在大模型训练过程中,显存(GPU内存)的占用主要来自三部分:模型参数显存激活显存优化器显存。每一部分的存在都有其必要性,共同支撑模型的训练过程。以下是它们的详细作用和原因:


1. 模型参数显存(Parameter Memory)

作用:存储模型的权重(参数)和梯度。
为什么需要?

  • 前向传播:计算时需要加载模型参数(如线性层的权重矩阵)。
  • 反向传播:需要保存参数的梯度以更新模型。
  • 存储形式:参数通常以float32格式存储(训练时),显存占用为 参数量 × 4字节
  • 举例:1750亿参数的GPT-3,仅参数显存需求为 175B × 4B ≈ 700GB(需多卡分布式存储)。

2. 激活显存(Activation Memory)

作用:存储前向传播的中间结果(激活值),用于反向传播的梯度计算。
为什么需要?

  • 链式法则依赖:反向传播时,梯度计算需要前向传播的中间结果(如ReLU的输出、注意力分数等)。
  • 显存峰值:激活显存通常是训练时的显存瓶颈,尤其是大batch size或长序列输入时。
  • 存储形式:激活值通常与输入数据量相关,显存占用为 batch_size × 序列长度 × 隐藏维度 × 层数 × 4字节
  • 举例:训练GPT-3时,激活显存可能远超参数显存(如batch_size=1024时可达TB级)。

3. 优化器显存(Optimizer Memory)

作用:存储优化器状态(如动量、方差等),用于参数更新。
为什么需要?

  • 优化器状态:如Adam优化器需保存动量(momentum)和方差(variance),每个参数额外占用 8字节float32动量 + float32方差)。
  • 显存占用:优化器显存通常为 参数量 × 12字节(参数4B + 梯度4B + 优化器状态8B)。
  • 举例:175B参数的模型,Adam优化器显存需求为 175B × 12B ≈ 2.1TB

三部分显存的协作流程

  1. 前向传播:加载参数 → 计算激活值并存储。
  2. 反向传播:根据激活值计算梯度 → 存储梯度。
  3. 参数更新:优化器读取梯度、更新参数(需访问优化器状态)。

为什么三者缺一不可?

  • 无参数显存:无法执行前向/反向计算。
  • 无激活显存:无法计算梯度(链式法则断裂)。
  • 无优化器显存:无法更新参数(训练停滞)。

显存优化的常见方法

  1. 参数分片(ZeRO):将参数、梯度、优化器状态分布式存储(如DeepSpeed的ZeRO-3)。
  2. 激活检查点(Activation Checkpointing):牺牲计算时间换显存,只存储部分激活值,其余重新计算。
  3. 混合精度训练:参数用float16,减少显存占用(需保留float32主副本防止精度损失)。
  4. 梯度累积:减小batch size以降低激活显存,通过多次累积梯度等效大batch。

总结

显存类型 内容 必要性 典型显存占用(以Adam为例)
模型参数显存 权重(参数) 前向/反向传播的基础 参数量 × 4B
激活显存 中间结果(激活值) 反向传播的梯度计算依赖 与batch size和序列长度相关
优化器显存 动量、方差等状态 参数更新的必需信息 参数量 × 8B

通过这三部分的协同,模型才能完成训练,而显存优化技术(如ZeRO、量化)的核心就是减少其中某一或多个部分的显存占用

问题

1. 什么是前向传播(Forward Propagation)

定义
输入数据从神经网络的输入层逐步计算,经过各层权重和激活函数的变换,最终得到输出预测值的过程。
数学表达
对于第 ( l ) 层的输出 ( \mathbf{a}^l ):
[
\mathbf{a}^l = f(\mathbf{W}^l \mathbf{a}^{l-1} + \mathbf{b}^l)
]
其中 ( f ) 是激活函数(如ReLU),( \mathbf{W}^l ) 和 ( \mathbf{b}^l ) 是权重和偏置。
作用

  • 计算模型的预测结果(如分类概率、回归值)。
  • 为反向传播提供中间激活值(用于梯度计算)。

2. 什么是反向传播(Backward Propagation)

定义
根据前向传播的输出与真实标签的误差(损失函数),从输出层反向逐层计算梯度,并更新模型参数的过程。
核心步骤

  1. 计算损失函数 ( \mathcal{L} ) 对输出的梯度 ( \frac{\partial \mathcal{L}}{\partial \mathbf{a}^L} )。
  2. 通过链式法则,从后向前逐层传递梯度,更新每一层的权重和偏置。

作用

  • 确定参数更新的方向和大小(通过梯度)。
  • 实现端到端的自动微分(无需手动推导导数)。

3. 什么是梯度(Gradient)

定义
损失函数对模型参数的偏导数向量,表示参数微小变化时损失函数的变化方向和速率。
例如,权重 ( \mathbf{W} ) 的梯度:
[
\nabla_{\mathbf{W}} \mathcal{L} = \frac{\partial \mathcal{L}}{\partial \mathbf{W}}
]
为什么计算梯度?

  • 优化模型:梯度指示参数如何调整能使损失函数下降(用于梯度下降法)。
  • 局部最优:通过梯度方向找到损失函数的极小值点(或鞍点)。

4. 什么是链式法则依赖(Chain Rule Dependency)

定义
复合函数求导的规则,反向传播中梯度通过链式法则从输出层传递到输入层。
数学表达
若 ( \mathcal{L} = f(g(x)) ),则:
[
\frac{d\mathcal{L}}{dx} = \frac{df}{dg} \cdot \frac{dg}{dx}
]
在反向传播中的应用
每一层的梯度是后一层梯度与本层局部梯度的乘积。例如:
[
\frac{\partial \mathcal{L}}{\partial \mathbf{a}^{l-1}} = \frac{\partial \mathcal{L}}{\partial \mathbf{a}^l}} \cdot \frac{\partial \mathbf{a}^l}{\partial \mathbf{a}^{l-1}}
]
为什么重要?

  • 避免重复计算,高效传递梯度。
  • 实现深度网络的自动微分(如PyTorch的Autograd机制)。

5. 什么是动量(Momentum)和方差(Variance)

动量(Momentum)

定义
优化器(如SGD with Momentum)中引入的“惯性”项,加速收敛并减少震荡。
数学表达
[
\mathbf{v}t = \beta \mathbf{v}{t-1} + (1-\beta) \nabla_{\mathbf{W}} \mathcal{L}
]
[
\mathbf{W} = \mathbf{W} - \alpha \mathbf{v}_t
]
其中 ( \beta ) 是动量系数(如0.9),( \alpha ) 是学习率。
作用

  • 保留历史梯度方向,平滑更新路径。
  • 帮助跳出局部极小值或鞍点。
方差(Variance,如Adam中的二阶矩)

定义
优化器(如Adam)中梯度平方的指数移动平均,用于自适应调整学习率。
数学表达
[
\mathbf{s}t = \beta_2 \mathbf{s}{t-1} + (1-\beta_2) (\nabla_{\mathbf{W}} \mathcal{L})^2
]
作用

  • 对频繁更新的参数减小学习率,对稀疏参数增大学习率。
  • 稳定训练过程(尤其适合非平稳目标函数)。

总结表格

概念 定义 作用
前向传播 输入数据逐层计算得到输出的过程 生成预测值,保存中间激活值
反向传播 从输出反向计算梯度并更新参数的过程 通过梯度优化模型参数
梯度 损失函数对参数的偏导数 指示参数更新方向,驱动模型优化
链式法则依赖 复合函数求导的规则,用于梯度反向传递 实现高效自动微分,支持深度网络训练
动量 历史梯度的加权平均,提供更新惯性 加速收敛,减少震荡
方差 梯度平方的指数平均(如Adam中的二阶矩) 自适应调整学习率,提升训练稳定性

补充说明

  • 梯度下降的直观理解
    梯度是“最陡上升方向”,负梯度则是“最陡下降方向”。模型通过反复沿负梯度方向更新参数,逐步逼近损失函数的最小值。
  • 动量 vs 方差
    动量解决梯度方向震荡问题,方差解决梯度尺度不均匀问题(两者在Adam中结合使用)。
  • 链式法则的工程实现
    现代框架(如PyTorch)通过计算图(Computational Graph)自动追踪依赖关系,无需手动实现链式法则。

通过理解这些核心概念,可以更深入地掌握深度学习模型的训练机制和优化原理。

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐