IT博客汇
  • 首页
  • 精华
  • 技术
  • 设计
  • 资讯
  • 扯淡
  • 权利声明
  • 登录 注册

    谁动了我的显存?——深度学习训练过程显存占用分析及优化

    游凯超发表于 2023-07-06 18:47:02
    love 0

    在大语言模型时代,不仅语言模型变得越来越大,而且几乎所有的模型都想变得越来越大,试图在模型变大之后观察到一些涌现出来的能力。

    模型变大之后,最突出的问题就是显存不够用了。本文对深度学习训练过程中的显存占用问题进行一些具体分析,加深我对训练过程的理解,能够进行一些简单的显存优化操作。如果读者们有更多的相关资料、优化技巧,欢迎在评论区补充。

    显存占用概述

    深度学习训练过程中的显存占用,大致可以分为三部分:

    • 框架占用,例如pytorch框架的cuda context会占用大约几百MB显存
    • 模型参数相关的占用,比如7B的模型以FP16格式要占用14GB显存。此处还包括优化器、梯度相关的参数占用,全量微调的情况下,梯度与参数一样大,优化器状态是梯度的1~2倍(SGD为1倍,Adam为2倍)。如果使用DDP进行多卡训练,则还需要乘以显卡数量;如果使用FSDP进行多卡训练,显存占用与显卡数无关,但是会增加通信开销。
    • 特征相关的占用,这部分显存占用是最复杂的,因为它与模型的具体计算流程有关。很多地方只会笼统地说这类占用与batchsize成正比,但是具体的比例系数很难分析。

    本文希望详细解析特征相关的显存占用到底是多少。

    统计方法

    我们用一个样例程序,来使用不同的方法、在不同的情况下计算(x+1)(y+1)这样一个简单的函数。具体程序为:

    import torch
    
    # Create two tensors with 1GB memory footprint each, initialized randomly, in fp16 format
    # For a tensor of float16 (2 bytes), 1GB of memory can hold 1GB / 2B = 500M elements
    tensor_size = 512 * 1024 * 1024 
    x = torch.randn(tensor_size, dtype=torch.float16, device='cuda')
    y = torch.randn(tensor_size, dtype=torch.float16, device='cuda')
    
    # Record current memory footprint, and reset max memory counter
    current_memory = torch.cuda.memory_allocated()
    torch.cuda.reset_peak_memory_stats()
    
    def compute(x, y):
        return (x + 1) * (y + 1)
    
    z = compute(x, y)
    
    # Record the additional memory (both peak memory and persistent memory) after calculating the resulting tensor
    additional_memory = torch.cuda.memory_allocated() - (current_memory + 1e9)
    peak_memory = torch.cuda.max_memory_allocated()
    additional_peak_memory = peak_memory - (current_memory + 1e9)
    
    print(f"Additional memory used: {additional_memory / (1024 ** 3)} GB")
    print(f"Additional peak memory used: {additional_peak_memory / (1024 ** 3)} GB")

    在这个函数计算过程中,输入x、y,输出z不可避免地要占用显存。我们希望在不同情况下、改变不同的计算方式,观察/理解为了计算这个函数所需要的额外显存。

    这里需要区分两个概念:峰值显存占用 与 持续显存占用 。在计算一个函数的过程中,我们可能创建了很多中间结果,它们需要临时占用显存;但是当函数计算完成之后,只有一部分结果需要持续存在(直到反向传播结束),另一部分可以被释放。上述示例小脚本,会分别输出持续显存占用和峰值显存占用。

    一:不需要计算梯度的情况

    上述示例脚本,直接运行的结果是:

    Additional memory used: 0.06867742538452148 GB
    Additional peak memory used: 2.0686774253845215 GB

    也就是说,函数运行期间需要大约2GB的显存占用,运行结束之后几乎不占显存。

    具体来说,函数计算过程中需要创建 x+1和y+1两个临时变量,乘积结果放在z中。因此大约需要2GB的显存来存储临时变量,它们在计算结束后会被释放。

    至于为什么持续显存占用不严格为0、峰值显存占用不严格为2GB,这就与pytorch的具体显存管理策略、对象的显存布局有关,我们暂时不关心这部分内容。

    二:需要计算梯度的情况

    我们把计算函数改写为:

    def compute(x, y):
        x.requires_grad_(True)
        y.requires_grad_(True)
        return (x + 1) * (y + 1)

    得到的结果为:

    Additional memory used: 2.0686774253845215 GB
    Additional peak memory used: 2.0686774253845215 GB

    也就是说,需要计算梯度时,计算过程中的临时变量并不会被释放,反而会持续存在于显存中,等待后续用于反向传播计算。

    这个问题可以变得更复杂一些,如果我们让一个输入要求梯度、一个参数不要求梯度,会发生什么呢?

    def compute(x, y):
        x.requires_grad_(True)
        return (x + 1) * (y + 1)

    得到的结果是:

    Additional memory used: 1.0686774253845215 GB
    Additional peak memory used: 2.0686774253845215 GB

    可以看到,计算完成后释放了一个临时变量,还有一个临时变量持续存在。这是因为我们只要求x能计算梯度,y不用计算梯度。

    有意思的是,大部分人看到这里,都觉得既然y不需要计算梯度,那么肯定是y+1这个临时变量被释放了。然而,事实上是x+1这个临时变量被释放掉了。

    为了说清楚这个问题,我们用具体的值来区分x和y,这里x的值是1,y的值是2,于是临时变量x+1的值是2,y+1的值是3.通过计算结果z记录的中间变量的值,我们可以区分z具体记录了哪个中间结果。

    def compute(x, y):
        x.zero_()
        y.zero_()
        x += 1
        y += 2
        x.requires_grad_(True)
        z = (x + 1) * (y + 1)
        print(z.grad_fn._saved_other.mean().item())
        return z

    这段代码的运行结果是:

    3.0
    Additional memory used: 1.0686774253845215 GB
    Additional peak memory used: 2.0686774253845215 GB

    可以看到,虽然是x要求梯度,但是在计算过程中保留的变量却是y+1。

    为了从原理上理解这个现象,我们来看看反向传播的本质:梯度求导。

    考虑神经网络中的某个函数c = f(a,b),输入为a和b两个参数,输出为c。c将继续参与后续运算,得到损失函数J = g(c)。反向传播的任务,就是在已知\frac{\partial J}{\partial c}的情况下,计算\frac{\partial J}{\partial a}和\frac{\partial J}{\partial b}。

    根据链式法则,\frac{\partial J}{\partial a} = \frac{\partial J}{\partial c} \frac{\partial c}{\partial a}且\frac{\partial J}{\partial b} = \frac{\partial J}{\partial c} \frac{\partial c}{\partial b}。于是,为了反向传播,我们需要记录\frac{\partial c}{\partial a}和\frac{\partial c}{\partial b}。

    不失一般性而言,\frac{\partial c}{\partial a}是a和b的函数。于是,为了反向传播,我们需要完整记录a和 b 。这是最简单粗暴的方法。

    实际上,对于很多简单函数来说,偏导数的表达式并不复杂。以本文的小脚本为例,c = f(a,b)=a * b,于是\frac{\partial c}{\partial a}=b只和b有关。也就是说,为了计算a的反向传播,只需要记录b的值。

    于是,我们就能理解,为什么x需要梯度(对应地x+1也需要梯度)时,反向传播记录的是y+1。

    这部分的内容本质上就是自动微分的内容。当我们为每一个原子操作(例如加减乘除)写好了反向传播算法,自动微分就能够沿着计算图进行自动求导。这种方法写起来很简单,也很直观。然而,它的缺点也很明显:显存占用大。

    三:不使用自动微分计算梯度

    有什么办法能够绕开自动微分的限制,使得显存开销更低吗?

    有的,答案就是pytorch提供的torch.autograd.Function。

    我们把计算部分的代码替换成Function的实现,直接用一个算子实现(x+1)*(y+1)的功能:

    from torch.autograd import Function
    
    class AddMulFunction(Function):
        @staticmethod
        def forward(ctx, x, y):
            ctx.save_for_backward(x, y)
            return (x + 1) * (y + 1)
    
        @staticmethod
        def backward(ctx, grad_output):
            x, y = ctx.saved_tensors
            grad_x = grad_output * (y + 1)
            grad_y = grad_output * (x + 1)
            return grad_x, grad_y
    
    func = AddMulFunction.apply
    
    def compute(x, y):
        x.requires_grad_(True)
        y.requires_grad_(True)
        return func(x, y)

    输出结果为:

    Additional memory used: 0.06867742538452148 GB
    Additional peak memory used: 2.0686774253845215 GB

    这个算子也能够进行反向传播,而且计算结束之后并不会占用显存。这是因为我们在它的backward函数里手动计算了这个算子的梯度,使得它不用记录临时变量x+1和y+1也能进行反向传播。

    从这个算子的实现中,我们能清晰地看到ctx.save_for_backward函数,它为反向传播过程记录必要的参数。

    关于torch.autograd.Function,有一个细节值得注意:torch.autograd.Function设计的初衷就是为了让高级用户绕开自动微分的限制,因此torch.autograd.Function的forward和backward函数执行过程中,并不会记录梯度操作。大致可以理解为:torch.autograd.Function的forward和backward函数执行过程被包裹在 with torch.no_grad()环境中。

    例如,我们把计算代码改成:

    from torch.autograd import Function
    
    class AddMulFunction(Function):
        @staticmethod
        def forward(ctx, x, y):
            ctx.save_for_backward(x, y)
            z = (x + 1) * (y + 1)
            print(z.requires_grad)
            print(z.grad_fn)
            return z
    
        @staticmethod
        def backward(ctx, grad_output):
            x, y = ctx.saved_tensors
            grad_x = grad_output * (y + 1)
            grad_y = grad_output * (x + 1)
            return grad_x, grad_y
    
    func = AddMulFunction.apply
    
    def compute(x, y):
        x.requires_grad_(True)
        y.requires_grad_(True)
        return func(x, y)
    
    z = compute(x, y)
    
    print(z.requires_grad)
    print(z.grad_fn)

    输出结果为:

    False
    None
    True
    <torch.autograd.function.AddMulFunctionBackward object at 0x7fea4304a5e0>
    Additional memory used: 0.06867742538452148 GB
    Additional peak memory used: 2.0686774253845215 GB

    即使x和y是需要梯度的,在Function的forward函数中,z是不需要梯度的。然而,当走出forward函数之后,pytorch会为它加上需要梯度的标志,并且通过grad_fn属性记录其反向传播需要执行的函数。

    通过这一细节,我们可以理解,为什么定义了AddMulFunction之后,不能直接使用AddMulFunction.forward函数,而必须用func = AddMulFunction.apply。

    以上涉及的内容,其实就是“算子融合”,通过手动计算反向传播过程,节约不必要的显存开销。上述算子还可以进一步优化,把峰值显存占用也降下来。感兴趣的朋友可以试试。

    我们日常使用的很多算子,都是融合过的。

    以sigmoid算子为例,如果我们自己来实现:

    def compute(x):
        x.requires_grad_(True)
        z = 1 / (1 + torch.exp(-x))
        return z
    
    z = compute(x)

    输出结果为:

    Additional memory used: 2.0686774253845215 GB
    Additional peak memory used: 3.0686774253845215 GB

    峰值显存占用为3GB,持续显存占用为2GB。

    如果改为pytorch自带的已经融合过的算子:

    def compute(x):
        x.requires_grad_(True)
        z = torch.nn.Sigmoid()(x)
        return z
    
    z = compute(x)

    输出结果为:

    Additional memory used: 0.06867742538452148 GB
    Additional peak memory used: 0.06867742538452148 GB

    峰值显存占用与持续显存占用几乎都是0!

    这是怎么做到的呢?

    • sigmoid函数是element-wise的函数,只需要申请一次显存,把所有的操作都变成in-place,再把这块显存作为输出内容,就不用申请临时空间了。
    • sigmoid函数 z=\frac{1}{1+e^{-x}}的导数是z * (1-z),为了计算反向传播,只需要记录输出z。而在我们的示例程序中,z原本就会保留,因此sigmoid函数的反向传播记录的z就不用额外占用空间。

    算子显存占用分析中的记账问题

    上述分析中,关于sigmoid算子显存占用为0的结论并不严谨。它占用的显存刚好是我们的输出,因此没有算在它的显存开销中。

    为了更准确地反映这一问题,我们让它多计算几次:

    def compute(x):
        x.requires_grad_(True)
        for i in range(5):    
            x = torch.nn.Sigmoid()(x)
        return x
    
    z = compute(x)

    计算5次,额外占用显存为4GB:

    Additional memory used: 4.0686774253845215 GB
    Additional peak memory used: 4.0686774253845215 GB

    大体上来说,一个算子持续占用的显存,就是它在前向传播过程中保存下来的变量所占的显存。但一个程序占用的显存总量,并不能用全部算子占用的显存数进行求和,因为这些变量之间可能有重复(正如我们的示例中的输入变量、输出变量那样)。

    总结

    本文介绍了深度学习训练过程中的显存占用分析方法、自动求导与手动算子融合、优化等技术原理。算子融合是深度学习编译器等技术的核心,而算子优化目前还需要人工设计。对算子优化感兴趣的朋友,可以看看FlashAttention论文(参见《Flashattention: Fast and memory-efficient exact attention with io-awareness》),它是一个十分优雅的算子优化的例子。

    注:

    如何查看pytorch自带算子为反向传播保存的变量?可以通过输出的grad_fn属性的dir(var.grad_fn)看到,里面的_saved_xxx就是为了反向传播保存的变量。

    对于乘法,这个属性是_saved_other,因为乘法的梯度是另一个变量;对于sigmoid算子,这个属性是_saved_result,因为sigmoid的梯度和计算结果有关。

    大部分的pytorch算子都可以通过这种方式获得保存的具体变量内容,例如卷积算子保留了以下内容:_saved_bias_sym_sizes_opt/_saved_dilation/_saved_groups/_saved_input/_saved_output_padding/_saved_padding/_saved_stride/_saved_transposed/_saved_weight. 其中大部分都是卷积的配置(例如padding大小、stride大小等内容),真正对显存占用影响最大的就是_saved_input和_saved_weight。

    附上这部分代码,感兴趣的读者可以用它来分析pytorch自带算子的具体计算机制。

    var = z
    names =[k for k in dir(var.grad_fn) if k.startswith('_saved')]
    for k in names:
        v = getattr(var.grad_fn, k)
        if isinstance(v, torch.Tensor):
            print(k, v.shape)
        else:
            print(k, v)



    来源:知乎 www.zhihu.com
    作者:游凯超

    【知乎日报】千万用户的选择,做朋友圈里的新鲜事分享大牛。 点击下载


沪ICP备19023445号-2号
友情链接