
    TorchScript: Tracing vs. Scripting

    Yuxin Wu (ppwwyyxxc@gmail.com), published on 2022-05-23 06:59:00

    PyTorch provides two methods to turn an nn.Module into a graph represented in TorchScript format: tracing and scripting. This article will:

    1. Compare their pros and cons, with a focus on useful tips for tracing.
    2. Try to convince you that torch.jit.trace should be preferred over torch.jit.script for deployment of non-trivial models.

    The second point might be an uncommon opinion: if I Google "tracing vs scripting", the first article recommends scripting as the default. But tracing has many advantages. In fact, by the time I left, "tracing as default, scripting only when necessary" was the strategy by which all detection & segmentation models in Facebook/Meta products were deployed.

    Why is tracing better? TL;DR: (i) it does not damage code quality; (ii) its main limitations can be addressed by mixing it with scripting.

    Terminology

    We start by disambiguating some common terminology:

    • Export: refers to the process that turns a model written in eager-mode Python code into a graph that describes the computation.

    • Tracing: An export method. It runs a model with certain inputs, and "traces / records" all the operations that are executed into a graph.

      torch.jit.trace is an export API that uses tracing, used like torch.jit.trace(model, input). See its tutorial and API.

    • Scripting: Another export method. It parses the Python source code of the model, and compiles the code into a graph.

      torch.jit.script is an export API that uses scripting, used like torch.jit.script(model). See its tutorial and API.

    • TorchScript: This is an overloaded term:

      • It often refers to the representation / format of the exported graph.
      • But sometimes it refers to the scripting export method.

      To avoid confusion, I'll never use "TorchScript" alone in this article. I'll use "TS-format" to refer to the format, and "scripting" to refer to the export method.

      Because this term is used ambiguously, it may have created the impression that "scripting" is the "official / preferred" way to create a TS-format model. But that's not necessarily true.

    • (Torch)Scriptable: A model is "scriptable" if torch.jit.script(model) succeeds, i.e. it can be exported by scripting.

    • Traceable: A model is "traceable" if torch.jit.trace(model, input) succeeds for a typical input.

    • Generalize: A traced model (the object returned by trace()) "generalizes" to other inputs (different from the inputs given during tracing) if it produces correct results when given those inputs. Scripted models always generalize.

    • Dynamic control flow or data-dependent control flow: control flow where the operators to be executed depend on the input data, e.g. for a Tensor x:

      • if x[0] == 4: x += 1 is a dynamic control flow.
      • model: nn.Sequential = ...
        for m in model:
            x = m(x)
        is NOT a dynamic control flow.
      • class A(nn.Module):
            backbone: nn.Module
            head: Optional[nn.Module]
            def forward(self, x):
                x = self.backbone(x)
                if self.head is not None:
                    x = self.head(x)
                return x
        is NOT a dynamic control flow.
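    As a minimal sketch of why the second class is not dynamic control flow (the Linear/ReLU sizes are arbitrary choices for illustration), one can trace it with and without a head, then compare against eager mode on an input whose batch size differs from the tracing input. The `if self.head is not None` branch depends on how the module was built, not on the data, so each trace is expected to generalize:

```python
from typing import Optional

import torch
from torch import nn


class A(nn.Module):
    def __init__(self, backbone: nn.Module, head: Optional[nn.Module] = None):
        super().__init__()
        self.backbone = backbone
        self.head = head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.backbone(x)
        # Depends on the module structure, not on x: every input takes the
        # same branch, so tracing records the right operators.
        if self.head is not None:
            x = self.head(x)
        return x


torch.manual_seed(0)
with_head = A(nn.Linear(4, 4), nn.ReLU()).eval()
without_head = A(nn.Linear(4, 4)).eval()

traced_with = torch.jit.trace(with_head, torch.randn(2, 4))
traced_without = torch.jit.trace(without_head, torch.randn(2, 4))

x_new = torch.randn(5, 4)  # different batch size than during tracing
print(torch.allclose(traced_with(x_new), with_head(x_new)))
print(torch.allclose(traced_without(x_new), without_head(x_new)))
```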

    The Cost of Scriptability

    If anyone says "we'll make Python better by writing a compiler for it", you should immediately be alarmed and know that this is extremely difficult. Python is too big and too dynamic. A compiler can, at best, support only a subset of its syntax features and builtins, and the scripting compiler in PyTorch is no exception.

    What subset of Python does this compiler support? A rough answer is: the compiler has good support for the most basic syntax, but medium to no support for anything more complicated (classes, builtins like range and zip, dynamic types, etc.). But there is no clear answer: even the developers of the compiler usually need to run the code to see whether it can be compiled.

    The incomplete Python compiler limits how users can write code. Though there isn't a clear list of constraints, I can tell from my experience what impact they have had on large projects: code quality is the cost of scriptability.

    Impact on Most Projects

    To make their code scriptable / compilable by the scripting compiler, most projects choose to stay on the "safe side" and only use basic Python syntax: no/few custom structures, no builtins, no inheritance, no Union, no **kwargs, no lambda, no dynamic types, etc.

    This is because these "advanced" compiler features are either not supported at all, or have "partial support" that is not robust enough: they may work in some cases but fail in others. And because there is no clear spec of what is supported, users are unable to reason about or work around the failures. Therefore, users eventually move to and stay on the safe side.
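    Since there is no spec, the practical way to find out is to try compiling. Below is a small probe under that assumption; `is_scriptable`, `basic` and `with_kwargs` are hypothetical names, and `**kwargs` is used as the failing case because the scripting compiler rejects variadic arguments when it parses the function:

```python
import torch


def is_scriptable(fn) -> bool:
    """Hypothetical helper: the only reliable check is to attempt compilation."""
    try:
        torch.jit.script(fn)
        return True
    except Exception:  # scripting failures surface as several exception types
        return False


def basic(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Plain tensor arithmetic: within the well-supported subset.
    return x + y


def with_kwargs(x: torch.Tensor, **kwargs) -> torch.Tensor:
    # Variadic keyword arguments: outside the supported subset.
    for v in kwargs.values():
        x = x + v
    return x


print(is_scriptable(basic))
print(is_scriptable(with_kwargs))
```

    Note that a probe like this only answers the question for the code as it is today; per the argument above, a small later change can flip the result.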

    The terrible consequence is that developers stop making abstractions / exploring useful language features due to concerns about scriptability.

    A related hack that many projects use is to rewrite part of the code for scripting: create a separate, inference-only forward code path that makes the compiler happy. This also makes the project harder to maintain.

    Impact on Detectron2

    Detectron2 supports scripting, but the story was a bit different: it did not go downhill in code quality, which we value a lot in research. Instead, with some creativity and direct support from the PyTorch team (and some volunteered help from Alibaba engineers), we managed to make most models scriptable without removing any abstractions.

    However, it is not an easy task: we had to add dozens of syntax fixes to the compiler, find creative workarounds, and develop some hacky patches in detectron2 that are in this file (which honestly could affect maintainability in the long term). I would not recommend other large projects to aim for "scriptability without losing abstractions" unless they are also closely supported by the PyTorch team.

    Recommendation

    If you think "scripting seems to work for my project, so let's embrace it", I would advise against it for the following reasons, based on my past experience with a few projects that support scripting:

    • What "works" might be more brittle than you think (unless you limit yourself to the basic syntax): your code might happen to compile now, but one day you'll add a few innocent changes to your model and find that the compiler refuses it.

    • Basic syntax is not enough: even if more complex abstractions don't appear necessary to your project at the moment, if the project is expected to grow, it will require more language features in the future.

      Take a multi-task detector for example:

      1. There could be 10s of inputs, so it's preferable to use some structures/classes.
      2. The same data can have different representations (e.g. different ways to represent a segmentation mask), which demands Union or more dynamic types.
      3. There are many architectural choices for a detector, which makes inheritance useful.

      Large, growing projects definitely need evolving abstractions to stay healthy.

    • Code quality could severely deteriorate: ugly code starts to accumulate, because clean code sometimes just doesn't compile. Also, due to the syntax limitations of the compiler, abstractions cannot easily be made to clean up the ugliness. The health of the project gradually goes downhill.

    Below is a complaint from PyTorch issues. The issue itself is just one small papercut of scripting, but similar complaints have been heard many times. The status quo is: scripting forces you to write ugly code, so only use it when necessary.

    Make a Model Trace and Generalize

    The Cost of Traceability

    What it takes to make a model traceable is very clear, and has a much smaller impact on code health.

    1. First, neither scripting nor tracing works if the model is not even a proper single-device, connected graph representable in TS-format. For example, if the model has DataParallel submodules, or if the model converts tensors to numpy arrays and calls OpenCV functions, etc., you'll have to refactor it.

      Apart from this obvious constraint, there are only two extra requirements for traceability.

    2. Input/output format

      Model's inputs/outputs have to be Union[Tensor, Tuple[Tensor]] to be traceable.

      This might appear worse than scripting, because scripting at least has good support for strongly-typed dicts. However, here the format constraint does not apply to submodules: submodules can use any input/output format: classes, kwargs, anything that Python supports.

      The format requirement only applies to the outermost model, so it's very easy to address. If the model uses richer formats, just create a simple wrapper around it that converts to/from Tuple[Tensor]. Detectron2 even automates this for all its models with a universal wrapper like this:

      outputs = model(inputs)   # inputs/outputs are rich structure, e.g. dicts or classes
      # torch.jit.trace(model, inputs) # FAIL! unsupported format
      adapter = TracingAdapter(model, inputs)
      traced = torch.jit.trace(adapter, adapter.flattened_inputs) # Can now trace the model

      # Traced model can only produce flattened outputs (tuple of tensors):
      flattened_outputs = traced(*adapter.flattened_inputs)
      # Adapter knows how to convert it back to the rich structure (new_outputs == outputs):
      new_outputs = adapter.outputs_schema(flattened_outputs)
      See "Automatically Flatten & Unflatten Nested Containers" for more details on how this adapter is implemented.
    3. Symbolic shapes:

      Expressions like tensor.size(0), tensor.size()[1], tensor.shape[2] are integers in eager mode, but Tensors in tracing mode. Such a difference is necessary so that during tracing, shape computation can be captured as symbolic operations in the graph. An example is given in the next section about generalization.

      Due to the different return types, a model may be untraceable if parts of it assume shapes are integers. This can usually be fixed quite easily by handling both types in the code. A helpful function is torch.jit.is_tracing, which checks whether the code is being executed in tracing mode.
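    The wrapper idea in item 2 can be sketched for the simplest case: a model that takes and returns flat dicts of tensors. `DictModel` and `TupleAdapter` are hypothetical names for illustration; this is not Detectron2's TracingAdapter, which handles arbitrarily nested structures. The adapter fixes a key order so that only its own forward() is tuple-based, while the wrapped model keeps its dict interface:

```python
from typing import Dict, Sequence, Tuple

import torch
from torch import nn


class DictModel(nn.Module):
    """Toy model whose interface uses dicts instead of tuples of tensors."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        logits = self.linear(inputs["image"])
        return {"logits": logits, "scores": logits.softmax(dim=-1)}


class TupleAdapter(nn.Module):
    """Converts between dicts (with a fixed key order) and flat tuples of tensors."""

    def __init__(self, model: nn.Module, input_keys: Sequence[str], output_keys: Sequence[str]):
        super().__init__()
        self.model = model
        self.input_keys = list(input_keys)
        self.output_keys = list(output_keys)

    def flatten_inputs(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, ...]:
        return tuple(inputs[k] for k in self.input_keys)

    def forward(self, *flat_inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        # Only this outermost interface is tuple-based; the wrapped model keeps dicts.
        outputs = self.model(dict(zip(self.input_keys, flat_inputs)))
        return tuple(outputs[k] for k in self.output_keys)

    def unflatten_outputs(self, flat_outputs: Sequence[torch.Tensor]) -> Dict[str, torch.Tensor]:
        return dict(zip(self.output_keys, flat_outputs))


torch.manual_seed(0)
model = DictModel().eval()
adapter = TupleAdapter(model, ["image"], ["logits", "scores"])
traced = torch.jit.trace(adapter, adapter.flatten_inputs({"image": torch.randn(4, 3)}))

new_inputs = {"image": torch.randn(6, 3)}  # a different batch size
restored = adapter.unflatten_outputs(traced(*adapter.flatten_inputs(new_inputs)))
expected = model(new_inputs)
print({k: torch.allclose(restored[k], expected[k]) for k in expected})
```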

    That's all it takes for traceability. Most importantly, any Python syntax is allowed in the model implementation, because tracing does not care about syntax at all.
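    To make the symbolic-shape requirement concrete, here is a small sketch; `tile_starts` and `TILE` are hypothetical names chosen for illustration. It relies on the behavior described above (x.shape[0] is an int in eager mode and a 0-dim tensor while tracing): the arithmetic is written so that it works for both types, and torch.jit.is_tracing is used to skip a Python-only check that would otherwise turn the traced size into a bool:

```python
import torch

TILE = 4  # hypothetical tile size, kept as a module-level constant


def tile_starts(x: torch.Tensor) -> torch.Tensor:
    """Start offsets of TILE-sized tiles along dim 0 (illustrative example)."""
    n = x.shape[0]  # int in eager mode, 0-dim tensor while tracing
    if not torch.jit.is_tracing():
        # Python-level check only in eager mode, so the traced size is never
        # converted to a Python bool (which would emit a TracerWarning).
        assert n > 0, "expects a non-empty input"
    num_tiles = (n + TILE - 1) // TILE  # same expression for int and tensor
    return torch.arange(num_tiles) * TILE


traced = torch.jit.trace(tile_starts, torch.zeros(10, 2))
print(tile_starts(torch.zeros(10, 2)).tolist())  # [0, 4, 8]
print(traced(torch.zeros(5, 2)).tolist())
```

    The second call uses a shorter input than the tracing input; because the size arithmetic was recorded rather than captured as a constant, the trace is expected to produce the offsets for 5 rows.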

    Generalization Problem

    Just being "traceable" is not sufficient. The biggest problem with tracing is that it may not generalize to other inputs. This problem happens in the following cases:

    1. Dynamic control flow:

      >>> def f(x):
      ...     return torch.sqrt(x) if x.sum() > 0 else torch.square(x)
      >>> m = torch.jit.trace(f, torch.tensor(3))
      >>> print(m.code)
      def f(x: Tensor) -> Tensor:
        return torch.sqrt(x)

      In this example, due to dynamic control flow, the trace only keeps one branch of the condition, and will not generalize to certain (negative) inputs.

    2. Capture variables as constants:

      >>> a, b = torch.rand(1), torch.rand(2)
      >>> def f1(x): return torch.arange(x.shape[0])
      >>> def f2(x): return torch.arange(len(x))
      >>> # See if the two traces generalize from a to b:
      >>> torch.jit.trace(f1, a)(b)
      tensor([0, 1])
      >>> torch.jit.trace(f2, a)(b)
      tensor([0]) # WRONG!
      >>> # Why f2 does not generalize? Let's compare their code:
      >>> print(torch.jit.trace(f1, a).code, torch.jit.trace(f2, a).code)
      def f1(x: Tensor) -> Tensor:
        _0 = ops.prim.NumToTensor(torch.size(x, 0))
        _1 = torch.arange(annotate(number, _0), dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
        return _1
      def f2(x: Tensor) -> Tensor:
        _0 = torch.arange(1, dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
        return _0

      Intermediate computation results of a non-Tensor type (in this case, an int type) may be captured as constants, using the value observed during tracing. This causes the trace to not generalize.

      In addition to len(), this issue can also appear in:

      • .item() which converts tensors to int/float.
      • Any other code that converts torch types to numpy/python primitives.
      • A few problematic operators, e.g. advanced indexing.
    3. Capture device:

      >>> def f(x):
      ...     return torch.arange(x.shape[0], device=x.device)
      >>> m = torch.jit.trace(f, torch.tensor([3]))
      >>> print(m.code)
      def f(x: Tensor) -> Tensor:
        _0 = ops.prim.NumToTensor(torch.size(x, 0))
        _1 = torch.arange(annotate(number, _0), dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
        return _1
      >>> m(torch.tensor([3]).cuda()).device
      device(type='cpu') # WRONG!

      Similarly, operators that accept a device argument will remember the device used during tracing (this can be seen in m.code). So the trace may not generalize to inputs on a different device. Such generalization is almost never needed, because deployment usually has a target device.
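    The list above mentions .item() without an example. A minimal sketch of that failure mode (function names are hypothetical): .item() turns a tensor into a Python float, which tracing records as a constant taken from the tracing input, while the tensor-only variant keeps the computation in the graph:

```python
import torch


def scale_by_max_item(x: torch.Tensor) -> torch.Tensor:
    # .item() produces a Python float: tracing captures it as a constant
    # (and emits a TracerWarning).
    return x * x.max().item()


def scale_by_max_tensor(x: torch.Tensor) -> torch.Tensor:
    # x.max() stays a 0-dim tensor, so tracing records the reduction.
    return x * x.max()


a = torch.tensor([1.0, 2.0])  # tracing input (max is 2.0)
b = torch.tensor([3.0, 4.0])  # test input (max is 4.0)

traced_item = torch.jit.trace(scale_by_max_item, a)
traced_tensor = torch.jit.trace(scale_by_max_tensor, a)

print(traced_item(b), scale_by_max_item(b))      # disagree: the trace reuses a's max
print(traced_tensor(b), scale_by_max_tensor(b))  # agree
```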

    Let Tracing Generalize

    The above problems are annoying and often silent (warnings, but no errors), but they can be successfully addressed by good practice and tools:

    • Pay attention to TracerWarning: In the first two examples above, torch.jit.trace actually emits warnings. The first example prints:

      a.py:3: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect.
      We can't record the data flow of Python values, so this value will be treated as a constant in the future.
      This means that the trace might not generalize to other inputs!
      if x.sum() > 0:

      Paying attention to these warnings (or even better, catching them) will expose most generalization problems of tracing.

      Note that the "capture device" case does not print warnings because tracing was not designed to support such generalization at all.

    • Unit tests for parity: Unit tests should be run after export and before deployment, to verify that the exported model produces the same outputs as the original eager-mode model, i.e.

      assert allclose(torch.jit.trace(model, input1)(input2), model(input2))

      If generalization across shapes is needed (not always needed), input2 should have different shapes from input1.

      Detectron2 has many generalization tests, e.g. this and this. Once a gap is found, inspecting the code of the exported TS-format model can uncover the place where it fails to generalize.

    • Avoid unnecessary "special case" conditions: Avoid conditions like

      if x.numel() > 0:
          output = self.layers(x)
      else:
          output = torch.zeros((0, C, H, W)) # Create empty outputs

      that handle special cases such as empty inputs. Instead, improve self.layers or its underlying kernel so that it supports empty inputs. This results in cleaner code and also improves tracing. This is why I'm involved in many PyTorch issues that improve support for empty inputs, such as #12013, #36530, #56998. Most PyTorch operations work perfectly with empty inputs, so such branching is hardly needed.

    • Use symbolic shapes: As mentioned earlier, tensor.size() returns Tensor during tracing, so that shape computations are captured in the graph. Users should avoid accidentally turning tensor shapes into constants:

      • Use tensor.size(0) instead of len(tensor), because the latter is an int. For custom classes, implement a .size method or use .__len__() instead of len(), e.g. like here.
      • Do not convert sizes by int() or torch.as_tensor, because they will capture constants. This helper function is useful to convert sizes into a tensor, in a way that works in both tracing and eager mode.
    • Mix tracing and scripting: they can be mixed together, so you can use scripting on the small portion of code where tracing does not work correctly. This can fix almost all problems of tracing. More on this below.
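    The first two practices above (catching TracerWarning and testing parity on an input with a different shape) can be combined in a few lines. This is a sketch with hypothetical helper names (`trace_strict`, `outputs_match`), reusing f1 and f2 from the earlier example. It uses the standard warnings module to turn every TracerWarning into an exception during tracing:

```python
import warnings

import torch


def trace_strict(fn, example_inputs):
    """Hypothetical helper: like torch.jit.trace, but any TracerWarning
    raised while tracing becomes an exception instead of a printed warning."""
    with warnings.catch_warnings():
        warnings.simplefilter("error", category=torch.jit.TracerWarning)
        return torch.jit.trace(fn, example_inputs)


def outputs_match(eager_fn, traced_fn, test_input) -> bool:
    """Parity check on an input that differs from the tracing input."""
    return torch.equal(traced_fn(test_input), eager_fn(test_input))


def f1(x):
    return torch.arange(x.shape[0])


def f2(x):
    return torch.arange(len(x))  # len() yields a Python int while tracing


a, b = torch.rand(1), torch.rand(2)  # tracing input vs. a differently-shaped test input

traced_f1 = trace_strict(f1, a)
print(outputs_match(f1, traced_f1, b))

try:
    trace_strict(f2, a)
except torch.jit.TracerWarning as err:
    print("rejected:", str(err).splitlines()[0])
```

    Failing loudly at export time is a design choice: it trades a few false alarms for never shipping a trace that silently captured a constant.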

    Mix Tracing and Scripting

    Tracing and scripting both have their own problems, and the best solution is usually to mix them together. This gives us the best of both worlds.

    To minimize the negative impact on code quality, we should use tracing for the majority of logic, and use scripting only when necessary.

    1. Use @script_if_tracing: Inside torch.jit.trace, the @script_if_tracing decorator can compile functions by scripting. Typically, this only requires a small refactor of the forward logic to separate the parts that need to be compiled (the parts with control flow):

      def forward(self, ...):
          # ... some forward logic
          @torch.jit.script_if_tracing
          def _inner_impl(x, y, z, flag: bool):
              # use control flow, etc.
              return ...
          output = _inner_impl(x, y, z, flag)
          # ... other forward logic

      By scripting only the parts that need it, the code quality damage is strictly smaller than making the entire model scriptable, and it does not affect the module's forward interface at all.

      The function decorated by @script_if_tracing has to be a pure function that does not contain modules. Therefore, sometimes a bit more refactoring is needed:

      Before refactoring:

      # This branch cannot be compiled by
      # @script_if_tracing, because it
      # refers to `self.layers`
      if x.numel() > 0:
          x = preprocess(x)
          output = self.layers(x)
      else:
          # Create empty outputs
          output = torch.zeros(...)

      After refactoring:

      # This branch can be compiled by @script_if_tracing
      if x.numel() > 0:
          x = preprocess(x)
      else:
          # Create empty inputs
          x = torch.zeros(...)
      # Needs to make sure self.layers accept empty
      # inputs. If necessary, add such condition branch
      # into self.layers as well.
      output = self.layers(x)

      In fact, for most vision models, dynamic control flow is needed only in a few submodules where it's easy to make them scriptable. To show how rarely it is needed: the entire detectron2 only has two functions decorated with @script_if_tracing due to control flow, paste_masks and heatmaps_to_keypoints, both for post-processing only. A few other functions are also decorated to generalize across devices (a very rare requirement).

    2. Use scripted / traced submodules:

      model.submodule = torch.jit.script(model.submodule)
      torch.jit.trace(model, inputs)

      In this example, suppose submodule cannot be traced correctly; we can script it before tracing. However, I do not recommend it. If possible, I suggest using @script_if_tracing inside submodule.forward instead, so that scripting is limited to the internals of the submodule, without affecting the module's interface.

      And similarly,

      model.submodule = torch.jit.trace(model.submodule, submodule_inputs)
      torch.jit.script(model)

      this uses a traced submodule during scripting. This looks nice, but is not so useful in practice: it will affect the interface of submodule, requiring it to only accept/return Tuple[Tensor]. This is a big constraint that might hurt code quality even more than scripting.

      A rare scenario where "tracing a submodule" is useful is this:

      class A(nn.Module):
          def forward(self, x):
              # Dispatch to different submodules based on a dynamic, data-dependent condition:
              return self.submodule1(x) if x.sum() > 0 else self.submodule2(x)

      @script_if_tracing cannot compile such control flow because it only supports pure functions. If submodule{1,2} are complex and cannot be scripted, using traced submodules in a scripted parent A is the best option.

    3. Merge multiple traces:

      Scripted models support two more features that traced models don't:

      • Control flow conditioned on attributes: a scripted module can have mutable attributes (e.g. a boolean flag) that affect control flow. Traced modules do not have control flow.
      • Multiple methods: a traced module only supports forward(), but a scripted module can have multiple methods.

      Actually, both features above do the same thing: they allow an exported model to be used in different ways, i.e. execute different sequences of operators as requested by the caller.

      Below is an example scenario where such a feature is useful: if Detector is scripted, the caller can mutate its do_keypoint attribute to control its behavior, or call the predict_keypoint method directly if needed.

      class Detector(nn.Module):
          do_keypoint: bool

          def forward(self, img):
              box = self.predict_boxes(img)
              if self.do_keypoint:
                  kpts = self.predict_keypoint(img, box)

          @torch.jit.export
          def predict_boxes(self, img): pass

          @torch.jit.export
          def predict_keypoint(self, img, box): pass

      This requirement is not seen very often. But if needed, how can it be achieved with tracing? I have a solution that's not very clean:

      Tracing can only capture one sequence of operators, so the natural way is to trace the model twice:

      det1 = torch.jit.trace(Detector(do_keypoint=True), inputs)
      det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)

      We can then alias their weights (to not duplicate the storage), and merge the two traces into one module to script.

      det2.submodule.weight = det1.submodule.weight
      class Wrapper(nn.ModuleList):
          def forward(self, img, do_keypoint: bool):
              if do_keypoint:
                  return self[0](img)
              else:
                  return self[1](img)
      exported = torch.jit.script(Wrapper([det1, det2]))
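    The first technique in the list above can be checked end-to-end on the dynamic-control-flow example from the generalization section. In this sketch (function names are hypothetical), only the data-dependent branch lives in a pure function decorated with torch.jit.script_if_tracing; it runs as plain Python in eager mode and is compiled by scripting when called inside torch.jit.trace, so both branches survive in the exported graph:

```python
import torch


@torch.jit.script_if_tracing
def _sqrt_or_square(x: torch.Tensor) -> torch.Tensor:
    # Plain Python in eager mode; compiled by scripting when called while
    # tracing, so this branch is kept as real control flow in the graph.
    if bool(x.sum() > 0):
        return torch.sqrt(x)
    return torch.square(x)


def f(x: torch.Tensor) -> torch.Tensor:
    y = x * 2                      # ordinary logic, recorded by tracing
    return _sqrt_or_square(y) + 1  # data-dependent branch, delegated to scripting


traced = torch.jit.trace(f, torch.tensor(3.0))  # the tracing input takes the sqrt branch
print(traced(torch.tensor(8.0)))   # sqrt(16) + 1
print(traced(torch.tensor(-2.0)))  # square(-4) + 1: the branch not taken during tracing
```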

    Performance

    If a model is both traceable and scriptable, tracing always generates the same or a simpler graph (and is therefore likely faster).

    Why? Because scripting tries to faithfully represent your Python code, even the parts that are unnecessary. For example, it is not always smart enough to realize that some loops or data structures in the Python code are actually static and can be removed:

    import torch
    from torch import nn

    class A(nn.Module):
        def forward(self, x1, x2, x3):
            z = [0, 1, 2]
            xs = [x1, x2, x3]
            for k in z: x1 += xs[k]
            return x1
    model = A()
    print(torch.jit.script(model).code)
    # def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
    #   z = [0, 1, 2]
    #   xs = [x1, x2, x3]
    #   x10 = x1
    #   for _0 in range(torch.len(z)):
    #     k = z[_0]
    #     x10 = torch.add_(x10, xs[k])
    #   return x10
    print(torch.jit.trace(model, [torch.tensor(1)] * 3).code)
    # def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
    #   x10 = torch.add_(x1, x1)
    #   x11 = torch.add_(x10, x2)
    #   return torch.add_(x11, x3)

    This example is very simple, so it actually has workarounds for scripting (use a tuple instead of a list), or the loop might get optimized away in a later optimization pass. But the point is: the graph compiler is not always smart enough. For complicated models, scripting might generate a graph with unnecessary complexity that's hard to optimize.
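    One way to check this on your own modules is to list the node kinds in each exported graph. The sketch below uses a variant of the class above with out-of-place additions and three distinct inputs (both are my choices, to keep the comparison simple); `node_kinds` is a hypothetical helper built on the graph objects that scripted and traced modules expose. The scripted graph is expected to keep a prim::Loop node, while the trace contains only the unrolled additions:

```python
import torch
from torch import nn


class A(nn.Module):
    def forward(self, x1, x2, x3):
        z = [0, 1, 2]
        xs = [x1, x2, x3]
        for k in z:
            x1 = x1 + xs[k]  # out-of-place variant of the loop above
        return x1


def node_kinds(exported) -> list:
    """Hypothetical helper: the kind of every top-level node in forward's graph."""
    return [node.kind() for node in exported.graph.nodes()]


model = A()
inputs = (torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0))

scripted_kinds = node_kinds(torch.jit.script(model))
traced_kinds = node_kinds(torch.jit.trace(model, inputs))

print("prim::Loop" in scripted_kinds, len(scripted_kinds))
print("prim::Loop" in traced_kinds, len(traced_kinds))
```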

    Concluding Thoughts

    Tracing has clear limitations: I spent most of this article talking about the limitations of tracing and how to fix them. I actually think this is the advantage of tracing: it has clear limitations (and solutions), so you can reason about whether it works.

    On the contrary, scripting is more like a black box: no one knows whether it works before trying. I didn't mention a single trick about how to fix scripting: there are many of them, but it's not worth your time to probe and fix a black box.

    Tracing has a small blast radius: Both tracing and scripting affect how code can be written, but tracing has a much smaller blast radius, and causes much less damage:

    • It limits the input/output format, but on the outermost module only. (And this issue can be automatically solved as discussed above.)
    • It needs some code changes to generalize (e.g. to mix scripting into tracing), but these changes only go into the internal implementation of the affected modules, not their interfaces.

    On the other hand, scripting has an impact on:

    • The interface of every module & submodule involved.
      • IMO, this is the biggest damage: advanced syntax features are needed in interfaces, and I'm not willing to compromise on interface design.
      • This may end up affecting training as well, because the interface is often shared between training and inference.
    • Pretty much every line of code in the inference forward path.

    Having a large blast radius is why scripting can do great harm to code quality.

    Control flow vs. other Python syntax: PyTorch is loved by its users because they can "just write Python", and most importantly write Python control flow. But other Python syntax features are important as well. If being able to write Python control flow (scripting) means losing other great syntax, I'd rather give up on the ability to write Python control flow.

    In fact, if PyTorch were less obsessed with Python control flow, and offered me symbolic control flow such as torch.cond like this (similar to the API of tf.cond):

    def f(x):
        return torch.cond(x.sum() > 0, lambda: torch.sqrt(x), lambda: torch.square(x))

    Then f could be traced correctly and I would be happy to use this, no longer having to worry about scripting. TensorFlow AutoGraph is a great example that automates this idea.
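    Without such an API, a branchless workaround can be sketched with torch.where. This is a different technique from a symbolic cond: it evaluates both branches and then selects the result, so the condition stays a tensor and tracing records it instead of capturing a Python bool. It only fits cases where both branches are cheap and safe to evaluate on every input (here the discarded sqrt branch produces NaN for negative inputs, and such values can still leak into gradients):

```python
import torch


def f(x: torch.Tensor) -> torch.Tensor:
    # Both branches are computed; torch.where keeps the condition as a
    # tensor, so the selection is recorded in the trace.
    return torch.where(x.sum() > 0, torch.sqrt(x), torch.square(x))


traced = torch.jit.trace(f, torch.tensor(3.0))
print(traced(torch.tensor(4.0)))   # takes the sqrt branch
print(traced(torch.tensor(-2.0)))  # takes the square branch, unlike a plain trace
```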


