DeepSpeed ZeRO Stage 3:参数分区,让大模型训练“零负担”

下面是我和grok的对话:
作者:Grok(xAI 助手)
发布日期:2025 年 10 月 22 日

zero1 的博客

zero 2的博客

嗨,AI 爱好者们!如果你正纠结大模型训练的内存瓶颈——比如用 4 张 A100 GPU 想训个 70B 参数的 Llama,却总 OOM(内存溢出)——那 DeepSpeed 的 ZeRO Stage 3 绝对是你的救星。ZeRO 全称 Zero Redundancy Optimizer(零冗余优化器),是 Microsoft 开发的分布式训练神器。它把模型的“内存大头”(参数、梯度、优化器状态)切片分摊到多 GPU 上,避免每个人都抱着一模一样的“行李”。Stage 3 是终极版:连模型参数也分区了,让你用有限硬件训出万亿参数怪物!

今天这篇博客,我会用最通俗的语言(像聊天一样)带你搞懂 Stage 3 的原理、算法、代码和例子。别怕公式,我会配上“切蛋糕”和“借书”的比喻。读完,你就能在 config 文件里一键开启,跑起大模型。走起!

ZeRO 系列小回顾:从“分行李”到“零负担”

先热身:传统数据并行(DP)训练大模型时,每个 GPU 都要存全套东西——模型参数(ψ,大小 ≈ 参数数 × 4 bytes)、梯度(≈ψ)、优化器状态(如 Adam 的动量和方差,≈2ψ)。总内存 O(4ψ),4 GPU 训 7B 模型(ψ≈28GB)就吃 112GB/GPU,超大模型直接跪。

ZeRO 分 3 阶段逐步“瘦身”:

  • Stage 1:分区优化器状态(2ψ/N,N=GPU 数),省 2x 内存。
  • Stage 2:+分区梯度(ψ/N),总 O(ψ + 3ψ/N) ≈ψ,额外省 2x。
  • Stage 3:+分区参数(ψ/N),总 O(4ψ/N) ≈ψ/N,额外省 4x(总 8x)!

比喻:Stage 1-2 像旅行时分行李(梯度/优化器),每个人只背自己那份。Stage 3 升级:连“房子钥匙”(参数)也分了,只存自己的房间,用时“借”来用。

内存对比(7B 模型,N=4):

阶段 内存/GPU (GB) 相比基准减少 能训模型规模(x 基准)
基准 112 1x 7B
Stage 1 70 1.6x 11B
Stage 2 49 2.3x 16B
Stage 3 20 (offload) 5.6x (8x 典型) 40B+

(数据来源:DeepSpeed 2025 基准,含 bf16 + offload。8x 是官方宣传,指典型 N=8 时规模扩展。)

Stage 3 的核心算法:分区 + 预取,动态“借参数”

Stage 3 的魔法在于:不全存参数,只存 1/N 份;计算时动态拉取(all-gather),用完释放(reduce-scatter)。算法分 4 步,基于 PyTorch 分布式 API(dist.all_gather / dist.reduce_scatter)。

1. 分区算法(Sharding):切蛋糕式分配

  • 输入:模型参数列表(e.g., Transformer 的权重矩阵)。
  • 步骤
    1. 遍历所有参数,按索引哈希分区:partition_id = param_idx % N(均匀切 N 份)。
    2. 每个 GPU 只初始化自己的分区(param.data = 值),其他设 None(释放内存)。
  • 输出:partitions[list],GPU0 管分区0 的参数。
  • 通俗说:像 4 人分蛋糕(参数),GPU0 切第 1/4 块,其他人切自己的。总省 (N-1)ψ/N 内存。

伪算法

for param_idx, param in enumerate(model.parameters()):
    pid = param_idx % N
    if pid == my_rank:  # my_rank=当前 GPU ID
        param.data = load_value()  # 只加载本地
    else:
        param.data = None  # 省内存

2. Forward 预取算法(Prefetch with All-Gather):借书不卡壳

  • 挑战:参数分区后,Forward 需要全参数,怎么办?
  • 算法:预测 + 异步拉取。
    1. 预测:基于模型结构(e.g., Transformer 层序),猜下一层参数(_predict_next_parameters)。
    2. All-Gather:扁平化本地参数,dist.all_gather 拉全套(异步,CUDA stream 重叠计算)。
    3. 赋值 & 计算:临时填入 model.param.data,跑 forward。
    4. 释放:计算完,data 回 None 或缓存。
  • 关键:只预取当前/下一层(O(层大小)),逐层释放。通信 O(ψ),但 bucketing(小参数打包)优化 50%。
  • 通俗说:像图书馆借书——知道下一章需哪本书(预测),边读边借(异步),读完还书(释放)。不借全图书馆!

伪算法(Transformer 示例):

def forward(layer_id):
    next_params = model.layers[layer_id + 1].parameters()  # 预测
    flat_local = flatten([p.data if p.data else zeros_like(p) for p in next_params])
    gathered_full = empty_like(flat_local)
    dist.all_gather([gathered_full], [flat_local], async=True)  # 拉取
    unflatten_assign(gathered_full, next_params)  # 临时用
    output = layer.forward(input)  # 计算
    release(next_params)  # 还书
    return output

3. Backward & Step 算法:Scatter 汇总更新

  • Backward:类似预取全梯度分区,计算后 dist.reduce_scatter(SUM 平均 + scatter 只留本地)。
  • Step:用分区优化器更新分区参数,all-scatter 同步全参数。
  • 通俗说:Backward 像“集体投票”——每个人投自己票(本地梯度),系统汇总平均(reduce),只发给你负责的部分(scatter)。更新后“广播新蛋糕”。

4. Offload 扩展:CPU/NVMe “外包”

  • 算法:闲时参数/优化器移 CPU(torch.pin_memory),需时 DMA 拉回。省 GPU 内存 90%,但慢 10%(异步顶上)。

代码示例:简化 Stage 3 实现

用 PyTorch + DeepSpeed 风格,包装你的模型:

import torch.distributed as dist
from deepspeed import DeepSpeedConfig  # 实际用 DeepSpeed 库

class ZeROStage3:
    def __init__(self, model, config):
        self.model = model
        self.N = dist.get_world_size()  # GPU 数
        self.rank = dist.get_rank()
        self._partition_parameters()
    
    def _partition_parameters(self):
        param_idx = 0
        for param in self.model.parameters():
            pid = param_idx % self.N
            if pid != self.rank:
                param.data = None  # 省内存
            param_idx += 1
    
    def forward(self, *args):
        # 预取示例(简化,实际钩子逐层)
        for layer in self.model.layers:
            self._all_gather_layer(layer.parameters())
            output = layer(args)  # 计算
            self._release_layer(layer.parameters())
        return output
    
    def _all_gather_layer(self, params):
        # 扁平 + all-gather(伪代码)
        flat = torch.cat([p.flatten() if p.data is not None else torch.zeros(p.numel()) for p in params])
        full = torch.empty_like(flat)
        dist.all_gather_into_tensor(full, flat, async_op=True)
        # 还原赋值...
    
    # Backward/Step 类似,用 reduce_scatter

# 用法:ds_config = {"zero_optimization": {"stage": 3}}
# engine, _, _, _ = deepspeed.initialize(model, config=ds_config)

实际:Hugging Face Accelerate 一键 deepspeed --num_gpus=4 train.py

具体例子:4 GPU 训 7B 模型模拟

假设简单 2 层模型,总参数 ψ=4 单位(l1_w1=1, l1_w2=2, l2_w1=3, l2_w2=4)。N=4,GPU0 管 l1_w1(分区0)。

初始化:GPU0 只存 1(内存 1),GPU1 存 2 等。

Forward Layer 1

  • 预测 [l1_w1, l1_w2]。
  • All-Gather:GPU0 发 1,GPU1 发 2 → 临时 [1,2](全 GPU,峰内存 3)。
  • 计算:y = input * 1 + 2。
  • 释放:回 1。

Layer 2:预测 [3,4](GPU2/3 发),计算 z = y * 3 + 4,释放。

Backward:全 GPU 临时算全梯度 [g1=0.5, g2=1, g3=1.5, g4=2]。

  • Reduce-Scatter:SUM 平均(e.g., avg_g1=0.5),GPU0 只收 avg_g1,删其他(内存回 1)。

Step:GPU0 用本地优化器更新 l1_w1,scatter 同步新值。

结果:总内存 ≈1(ψ/N=1),baseline 4。训 7B 模型:4 GPU Stage 3 只 20GB/GPU,能扩展到 56B(8x)!

实际益处 & 小坑

  • 益处:BLOOM-176B 在 64 GPU 训只需 1TB(vs. 7TB)。速度 ≈ baseline 90%,支持 MoE/3D 并行。Hugging Face + VERL 生态无缝。
  • 小坑:通信多(N>8 调 bucketing);小模型 overhead 高,先 Stage 2 测试。开启 offload:"offload_param": {"device": "cpu"}

Stage 3 不是魔法,是聪明算法!试试 DeepSpeed GitHub 教程,训个 Llama-7B 玩玩。有什么问题?评论区聊!点赞+关注,更多 AI 干货来袭。

参考:DeepSpeed 官网(deepspeed.ai)。训练愉快~ 😊

Logo

中国智能体开发者社区,聚焦智能体与大模型开发,提供前沿资讯、实用工具链、开源项目及行业案例。通过技术沙龙、开发者大赛等活动,促进经验交流与协作,助力开发者快速构建创新智能应用。

更多推荐