完全分片数据并行(FSDP, Fully Sharded Data Parallel)
完全分片数据并行(**FSDP, Fully Sharded Data Parallel**)是 PyTorch 提供的一种分布式训练技术,主要用于在多 GPU(甚至多节点)环境中训练**超大模型**时节省显存和提升可扩展性。
它的核心思路可以用一句话概括:
> **把模型的参数(Parameters)、梯度(Gradients)和优化器状态(Optimizer States)都切成小片(Shard),分散存储到不同 GPU 上,按需取用,计算完成后再释放。**
---
## 1. 背景
在常规的 **Data Parallel (DP)** 中:
* 每个 GPU 都保存**完整模型的副本**。
* 每次前向和反向传播都在本地 GPU 上完成。
* GPU 间同步梯度。
问题:
* 如果模型很大(比如数百亿甚至上百亿参数),单个 GPU **装不下完整的模型参数 + 梯度 + 优化器状态**。
* 即使能装下,也会浪费显存(每个 GPU 都冗余存储了相同数据)。
---
## 2. FSDP 的核心机制
FSDP 解决了 **冗余存储** 的问题,通过**完全切分 (full sharding)**:
| 存储内容 | 普通 Data Parallel | FSDP |
| ----------------------- | ---------------- | --------------------- |
| 模型参数(Parameters) | 每个 GPU 保存全量 | 被切成 N 份,每个 GPU 保存 1/N |
| 梯度(Gradients) | 每个 GPU 保存全量 | 被切成 N 份,每个 GPU 保存 1/N |
| 优化器状态(Optimizer States) | 每个 GPU 保存全量 | 被切成 N 份,每个 GPU 保存 1/N |
工作过程:
1. **初始化阶段**
* 将模型的参数按 GPU 数量分片(Sharding),每块参数只存储在一个 GPU 上。
2. **前向计算(Forward)**
* 当某个层需要计算时,FSDP 会\*\*按需广播(AllGather)\*\*该层的参数到所有 GPU。
* 完成计算后,立即释放这些全量参数(节省显存)。
3. **反向计算(Backward)**
* 计算梯度时,也会在各 GPU 上收集必要的梯度分片。
* 完成梯度计算后,把梯度分片化(ReduceScatter),只保留属于本 GPU 的那部分。
4. **优化器更新(Optimizer Step)**
* 只在本地更新自己负责的参数分片(节省内存和通信量)。
---
## 3. 直观类比
你可以把普通 Data Parallel 想象成:
> 每个人都要背**整本书**(全量模型)去考试,浪费了很多背书的时间和力气。
而 FSDP 则是:
> 把书拆成几份(参数分片),每个人只负责背其中一部分。考试时,先互相借书抄答案(AllGather),抄完就归还,减少背书负担(显存占用)。
---
## 4. FSDP 的优点
* **显存节省**:能训练远大于单卡显存容量的模型。
* **高扩展性**:支持多 GPU / 多节点。
* **减少冗余存储**:参数、梯度、优化器状态都分片存储。
---
## 5. 可能的缺点
* **通信开销大**:频繁 AllGather / ReduceScatter,可能导致网络瓶颈。
* **实现复杂**:需要正确划分模型结构(wrap layer)。
* **调参复杂**:需要合理设置 shard 策略、重计算(checkpointing)、混合精度等。
---
## 6. PyTorch FSDP 典型用法
```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import wrap
# 初始化分布式环境
torch.distributed.init_process_group("nccl")
# 定义模型
model = MyModel()
# FSDP 包装模型
model = FSDP(wrap(model))
# 正常训练流程
for data, target in dataloader:
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
---
✅ 总结一句话:
**FSDP 是一种“全量模型切片+按需取用+用完释放”的分布式训练方式,能让单卡显存负担只相当于总模型大小的 1/N,从而在有限显存下训练超大模型。**
更多推荐



所有评论(0)