build_loss_mask_for_mtp 详细执行过程

build_loss_mask_for_mtp 函数用于为 MTP (Multi-Token Prediction) 训练构建损失掩码,将原本只覆盖 response 部分的损失掩码扩展到覆盖整个序列(prompt + response)。 1

函数作用和背景

MTP 训练需要预测序列中的多个 token,因此需要完整的序列标签(包括 prompt 部分),而不仅仅是 response 部分。 2 这个函数就是将标准的 RL 训练掩码转换为 MTP 训练所需的掩码。

详细执行示例

假设有一个 batch 包含 2 个样本:

输入数据:

batch = {
    "tokens": torch.tensor([[101, 102, 103, 104, 105, 106, 107, 108],    # 样本1: prompt(5) + response(3)
                           [201, 202, 203, 204, 205, 206, 207, 208]]),   # 样本2: prompt(4) + response(4)
    "total_lengths": [8, 8],           # 总长度(prompt + response)
    "response_lengths": [3, 4],        # response 长度
    "loss_masks": [torch.tensor([1, 1, 1]),    # 样本1的response掩码
                   torch.tensor([1, 1, 1, 1])] # 样本2的response掩码
}

逐步执行过程:

1. 遍历每个样本

函数遍历 batch 中的每个样本,处理 total_lengthsresponse_lengthsloss_masks3

2. 计算 prompt 长度

对每个样本计算 prompt 长度: 4

  • 样本1: prompt_len = 8 - 3 = 5
  • 样本2: prompt_len = 8 - 4 = 4
3. 创建完整序列掩码

为每个样本创建覆盖整个序列的掩码: 5

样本1:

# 原始 response 掩码: [1, 1, 1]
# 创建全零掩码长度为 8: [0, 0, 0, 0, 0, 0, 0, 0]
# 将 response 掩码放到后半部分: [0, 0, 0, 0, 0, 1, 1, 1]
full_mask = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1])

样本2:

# 原始 response 掩码: [1, 1, 1, 1]
# 创建全零掩码长度为 8: [0, 0, 0, 0, 0, 0, 0, 0]
# 将 response 掩码放到后半部分: [0, 0, 0, 0, 1, 1, 1, 1]
full_mask = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
4. Context Parallelism 处理

如果启用了 Context Parallelism,会调用 slice_with_cp 对掩码进行切分: 6

  • 假设没有启用 CP,slice_with_cp 直接返回原掩码
5. 拼接所有样本的掩码

将所有样本处理后的掩码拼接: 7

flattened_mask = torch.cat([
    torch.tensor([0, 0, 0, 0, 0, 1, 1, 1]),  # 样本1
    torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])   # 样本2
], dim=0)
# 结果: [0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1]
6. 长度检查和 Padding

检查掩码长度是否超过 tokens 长度,必要时进行 padding: 8

# tokens 形状: [2, 8],seq_len = 16
# flattened_mask 长度: 16,正好匹配
# 如果不匹配,会进行 padding 或截断

最终输出:

# 形状: [1, 16] (batch_size=1, seq_len=16)
loss_mask_tensor = torch.tensor([[0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1]])

在训练中的使用

这个构建的掩码会传递给 Megatron 的 forward 函数: 9

if args.enable_mtp_training:
    loss_mask = build_loss_mask_for_mtp(batch)
    mtp_kwargs = {
        "mtp_labels": batch["tokens"],  # 完整序列作为标签
    }

output_tensor = model(
    input_ids=batch["tokens"],
    loss_mask=loss_mask,  # 扩展后的掩码
    mtp_kwargs=mtp_kwargs,
)

关键设计要点

  1. 掩码扩展:将 response-only 掩码扩展为 full-sequence 掩码,使 MTP 层能够学习预测 prompt 部分
  2. 长度对齐:确保掩码长度与 tokens 长度完全匹配 10
  3. CP 兼容:通过 slice_with_cp 支持 Context Parallelism 6
  4. 批量处理:高效处理 batch 中的多个样本

Notes

  • MTP 训练需要完整的序列标签(包括 prompt),因此需要扩展损失掩码 2
  • 函数确保掩码长度不超过 tokens 长度 10
  • 支持 Context Parallelism,通过 slice_with_cp 处理序列切分 6

Wiki pages you might want to explore:

Citations

File: slime/backends/megatron_utils/model.py (L374-399)

        def build_loss_mask_for_mtp(batch: dict[str, object]) -> torch.Tensor | None:
            tokens_tensor: torch.Tensor = batch["tokens"]

            mask_chunks: list[torch.Tensor] = []
            for total_len, response_len, resp_mask in zip(
                batch["total_lengths"], batch["response_lengths"], batch["loss_masks"]
            ):
                assert (
                    resp_mask.numel() == response_len
                ), f"Unexpected loss mask size {resp_mask.numel()} (expected {response_len} or {total_len})."
                prompt_len = total_len - response_len
                full_mask = resp_mask.new_zeros(total_len)
                full_mask[prompt_len:] = resp_mask

                mask_chunks.append(slice_with_cp(full_mask, 0.0))

            flattened_mask = torch.cat(mask_chunks, dim=0)
            seq_len = tokens_tensor.size(-1)
            assert (
                flattened_mask.numel() <= seq_len
            ), f"MTP loss mask ({flattened_mask.numel()}) exceeds token length ({seq_len})."

            # token tensor may be padded by 128, so pad loss mask to the same length
            loss_mask_tensor = flattened_mask.new_zeros(seq_len)
            loss_mask_tensor[: flattened_mask.numel()] = flattened_mask
            return loss_mask_tensor.unsqueeze(0)

File: slime/backends/megatron_utils/model.py (L416-434)

            if args.enable_mtp_training:
                loss_mask = build_loss_mask_for_mtp(batch)
                assert (
                    loss_mask.shape == batch["tokens"].shape
                ), f"loss_mask shape {loss_mask.shape} mismatches token shape {batch['tokens'].shape}"
                mtp_kwargs = {
                    # We have to set labels to tokens for MTP training, to point out samples to train.
                    "mtp_labels": batch["tokens"],
                }

            output_tensor = model(
                input_ids=batch["tokens"],
                position_ids=None,
                attention_mask=None,
                labels=None,
                packed_seq_params=batch["packed_seq_params"],
                loss_mask=loss_mask,
                **(dict(mtp_kwargs=mtp_kwargs) if mtp_kwargs is not None else {}),
            )

File: docker/patch/latest/megatron.patch (L169-170)

-        if mtp_in_postprocess:
+        if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None:

这是一个非常Pythonic的写法,使用了 zip() 函数来同时迭代多个序列。

让我们来详细解释一下:

1. zip() 函数

zip() 函数接收可迭代对象作为参数,并将其元素打包成一个个元组,然后返回由这些元组组成的迭代器。

例如:

list1 = [1, 2, 3]
list2 = ['a', 'b', 'c']
list3 = [True, False, True]

zipped_lists = zip(list1, list2, list3)

print(list(zipped_lists)) # Output: [(1, 'a', True), (2, 'b', False), (3, 'c', True)]

2. 在 for 循环中的应用

zip() 函数返回的迭代器用于 for 循环时,每次迭代都会从每个可迭代对象中取出一个元素,并将它们组合成一个元组。这个元组随后被解包(unpacking)到循环变量中。

在你的代码中:

for total_len, response_len, resp_mask in zip(
    batch["total_lengths"], batch["response_lengths"], batch["loss_masks"]
):
    # ... 循环体 ...
  • batch["total_lengths"]: 这是一个列表或张量,包含了批次中每个样本的总长度
  • batch["response_lengths"]: 这是一个列表或张量,包含了批次中每个样本的响应部分的长度
  • batch["loss_masks"]: 这是一个列表,包含了批次中每个样本的响应部分的损失掩码

zip() 函数会做以下事情:

  1. 取出 batch["total_lengths"] 的第一个元素。
  2. 取出 batch["response_lengths"] 的第一个元素。
  3. 取出 batch["loss_masks"] 的第一个元素。
  4. 将这三个元素打包成一个元组,例如 (first_total_len, first_response_len, first_resp_mask)
  5. for 循环的第一次迭代中,这个元组被解包:
    • total_len 得到 first_total_len 的值。
    • response_len 得到 first_response_len 的值。
    • resp_mask 得到 first_resp_mask 的值。
  6. 循环继续,取出第二个元素、第三个元素,直到其中一个序列耗尽。

目的和优点

这种用法的主要目的是:

  • 同步迭代: 它允许你同时、按顺序地处理来自多个相关数据结构的数据。在你的场景中,total_lengthsresponse_lengthsloss_masks 都对应于批次中的同一个样本,因此需要同时访问它们。
  • 代码简洁: 避免了使用索引 i 来访问 batch["total_lengths"][i]batch["response_lengths"][i]batch["loss_masks"][i],使代码更清晰易读。
  • 可读性: 循环变量 total_len, response_len, resp_mask 直接反映了它们所代表的含义,使得循环体内的逻辑更容易理解。

总结

for total_len, response_len, resp_mask in zip(batch["total_lengths"], batch["response_lengths"], batch["loss_masks"]):

这行代码的含义是:“对于批次中的每个样本,分别取出它的总长度、响应长度和响应损失掩码,并将它们赋值给 total_lenresponse_lenresp_mask,然后执行循环体中的操作。” 它高效且优雅地处理了批次中每个样本的相关数据。

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐