@contextmanager
def unwrap_model_for_generation(
    model,
    accelerator,
    gather_deepspeed3_params=True,
    gather_parameters: List = None,
):
    unwrapped_model = accelerator.unwrap_model(model)
    if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
        if not gather_deepspeed3_params:
            yield accelerator.unwrap_model(model)
        else:
            import deepspeed
            parameters = [
                parameter for name, parameter in model.named_parameters()
                if not gather_parameters or name in gather_parameters
            ]
            with deepspeed.zero.GatheredParameters(parameters):
                from trl.models.utils import remove_hooks
                remove_hooks(model)
                yield accelerator.unwrap_model(model)
                from trl.models.utils import add_hooks
                add_hooks(model)
    else:
        yield unwrapped_model


class GRPOCallback(TrainerCallback):

    def __init__(self, trainer):
        self.trainer = trainer

    # offload original_modules to cpu, to save memory
    def on_train_begin(self, args, state, control, **kwargs):
        self.trainer.queue = self.trainer.train_queue
        train_dataloader = getattr(state, 'train_dataloader', None) or kwargs.get('train_dataloader')
        self.trainer._prefetch(train_dataloader)


@dataclass
class DataCache:
    inputs: List[Dict] = field(default_factory=list)
    outputs: List[Dict] = field(default_factory=list)
    distributed_idx: List[List] = field(default_factory=list)

est

Poe

这段代码是用于大语言模型训练的,使用了 AccelerateDeepSpeedTrainerCallback 以及自定义的数据缓存结构。它涉及模型解包(unwrap)和 DeepSpeed Zero Stage 3 的特殊处理。下面是对代码中每一部分逐行的详细解释:


一、unwrap_model_for_generation 上下文管理器

python

运行复制

@contextmanager
def unwrap_model_for_generation(
    model,
    accelerator,
    gather_deepspeed3_params=True,
    gather_parameters: List = None,
):
  • @contextmanager:表明该函数是一个上下文管理器(可用于 with 语句)。
  • model:传入的模型对象。
  • acceleratorAccelerate 框架的对象,用于处理多设备训练、混合精度等。
  • gather_deepspeed3_params:是否需要收集 DeepSpeed Stage 3 分散的参数。
  • gather_parameters:指定收集哪些参数的名字列表(可选)。

python

运行复制

    unwrapped_model = accelerator.unwrap_model(model)
  • 使用 accelerator.unwrap_model 解包模型,去除封装(如 DDP、FP16、Deepspeed 包装等),得到底层原始模型。

python

运行复制

    if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
  • 判断是否使用了 DeepSpeed,并且是 Zero Stage 3(该阶段会对参数进行分片,不能直接访问完整参数)。

python

运行复制

        if not gather_deepspeed3_params:
            yield accelerator.unwrap_model(model)
  • 如果不需要收集分片参数,直接 yield 解包后的模型。

python

运行复制

        else:
            import deepspeed
  • 否则导入 deepspeed,准备使用其工具收集参数。

python

运行复制

            parameters = [
                parameter for name, parameter in model.named_parameters()
                if not gather_parameters or name in gather_parameters
            ]
  • 根据 gather_parameters 筛选出要收集的参数。如果没有指定,则收集所有参数。

python

运行复制

            with deepspeed.zero.GatheredParameters(parameters):
  • 使用 Deepspeed 的上下文管理器,将 Stage 3 中被分散的参数 收集到当前进程,以便可以正常使用。

python

运行复制

                from trl.models.utils import remove_hooks
                remove_hooks(model)
  • 从模型中暂时移除 hook(比如 forward hook),防止收集参数时引发副作用。

python

运行复制

                yield accelerator.unwrap_model(model)
  • yield 解包后的模型,此时参数是完整的。

python

运行复制

                from trl.models.utils import add_hooks
                add_hooks(model)
  • 在 yield 后恢复之前移除的 hook,保持模型状态的一致性。

python

运行复制

    else:
        yield unwrapped_model
  • 如果不是 DeepSpeed Stage 3,直接 yield 解包后的模型。

二、GRPOCallback 回调类

python

运行复制

class GRPOCallback(TrainerCallback):
  • 自定义的训练回调类,继承自 TrainerCallback(用于训练过程中的事件响应)。

python

运行复制

    def __init__(self, trainer):
        self.trainer = trainer
  • 初始化时将 trainer 对象保存为成员变量,以便后续使用。

python

运行复制

    def on_train_begin(self, args, state, control, **kwargs):
  • 训练开始时回调的函数。

python

运行复制

        self.trainer.queue = self.trainer.train_queue
  • train_queue 赋值给 queue,可能用于数据加载或任务队列处理。

python

运行复制

        train_dataloader = getattr(state, 'train_dataloader', None) or kwargs.get('train_dataloader')
  • 尝试从 statekwargs 中获取训练数据加载器。

python

运行复制

        self.trainer._prefetch(train_dataloader)
  • 调用 _prefetch 方法预加载数据,提高训练效率。

三、DataCache 数据缓存结构

python

运行复制

@dataclass
class DataCache:
  • 使用 dataclass 声明一个简单的数据结构类 DataCache

python

运行复制

    inputs: List[Dict] = field(default_factory=list)
  • 存储输入数据,类型为字典的列表。

python

运行复制

    outputs: List[Dict] = field(default_factory=list)
  • 存储模型输出,类型为字典的列表。

python

运行复制

    distributed_idx: List[List] = field(default_factory=list)
  • 存储分布式训练场景下的索引映射(可能用于还原或聚合)。

总结

这段代码展示了:

  • 如何处理 DeepSpeed Stage 3 参数收集。
  • 如何自定义训练回调以优化内存和数据加载。
  • 如何定义训练过程中的输入输出缓存结构。

它是构建一个高效、可扩展的大模型训练框架的重要组件。

class GRPOTrainer(RLHFTrainerMixin, SwiftMixin, HFGRPOTrainer):
    executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    def __init__(self,
                 model: Optional[Union[PreTrainedModel, nn.Module]] = None,
                 ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
                 reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
                 reward_funcs: Optional[List[Union[str, Callable]]] = None,
                 *_args,
                 **kwargs):
        from swift.trainers.rlhf_arguments import GRPOConfig
        args: GRPOConfig = kwargs['args']
        self.args = args
        self.train_queue = Queue()
        self.eval_queue = Queue()
        self.processing_class = kwargs.get('template').tokenizer
        self.offload_modules = {}
        self.offload_states = {}
        _, _, _, local_world_size = get_dist_setting()

        if not isinstance(reward_funcs, list):
            reward_funcs = [reward_funcs]

        if reward_funcs:
            for i, reward_func in enumerate(reward_funcs):
                if reward_func in orms:
                    reward_func_class = orms[reward_func]
                    reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters)
                    reward_func_kwargs = {
                        key: getattr(args, key)
                        for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key)
                    }
                    if 'tokenizer' in reward_func_args:
                        reward_func_kwargs['tokenizer'] = self.processing_class
                    reward_funcs[i] = reward_func_class(**reward_func_kwargs)
                elif not callable(reward_func):
                    raise ValueError(f'reward_function {reward_func} is not implemented in swift.llm.plugin')

        self.reward_funcs = reward_funcs

        self.multi_turn_func = None
        if self.args.multi_turn_func:
            if isinstance(self.args.multi_turn_func, str):
                assert self.args.multi_turn_func in multi_turns
                multi_turn_func = multi_turns[self.args.multi_turn_func]
                self.multi_turn_func = multi_turn_func
            else:
                self.multi_turn_func = self.args.multi_turn_func

        self.reward_templates = [None] * len(self.reward_funcs)
        if reward_model is not None:
            self.reward_templates.append(kwargs.pop('reward_template', None))
            self.reward_funcs.append(reward_model)
        if not self.reward_funcs:
            raise ValueError('You must specify reward_funcs or reward_model')

        # Reward weights
        if args.reward_weights is not None:
            if len(args.reward_weights) != len(reward_funcs):
                raise ValueError(f'Number of reward weights ({len(args.reward_weights)}) must match number of reward '
                                 f'functions ({len(reward_funcs)})')
            self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
        else:
            self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)

        self.num_generations = args.num_generations
        self.temperature = args.temperature
        model.warnings_issued['estimate_tokens'] = True
        kwargs['data_collator'] = lambda features: features
        self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)}

        use_vllm = args.use_vllm
        use_lmdeploy = args.use_lmdeploy

        if self.args.tensor_parallel_size > 1 and self.multi_turn_func:
            import torch.distributed as dist
            rank, _, _, _ = get_dist_setting()
            for tp_group in self.tp_group_ranks():
                group = dist.new_group(tp_group)
                if rank in tp_group:
                    self.group = group

        super().__init__(model, ref_model, *_args, **kwargs)

        num_processes = self.accelerator.num_processes
        global_batch_size = args.per_device_train_batch_size * num_processes
        possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
        if self.num_generations not in possible_values:
            raise ValueError(
                f'The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly '
                f'divisible by the number of generations per prompt ({self.num_generations}). Given the current train '
                f'batch size, the valid values for the number of generations are: {possible_values}.')
        if self.args.eval_strategy != 'no':
            global_batch_size = args.per_device_eval_batch_size * num_processes
            possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
            if self.num_generations not in possible_values:
                raise ValueError(
                    f'The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly '
                    f'divisible by the number of generations per prompt ({self.num_generations}). Given the current '
                    f'eval batch size, the valid values for the number of generations are: {possible_values}.')

        # Ensure each process receives a unique seed to prevent duplicate completions when generating with
        # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
        # it's safer to set it in all cases.
        set_seed(args.seed, device_specific=True)
        self.parameter_groups, self.parameter_groups_no_lora = self.split_batches()
        self.infer_device = None

        if use_vllm or use_lmdeploy:
            if self.infer_rank >= 0:
                fast_infer_device = self.args.vllm_device or self.args.lmdeploy_device
                if fast_infer_device[0] == 'auto':
                    if get_device_count() == 1:
                        fast_infer_device = [get_device()]  # particular case when training with only 1 GPU: share it
                    else:
                        fast_infer_device = []
                        for idx in range(get_device_count() - self.args.num_infer_workers, get_device_count()):
                            fast_infer_device.append(get_device(idx))

                for _device in fast_infer_device:
                    # Check that the requested device is available
                    if _device.split(':')[0] in {'cuda', 'npu'} and int(_device.split(':')[1]) >= get_device_count():
                        raise ValueError(f'The requested device for vllm ({_device}) is not available. '
                                         f'You are likely using vLLM '
                                         'without restricting the number of GPUs for training. '
                                         'Set the `--num_processes` argument to a '
                                         'value lower than the number of GPUs available on your machine—typically, '
                                         'reducing it by one is sufficient. '
                                         f'In your case: `--num_processes {get_device_count() - 1}`.')

                if use_vllm:
                    if not is_vllm_available():
                        raise ImportError('vLLM is not available and `use_vllm` is set to True. '
                                          'Please install vLLM with `pip install vllm -U` to use it.')
                    self.prepare_vllm(model, fast_infer_device)
                    self.infer_device = fast_infer_device[self.local_infer_rank]
                elif use_lmdeploy:
                    if not is_lmdeploy_available():
                        raise ImportError('LMDeploy is not available and `use_lmdeploy` is set to True.'
                                          'Please install LMDeploy with `pip install lmdeploy -U` to use it.')
                    from swift.llm import LmdeployEngine
                    from swift.tuners import Swift
                    with Swift.grpo_context(model, self.template.processor):
                        fast_infer_device = int(fast_infer_device[self.local_infer_rank].split(':')[1])
                        self.engine = LmdeployEngine(
                            model.model_dir,
                            model.model_info.torch_dtype,
                            model_type=model.model_meta.model_type,
                            devices=[fast_infer_device],
                            session_len=args.lmdeploy_session_len,
                            cache_max_entry_count=args.lmdeploy_cache_max_entry_count,
                            reload_weights=True)
                        self.infer_device = fast_infer_device
                    self.engine.default_template = copy(self.template)  # Avoid thread-unsafe modifications of the mode.
            self._last_loaded_step = 0  # tag to avoid useless loading during grad accumulation

            # When using vLLM, the main process is responsible for loading the model weights. This can cause process
            # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
            # synchronize all processes after vLLM has been fully initialized.
            self.accelerator.wait_for_everyone()
        else:
            from swift.llm import PtEngine
            self.engine = PtEngine.from_model_template(self.model, copy(self.template), max_batch_size=0)  # 0: no limit
        # Avoid thread-unsafe modifications of the mode.
        self.request_config = RequestConfig(
            max_tokens=args.max_completion_length,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            repetition_penalty=args.repetition_penalty,
            stop=args.stop_words,
        )

o-Latest

Poe

这段代码定义了 GRPOTrainer 类,是 Swift 框架下用于强化学习人类反馈(RLHF)的大模型训练器,集成了多个模块(如 vLLM、LMDeploy、DeepSpeed 等)。我们逐行详细解析整个类的逻辑与作用。


🔶 类定义

python

运行复制

class GRPOTrainer(RLHFTrainerMixin, SwiftMixin, HFGRPOTrainer):
  • 继承自多个混入类RLHFTrainerMixinSwiftMixinHFGRPOTrainer,它们提供了 RLHF 训练、Swift 框架支持和 HuggingFace Trainer 的功能。
  • GRPOTrainer 是用于 GRPO(Preference Optimization)训练流程的核心类。

python

运行复制

executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
  • 创建一个线程池执行器对象,最大线程数为 1,可能用于异步任务(如推理、日志记录等)。

🔷 初始化函数 __init__

python

运行复制

def __init__(self, model, ref_model, reward_model, reward_funcs, *_args, **kwargs):
  • 构造函数,接收多个模型(主模型、参考模型、奖励模型)和奖励函数列表,以及其他参数。

python

运行复制

from swift.trainers.rlhf_arguments import GRPOConfig
args: GRPOConfig = kwargs['args']
self.args = args
  • 导入并提取训练配置,args 是一个包含所有训练超参的对象。

python

运行复制

self.train_queue = Queue()
self.eval_queue = Queue()
  • 创建训练队列和评估队列(线程安全),用于异步数据处理。

python

运行复制

self.processing_class = kwargs.get('template').tokenizer
  • 获取模板中的 tokenizer,用于后续文本处理。

python

运行复制

self.offload_modules = {}
self.offload_states = {}
  • 用于记录被 CPU/offload 的模块状态和数据,节省内存。

python

运行复制

_, _, _, local_world_size = get_dist_setting()
  • 获取分布式训练的设置,local_world_size 是当前节点的 GPU 数目。

🎯 奖励函数初始化

python

运行复制

if not isinstance(reward_funcs, list):
    reward_funcs = [reward_funcs]
  • 将单个奖励函数包装成列表。

python

运行复制

for i, reward_func in enumerate(reward_funcs):
  • 遍历每个奖励函数。

python

运行复制

if reward_func in orms:
  • orms 是一个注册表,查找是否为内置奖励函数。

python

运行复制

reward_func_class = orms[reward_func]
reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters)
  • 获取奖励类构造函数的参数。

python

运行复制

reward_func_kwargs = {
    key: getattr(args, key)
    for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key)
}
  • 根据配置构造奖励函数参数字典。

python

运行复制

if 'tokenizer' in reward_func_args:
    reward_func_kwargs['tokenizer'] = self.processing_class
reward_funcs[i] = reward_func_class(**reward_func_kwargs)
  • 实例化奖励函数对象。

python

运行复制

elif not callable(reward_func):
    raise ValueError(...)
  • 如果不是可调用对象,也不是注册名,抛出异常。

python

运行复制

self.reward_funcs = reward_funcs
  • 保存奖励函数列表。

🤖 多轮对话处理函数设置

python

运行复制

self.multi_turn_func = None
if self.args.multi_turn_func:
  • 多轮对话处理函数。

python

运行复制

if isinstance(self.args.multi_turn_func, str):
    ...
else:
    self.multi_turn_func = self.args.multi_turn_func
  • 支持字符串引用或直接传入函数。

🧠 奖励模型与模板

python

运行复制

self.reward_templates = [None] * len(self.reward_funcs)
if reward_model is not None:
    self.reward_templates.append(kwargs.pop('reward_template', None))
    self.reward_funcs.append(reward_model)
  • reward_model 也作为奖励函数之一加入。

python

运行复制

if not self.reward_funcs:
    raise ValueError(...)
  • 如果没有任何奖励函数或模型,抛出错误。

⚖️ 奖励权重处理

python

运行复制

if args.reward_weights is not None:
    ...
self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
  • 设置奖励函数的权重,确保数量一致。

python

运行复制

else:
    self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
  • 默认所有权重为 1。

🔧 训练参数配置

python

运行复制

self.num_generations = args.num_generations
self.temperature = args.temperature
model.warnings_issued['estimate_tokens'] = True
  • 设置生成数量、温度,并记录 token 估计警告。

python

运行复制

kwargs['data_collator'] = lambda features: features
self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)}
  • 设置数据收集器(跳过)和初始化指标缓存。

⚙️ 分布式相关:Tensor Parallel

python

运行复制

if self.args.tensor_parallel_size > 1 and self.multi_turn_func:
    import torch.distributed as dist
    rank, _, _, _ = get_dist_setting()
    for tp_group in self.tp_group_ranks():
        group = dist.new_group(tp_group)
        if rank in tp_group:
            self.group = group
  • 创建 tensor parallel 分组,用于跨卡共享模型参数。

🧬 调用父类初始化

python

运行复制

super().__init__(model, ref_model, *_args, **kwargs)
  • 调用 HFGRPOTrainer 的构造函数。

✅ 校验 num_generations 合法性

python

运行复制

num_processes = self.accelerator.num_processes
global_batch_size = args.per_device_train_batch_size * num_processes
...
if self.num_generations not in possible_values:
    raise ValueError(...)
  • 确保 num_generations 能被全局 batch size 整除。

🔐 设置随机种子

python

运行复制

set_seed(args.seed, device_specific=True)
  • 保证每个进程种子不同,避免重复生成。

📦 模型推理设备配置

python

运行复制

if use_vllm or use_lmdeploy:
    if self.infer_rank >= 0:
        ...
        for _device in fast_infer_device:
            ...
  • 自动推断推理设备,分配给 infer worker。

🔍 vLLM 初始化

python

运行复制

if use_vllm:
    if not is_vllm_available():
        raise ImportError(...)
    self.prepare_vllm(model, fast_infer_device)
    self.infer_device = fast_infer_device[self.local_infer_rank]
  • 初始化 vLLM 推理引擎。

🧠 LMDeploy 初始化

python

运行复制

elif use_lmdeploy:
    ...
    self.engine = LmdeployEngine(...)
    self.infer_device = fast_infer_device
  • 初始化 LMDeploy 推理引擎。

python

运行复制

self.engine.default_template = copy(self.template)
  • 避免多线程对模板的修改。

🧱 PtEngine 作为默认推理引擎

python

运行复制

else:
    from swift.llm import PtEngine
    self.engine = PtEngine.from_model_template(self.model, copy(self.template), max_batch_size=0)
  • 使用 Swift 自带引擎,如果没有使用 vLLM/LMDeploy。

🔧 请求配置初始化

python

运行复制

self.request_config = RequestConfig(
    max_tokens=args.max_completion_length,
    temperature=args.temperature,
    top_p=args.top_p,
    top_k=args.top_k,
    repetition_penalty=args.repetition_penalty,
    stop=args.stop_words,
)
  • 设置生成请求的参数,如生成长度、采样策略、惩罚系数等。

✅ 总结

该类是 Swift 框架中 GRPO 算法的核心训练器,完成了如下功能:

  1. 初始化模型、参考模型、奖励模型与奖励函数。
  2. 校验配置项合法性(如 batch size 可整除)。
  3. 支持多种推理后端(vLLM、LMDeploy、PtEngine)。
  4. 支持 Tensor Parallel、分布式训练。
  5. 提供异步推理、生成配置等功能。

如果你需要,我还可以为你绘制结构图和类依赖关系图,帮助更直观理解。

if local_world_size == self.args.num_infer_workers == get_device_count() and local_world_size > 1:
            self.request_config.n = self.args.tensor_parallel_size
            if self.infer_rank >= 0:
                self.request_config.seed = self.infer_rank // self.args.tensor_parallel_size

        self.model_accepts_loss_kwargs = False
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
        self.log_completions = args.log_completions
        self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'completions.jsonl'))

        # Multi-step
        self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
        self.epsilon_low = args.epsilon
        self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon

        # Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle. # noqa
        self._step = 0
        # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
        # `_get_train_sampler` and `_prepare_inputs`.
        self._buffered_inputs = [None] * args.gradient_accumulation_steps
        if self.args.async_generate:
            self.add_callback(GRPOCallback(self))
        self.set_multi_turn_engine_default_max_tokens()

    def split_batches(self):
        """Sync weights in batches
        Only split LLM layers for now:
        1. N batches for layers
        2. other, embeds, lm_heads in one batch
        3. multi-modal components in one batch
        """
        model = self.accelerator.unwrap_model(self.model)
        if self.args.move_model_batches is None:
            # All in one
            return [[n for n, p in model.named_parameters() if 'ref_model' not in n]], [None]

        model_arch = get_model_arch(model.model_meta.model_arch)
        non_llm_parameters = []
        llm_embeds = []
        parameters = []
        pattern = r'\.(\d+)\.'

        layer_count = None
        # Get the number of layers in LLM modules
        for name, module in model.named_modules():
            if isinstance(module, ModuleList):
                if model_arch is not None and isinstance(model_arch, MultiModelKeys):
                    llm = model_arch.language_model
                    vision_tower = model_arch.vision_tower
                    if any(vt in name for vt in vision_tower):
                        continue
                    if isinstance(llm, list):
                        llm = llm[0]
                    if name.startswith('base_model'):
                        name = name.replace('base_model.', '')
                    if llm in name:
                        layer_count = len(module)
                else:
                    layer_count = len(module)
        assert layer_count is not None, 'Cannot find ModuleList to split modules.'

        n_layers = ceil(layer_count / self.args.move_model_batches)
        for _ in range(self.args.move_model_batches):
            parameters.append([])

        def replace_lora(name):
            if 'lora_' in name:
                return ''
            else:
                return name.replace('base_layer.', '')

        def remove_lora_and_prefix(names):
            names = set([re.sub(r'^_model\.', '', replace_lora(n)) for n in names])
            return [n for n in names if n]

        def split_llm(name):
            match = re.search(pattern, name)
            if match:
                number = match.group(1)
                group = int(number) // n_layers
                parameters[group].append(name)
            else:
                llm_embeds.append(name)

        for name, parameter in model.named_parameters():
            if 'ref_model' in name:
                continue
            if model_arch is not None and isinstance(model_arch, MultiModelKeys):
                llm = model_arch.language_model
                vision_tower = model_arch.vision_tower
                if any(vt in name for vt in vision_tower):
                    non_llm_parameters.append(name)
                elif isinstance(llm, list):
                    llm = llm[0]
                    if llm in name:
                        split_llm(name)
                    else:
                        non_llm_parameters.append(name)
            else:
                split_llm(name)

        if llm_embeds:
            parameters.append(llm_embeds)
        if non_llm_parameters:
            parameters.append(non_llm_parameters)
        parameters = [p for p in parameters if p]
        parameters_no_lora = [remove_lora_and_prefix(p_list) for p_list in parameters]
        return parameters, parameters_no_lora

    def prepare_vllm(self, model, fast_infer_device):
        from swift.tuners import Swift
        from swift.llm import VllmEngine
        from swift.llm.infer.infer_engine import GRPOVllmEngine
        _, _, _, local_world_size = get_dist_setting()
        if local_world_size == self.args.num_infer_workers == get_device_count() and local_world_size > 1:
            # Compatibility with TP
            cls = GRPOVllmEngine
        else:
            cls = VllmEngine
        with Swift.grpo_context(model, self.template.processor):
            self.engine = cls(
                model.model_dir,
                model.model_info.torch_dtype,
                model_type=model.model_meta.model_type,
                device=fast_infer_device[self.local_infer_rank],
                tensor_parallel_size=self.args.tensor_parallel_size,
                gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
                enable_prefix_caching=self.args.vllm_enable_prefix_caching,
                max_num_seqs=self.args.vllm_max_num_seqs,
                enforce_eager=self.args.vllm_enforce_eager,
                limit_mm_per_prompt=self.args.vllm_limit_mm_per_prompt,
                num_infer_workers=self.args.num_infer_workers,
                enable_sleep_mode=self.args.sleep_level > 0,
                use_async_engine=False,
                distributed_executor_backend='external_launcher',
                max_model_len=self.args.vllm_max_model_len)
            self.engine.default_template = self.template

atest

Poe

我们继续逐行解释你提供的 Swift 框架中 GRPOTrainer 的后续代码。这部分主要涉及:

  • 推理设备的配置
  • 奖励模型准备
  • 多轮训练参数设置
  • 模型参数的分批管理(offload)
  • vLLM 的初始化等

🔧 设置推理配置(request_config)

python

运行复制

if local_world_size == self.args.num_infer_workers == get_device_count() and local_world_size > 1:
  • 如果本地 GPU 数量和推理工作线程数一致,并且 GPU 数量大于 1,说明使用了 Tensor Parallel(TP)推理

python

运行复制

    self.request_config.n = self.args.tensor_parallel_size
  • 设置每个 prompt 的生成数(n)为 TP 大小,用于并行生成。

python

运行复制

    if self.infer_rank >= 0:
        self.request_config.seed = self.infer_rank // self.args.tensor_parallel_size
  • 如果当前是推理进程,为每个 TP 分组设置不同的随机种子,避免生成重复内容。

🧪 奖励模型准备

python

运行复制

self.model_accepts_loss_kwargs = False
  • 标记模型是否接受 loss_kwargs,默认不支持。

python

运行复制

for i, reward_func in enumerate(self.reward_funcs):
    if isinstance(reward_func, PreTrainedModel):
        self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
  • 如果奖励函数是模型对象,使用 Accelerate 框架将其封装并切换为评估模式。

📁 完成记录和日志初始化

python

运行复制

self.log_completions = args.log_completions
self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'completions.jsonl'))
  • 是否记录生成内容,并初始化写入 .jsonl 文件的对象。

🔁 GRPO 多步迭代参数设置(对应论文中的 μ 和 ε)

python

运行复制

self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
self.epsilon_low = args.epsilon
self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
  • 设置 GRPO 的核心迭代次数 (μ) 和 epsilon 探索参数。

🔄 训练状态相关变量

python

运行复制

self._step = 0
  • 当前训练步数。

python

运行复制

self._buffered_inputs = [None] * args.gradient_accumulation_steps
  • 缓存梯度累积周期内的输入数据。

python

运行复制

if self.args.async_generate:
    self.add_callback(GRPOCallback(self))
  • 如果启用了异步生成,注册自定义回调函数。

python

运行复制

self.set_multi_turn_engine_default_max_tokens()
  • 设置多轮对话生成的默认最大 token 数。

🔀 模型参数分批(offload)处理:split_batches

python

运行复制

def split_batches(self):
  • 将模型参数分组,以支持分批迁移(例如 CPU/GPU offload)。

python

运行复制

model = self.accelerator.unwrap_model(self.model)
  • 解包模型,去除 Accelerator 封装层。

python

运行复制

if self.args.move_model_batches is None:
    return [[n for n, p in model.named_parameters() if 'ref_model' not in n]], [None]
  • 如果未设置分批数量,则所有参数一组。

🔍 模型结构识别与分层

python

运行复制

model_arch = get_model_arch(model.model_meta.model_arch)
non_llm_parameters = []
llm_embeds = []
parameters = []
pattern = r'\.(\d+)\.'
  • 获取模型结构,准备分类参数。
  • 正则从参数名中提取层数编号。

python

运行复制

layer_count = None
for name, module in model.named_modules():
    if isinstance(module, ModuleList):
        ...
        layer_count = len(module)
  • 遍历模型模块,查找 LLM 的层数。

python

运行复制

assert layer_count is not None, 'Cannot find ModuleList to split modules.'
  • 必须找到层数信息,否则报错。

📦 分批构建

python

运行复制

n_layers = ceil(layer_count / self.args.move_model_batches)
for _ in range(self.args.move_model_batches):
    parameters.append([])
  • 按照层数将参数平均分成 move_model_batches 组。

🔧 工具函数(去掉 lora、前缀名)

python

运行复制

def replace_lora(name): ...
def remove_lora_and_prefix(names): ...
  • 清理参数名,去除 Lora 层、模型封装前缀。

🧩 按照层号进行分组

python

运行复制

def split_llm(name):
    match = re.search(pattern, name)
    if match:
        number = match.group(1)
        group = int(number) // n_layers
        parameters[group].append(name)
    else:
        llm_embeds.append(name)
  • 根据参数名中的层号对参数分组。

🧹 参数分类

python

运行复制

for name, parameter in model.named_parameters():
    if 'ref_model' in name:
        continue
    ...
    split_llm(name)
  • 把 LLM 层、embedding 层、非LLM 参数分类并分组。

✅ 返回最终分组结果

python

运行复制

if llm_embeds: parameters.append(llm_embeds)
if non_llm_parameters: parameters.append(non_llm_parameters)
parameters = [p for p in parameters if p]
parameters_no_lora = [remove_lora_and_prefix(p_list) for p_list in parameters]
return parameters, parameters_no_lora
  • 返回两个参数列表:一个是完整参数分组,一个是去除 Lora 和前缀后的名称。

🚀 vLLM 引擎初始化:prepare_vllm

python

运行复制

def prepare_vllm(self, model, fast_infer_device):
  • 使用 vLLM 后端进行高性能推理。

python

运行复制

from swift.tuners import Swift
from swift.llm import VllmEngine
from swift.llm.infer.infer_engine import GRPOVllmEngine
  • 导入推理引擎。

python

运行复制

_, _, _, local_world_size = get_dist_setting()
if local_world_size == self.args.num_infer_workers == get_device_count() and local_world_size > 1:
    cls = GRPOVllmEngine
else:
    cls = VllmEngine
  • 根据是否为 TP 设置选择合适的引擎类。

⚙️ 使用 grpo_context 上下文管理器

python

运行复制

with Swift.grpo_context(model, self.template.processor):
  • 使用 grpo_context 临时启用 Swift 的上下文环境(可能设置模型为 eval,关闭 hook)。

🧠 初始化 vLLM 引擎实例

python

运行复制

self.engine = cls(
    model.model_dir,
    model.model_info.torch_dtype,
    model_type=model.model_meta.model_type,
    device=fast_infer_device[self.local_infer_rank],
    tensor_parallel_size=self.args.tensor_parallel_size,
    gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
    enable_prefix_caching=self.args.vllm_enable_prefix_caching,
    max_num_seqs=self.args.vllm_max_num_seqs,
    enforce_eager=self.args.vllm_enforce_eager,
    limit_mm_per_prompt=self.args.vllm_limit_mm_per_prompt,
    num_infer_workers=self.args.num_infer_workers,
    enable_sleep_mode=self.args.sleep_level > 0,
    use_async_engine=False,
    distributed_executor_backend='external_launcher',
    max_model_len=self.args.vllm_max_model_len)
  • 创建并配置 vLLM 推理引擎,参数包括模型路径、数据类型、推理设备、并行设置、缓存策略等。

python

运行复制

self.engine.default_template = self.template
  • 设置默认模板,避免多线程访问时出错。

✅ 总结

这一部分代码主要做了以下关键事情:

  • 根据分布式设备情况配置生成参数。
  • 初始化奖励模型、日志工具。
  • 支持多轮对话的设置。
  • 将大模型参数智能分批,便于 offload 管理。
  • 初始化 vLLM 推理引擎,提升生成效率。

如果你需要图示解释,比如:

  • 模型分批结构图
  • 推理引擎架构图
  • 参数分组可视化
Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐