PyTorch provides two methods to turn an `nn.Module` into a graph represented in TorchScript format: tracing and scripting. This article will:

1. Disambiguate some common terminologies around these two export methods.
2. Argue that `torch.jit.trace` should be preferred over `torch.jit.script` for deployment of non-trivial models.

The second point might be an uncommon opinion: if I Google "tracing vs scripting", the first article recommends scripting as default. But tracing has many advantages. In fact, by the time I left, "tracing as default, scripting only when necessary" was the strategy by which all detection & segmentation models in Facebook/Meta products were deployed.
Why is tracing better? TL;DR: (i) it does not damage code quality; (ii) its main limitations can be addressed by mixing in scripting.
We start by disambiguating some common terminologies:
* Export: the process that turns a model written in eager-mode Python code into a graph that describes the computation.
* Tracing: an export method. It runs a model with certain inputs, and "traces / records" all the operations that are executed into a graph. `torch.jit.trace` is an export API that uses tracing, used like `torch.jit.trace(model, input)`. See its tutorial and API.
* Scripting: another export method. It parses the Python source code of the model, and compiles the code into a graph. `torch.jit.script` is an export API that uses scripting, used like `torch.jit.script(model)`. See its tutorial and API.
* TorchScript: this is an overloaded term. It can refer to the exported graph format, but it is also often associated with the scripting export method. To avoid confusion, I'll never use "TorchScript" alone in this article: I'll use "TS-format" to refer to the format, and "scripting" to refer to the export method. Because this term is used with ambiguity, it may have caused the impression that scripting is the "official / preferred" way to create a TS-format model. But that's not necessarily true.
* (Torch)Scriptable: a model is "scriptable" if `torch.jit.script(model)` succeeds, i.e. it can be exported by scripting.
* Traceable: a model is "traceable" if `torch.jit.trace(model, input)` succeeds for a typical input.
* Generalize: a traced model (the object returned by `trace()`) "generalizes" to other inputs (different from the inputs given during tracing) if it can run inference correctly when given other inputs. Scripted models always generalize.
* Dynamic control flow, or data-dependent control flow: control flow where the operators to be executed depend on the input data. For example, given a `Tensor x`, `if x[0] == 4: x += 1` is a dynamic control flow.
If anyone says "we'll make Python better by writing a compiler for it", you should immediately be alarmed and know that this is extremely difficult. Python is too big and too dynamic. A compiler can only support a subset of its syntax features and builtins, at best -- the scripting compiler in PyTorch is no exception.
What subset of Python does this compiler support? A rough answer is: the compiler has good support for the most basic syntax, but medium to no support for anything more complicated (classes, builtins like `range` and `zip`, dynamic types, etc.). But there is no clear answer: even the developers of the compiler usually need to run the code to see whether it can be compiled or not.
The incomplete Python compiler limits how users can write code. Though there isn't a clear list of constraints, I can tell from my experience what impact they have had on large projects: code quality is the cost of scriptability.
To make their code scriptable / compilable by the scripting compiler, most projects choose to stay on the "safe side" and only use basic Python syntax: no/few custom structures, no builtins, no inheritance, no `Union`, no `**kwargs`, no lambda, no dynamic types, etc.
This is because these "advanced" compiler features are either not supported at all, or come with "partial support" that is not robust enough: they may work in some cases but fail in others. And because there is no clear spec of what is supported, users are unable to reason about or work around the failures. Therefore, eventually users move to, and stay on, the safe side.
The terrible consequence is that developers stop making abstractions / exploring useful language features due to concerns about scriptability.
A related hack that many projects adopt is to rewrite part of the code for scripting: create a separate, inference-only forward codepath that makes the compiler happy. This also makes the project harder to maintain.
Detectron2 supports scripting, but the story was a bit different: its code quality, which we value a lot in research, did not go downhill. Instead, with some creativity and direct support from the PyTorch team (and some volunteered help from Alibaba engineers), we managed to make most models scriptable without removing any abstractions.
However, it was not an easy task: we had to add dozens of syntax fixes to the compiler, find creative workarounds, and develop some hacky patches in detectron2 that live in this file (which honestly could affect maintainability in the long term). I would not recommend that other large projects aim for "scriptability without losing abstractions" unless they are also closely supported by the PyTorch team.
If you think "scripting seems to work for my project, so let's embrace it", I'd advise against it for the following reasons, based on my past experience with a few projects that support scripting:
What "works" might be more brittle than you think (unless you limit yourself to the basic syntax):Your code might happen to compile now, but one day you'll add a few innocent changes to your modeland find that the compiler refuses it.
Basic syntax is not enough: even if more complex abstractions don't appear necessary to your project at the moment, a project that is expected to grow will require more language features in the future. Take a multi-task detector for example: as tasks are added, its interfaces may need `Union` or more dynamic types. Large, growing projects definitely need evolving abstractions to stay healthy.
Code quality could severely deteriorate: ugly code starts to accumulate, because clean code sometimes just doesn't compile. Also, due to syntax limitations of the compiler, abstractions cannot easily be made to clean up the ugliness. The health of the project gradually goes downhill.
One complaint from PyTorch issues illustrates this. The issue itself is just one small papercut of scripting, but similar complaints were heard many times. The status quo is: scripting forces you to write ugly code, so only use it when necessary.
What it takes to make a model traceable is very clear, and has a much smaller impact on code health.
First, neither scripting nor tracing works if the model is not even a proper single-device, connected graph representable in TS-format. For example, if the model has `DataParallel` submodules, or if the model converts tensors to numpy arrays and calls OpenCV functions, etc., you'll have to refactor it first.
Apart from this obvious constraint, there are only two extra requirements for traceability.
Input/output format: the model's inputs/outputs have to be `Union[Tensor, Tuple[Tensor]]` to be traceable.
This might appear worse than scripting, because scripting at least has good support for strongly-typed dicts. However, the format constraint here does not apply to submodules: submodules can use any input/output format, including classes, kwargs, anything that Python supports.
The format requirement only applies to the outermost model, so it's very easy to address. If the model uses richer formats, just create a simple wrapper around it that converts to/from `Tuple[Tensor]`. Detectron2 even automates this for all its models with a universal wrapper:
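A minimal sketch of such a wrapper (detectron2's actual `TracingAdapter` is far more general; the dict-based interface here is a simplifying assumption):

```python
import torch
from torch import Tensor, nn
from typing import Dict, Tuple

class TupleWrapper(nn.Module):
    """Wrap a model that takes/returns a dict of tensors, so that the
    wrapped model only takes/returns tensors and becomes traceable."""

    def __init__(self, model: nn.Module, keys):
        super().__init__()
        self.model = model
        self.keys = list(keys)  # fixed key order used to flatten inputs

    def forward(self, *flat_inputs: Tensor) -> Tuple[Tensor, ...]:
        # Rebuild the rich input format from the flat tuple ...
        inputs: Dict[str, Tensor] = dict(zip(self.keys, flat_inputs))
        outputs = self.model(inputs)
        # ... and flatten the rich outputs back into a tuple.
        return tuple(outputs[k] for k in sorted(outputs.keys()))
```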
Symbolic shapes: expressions like `tensor.size(0)`, `tensor.size()[1]`, `tensor.shape[2]` are integers in eager mode, but `Tensor`s in tracing mode. This difference is necessary so that, during tracing, shape computations can be captured as symbolic operations in the graph. An example is given in the next section about generalization.
Due to the different return types, a model may be untraceable if parts of it assume shapes are integers. This can usually be fixed quite easily by handling both types in the code. A helpful function is `torch.jit.is_tracing`, which checks whether the code is being executed in tracing mode, as sketched below.
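A hedged sketch of handling both types (the `get_h_w` helper is hypothetical):

```python
import torch
from torch import Tensor

def get_h_w(x: Tensor):
    # In eager mode, x.size(i) yields plain ints; in tracing mode it yields
    # 0-dim Tensors, which keeps shape arithmetic symbolic in the graph.
    h, w = x.size(2), x.size(3)
    if not torch.jit.is_tracing():
        # This int-only assumption is checked in eager mode only:
        assert isinstance(h, int) and isinstance(w, int)
    return h, w
```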
That's all it takes for traceability. Most importantly, any Python syntax is allowed in the model's implementation, because tracing does not care about syntax at all.
Just being "traceable" is not sufficient. The biggest problem with tracing is that it may not generalize to other inputs. This problem happens in the following cases:
Dynamic control flow:
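A minimal illustration (the function `f` here is hypothetical; any value-dependent branch behaves the same way):

```python
import torch

def f(x):
    # Which branch runs depends on the *value* of x:
    return torch.sqrt(x) if x.sum() > 0 else torch.square(x)

m = torch.jit.trace(f, torch.tensor([3.0]))  # traced with a positive input
print(m.code)  # the graph contains only the torch.sqrt branch
```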
In this example, due to dynamic control flow, the trace only keeps one branch of the condition, and will not generalize to certain (negative) inputs.
Capture variables as constants:
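A minimal illustration (hypothetical `f`; `len()` returns a plain Python int):

```python
import torch

def f(x):
    # len(x) is a Python int, not a symbolic Tensor:
    return torch.arange(len(x))

m = torch.jit.trace(f, torch.tensor([1.0, 2.0, 3.0]))
print(m.code)  # the length 3 is baked into the graph as a constant
```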
Intermediate computation results of a non-Tensor type (in this case, an int type) may be captured as constants, using the value observed during tracing. This causes the trace to not generalize.
In addition to `len()`, this issue can also appear in `.item()`, which converts tensors to int/float.

Capture device:
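Another minimal illustration (hypothetical `f`; the `device=` argument is the culprit):

```python
import torch

def f(x):
    # torch.ones receives a device argument derived from x at trace time:
    return torch.ones(x.shape[0], device=x.device) + x

m = torch.jit.trace(f, torch.tensor([1.0, 2.0]))  # traced on CPU
print(m.code)  # the CPU device shows up as a constant in the graph
```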
Similarly, operators that accept a `device` argument will remember the device used during tracing (this can be seen in `m.code`). So the trace may not generalize to inputs on a different device. Such generalization is almost never needed, because a deployment usually has a known target device.
The above problems are annoying and often silent (warnings, but no errors), but they can be successfully addressed by good practices and tools:
Pay attention to `TracerWarning`: in the first two examples above, `torch.jit.trace` actually emits warnings. The first example prints a warning along these lines:
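```
TracerWarning: Converting a tensor to a Python boolean might cause the trace
to be incorrect. We can't record the data flow of Python values, so this
value will be treated as a constant in the future. This means that the trace
might not generalize to other inputs!
```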
Paying attention to these warnings (or, even better, catching them) will expose most generalization problems of tracing.
Note that the "capture device" case does not print warnings because tracing was not designed to support such generalization at all.
Unittests for parity: unittests should be run after export and before deployment, to verify that the exported model produces the same outputs as the original eager-mode model, i.e.:
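A sketch of such a parity check (`model`, `input1`, `input2` stand for whatever your project uses):

```python
import torch

ts_model = torch.jit.trace(model, (input1,))
# Parity on the input used for tracing:
assert torch.allclose(model(input1), ts_model(input1))
# Generalization to a different input:
assert torch.allclose(model(input2), ts_model(input2))
```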
If generalization across shapes is needed (it is not always needed), `input2` should have different shapes from `input1`.
Detectron2 has many such generalization tests, e.g. this and this. Once a gap is found, inspecting the code of the exported TS-format model can uncover the place where it fails to generalize.
Avoid unnecessary "special case" conditions: avoid conditions that handle special cases such as empty inputs, like the hypothetical branch below (names and shapes are illustrative):
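```python
if x.numel() > 0:
    output = self.layers(x)
else:
    # manual handling of empty inputs
    output = torch.zeros((0, C, H, W), device=x.device)
```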
Instead, improve `self.layers` or its underlying kernel so it supports empty inputs. This results in cleaner code and also improves tracing. This is why I was involved in many PyTorch issues that improve support for empty inputs, such as #12013, #36530, #56998. Most PyTorch operations work perfectly with empty inputs, so such branching is hardly ever needed.
Use symbolic shapes: as mentioned earlier, `tensor.size()` returns `Tensor`s during tracing, so that shape computations are captured in the graph. Users should avoid accidentally turning tensor shapes into constants:

* Use `tensor.size(0)` instead of `len(tensor)`, because the latter is a plain int. For custom classes, implement a `.size` method or use `.__len__()` instead of `len()`, e.g. like here.
* Avoid converting sizes with `int()` or `torch.as_tensor`, because they will capture constants. A helper function is useful to convert sizes into a tensor, in a way that works in both tracing and eager mode (a simplified sketch follows this list).

Mix tracing and scripting: they can be mixed together, so you can use scripting for the small portion of code where tracing does not work correctly. This can fix almost all of tracing's problems. More on this below.
Tracing and scripting both have their own problems, and the best solution is usually to mix them together. This gives us the best of both worlds.
To minimize the negative impact on code quality, we should use tracing for the majority of logic, and use scripting only when necessary.
Use `@script_if_tracing`: inside `torch.jit.trace`, the `@script_if_tracing` decorator can compile functions by scripting. Typically, this only requires a small refactor of the forward logic to separate the parts that need to be compiled (the parts with control flow):
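A minimal sketch (module and function names are hypothetical):

```python
import torch
from torch import Tensor, nn

@torch.jit.script_if_tracing
def _pick_branch(x: Tensor) -> Tensor:
    # The data-dependent control flow lives here and is compiled by
    # scripting, so the trace of the whole model still generalizes.
    if x.sum() > 0:
        return torch.sqrt(x.abs())
    return torch.square(x)

class MyModel(nn.Module):
    def forward(self, x: Tensor) -> Tensor:
        x = x * 2               # plain ops are simply traced
        return _pick_branch(x)  # scripted only when tracing
```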
By scripting only the parts that need it, the damage to code quality is strictly smaller than making the entire model scriptable, and it does not affect the module's forward interface at all.
The function decorated by `@script_if_tracing` has to be a pure function that does not contain modules. Therefore, sometimes a bit more refactoring is needed, as in the hypothetical before/after below.
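Here, the module call is hoisted out of the decorated function so it stays pure (all names are illustrative):

```python
import torch
from torch import Tensor, nn

# Before refactoring: control flow and a module call are entangled, so
# this forward cannot be decorated with @script_if_tracing (not pure).
class Before(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)

    def forward(self, x: Tensor) -> Tensor:
        if x.sum() > 0:        # dynamic control flow
            x = x / x.sum()
        return self.layer(x)   # module call mixed with the control flow

# After refactoring: the control flow is extracted into a pure function,
# which can be compiled by scripting while the rest is traced.
@torch.jit.script_if_tracing
def _normalize(x: Tensor) -> Tensor:
    if x.sum() > 0:
        x = x / x.sum()
    return x

class After(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)

    def forward(self, x: Tensor) -> Tensor:
        return self.layer(_normalize(x))
```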
In fact, for most vision models, dynamic control flow is needed only in a few submodules, where making it scriptable is easy. To show how rarely it is needed: the entire detectron2 has only two functions decorated with `@script_if_tracing` due to control flow, paste_masks and heatmaps_to_keypoints, both for post-processing only. A few other functions are also decorated to generalize across devices (a very rare requirement).
Use scripted / traced submodules:
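A sketch (assuming a hypothetical `model.submodule` that tracing handles incorrectly):

```python
import torch

# Script only the problematic submodule, then trace the whole model:
model.submodule = torch.jit.script(model.submodule)
ts_model = torch.jit.trace(model, (inputs,))
```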
In this example, suppose `submodule` cannot be traced correctly; we can script it before tracing the whole model. However, I do not recommend this. If possible, I suggest using `@script_if_tracing` inside `submodule.forward` instead, so that scripting is limited to the internals of the submodule, without affecting the module's interface.
And similarly,
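A sketch of the reverse direction (same hypothetical names; `submodule_inputs` are example inputs for the submodule):

```python
import torch

# Trace a submodule first, then script the parent:
model.submodule = torch.jit.trace(model.submodule, (submodule_inputs,))
ts_model = torch.jit.script(model)
```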
this uses a traced submodule during scripting. It looks nice, but is not so useful in practice: it affects the interface of `submodule`, requiring it to only accept/return `Tuple[Tensor]` -- a big constraint that might hurt code quality even more than scripting.
A rare scenario where "tracing a submodule" is useful is this:
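A sketch of such a parent module (names are hypothetical):

```python
import torch
from torch import Tensor, nn

class A(nn.Module):
    def __init__(self, submodule1: nn.Module, submodule2: nn.Module):
        super().__init__()
        self.submodule1 = submodule1
        self.submodule2 = submodule2

    def forward(self, x: Tensor) -> Tensor:
        # Dynamic control flow around *modules*, not a pure function:
        if x.sum() > 0:
            return self.submodule1(x)
        return self.submodule2(x)
```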
`@script_if_tracing` cannot compile such control flow because it only supports pure functions. If `submodule{1,2}` are complex and cannot be scripted, using traced submodules in a scripted parent `A` is the best option.
Merge multiple traces: scripted models support two more features that traced models don't:

* Attributes that the caller can mutate to change the model's behavior.
* Extra methods: a traced module only supports `forward()`, but a scripted module can have multiple methods.

Actually, both features above are doing the same thing: they allow an exported model to be used in different ways, i.e. to execute different sequences of operators as requested by the caller.
Below is an example scenario where such features are useful: if `Detector` is scripted, the caller can mutate its `do_keypoint` attribute to control its behavior, or call the `predict_keypoint` method directly if needed.
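A sketch of such a model (all names, signatures, and placeholder bodies are illustrative):

```python
import torch
from torch import Tensor, nn
from typing import Optional, Tuple

class Detector(nn.Module):
    do_keypoint: bool  # attribute the caller can mutate after scripting

    def __init__(self):
        super().__init__()
        self.do_keypoint = True

    def forward(self, img: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        boxes = self.predict_boxes(img)
        kpts: Optional[Tensor] = None
        if self.do_keypoint:
            kpts = self.predict_keypoint(img, boxes)
        return boxes, kpts

    def predict_boxes(self, img: Tensor) -> Tensor:
        return img.new_zeros([1, 4])  # placeholder box predictor

    @torch.jit.export  # exposed as a separate method on the scripted module
    def predict_keypoint(self, img: Tensor, boxes: Tensor) -> Tensor:
        return img.new_zeros([boxes.size(0), 17, 2])  # placeholder keypoints

det = torch.jit.script(Detector())
det.do_keypoint = False  # the caller mutates the model's behavior
```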
This requirement is not seen very often. But if needed, how can it be achieved with tracing? I have a solution that's not very clean: tracing can only capture one sequence of operators, so the natural way is to trace the model twice:
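A sketch, continuing the hypothetical `Detector` above (`img` is an example input; this assumes the model's outputs are trace-friendly in both modes):

```python
import torch

model = Detector()
model.do_keypoint = True
traced_kpt = torch.jit.trace(model, (img,))    # captures the keypoint path
model.do_keypoint = False
traced_nokpt = torch.jit.trace(model, (img,))  # captures the box-only path
```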
We can then alias their weights (to avoid duplicating the storage), and merge the two traces into one module to script:
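A sketch of the merge (the weight aliasing is elided; it amounts to making both traces hold the same parameter objects, and the two traces are assumed to return the same type):

```python
import torch
from torch import Tensor, nn

class Merged(nn.Module):
    def __init__(self, with_kpt: nn.Module, without_kpt: nn.Module):
        super().__init__()
        self.with_kpt = with_kpt
        self.without_kpt = without_kpt

    def forward(self, img: Tensor, do_keypoint: bool):
        # Scripting compiles this dispatch; each branch runs one trace.
        if do_keypoint:
            return self.with_kpt(img)
        return self.without_kpt(img)

merged = torch.jit.script(Merged(traced_kpt, traced_nokpt))
```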
If a model is both traceable and scriptable, tracing always generates the same or a simpler graph (and is therefore likely faster).
Why? Because scripting tries to faithfully represent your Python code, even the parts of it that are unnecessary. For example, it is not always smart enough to realize that some loops or data structures in the Python code are actually static and can be removed:
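A minimal illustration (hypothetical `f`; the loop is over a fixed Python list):

```python
import torch
from torch import Tensor

def f(x: Tensor) -> Tensor:
    outs = []
    for i in [1, 2]:       # a loop over a static Python list
        outs.append(x + i)
    return outs[0] + outs[1]

print(torch.jit.script(f).code)            # keeps the loop and list ops
inp = torch.randn(3)
print(torch.jit.trace(f, (inp,)).code)     # unrolled into a few plain adds
```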
This example is very simple, so it actually has workarounds for scripting (use a tuple instead of a list), or the loop might get optimized away in a later optimization pass. But the point is: the graph compiler is not always smart enough. For complicated models, scripting might generate a graph with unnecessary complexity that's hard to optimize.
Tracing has clear limitations: I spent most of this article talking about the limitations of tracing and how to fix them. I actually think this is the advantage of tracing: it has clear limitations (and solutions), so you can reason about whether it works.
On the contrary, scripting is more like a black box: no one knows whether it works before trying. I didn't mention a single trick for fixing scripting: there are many of them, but it's not worth your time to probe and fix a black box.
Tracing has a small blast radius: both tracing and scripting affect how code can be written, but tracing has a much smaller blast radius and causes much less damage: it only constrains the input/output format of the outermost module. Scripting, on the other hand, has an impact on any syntax used anywhere in the model's logic. This large blast radius is why scripting can do great harm to code quality.
Control flow vs. other Python syntax: PyTorch is loved by its users because they can "just write Python", and most importantly write Python control flow. But other Python syntax is important as well. If being able to write Python control flow (scripting) means losing other great syntax, I'd rather give up the ability to write Python control flow.
In fact, if PyTorch were less obsessed with Python control flow, and offered me symbolic control flow such as a hypothetical `torch.cond` used like this (similar to the API of `tf.cond`):
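A sketch of what this could look like, reusing the earlier dynamic-control-flow example (`torch.cond` here is the imagined API, not an existing function at the time of writing):

```python
def f(x):
    # A hypothetical symbolic-control-flow API, mirroring tf.cond:
    # both branches would be captured in the graph.
    return torch.cond(x.sum() > 0,
                      lambda: torch.sqrt(x),
                      lambda: torch.square(x))
```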
then `f` could be traced correctly, and I would be happy to use this, no longer having to worry about scripting. TensorFlow AutoGraph is a great example that automates this idea.