IT博客汇
  • 首页
  • 精华
  • 技术
  • 设计
  • 资讯
  • 扯淡
  • 权利声明
  • 登录 注册

    理解 PyTorch 分布式 Autograd 设计

    Yanjun发表于 2024-02-12 10:13:15
    love 0

    Autograd 是一个反向自动微分系统(或梯度计算引擎),基于记录所有的操作来构建一个有向无环图——Autograd 计算图,其中叶子节点是输入 Tensor,根节点 root 是输出 Tensor,通过跟踪图中从根节点 root 到叶子节点的路径上的操作,能够自动地计算出梯度。
    在 PyTorch 中,模型训练的每一轮迭代,都会创建对应的 Autograd 计算图:在前向传播阶段动态地创建 Autograd 计算图,在反向传播阶段根据 Autograd 计算图来进行梯度的计算。

    构建分布式 Autograd 计算图

    对于分布式模型训练环境下,需要在各个节点(主机)之间进行大量的 RPC 调用,统一协调各个过程来完成模型的训练。PyTorch 实现的分布式 Autograd,在前向传播过程中构建 Autograd 计算图,并且基于 Autograd 计算图在反向传播过程中计算梯度。在前向传播过程中,PyTorch 持续跟踪各个 RPC 调用的情况,必须确保在反向传播过程中计算是正确的,所以 PyTorch 在实现过程中使用了 send、recv 这一对函数来进行跟踪,当执行 RPC 调用时将 send 和 recv 绑定到 Autograd 计算图上。

    • send 函数被绑定到 RPC 调用的源节点(Source Node)端,send 函数的输出边指向 RPC 的输入 Tensor 变量;在反向传播阶段,send 函数会接收从目的节点(Destination Node)端与之对应的 recv 函数发送过来的结果,作为 send 函数的输入
    • recv 函数被绑定到 RPC 调用的目的节点端,通过在目的节点端查询对应前向计算得到的结果作为 recv 的输入;在反向传播阶段,recv 函数执行得到的梯度结果被发送到源节点端对应的 send 函数

    为了说明这个过程,以 PyTorch 文档中下面的简单计算为例:

    import torch
    import torch.distributed.rpc as rpc
    
    def my_add(t1, t2):
      return torch.add(t1, t2)
    
    # On worker 0:
    t1 = torch.rand((3, 3), requires_grad=True)
    t2 = torch.rand((3, 3), requires_grad=True)
    
    # Perform some computation remotely.
    t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
    
    # Perform some computation locally based on remote result.
    t4 = torch.rand((3, 3), requires_grad=True)
    t5 = torch.mul(t3, t4)
    
    # Compute some loss.
    loss = t5.sum()
    

    这个例子非常简单,计算 t5 = (t1 + t2) * t4 的结果,其中 t1 + t2 的计算是通过 RPC 调用在 worker1 节点上计算,计算结果 t3 返回到 worker0 节点上,继续后面的乘法计算 t5 = t3 * t4。这个分布式计算的例子,在执行前向传播过程中生成 Autograd 计算图,在反向传播阶段使用它来计算梯度,如下图所示:
    send_recv_functions
    图中存在两个 send-recv 调用对,其中,在 Worker 0 上有 2 个 Autograd 计算图,分别以 mul 和 send 函数为根的 root;在 Worker 1 有 1 个计算图,根 root 是 send 函数。

    FAST mode 算法

    目前,PyTorch 已经实现了 FAST mode 算法,该算法考虑了对性能要求比较敏感的应用场景,通过设置较强的假设越是来简化分布式梯度计算。而 SMART mode 算法是更通用意义上的算法,当前正在进行中,还没有完成实现。
    FAST mode 算法的关键假设:
    每一个 send 函数在反向传播阶段只存在一个依赖,也就是说,通过一个 RPC 调用 send 函数只需要从目的节点端接收一个梯度结果。下面是 FAST mode 算法的基本流程:

    1. 在一个 Worker 节点上开始执行反向传播计算,从起始的根 root 开始,这就要求所有的根 root 必须是本地的。
    2. 为当前的分布式 Autograd 上下文对象(通过 dist_autograd.context()获得)查询所有的 send 函数(一次训练迭代中,Autograd 计算图中所有的 send-recv 对都保存在分布式 Autograd 上下文对象中)。
    3. 根据已经确定好的根 root,在本地计算这些根 root 的依赖关系,也包括所有 send 函数的依赖关系,然后启动本地 Autograd 引擎开始计算梯度。
    4. 当 Autograd 引擎执行 recv 函数时,recv 函数基于 RPC 调用将梯度发送给与之对应的 send 函数所在的节点端,实际上 recv 函数只需要发送对应的两个 ID 即可:autograd_context_id(唯一对应于一次迭代的分布式 Autograd 上下文对象)和 autograd_message_id(唯一对应于一个 send-recv 对)。
    5. 远程节点端接收到对应的 autograd_context_id 和 autograd_message_id,查询找到对应的 send 函数,如果这是第一次在该节点上接收到 autograd_context_id,会在该节点本地计算与 autograd_context_id 对应的所有依赖关系。
    6. send 函数被放到执行队列中,等待调度执行,得到该 send-recv 对对应的 RPC 调用结果,并返回给调用端。
    7. 单独为每个分布式 Autograd 上下文对象计算累加梯度(Accumulated Gradient),计算结果保存在 Dict[Tensor, Tensor] 结构中,基于这个 Dict 可以通过一个给定的 Tensor 得到它对应的累加梯度。

    为了说明 FAST mode 算法的流程,下面通过一个简单的例子来加深理解计算的过程:

    import torch
    import torch.distributed.autograd as dist_autograd
    import torch.distributed.rpc as rpc
    
    def my_add(t1, t2):
      return torch.add(t1, t2)
    
    # On worker 0:
    
    # Setup the autograd context. Computations that take
    # part in the distributed backward pass must be within
    # the distributed autograd context manager.
    with dist_autograd.context() as context_id:
      t1 = torch.rand((3, 3), requires_grad=True)
      t2 = torch.rand((3, 3), requires_grad=True)
    
      # Perform some computation remotely.
      t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2))
    
      # Perform some computation locally based on remote result.
      t4 = torch.rand((3, 3), requires_grad=True)
      t5 = torch.mul(t3, t4)
    
      # Compute some loss.
      loss = t5.sum()
    
      # Run the backward pass.
      dist_autograd.backward(context_id, [loss])
    
      # Retrieve the gradients from the context.
      dist_autograd.get_gradients(context_id)
    

    通过前向传播计算,会得到 Autograd 计算图,下图描述了 Autograd 计算图以及在进行分布式 Autograd 计算过程中得到的依赖关系:
    fast_mode_distributed_dependencies_computed
    通过上图可以看到,在前向传播阶段,构建好了 Autograd 计算图,其中:Worker 0 上有两个子图,它们的根 root 分别是 mul1 和 send1;Worker 2 上有一个计算图,根 root 为 send2。
    分布式 Autograd 梯度的详细计算过程,描述如下:

    1. 在 Worker 0 上,从 loss 和 send1 开始,计算依赖关系:send1 有 1 个依赖、mul1 有 1 个依赖。
    2. 在 Worker 0 上启动 Autograd 引擎,首先执行 mul1 函数,将结果在对应的分布式 Autograd 上下文对象中进行累加,对应着 t4;然后执行 recv2 函数,将计算结果梯度发送给 Worker 1 上。
    3. Worker 1 第一次得知在进行反向传播计算,所以首先计算本地依赖关系:send2 有 1 个依赖、add1 有 1 个依赖、recv1 有 1 个依赖。
    4. 然后在 Worker 1 上启动本地 Autograd 引擎,并将 send2 加入到执行队列等待调度,接着依次执行 add1、recv1 函数,当 recv1 计算完成后,会将梯度计算结果发送给 Worker 0。
    5. Worker 0 接收到 recv1 的梯度结果,在本地执行 send1 函数。
    6. 最后,t1、 t2、t4 的梯度都会在当前的分布式 Autograd 上下文对象中进行累加计算,这样就完成了一轮迭代的分布式梯度计算。

    参考资源

    • https://pytorch.org/docs/stable/rpc.html#distributed-autograd-framework
    • https://pytorch.org/docs/stable/notes/autograd.html
    • https://pytorch.org/docs/stable/rpc/distributed_autograd.html
    • https://pytorch.org/blog/overview-of-pytorch-autograd-engine/


沪ICP备19023445号-2号
友情链接