Autograd 是一个反向自动微分系统(或梯度计算引擎),基于记录所有的操作来构建一个有向无环图——Autograd 计算图,其中叶子节点是输入 Tensor,根节点 root 是输出 Tensor,通过跟踪图中从根节点 root 到叶子节点的路径上的操作,能够自动地计算出梯度。
在 PyTorch 中,模型训练的每一轮迭代,都会创建对应的 Autograd 计算图:在前向传播阶段动态地创建 Autograd 计算图,在反向传播阶段根据 Autograd 计算图来进行梯度的计算。
构建分布式 Autograd 计算图
对于分布式模型训练环境下,需要在各个节点(主机)之间进行大量的 RPC 调用,统一协调各个过程来完成模型的训练。PyTorch 实现的分布式 Autograd,在前向传播过程中构建 Autograd 计算图,并且基于 Autograd 计算图在反向传播过程中计算梯度。在前向传播过程中,PyTorch 持续跟踪各个 RPC 调用的情况,必须确保在反向传播过程中计算是正确的,所以 PyTorch 在实现过程中使用了 send、recv 这一对函数来进行跟踪,当执行 RPC 调用时将 send 和 recv 绑定到 Autograd 计算图上。
为了说明这个过程,以 PyTorch 文档中下面的简单计算为例:
import torch import torch.distributed.rpc as rpc def my_add(t1, t2): return torch.add(t1, t2) # On worker 0: t1 = torch.rand((3, 3), requires_grad=True) t2 = torch.rand((3, 3), requires_grad=True) # Perform some computation remotely. t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2)) # Perform some computation locally based on remote result. t4 = torch.rand((3, 3), requires_grad=True) t5 = torch.mul(t3, t4) # Compute some loss. loss = t5.sum()
这个例子非常简单,计算 t5 = (t1 + t2) * t4 的结果,其中 t1 + t2 的计算是通过 RPC 调用在 worker1 节点上计算,计算结果 t3 返回到 worker0 节点上,继续后面的乘法计算 t5 = t3 * t4。这个分布式计算的例子,在执行前向传播过程中生成 Autograd 计算图,在反向传播阶段使用它来计算梯度,如下图所示:
图中存在两个 send-recv 调用对,其中,在 Worker 0 上有 2 个 Autograd 计算图,分别以 mul 和 send 函数为根的 root;在 Worker 1 有 1 个计算图,根 root 是 send 函数。
FAST mode 算法
目前,PyTorch 已经实现了 FAST mode 算法,该算法考虑了对性能要求比较敏感的应用场景,通过设置较强的假设越是来简化分布式梯度计算。而 SMART mode 算法是更通用意义上的算法,当前正在进行中,还没有完成实现。
FAST mode 算法的关键假设:
每一个 send 函数在反向传播阶段只存在一个依赖,也就是说,通过一个 RPC 调用 send 函数只需要从目的节点端接收一个梯度结果。下面是 FAST mode 算法的基本流程:
为了说明 FAST mode 算法的流程,下面通过一个简单的例子来加深理解计算的过程:
import torch import torch.distributed.autograd as dist_autograd import torch.distributed.rpc as rpc def my_add(t1, t2): return torch.add(t1, t2) # On worker 0: # Setup the autograd context. Computations that take # part in the distributed backward pass must be within # the distributed autograd context manager. with dist_autograd.context() as context_id: t1 = torch.rand((3, 3), requires_grad=True) t2 = torch.rand((3, 3), requires_grad=True) # Perform some computation remotely. t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2)) # Perform some computation locally based on remote result. t4 = torch.rand((3, 3), requires_grad=True) t5 = torch.mul(t3, t4) # Compute some loss. loss = t5.sum() # Run the backward pass. dist_autograd.backward(context_id, [loss]) # Retrieve the gradients from the context. dist_autograd.get_gradients(context_id)
通过前向传播计算,会得到 Autograd 计算图,下图描述了 Autograd 计算图以及在进行分布式 Autograd 计算过程中得到的依赖关系:
通过上图可以看到,在前向传播阶段,构建好了 Autograd 计算图,其中:Worker 0 上有两个子图,它们的根 root 分别是 mul1 和 send1;Worker 2 上有一个计算图,根 root 为 send2。
分布式 Autograd 梯度的详细计算过程,描述如下:
参考资源