A typical PyTorch training program on 8 GPUs with 4 dataloader workers per GPU would create at least 8 x (4 + 1) = 40 processes. A naive use of torch dataset and dataloader can easily replicate your dataset's RAM usage by 40 times. This issue has probably affected everyone who has done anything nontrivial with PyTorch. In this post, we will explain why it happens, and how to avoid the 40x RAM usage.
All code examples and experiment results are available on github at ppwwyyxx/RAM-multiprocess-dataloader. The content is not specific to PyTorch: it applies to any user of Python's multiprocessing library on Linux.
Datasets for machine learning are usually not stored in RAM. But it's common to store their "metadata" in RAM, and this may still cause nontrivial RAM usage. The metadata could be, for example, file names, labels, or other per-sample annotations.
As a concrete case, loading the metadata of COCO training set into Python takes ~2.4G of RAM:
|
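A minimal sketch of this measurement, assuming the standard COCO 2017 annotation file and using psutil to read the process RSS (the exact script is in the repo):

```python
import json
import psutil

# Load the COCO 2017 training annotations into plain Python objects (dicts, lists, str, int).
with open("instances_train2017.json") as f:
    coco_metadata = json.load(f)

# Resident memory of this process, now dominated by the metadata we just loaded.
rss = psutil.Process().memory_info().rss
print(f"RSS: {rss / 2**30:.2f} GB")
```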
We obviously don't want to replicate this 2.4G of RAM across all processes.
We acknowledge that there are ways to offload this metadata to disk, e.g. by storing it in an on-disk database or loading it lazily from files. By doing so, the RAM usage of a dataset becomes negligible. However, these methods sacrifice flexibility and capabilities, such as random access, perfect shuffle, merging datasets arbitrarily, custom subsampling support, etc. Notably, PyTorch's commonly used map-style datasets support random access & sampling. All of these capabilities require certain metadata in RAM.
This article sets these offloading methods aside. Instead, we'll discuss how to reduce the RAM usage without moving these data out of RAM. The idea is simple: we'll try to let all processes share a single copy of the dataset.
First let's build tools to measure RAM usage - which is not as easy as it sounds.
Common tools like top -p PID or psutil.Process(PID).memory_info() obtain memory statistics from /proc/{PID}/statm or /proc/{PID}/status, but they are insufficient for our analysis. Instead, we'll use the information provided in:

- /proc/{PID}/smaps: per-memory-mapping RAM usage information, documented in this man page
- /proc/{PID}/smaps_rollup: aggregation of data from smaps
We'll derive the following important measurements from it:

- USS (Unique Set Size): memory that is private to the process, i.e. not shared with any other process, obtained by summing the private mappings in smaps.
- Shared: memory that the process shares with other processes, also obtained from smaps.
- PSS (Proportional Set Size): USS, plus shared memory divided evenly among the processes sharing it. Summing PSS over all processes gives the total RAM usage, a number not available in top/htop.

To obtain these measurements, we use psutil.Process(PID).memory_maps(), which parses smaps under the hood:
|
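A sketch of how these numbers can be aggregated from memory_maps() (the field names below are psutil's; the actual util in the repo may differ slightly):

```python
import psutil

def get_mem_info(pid: int) -> dict:
    """Aggregate PSS / USS / shared memory (in bytes) over all mappings of a process."""
    res = {"pss": 0, "uss": 0, "shared": 0}
    for mmap in psutil.Process(pid).memory_maps(grouped=True):
        res["pss"] += mmap.pss
        res["uss"] += mmap.private_clean + mmap.private_dirty
        res["shared"] += mmap.shared_clean + mmap.shared_dirty
    return res
```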
Then we create a MemoryMonitor
utility to measure and print the results for a list of PIDs. The code is straightforward and can be found here.
We start with a naive implementation of a dataset that produces items from a list:
|
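A sketch of such a dataset (the name is illustrative; it simply indexes into an in-memory Python list):

```python
class DatasetFromList:
    """A naive map-style dataset that returns items from an in-memory list."""

    def __init__(self, lst):
        self.lst = lst

    def __len__(self):
        return len(self.lst)

    def __getitem__(self, idx):
        return self.lst[idx]
```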
Then we start subprocesses to read from this dataset, created with the list of COCO metadata. To make a cleaner demo, we don't use PyTorch's dataloader, but just launch 4 subprocesses ourselves:
|
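A sketch of how the workers might be launched, reusing the DatasetFromList sketch above; create_coco() is a hypothetical stand-in for loading the COCO metadata list:

```python
import multiprocessing as mp
import time

def worker(worker_id: int, dataset):
    # Keep reading items, like a dataloader worker would.
    while True:
        for i in range(len(dataset)):
            _ = dataset[i]       # read one item
            time.sleep(0.001)    # slow down so memory growth is observable over time

if __name__ == "__main__":
    ds = DatasetFromList(create_coco())   # create_coco(): hypothetical loader of the COCO metadata list
    ctx = mp.get_context("fork")
    procs = [ctx.Process(target=worker, args=(i, ds)) for i in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```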
We then add our MemoryMonitor to it. The full code and its output logs are available on github. Each segment in the log contains memory measurements for the main process + 4 workers:
|
The code looks completely innocent. However, if we plot the memory usage of any dataloader worker over time, we seem to find a memory leak!
This is the notorious "dataloader leaks memory" issue that is discussed in multiple places, e.g. this PyTorch issue and Edward's podcast.
In fact, the growth of RAM usage does stop in the end, so this issue is not a memory leak. But in reality, users often do not see the end before the system OOMs, and they may wrongly conclude that this is a "memory leak".
The root cause of this issue is "copy-on-read" of forked CPython objects.
Linux has a copy-on-write mechanism: when a process forks, the child process will share its entire memory space with the parent, and only copy the relevant pages when necessary, i.e. when the child process needs to write to the page. This mechanism allows read-only pages to be shared to reduce total memory usage.
The copy-on-write behavior can be clearly observed in the above figure: at time=0, the worker has 2.6G of shared RAM, 0 USS, and roughly 0.5G (2.6G / 5) of PSS, because the RAM is shared among 5 processes (4 workers + 1 main).
However, this mechanism did not help us when we read our dataset. The problem is that our dataset is a large nested data structure that contains many small Python objects. Even though the dataset is "read-only" in theory, accessing any Python object will increment its refcount - causing a lot of memory writes. With these writes, memory can no longer be shared among parent and child processes. Therefore, in the figure we see that the "Shared" RAM decreases and "USS" increases.
The end game is that each child process has to replicate all the pages that contain object refcounts in the dataset. For a dataset with many objects, this is almost the size of the dataset itself. In the output log, we see that this program uses 10G total PSS in the end, where each child process replicates 1.8G of USS.
The copy-on-read issue is due to CPython's reference counting. There are ways to change CPython's behavior, e.g. by gc.freeze, but it has far-reaching consequences and I failed to make it work for the example here. However, there is a simple and transparent way to solve the issue: store the dataset using very few Python objects, so there are very few refcounts! Below is a minimal implementation:
|
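A minimal sketch of the idea: pickle every item and pack all of them into two flat numpy arrays (one for data, one for addresses), so the whole dataset is held by only a handful of Python objects. The class name here is illustrative:

```python
import pickle
import numpy as np

class NumpySerializedList:
    """Store a list of picklable items inside two numpy arrays, with almost no Python objects."""

    def __init__(self, lst: list):
        buffers = [np.frombuffer(pickle.dumps(x), dtype=np.uint8) for x in lst]
        self._addr = np.cumsum([len(b) for b in buffers], dtype=np.int64)
        self._data = np.concatenate(buffers)

    def __len__(self):
        return len(self._addr)

    def __getitem__(self, idx):
        start = 0 if idx == 0 else self._addr[idx - 1].item()
        end = self._addr[idx].item()
        return pickle.loads(self._data[start:end].tobytes())
```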
Detectron2 enables this type of serialization by default (since this commit by Yanghan). To compare different serialization mechanisms, we borrow its code into a serialization util, and use it here:
|
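The change might look roughly like this, reusing the DatasetFromList and NumpySerializedList sketches above (create_coco() is again a hypothetical loader):

```python
lst = create_coco()                               # the list of COCO metadata, as before
ds = DatasetFromList(NumpySerializedList(lst))    # the one-line change: wrap the list first
```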
With just this simple one-line change, the RAM usage is greatly reduced. The end of the output log file is shown below.
|
We can see that:

- Each worker now has a tiny USS: the serialized dataset is shared with the main process through copy-on-write, so total RAM usage no longer grows in proportion to #processes.
- The serialized dataset is also smaller than the original, because pickle.dumps not only serializes but also compresses the data.

We benefit from both sharing and compression by applying this optimization, at the cost of a tiny pickle.loads overhead in each access. Actually, after compression, the dataset only takes ~500M (printed at the beginning of the log). So a question arises: why does the main process use 1.6G of RAM before starting subprocesses?
This can be explained by another CPython internal: it does not always release memory back to the OS. In fact, if we run this simple serialization/compression code:
|
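A sketch of that experiment (create_coco() is again a hypothetical loader of the COCO metadata list; RSS is read via psutil):

```python
import gc
import pickle
import numpy as np
import psutil

def rss_gb() -> float:
    return psutil.Process().memory_info().rss / 2**30

lst = create_coco()                    # hypothetical loader: the ~2.4G list of COCO metadata
serialized = np.concatenate(
    [np.frombuffer(pickle.dumps(x), dtype=np.uint8) for x in lst])
print(f"after serialization: RSS = {rss_gb():.2f} GB")

del lst, serialized
gc.collect()
print(f"after deleting everything: RSS = {rss_gb():.2f} GB")  # RAM not fully returned to the OS
```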
We see that we seem to "lose" ~700MB of RAM even after we've deleted everything:
|
Such behavior is typically not a concern, since CPython will find opportunities to reuse these free buffers.
In our code above, we launched subprocesses using a start_method="fork" argument. "fork", "spawn", and "forkserver" are the 3 "start methods" of Python's multiprocessing library. This article is a good reference that explains their differences.
Since start_method="fork" is unsafe (in practice, it causes various crashes & deadlocks) and might no longer be the default in the future, we want to rerun our code above with start_method="spawn" or "forkserver". Sadly, the serialized array is no longer shared among workers. Each worker has a large USS:
|
The reason why our trick no longer works is that "spawn" and "forkserver" don't benefit from the copy-on-write mechanism. They will start a "fresh" subprocess with fresh memory space, instead of sharing with the parent. Everything the child process needs to access is pickled in the parent process and sent to the child. This ensures safe behavior, but is bad for start-up speed and memory usage.
In our case, the entire dataset will be pickled and sent to child processes. This is why each child process consumes a large USS.
It turns out there is a simple fix to this problem: just store the serialized dataset in a torch.Tensor
instead of a numpy array. The reason why it works, is that multiprocessing uses a customizable pickle implementation called ForkingPickler
, and PyTorch customizes how torch.Tensor
should be pickled by it: the tensor data will not be serialized to bytes. Instead, during pickling the tensor will be moved to shared memory files (typically under /dev/shm
) to be accessed by other processes directly.
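Building on the NumpySerializedList sketch from earlier, the change might be as small as wrapping the two arrays in torch tensors (an illustrative sketch, not the repo's exact code):

```python
import pickle
import torch

class TorchSerializedList(NumpySerializedList):
    """Same as the numpy version, but the storage is torch.Tensors, so that
    multiprocessing's ForkingPickler moves the data to shared memory instead of copying it."""

    def __init__(self, lst: list):
        super().__init__(lst)
        self._addr = torch.from_numpy(self._addr)
        self._data = torch.from_numpy(self._data)

    def __getitem__(self, idx):
        start = 0 if idx == 0 else self._addr[idx - 1].item()
        end = self._addr[idx].item()
        return pickle.loads(self._data[start:end].numpy().tobytes())
```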
To test tensor-based serialization, we run ./main-torchserialize.py spawn using the code here, and observe the following memory usage in workers (raw log is here):

- The workers' shared memory grows over time, because each worker maps pages of the shared torch.Tensor as needed. This is different from start_method="fork", where the entire memory space is shared at the beginning.
- Each worker still has ~160MB of USS, which comes from import torch.

After applying tensor-based serialization, the total PSS usage in the end is 2.2G -- still worse than our earlier number using start_method="fork". The next section will optimize it further.
The last culprit in the above experiment is the 160MB per-worker USS in the above figure: this is just the memory footprint of import torch, mainly for PyTorch's global variables, etc. Since every child process launched by "spawn / forkserver" is a "fresh" one, they all need to import torch independently, hence each has 160MB of USS.
Luckily, "forkserver" provides a way to share the import torch RAM usage through copy-on-write. By calling the undocumented Python API multiprocessing.set_forkserver_preload(["torch"]) before launching processes, each child process will be "less fresh": the torch library is preloaded (and shared), and doesn't need to be imported by each process independently.
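A sketch of how this might be wired up before creating worker processes:

```python
import multiprocessing as mp

if __name__ == "__main__":
    # Ask the forkserver to import torch once; children then inherit it via copy-on-write.
    mp.set_forkserver_preload(["torch"])
    ctx = mp.get_context("forkserver")
    # ... create the dataset and launch workers with ctx.Process(...) as before.
```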
Below are the experiment results. Code and full logs are on github:
|
With preloading, the per-worker cost of import torch is shared through copy-on-write, and the memory usage becomes close to what we had with start_method="fork".
(Note that this optimization may be unsafe if import torch creates any threads. My observation is that threads are indeed created due to import numpy inside torch, but they can be disabled with environment variables.)
So far we've only looked at a single dataloader (with 4 workers). In reality, the only scalable way to use PyTorch on multiple GPUs is to use one process per GPU, each with its own dataloader and dataloader workers. This gives a total of #GPUs x (#DL workers + 1)
processes organized like below:
We modified the previous experiment slightly into this code to run on 2 GPUs. The memory usage looks like this:
|
Our previous optimization on dataloader workers is still effective - dataloader workers have a tiny USS. However, RAM usage is now replicated #GPUs times, because we let each GPU worker read the dataset independently.
An inconvenient solution to this problem is to load and serialize the dataset before launching GPU workers. By doing this, all GPU workers share the dataset just like the dataloader workers do. However, this limits flexibility and often requires significant refactoring.
Another simple solution to this problem is again to use torch.Tensor
and ForkingPickler
to share the dataset among GPU workers, except that we need to do it explicitly like this:
|
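Here is a sketch of the idea, assuming one process per GPU on a single machine with torch.distributed already initialized; the function names are illustrative and the repo's util handles more details:

```python
import pickle
import torch
import torch.distributed as dist
# Importing torch.multiprocessing registers PyTorch's tensor reducers with ForkingPickler.
import torch.multiprocessing  # noqa: F401
from multiprocessing.reduction import ForkingPickler

def load_dataset_shared(create_fn):
    """Load the serialized dataset only on rank 0 and share it with the other GPU workers."""
    if dist.get_rank() == 0:
        tensor = create_fn()  # e.g. returns the torch.Tensor holding the serialized dataset
        # Pickling a torch.Tensor with ForkingPickler moves its storage to shared memory
        # (/dev/shm) and yields a small handle that other processes can open.
        handle = [bytes(ForkingPickler.dumps(tensor))]
    else:
        handle = [None]
    # Works on a single machine only: /dev/shm is not visible across machines.
    dist.broadcast_object_list(handle, src=0)
    # All ranks now hold a tensor view backed by the same shared memory.
    return pickle.loads(handle[0])
```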
This logic is implemented as another serialization util here. When using it as a drop-in replacement (full code here), the dataset is no longer replicated by GPU workers:
|
GPU worker 1 still has a small amount of extra USS, and that's just the footprint of import torch
that we saw earlier, and can be avoided using set_forkserver_preload
.
We've successfully reduced the total RAM usage by (approximately) a factor of #GPUs x (#DL workers + 1).
The essence of the solution is to let all processes share memory through a single torch.Tensor
object, which needs to be moved to Linux shared memory by PyTorch's custom pickling routine. The TLDR on how to achieve sharing is:
- Don't let dataloader workers access many Python objects in their parent. Serialize all objects into a single
torch.Tensor
(but not numpy array) for workers to access.- Don't let all GPU workers load data independently. Load in one GPU worker, and share with others through a
torch.Tensor
.
For list-like data, all of these can be implemented transparently using the serialization routines developed in this article.
Multi-processing is often the only way to achieve true parallelism in Python, but it comes with many tricky problems. This article hopefully provides an in-depth view of the problem of RAM usage.