Gradient Checkpointing 是什么有什么作用
我们来详细介绍一下 Gradient Checkpointing(有时也称为 Activation Checkpointing)。

请添加图片描述
图为娱乐,与文无关

  1. 背景:深度学习训练中的显存瓶颈

在训练深度神经网络时,标准的反向传播算法 (Backpropagation) 需要计算梯度 ∂L/∂W (损失 L 对权重 W 的偏导数) 来更新模型参数。为了计算这些梯度,算法需要在前向传播 (Forward Pass) 过程中存储每一层的激活值 (Activations)——也就是每一层计算的中间输出结果。

为什么需要存储激活值?因为在反向传播 (Backward Pass) 过程中,计算某一层权重的梯度通常需要用到该层的输入(即前一层的激活值)以及损失函数对该层输出的梯度。

对于非常深或者非常宽的模型(如大型 Transformer),存储所有层的激活值会消耗大量的 GPU 显存。这常常成为训练大型模型的瓶颈,限制了可以使用的模型大小或训练时的批次大小 (Batch Size)。

  1. Gradient Checkpointing 的核心思想与工作原理

Gradient Checkpointing 是一种用计算换显存的技术。其核心思想是:在前向传播过程中,不存储所有中间层的激活值,只选择性地存储其中一部分(称为“检查点”)。然后在反向传播过程中,当需要用到那些没有被存储的激活值时,再从最近的一个检查点开始,重新进行一小部分前向计算,以临时生成所需的激活值。

具体步骤如下:

• 前向传播 (Forward Pass):
• 模型被逻辑上划分为多个“段”(Segment),段与段之间的边界就是“检查点”(Checkpoint)。
• 在执行前向传播时,只有检查点位置的激活值会被明确存储在内存中。
• 对于非检查点层的激活值,在计算完成后,它们所占用的内存会被立即释放(或者不被存储)。
• 反向传播 (Backward Pass):
• 反向传播正常进行,计算梯度。
• 当反向传播进行到需要某个未被存储的激活值(位于两个检查点之间,或输入与第一个检查点之间)时,系统会暂停梯度的计算。
• 它会找到该激活值所在段的前一个检查点(或者模型的输入)。
• 从这个检查点开始,重新执行该段的前向计算,生成所有必要的中间激活值,但这次只是临时生成,用完即弃。
• 使用这些重新计算得到的激活值,完成该段的梯度计算。
• 梯度计算完成后,这些临时重新计算的激活值再次被丢弃。
• 反向传播继续进行到下一个段或下一个检查点。
3. Gradient Checkpointing 的作用(优点)

• 显著降低显存占用: 这是最主要的作用。通过只存储少量检查点的激活值,而不是所有层的激活值,可以大幅减少训练过程中峰值显存的需求。
• 支持训练更大的模型: 降低了显存门槛,使得在有限的 GPU 显存上能够训练原本无法容纳的更大、更深的模型。
• 支持更大的批次大小: 在模型大小固定的情况下,节省下来的显存可以用来增加训练的批次大小,有时可以提高训练效率或模型性能。
4. Gradient Checkpointing 的缺点(代价)

• 增加计算时间: 由于部分前向计算需要在反向传播过程中重新执行一次,总的计算量增加了。通常,这会导致训练时间延长(例如,可能增加 20-30% 或更多,具体取决于检查点的设置和模型结构)。这是一个典型的时间-空间权衡 (Time-Memory Trade-off)。
5. 为什么 Gradient Checkpointing 会导致(或关联)显存峰值问题?

这里需要澄清一个常见的误解。Gradient Checkpointing 的主要目的和效果是降低整体的峰值显存占用,使其低于不使用该技术时的水平。然而,在讨论 QLoRA 和 Paged Optimizers 的背景下提到 Gradient Checkpointing 相关的显存峰值,通常是指在使用了 Gradient Checkpointing 的情况下,其内部工作机制仍然会产生一些局部的、短暂的显存使用高峰,这些高峰可能对显存管理(如 Paged Optimizers)提出挑战。

具体来说,这个“峰值问题”体现在:

• 重计算阶段的临时显存需求: 当 Gradient Checkpointing 在反向传播期间触发重计算(Recomputation)时,它需要为正在重计算的那个“段”分配内存来存储临时生成的激活值。虽然这些激活值用完即弃,并且只覆盖模型的一小部分,但在重计算发生的那个短暂时刻,显存使用量会暂时性地上升。
• 与优化器状态的交互: 训练过程中除了激活值,显存还被梯度(Gradients)和优化器状态(Optimizer States,如 Adam 中的 momentum 和 variance)占用。优化器状态通常非常占用显存(例如,AdamW 优化器状态大约是模型参数量的 2 倍,如果用 FP32 存储的话)。当 Gradient Checkpointing 进行重计算导致临时激活值显存增加时,如果此时梯度和优化器状态也占用了大量显存,总的显存需求就可能达到一个局部的峰值。
• Paged Optimizers 的背景: QLoRA 论文中提到 Paged Optimizers 是为了解决在 Gradient Checkpointing 下可能出现的显存碎片化或短暂 OOM 问题。当重计算需要分配一块连续内存,而此时 GPU 显存虽然总量足够,但可能因为碎片化而无法分配成功,或者因为这个短暂的峰值需求(重计算的激活值 + 梯度 + 优化器状态)超过了可用物理显存时,Paged Optimizers 就能介入,将优化器状态暂时移到 CPU 内存,从而释放 GPU 显存以满足重计算的需求,避免训练崩溃。
总结:

Gradient Checkpointing 本身是为了降低训练过程中的总体峰值显存,使得大模型训练成为可能。它通过牺牲计算时间来换取显存空间。然而,在其工作机制中,反向传播期间的“重计算”步骤会临时性地增加显存使用。这个局部的、短暂的显存峰值,虽然低于不使用 GC 时的整体峰值,但在显存极度受限(如 QLoRA 场景)或存在碎片化的情况下,仍可能构成挑战,这就是为什么 Paged Optimizers 等技术会与 Gradient Checkpointing 结合使用,以更好地管理这些动态的显存需求。

• Gradient Checkpointing:牺牲时间换显存:只保存部分激活,在需要梯度的时候,选择最近的激活前向计算一下。
• flash-attention:牺牲时间换带宽:通过利用高速缓存的高速计算,不保存注意力权重S和softmax后的P,需要的时候再计算。并减少

欢迎关注公众号查看更多系列内容:AI沙砾

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐