【RL框架】Slime 学习
函数用于为 MTP (Multi-Token Prediction) 训练构建损失掩码,将原本只覆盖 response 部分的损失掩码扩展到覆盖整个序列(prompt + response)。MTP 训练需要预测序列中的多个 token,因此需要完整的序列标签(包括 prompt 部分),而不仅仅是 response 部分。这行代码的含义是:“对于批次中的每个样本,分别取出它的总长度、响应长度和响
build_loss_mask_for_mtp 详细执行过程
build_loss_mask_for_mtp 函数用于为 MTP (Multi-Token Prediction) 训练构建损失掩码,将原本只覆盖 response 部分的损失掩码扩展到覆盖整个序列(prompt + response)。 1
函数作用和背景
MTP 训练需要预测序列中的多个 token,因此需要完整的序列标签(包括 prompt 部分),而不仅仅是 response 部分。 2 这个函数就是将标准的 RL 训练掩码转换为 MTP 训练所需的掩码。
详细执行示例
假设有一个 batch 包含 2 个样本:
输入数据:
batch = {
"tokens": torch.tensor([[101, 102, 103, 104, 105, 106, 107, 108], # 样本1: prompt(5) + response(3)
[201, 202, 203, 204, 205, 206, 207, 208]]), # 样本2: prompt(4) + response(4)
"total_lengths": [8, 8], # 总长度(prompt + response)
"response_lengths": [3, 4], # response 长度
"loss_masks": [torch.tensor([1, 1, 1]), # 样本1的response掩码
torch.tensor([1, 1, 1, 1])] # 样本2的response掩码
}
逐步执行过程:
1. 遍历每个样本
函数遍历 batch 中的每个样本,处理 total_lengths、response_lengths 和 loss_masks。 3
2. 计算 prompt 长度
对每个样本计算 prompt 长度: 4
- 样本1:
prompt_len = 8 - 3 = 5 - 样本2:
prompt_len = 8 - 4 = 4
3. 创建完整序列掩码
为每个样本创建覆盖整个序列的掩码: 5
样本1:
# 原始 response 掩码: [1, 1, 1]
# 创建全零掩码长度为 8: [0, 0, 0, 0, 0, 0, 0, 0]
# 将 response 掩码放到后半部分: [0, 0, 0, 0, 0, 1, 1, 1]
full_mask = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1])
样本2:
# 原始 response 掩码: [1, 1, 1, 1]
# 创建全零掩码长度为 8: [0, 0, 0, 0, 0, 0, 0, 0]
# 将 response 掩码放到后半部分: [0, 0, 0, 0, 1, 1, 1, 1]
full_mask = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
4. Context Parallelism 处理
如果启用了 Context Parallelism,会调用 slice_with_cp 对掩码进行切分: 6
- 假设没有启用 CP,
slice_with_cp直接返回原掩码
5. 拼接所有样本的掩码
将所有样本处理后的掩码拼接: 7
flattened_mask = torch.cat([
torch.tensor([0, 0, 0, 0, 0, 1, 1, 1]), # 样本1
torch.tensor([0, 0, 0, 0, 1, 1, 1, 1]) # 样本2
], dim=0)
# 结果: [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1]
6. 长度检查和 Padding
检查掩码长度是否超过 tokens 长度,必要时进行 padding: 8
# tokens 形状: [2, 8],seq_len = 16
# flattened_mask 长度: 16,正好匹配
# 如果不匹配,会进行 padding 或截断
最终输出:
# 形状: [1, 16] (batch_size=1, seq_len=16)
loss_mask_tensor = torch.tensor([[0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1]])
在训练中的使用
这个构建的掩码会传递给 Megatron 的 forward 函数: 9
if args.enable_mtp_training:
loss_mask = build_loss_mask_for_mtp(batch)
mtp_kwargs = {
"mtp_labels": batch["tokens"], # 完整序列作为标签
}
output_tensor = model(
input_ids=batch["tokens"],
loss_mask=loss_mask, # 扩展后的掩码
mtp_kwargs=mtp_kwargs,
)
关键设计要点
- 掩码扩展:将 response-only 掩码扩展为 full-sequence 掩码,使 MTP 层能够学习预测 prompt 部分
- 长度对齐:确保掩码长度与 tokens 长度完全匹配 10
- CP 兼容:通过
slice_with_cp支持 Context Parallelism 6 - 批量处理:高效处理 batch 中的多个样本
Notes
- MTP 训练需要完整的序列标签(包括 prompt),因此需要扩展损失掩码 2
- 函数确保掩码长度不超过 tokens 长度 10
- 支持 Context Parallelism,通过
slice_with_cp处理序列切分 6
Wiki pages you might want to explore:
Citations
File: slime/backends/megatron_utils/model.py (L374-399)
def build_loss_mask_for_mtp(batch: dict[str, object]) -> torch.Tensor | None:
tokens_tensor: torch.Tensor = batch["tokens"]
mask_chunks: list[torch.Tensor] = []
for total_len, response_len, resp_mask in zip(
batch["total_lengths"], batch["response_lengths"], batch["loss_masks"]
):
assert (
resp_mask.numel() == response_len
), f"Unexpected loss mask size {resp_mask.numel()} (expected {response_len} or {total_len})."
prompt_len = total_len - response_len
full_mask = resp_mask.new_zeros(total_len)
full_mask[prompt_len:] = resp_mask
mask_chunks.append(slice_with_cp(full_mask, 0.0))
flattened_mask = torch.cat(mask_chunks, dim=0)
seq_len = tokens_tensor.size(-1)
assert (
flattened_mask.numel() <= seq_len
), f"MTP loss mask ({flattened_mask.numel()}) exceeds token length ({seq_len})."
# token tensor may be padded by 128, so pad loss mask to the same length
loss_mask_tensor = flattened_mask.new_zeros(seq_len)
loss_mask_tensor[: flattened_mask.numel()] = flattened_mask
return loss_mask_tensor.unsqueeze(0)
File: slime/backends/megatron_utils/model.py (L416-434)
if args.enable_mtp_training:
loss_mask = build_loss_mask_for_mtp(batch)
assert (
loss_mask.shape == batch["tokens"].shape
), f"loss_mask shape {loss_mask.shape} mismatches token shape {batch['tokens'].shape}"
mtp_kwargs = {
# We have to set labels to tokens for MTP training, to point out samples to train.
"mtp_labels": batch["tokens"],
}
output_tensor = model(
input_ids=batch["tokens"],
position_ids=None,
attention_mask=None,
labels=None,
packed_seq_params=batch["packed_seq_params"],
loss_mask=loss_mask,
**(dict(mtp_kwargs=mtp_kwargs) if mtp_kwargs is not None else {}),
)
File: docker/patch/latest/megatron.patch (L169-170)
- if mtp_in_postprocess:
+ if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None:
这是一个非常Pythonic的写法,使用了 zip() 函数来同时迭代多个序列。
让我们来详细解释一下:
1. zip() 函数
zip() 函数接收可迭代对象作为参数,并将其元素打包成一个个元组,然后返回由这些元组组成的迭代器。
例如:
list1 = [1, 2, 3]
list2 = ['a', 'b', 'c']
list3 = [True, False, True]
zipped_lists = zip(list1, list2, list3)
print(list(zipped_lists)) # Output: [(1, 'a', True), (2, 'b', False), (3, 'c', True)]
2. 在 for 循环中的应用
当 zip() 函数返回的迭代器用于 for 循环时,每次迭代都会从每个可迭代对象中取出一个元素,并将它们组合成一个元组。这个元组随后被解包(unpacking)到循环变量中。
在你的代码中:
for total_len, response_len, resp_mask in zip(
batch["total_lengths"], batch["response_lengths"], batch["loss_masks"]
):
# ... 循环体 ...
batch["total_lengths"]: 这是一个列表或张量,包含了批次中每个样本的总长度。batch["response_lengths"]: 这是一个列表或张量,包含了批次中每个样本的响应部分的长度。batch["loss_masks"]: 这是一个列表,包含了批次中每个样本的响应部分的损失掩码。
zip() 函数会做以下事情:
- 取出
batch["total_lengths"]的第一个元素。 - 取出
batch["response_lengths"]的第一个元素。 - 取出
batch["loss_masks"]的第一个元素。 - 将这三个元素打包成一个元组,例如
(first_total_len, first_response_len, first_resp_mask)。 - 在
for循环的第一次迭代中,这个元组被解包:total_len得到first_total_len的值。response_len得到first_response_len的值。resp_mask得到first_resp_mask的值。
- 循环继续,取出第二个元素、第三个元素,直到其中一个序列耗尽。
目的和优点
这种用法的主要目的是:
- 同步迭代: 它允许你同时、按顺序地处理来自多个相关数据结构的数据。在你的场景中,
total_lengths、response_lengths和loss_masks都对应于批次中的同一个样本,因此需要同时访问它们。 - 代码简洁: 避免了使用索引
i来访问batch["total_lengths"][i]、batch["response_lengths"][i]和batch["loss_masks"][i],使代码更清晰易读。 - 可读性: 循环变量
total_len,response_len,resp_mask直接反映了它们所代表的含义,使得循环体内的逻辑更容易理解。
总结
for total_len, response_len, resp_mask in zip(batch["total_lengths"], batch["response_lengths"], batch["loss_masks"]):
这行代码的含义是:“对于批次中的每个样本,分别取出它的总长度、响应长度和响应损失掩码,并将它们赋值给 total_len、response_len 和 resp_mask,然后执行循环体中的操作。” 它高效且优雅地处理了批次中每个样本的相关数据。
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)