This post is about a small piece of functionality that has proven useful in TensorFlow / JAX / PyTorch.
Low-level components of these systems often use a plain list of values/tensors as inputs & outputs. However, end-users who develop models often want to work with more complicated data structures: `Dict[str, Any]`, `List[Any]`, custom classes, and their nested combinations. Therefore, we need bidirectional conversion between nested structures and a plain list of tensors. I found that different libraries invent similar approaches to solve this problem, and it's interesting to list them here.
Though many simple deep learning models need just a few input/output tensors, nested containers are useful abstractions in advanced models, because many concepts are naturally represented by more than one tensor.
When a frequently-used concept has such natural complexity, representing it in a flat structure (e.g. `Dict[str, Tensor]`) consisting of only regular tensors may result in ugly code. A multi-level nested structure sometimes becomes helpful. Take sparse tensor as a simple example:
| | Use nested containers | Use a flat `Dict[str, Tensor]` |
|---|---|---|
| Representation | `{"a": SparseTensor, ...}`; `SparseTensor` can be a namedtuple/dataclass, or a new class | `{"a_values": Tensor, "a_indices": Tensor, ...}` |
| Sanity check | The `SparseTensor` class can guarantee both tensors exist and follow certain contracts (e.g. their shapes match) | Need to check that `a_values` and `a_indices` co-exist in the dict |
| Pass to another function | Pass `x["a"]` directly | Extract `x["a_values"]`, `x["a_indices"]` and pass both |
| Operations | The `SparseTensor` class can have methods that work like regular tensors, e.g. `y = x["a"] + 1` | Need to implement many new functions, e.g. `y = add_sparse(x["a_values"], x["a_indices"], 1)` |
Despite the benefits, lower-level stacks often ignore these abstractions and choose to use a "flat" interface: their inputs & outputs are a flat list of values / Tensors. This is because: (i) the abstraction may no longer be useful at a lower level; (ii) a simple structure simplifies their implementation; (iii) a flat list is a data structure available even in lower-level languages & systems.
Therefore, conversion from a nested structure to a plain list of values is important. This is often referred to as "flatten". It is pretty straightforward to flatten a container recursively -- like the following `flatten` function:
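A minimal sketch, assuming the only containers are lists, tuples, and dicts (with dict keys visited in sorted order):

```python
import torch

x, y, z = torch.rand(2), torch.rand(3), torch.rand(4)

def flatten(obj):
    """Recursively flatten a nested container into a plain list of leaves."""
    if isinstance(obj, (list, tuple)):
        return [leaf for item in obj for leaf in flatten(item)]
    if isinstance(obj, dict):
        # Visit keys in sorted order so the flattening order is deterministic.
        return [leaf for key in sorted(obj) for leaf in flatten(obj[key])]
    return [obj]  # a leaf, e.g. a tensor

obj = [x, {"key": y}, z]
values = flatten(obj)  # [x, y, z]
```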
The inverse of `flatten` is also important: given new values `[x2, y2, z2]`, we want the `unflatten` function below to construct `obj2` that has the same structure as `obj`.
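In sketch form (the `???` placeholder stands for whatever representation of the structure `unflatten` needs, discussed below):

```python
x2, y2, z2 = torch.rand(2), torch.rand(3), torch.rand(4)

# Rebuild a container with the same structure as obj, but holding the new leaves.
obj2 = unflatten([x2, y2, z2], ???)
# obj2 == [x2, {"key": y2}, z2]
```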
`unflatten` is a very handy utility. For example, to create a clone of `obj` on a different device, we simply do this:
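Roughly like this (again with `???` as the structure placeholder, and assuming every leaf is a tensor):

```python
device = torch.device("cuda")

values = flatten(obj)                        # nested -> flat list of tensors
new_values = [v.to(device) for v in values]  # process each tensor independently
obj_clone = unflatten(new_values, ???)       # flat list -> same nested structure
```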
Without `unflatten`, every such functionality needs to be implemented as a recursive function, like PyTorch's `pin_memory`.
## `unflatten`

How do we implement `unflatten`? Apparently, we need to give it a representation of the structure (noted as a placeholder `???` in the above code). There are two high-level approaches to solve this problem:
Schema-based: when flattening a container, explicitly record its structure/schema, to be used later for unflattening. Its API may look like this:
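For instance, something along these lines (names illustrative):

```python
# Flattening returns the flat leaves plus a "schema" object that records the structure.
values, schema = flatten(obj)
# The schema alone is enough to rebuild the same structure around new values.
obj2 = unflatten(new_values, schema)
```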
Examples: Detectron2's `flatten_to_tuple`, TensorFlow's `FetchMapper`, JAX's pytree.
Schema-less: use the entire nested container as an implicit representation of structure. Its interface looks like this:
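For instance (again, names illustrative):

```python
# No separate schema object: flatten returns only the flat leaves ...
values = flatten(obj)
# ... and unflatten takes another container (often the original one)
# as an implicit description of the structure.
obj2 = unflatten(new_values, obj)
```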
Examples: TensorFlow's `tf.nest`, DeepMind's `dm-tree`.
The two approaches each have their own pros and cons.
## JAX's pytree

JAX's low-level components accept/return flat tensors, so functions can be transformed and optimized more easily. Since end-users need nested containers, JAX transformations support pytree containers, which by default include flattening & unflattening for common Python containers. It further allows users to register custom classes by `register_pytree_node`.
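As a sketch, registering a custom `SparseTensor` class (like the one discussed earlier) might look like this:

```python
from jax.tree_util import register_pytree_node

class SparseTensor:
    def __init__(self, values, indices):
        self.values = values
        self.indices = indices

register_pytree_node(
    SparseTensor,
    # flatten: return the children (leaves) plus any auxiliary static data
    lambda st: ((st.values, st.indices), None),
    # unflatten: rebuild the object from auxiliary data + children
    lambda aux, children: SparseTensor(*children),
)
```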
Pytree uses a schema-based implementation that we already showcased above.
When we need to independently process each leaf of the container, JAX provides another handy function, `tree_map`:
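An illustrative example:

```python
import jax.numpy as jnp
from jax.tree_util import tree_map

obj = {"a": jnp.ones(3), "b": [jnp.zeros(2), jnp.arange(2)]}
# Apply a function to every leaf while keeping the container structure.
doubled = tree_map(lambda x: x * 2, obj)
# {"a": Array([2., 2., 2.]), "b": [Array([0., 0.]), Array([0, 2])]}
```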
PyTorch also adds a similar implementation of pytree here, which is used in its FX tracing.
## `TracingAdapter`
`torch.jit.trace(model, inputs)` executes the model with the given inputs, and returns a graph representation of the model's execution. This is one of the most common methods (and the best, IMO) by which PyTorch models are exported today. However, it has a limitation that the model's inputs & outputs have to be flat (more precisely, `Union[Tensor, Tuple[Tensor]]`).
In order to trace models with more complicated inputs & outputs, I created the `TracingAdapter` tool in detectron2, which flattens/unflattens a model's inputs and outputs to make it traceable. A minimal implementation of it may look like this:
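A sketch (not the actual detectron2 code), assuming a schema-based `flatten(obj) -> (values, schema)` / `unflatten(values, schema)` pair as described above, and assuming `inputs` is a tuple of the model's positional arguments:

```python
import torch

class TracingAdapter(torch.nn.Module):
    """Wrap a model so that its inputs & outputs become flat tuples of tensors."""

    def __init__(self, model, inputs):
        super().__init__()
        self.model = model
        # Flatten the example inputs once, remembering their structure.
        self.flattened_inputs, self.inputs_schema = flatten(inputs)
        self.outputs_schema = None

    def forward(self, *flattened_inputs):
        # Rebuild the nested inputs the model expects, run the model,
        # then flatten its nested outputs so tracing only sees tensors.
        inputs = unflatten(list(flattened_inputs), self.inputs_schema)
        outputs = self.model(*inputs)
        flattened_outputs, self.outputs_schema = flatten(outputs)
        return tuple(flattened_outputs)

# Usage: trace the adapter with flat tensor inputs.
# adapter = TracingAdapter(model, inputs)
# traced = torch.jit.trace(adapter, adapter.flattened_inputs)
```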
where `flatten` uses a schema-based implementation that can be found in this file. Coincidentally, its interface looks like JAX's pytree:
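Roughly, side by side (my paraphrase; treat the exact detectron2 signature as approximate):

```python
# detectron2: flatten_to_tuple returns the flat values plus a schema object;
# calling the schema on flat values rebuilds the original structure.
values, schema = flatten_to_tuple(obj)
obj = schema(values)

# JAX pytree:
values, treedef = jax.tree_util.tree_flatten(obj)
obj = jax.tree_util.tree_unflatten(treedef, values)
```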
Perception models at Meta accept a wide range of input/output formats: they may take any number of images plus auxiliary data as inputs, and predict boxes, masks, keypoints or any other interesting attributes as outputs. But deployment prefers a flat interface for optimizability and interoperability. `TracingAdapter`'s automatic flattening and unflattening mechanism has freed engineers from writing format-conversion glue code when deploying these models.
In addition to deployment, `TracingAdapter` is also useful in a few other places to smooth the experience of `torch.jit.trace`; wrapping a model with it is often the easiest way to satisfy tools that only accept flat inputs. For example, tensorboard's `add_graph` method visualizes the graph structure of a model, but it requires flattened inputs, therefore `TracingAdapter` can be used like this. `TracingAdapter` is useful in other such cases as well, e.g. here.

## `tf.nest`
`tf.nest.flatten` and `tf.nest.pack_sequence_as` implement schema-less flattening and unflattening.
The unflatten function requires a container: it flattens this container on-the-fly while simultaneously "packing" the flat values into the structure of this container.
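A sketch along the lines of the official `tf.nest.pack_sequence_as` example (note that dict values are ordered by keys):

```python
import tensorflow as tf

structure = {"key3": "", "key1": "", "key2": ""}
flat_sequence = ["value1", "value2", "value3"]

tf.nest.pack_sequence_as(structure, flat_sequence)
# {'key3': 'value3', 'key1': 'value1', 'key2': 'value2'}
# Keys are visited in sorted order, so "value1" is packed under "key1", etc.
```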
`tf.nest.{flatten,pack_sequence_as}` are widely used in TensorFlow because many low-level components have a flat interface, especially for interop with C APIs.
`tf.nest.map_structure` has the same functionality as JAX's `tree_map`.
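An illustrative use, mirroring the `tree_map` example above:

```python
import tensorflow as tf

obj = {"a": tf.ones(3), "b": [tf.zeros(2), tf.ones(2)]}
# Apply a function to every leaf while keeping the container structure.
doubled = tf.nest.map_structure(lambda x: x * 2, obj)
```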
## `FetchMapper`
TFv1's `session.run(fetches)` supports fetching nested containers. This is demonstrated by an example in the official documentation:
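A sketch in the spirit of the official `tf.compat.v1.Session.run` example:

```python
import collections
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

a = tf.constant([10, 20])
b = tf.constant([1.0, 2.0])
MyData = collections.namedtuple("MyData", ["a", "b"])

with tf.Session() as sess:
    # `fetches` can be an arbitrarily nested structure of graph elements;
    # the result has the same structure, with tensors replaced by numpy arrays.
    v = sess.run({"k1": MyData(a, b), "k2": [b, a]})
    # v["k1"] is a MyData namedtuple of numpy arrays; v["k2"] is a list of arrays.
```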
This powerful interface exists in TF's Python client only. The client interacts with the C API's `TF_SessionRun`, which only accepts a plain array of inputs/outputs. Therefore, the client needs to:

- flatten the nested `fetches` into a plain list of tensors before calling the C API;
- unflatten the values returned by the C API back into the structure of `fetches`, to be returned to the user.
The flatten/unflatten logic uses a schema-based implementation in the client's `FetchMapper`. This implementation is a bit more complicated due to an extra guarantee that the flattened tensors are unique. (This ensures the client won't fetch the same tensor twice in one call; such deduplication cannot be done by using `tf.nest`.)
In addition to builtin Python containers, `FetchMapper` supports a few other TF containers (such as `SparseTensor`) and can be extended to new containers by registering conversion functions.
## The `tree` library

DeepMind has a `tree` library as a standalone alternative to `tf.nest`:
deepmind/tree | tf.nest | jax.tree_util |
---|---|---|
tree.flatten | tf.nest.flatten | jax.tree_util.tree_flatten |
tree.unflatten_as | tf.nest.pack_sequence_as | jax.tree_util.tree_unflatten |
tree.map_structure | tf.nest.map_structure | jax.tree_util.tree_map |