swift GRPOTrainer类 源码解析
如何处理 DeepSpeed Stage 3 参数收集。如何自定义训练回调以优化内存和数据加载。如何定义训练过程中的输入输出缓存结构。它是构建一个高效、可扩展的大模型训练框架的重要组件。*_args,**kwargs):else:else:= 'no':else:else:o-LatestPoe这段代码定义了类,是 Swift 框架下用于强化学习人类反馈(RLHF)的大模型训练器,集成了多个模块(
@contextmanager
def unwrap_model_for_generation(
model,
accelerator,
gather_deepspeed3_params=True,
gather_parameters: List = None,
):
unwrapped_model = accelerator.unwrap_model(model)
if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
if not gather_deepspeed3_params:
yield accelerator.unwrap_model(model)
else:
import deepspeed
parameters = [
parameter for name, parameter in model.named_parameters()
if not gather_parameters or name in gather_parameters
]
with deepspeed.zero.GatheredParameters(parameters):
from trl.models.utils import remove_hooks
remove_hooks(model)
yield accelerator.unwrap_model(model)
from trl.models.utils import add_hooks
add_hooks(model)
else:
yield unwrapped_model
class GRPOCallback(TrainerCallback):
def __init__(self, trainer):
self.trainer = trainer
# offload original_modules to cpu, to save memory
def on_train_begin(self, args, state, control, **kwargs):
self.trainer.queue = self.trainer.train_queue
train_dataloader = getattr(state, 'train_dataloader', None) or kwargs.get('train_dataloader')
self.trainer._prefetch(train_dataloader)
@dataclass
class DataCache:
inputs: List[Dict] = field(default_factory=list)
outputs: List[Dict] = field(default_factory=list)
distributed_idx: List[List] = field(default_factory=list)
Poe
这段代码是用于大语言模型训练的,使用了 Accelerate、DeepSpeed、TrainerCallback 以及自定义的数据缓存结构。它涉及模型解包(unwrap)和 DeepSpeed Zero Stage 3 的特殊处理。下面是对代码中每一部分逐行的详细解释:
一、unwrap_model_for_generation 上下文管理器
python
运行复制
@contextmanager
def unwrap_model_for_generation(
model,
accelerator,
gather_deepspeed3_params=True,
gather_parameters: List = None,
):
@contextmanager:表明该函数是一个上下文管理器(可用于with语句)。model:传入的模型对象。accelerator:Accelerate框架的对象,用于处理多设备训练、混合精度等。gather_deepspeed3_params:是否需要收集 DeepSpeed Stage 3 分散的参数。gather_parameters:指定收集哪些参数的名字列表(可选)。
python
运行复制
unwrapped_model = accelerator.unwrap_model(model)
- 使用
accelerator.unwrap_model解包模型,去除封装(如 DDP、FP16、Deepspeed 包装等),得到底层原始模型。
python
运行复制
if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
- 判断是否使用了 DeepSpeed,并且是 Zero Stage 3(该阶段会对参数进行分片,不能直接访问完整参数)。
python
运行复制
if not gather_deepspeed3_params:
yield accelerator.unwrap_model(model)
- 如果不需要收集分片参数,直接 yield 解包后的模型。
python
运行复制
else:
import deepspeed
- 否则导入
deepspeed,准备使用其工具收集参数。
python
运行复制
parameters = [
parameter for name, parameter in model.named_parameters()
if not gather_parameters or name in gather_parameters
]
- 根据
gather_parameters筛选出要收集的参数。如果没有指定,则收集所有参数。
python
运行复制
with deepspeed.zero.GatheredParameters(parameters):
- 使用 Deepspeed 的上下文管理器,将 Stage 3 中被分散的参数 收集到当前进程,以便可以正常使用。
python
运行复制
from trl.models.utils import remove_hooks
remove_hooks(model)
- 从模型中暂时移除 hook(比如 forward hook),防止收集参数时引发副作用。
python
运行复制
yield accelerator.unwrap_model(model)
- yield 解包后的模型,此时参数是完整的。
python
运行复制
from trl.models.utils import add_hooks
add_hooks(model)
- 在 yield 后恢复之前移除的 hook,保持模型状态的一致性。
python
运行复制
else:
yield unwrapped_model
- 如果不是 DeepSpeed Stage 3,直接 yield 解包后的模型。
二、GRPOCallback 回调类
python
运行复制
class GRPOCallback(TrainerCallback):
- 自定义的训练回调类,继承自
TrainerCallback(用于训练过程中的事件响应)。
python
运行复制
def __init__(self, trainer):
self.trainer = trainer
- 初始化时将
trainer对象保存为成员变量,以便后续使用。
python
运行复制
def on_train_begin(self, args, state, control, **kwargs):
- 训练开始时回调的函数。
python
运行复制
self.trainer.queue = self.trainer.train_queue
- 将
train_queue赋值给queue,可能用于数据加载或任务队列处理。
python
运行复制
train_dataloader = getattr(state, 'train_dataloader', None) or kwargs.get('train_dataloader')
- 尝试从
state或kwargs中获取训练数据加载器。
python
运行复制
self.trainer._prefetch(train_dataloader)
- 调用
_prefetch方法预加载数据,提高训练效率。
三、DataCache 数据缓存结构
python
运行复制
@dataclass
class DataCache:
- 使用
dataclass声明一个简单的数据结构类DataCache。
python
运行复制
inputs: List[Dict] = field(default_factory=list)
- 存储输入数据,类型为字典的列表。
python
运行复制
outputs: List[Dict] = field(default_factory=list)
- 存储模型输出,类型为字典的列表。
python
运行复制
distributed_idx: List[List] = field(default_factory=list)
- 存储分布式训练场景下的索引映射(可能用于还原或聚合)。
总结
这段代码展示了:
- 如何处理 DeepSpeed Stage 3 参数收集。
- 如何自定义训练回调以优化内存和数据加载。
- 如何定义训练过程中的输入输出缓存结构。
它是构建一个高效、可扩展的大模型训练框架的重要组件。
class GRPOTrainer(RLHFTrainerMixin, SwiftMixin, HFGRPOTrainer):
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
def __init__(self,
model: Optional[Union[PreTrainedModel, nn.Module]] = None,
ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
reward_funcs: Optional[List[Union[str, Callable]]] = None,
*_args,
**kwargs):
from swift.trainers.rlhf_arguments import GRPOConfig
args: GRPOConfig = kwargs['args']
self.args = args
self.train_queue = Queue()
self.eval_queue = Queue()
self.processing_class = kwargs.get('template').tokenizer
self.offload_modules = {}
self.offload_states = {}
_, _, _, local_world_size = get_dist_setting()
if not isinstance(reward_funcs, list):
reward_funcs = [reward_funcs]
if reward_funcs:
for i, reward_func in enumerate(reward_funcs):
if reward_func in orms:
reward_func_class = orms[reward_func]
reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters)
reward_func_kwargs = {
key: getattr(args, key)
for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key)
}
if 'tokenizer' in reward_func_args:
reward_func_kwargs['tokenizer'] = self.processing_class
reward_funcs[i] = reward_func_class(**reward_func_kwargs)
elif not callable(reward_func):
raise ValueError(f'reward_function {reward_func} is not implemented in swift.llm.plugin')
self.reward_funcs = reward_funcs
self.multi_turn_func = None
if self.args.multi_turn_func:
if isinstance(self.args.multi_turn_func, str):
assert self.args.multi_turn_func in multi_turns
multi_turn_func = multi_turns[self.args.multi_turn_func]
self.multi_turn_func = multi_turn_func
else:
self.multi_turn_func = self.args.multi_turn_func
self.reward_templates = [None] * len(self.reward_funcs)
if reward_model is not None:
self.reward_templates.append(kwargs.pop('reward_template', None))
self.reward_funcs.append(reward_model)
if not self.reward_funcs:
raise ValueError('You must specify reward_funcs or reward_model')
# Reward weights
if args.reward_weights is not None:
if len(args.reward_weights) != len(reward_funcs):
raise ValueError(f'Number of reward weights ({len(args.reward_weights)}) must match number of reward '
f'functions ({len(reward_funcs)})')
self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
else:
self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
self.num_generations = args.num_generations
self.temperature = args.temperature
model.warnings_issued['estimate_tokens'] = True
kwargs['data_collator'] = lambda features: features
self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)}
use_vllm = args.use_vllm
use_lmdeploy = args.use_lmdeploy
if self.args.tensor_parallel_size > 1 and self.multi_turn_func:
import torch.distributed as dist
rank, _, _, _ = get_dist_setting()
for tp_group in self.tp_group_ranks():
group = dist.new_group(tp_group)
if rank in tp_group:
self.group = group
super().__init__(model, ref_model, *_args, **kwargs)
num_processes = self.accelerator.num_processes
global_batch_size = args.per_device_train_batch_size * num_processes
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
if self.num_generations not in possible_values:
raise ValueError(
f'The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly '
f'divisible by the number of generations per prompt ({self.num_generations}). Given the current train '
f'batch size, the valid values for the number of generations are: {possible_values}.')
if self.args.eval_strategy != 'no':
global_batch_size = args.per_device_eval_batch_size * num_processes
possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
if self.num_generations not in possible_values:
raise ValueError(
f'The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly '
f'divisible by the number of generations per prompt ({self.num_generations}). Given the current '
f'eval batch size, the valid values for the number of generations are: {possible_values}.')
# Ensure each process receives a unique seed to prevent duplicate completions when generating with
# transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
# it's safer to set it in all cases.
set_seed(args.seed, device_specific=True)
self.parameter_groups, self.parameter_groups_no_lora = self.split_batches()
self.infer_device = None
if use_vllm or use_lmdeploy:
if self.infer_rank >= 0:
fast_infer_device = self.args.vllm_device or self.args.lmdeploy_device
if fast_infer_device[0] == 'auto':
if get_device_count() == 1:
fast_infer_device = [get_device()] # particular case when training with only 1 GPU: share it
else:
fast_infer_device = []
for idx in range(get_device_count() - self.args.num_infer_workers, get_device_count()):
fast_infer_device.append(get_device(idx))
for _device in fast_infer_device:
# Check that the requested device is available
if _device.split(':')[0] in {'cuda', 'npu'} and int(_device.split(':')[1]) >= get_device_count():
raise ValueError(f'The requested device for vllm ({_device}) is not available. '
f'You are likely using vLLM '
'without restricting the number of GPUs for training. '
'Set the `--num_processes` argument to a '
'value lower than the number of GPUs available on your machine—typically, '
'reducing it by one is sufficient. '
f'In your case: `--num_processes {get_device_count() - 1}`.')
if use_vllm:
if not is_vllm_available():
raise ImportError('vLLM is not available and `use_vllm` is set to True. '
'Please install vLLM with `pip install vllm -U` to use it.')
self.prepare_vllm(model, fast_infer_device)
self.infer_device = fast_infer_device[self.local_infer_rank]
elif use_lmdeploy:
if not is_lmdeploy_available():
raise ImportError('LMDeploy is not available and `use_lmdeploy` is set to True.'
'Please install LMDeploy with `pip install lmdeploy -U` to use it.')
from swift.llm import LmdeployEngine
from swift.tuners import Swift
with Swift.grpo_context(model, self.template.processor):
fast_infer_device = int(fast_infer_device[self.local_infer_rank].split(':')[1])
self.engine = LmdeployEngine(
model.model_dir,
model.model_info.torch_dtype,
model_type=model.model_meta.model_type,
devices=[fast_infer_device],
session_len=args.lmdeploy_session_len,
cache_max_entry_count=args.lmdeploy_cache_max_entry_count,
reload_weights=True)
self.infer_device = fast_infer_device
self.engine.default_template = copy(self.template) # Avoid thread-unsafe modifications of the mode.
self._last_loaded_step = 0 # tag to avoid useless loading during grad accumulation
# When using vLLM, the main process is responsible for loading the model weights. This can cause process
# desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
# synchronize all processes after vLLM has been fully initialized.
self.accelerator.wait_for_everyone()
else:
from swift.llm import PtEngine
self.engine = PtEngine.from_model_template(self.model, copy(self.template), max_batch_size=0) # 0: no limit
# Avoid thread-unsafe modifications of the mode.
self.request_config = RequestConfig(
max_tokens=args.max_completion_length,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
repetition_penalty=args.repetition_penalty,
stop=args.stop_words,
)
Poe
这段代码定义了 GRPOTrainer 类,是 Swift 框架下用于强化学习人类反馈(RLHF)的大模型训练器,集成了多个模块(如 vLLM、LMDeploy、DeepSpeed 等)。我们逐行详细解析整个类的逻辑与作用。
🔶 类定义
python
运行复制
class GRPOTrainer(RLHFTrainerMixin, SwiftMixin, HFGRPOTrainer):
- 继承自多个混入类:
RLHFTrainerMixin、SwiftMixin和HFGRPOTrainer,它们提供了 RLHF 训练、Swift 框架支持和 HuggingFace Trainer 的功能。 GRPOTrainer是用于 GRPO(Preference Optimization)训练流程的核心类。
python
运行复制
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
- 创建一个线程池执行器对象,最大线程数为 1,可能用于异步任务(如推理、日志记录等)。
🔷 初始化函数 __init__
python
运行复制
def __init__(self, model, ref_model, reward_model, reward_funcs, *_args, **kwargs):
- 构造函数,接收多个模型(主模型、参考模型、奖励模型)和奖励函数列表,以及其他参数。
python
运行复制
from swift.trainers.rlhf_arguments import GRPOConfig
args: GRPOConfig = kwargs['args']
self.args = args
- 导入并提取训练配置,
args是一个包含所有训练超参的对象。
python
运行复制
self.train_queue = Queue()
self.eval_queue = Queue()
- 创建训练队列和评估队列(线程安全),用于异步数据处理。
python
运行复制
self.processing_class = kwargs.get('template').tokenizer
- 获取模板中的 tokenizer,用于后续文本处理。
python
运行复制
self.offload_modules = {}
self.offload_states = {}
- 用于记录被 CPU/offload 的模块状态和数据,节省内存。
python
运行复制
_, _, _, local_world_size = get_dist_setting()
- 获取分布式训练的设置,
local_world_size是当前节点的 GPU 数目。
🎯 奖励函数初始化
python
运行复制
if not isinstance(reward_funcs, list):
reward_funcs = [reward_funcs]
- 将单个奖励函数包装成列表。
python
运行复制
for i, reward_func in enumerate(reward_funcs):
- 遍历每个奖励函数。
python
运行复制
if reward_func in orms:
orms是一个注册表,查找是否为内置奖励函数。
python
运行复制
reward_func_class = orms[reward_func]
reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters)
- 获取奖励类构造函数的参数。
python
运行复制
reward_func_kwargs = {
key: getattr(args, key)
for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key)
}
- 根据配置构造奖励函数参数字典。
python
运行复制
if 'tokenizer' in reward_func_args:
reward_func_kwargs['tokenizer'] = self.processing_class
reward_funcs[i] = reward_func_class(**reward_func_kwargs)
- 实例化奖励函数对象。
python
运行复制
elif not callable(reward_func):
raise ValueError(...)
- 如果不是可调用对象,也不是注册名,抛出异常。
python
运行复制
self.reward_funcs = reward_funcs
- 保存奖励函数列表。
🤖 多轮对话处理函数设置
python
运行复制
self.multi_turn_func = None
if self.args.multi_turn_func:
- 多轮对话处理函数。
python
运行复制
if isinstance(self.args.multi_turn_func, str):
...
else:
self.multi_turn_func = self.args.multi_turn_func
- 支持字符串引用或直接传入函数。
🧠 奖励模型与模板
python
运行复制
self.reward_templates = [None] * len(self.reward_funcs)
if reward_model is not None:
self.reward_templates.append(kwargs.pop('reward_template', None))
self.reward_funcs.append(reward_model)
- 将
reward_model也作为奖励函数之一加入。
python
运行复制
if not self.reward_funcs:
raise ValueError(...)
- 如果没有任何奖励函数或模型,抛出错误。
⚖️ 奖励权重处理
python
运行复制
if args.reward_weights is not None:
...
self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
- 设置奖励函数的权重,确保数量一致。
python
运行复制
else:
self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
- 默认所有权重为 1。
🔧 训练参数配置
python
运行复制
self.num_generations = args.num_generations
self.temperature = args.temperature
model.warnings_issued['estimate_tokens'] = True
- 设置生成数量、温度,并记录 token 估计警告。
python
运行复制
kwargs['data_collator'] = lambda features: features
self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)}
- 设置数据收集器(跳过)和初始化指标缓存。
⚙️ 分布式相关:Tensor Parallel
python
运行复制
if self.args.tensor_parallel_size > 1 and self.multi_turn_func:
import torch.distributed as dist
rank, _, _, _ = get_dist_setting()
for tp_group in self.tp_group_ranks():
group = dist.new_group(tp_group)
if rank in tp_group:
self.group = group
- 创建 tensor parallel 分组,用于跨卡共享模型参数。
🧬 调用父类初始化
python
运行复制
super().__init__(model, ref_model, *_args, **kwargs)
- 调用
HFGRPOTrainer的构造函数。
✅ 校验 num_generations 合法性
python
运行复制
num_processes = self.accelerator.num_processes
global_batch_size = args.per_device_train_batch_size * num_processes
...
if self.num_generations not in possible_values:
raise ValueError(...)
- 确保
num_generations能被全局 batch size 整除。
🔐 设置随机种子
python
运行复制
set_seed(args.seed, device_specific=True)
- 保证每个进程种子不同,避免重复生成。
📦 模型推理设备配置
python
运行复制
if use_vllm or use_lmdeploy:
if self.infer_rank >= 0:
...
for _device in fast_infer_device:
...
- 自动推断推理设备,分配给 infer worker。
🔍 vLLM 初始化
python
运行复制
if use_vllm:
if not is_vllm_available():
raise ImportError(...)
self.prepare_vllm(model, fast_infer_device)
self.infer_device = fast_infer_device[self.local_infer_rank]
- 初始化 vLLM 推理引擎。
🧠 LMDeploy 初始化
python
运行复制
elif use_lmdeploy:
...
self.engine = LmdeployEngine(...)
self.infer_device = fast_infer_device
- 初始化 LMDeploy 推理引擎。
python
运行复制
self.engine.default_template = copy(self.template)
- 避免多线程对模板的修改。
🧱 PtEngine 作为默认推理引擎
python
运行复制
else:
from swift.llm import PtEngine
self.engine = PtEngine.from_model_template(self.model, copy(self.template), max_batch_size=0)
- 使用 Swift 自带引擎,如果没有使用 vLLM/LMDeploy。
🔧 请求配置初始化
python
运行复制
self.request_config = RequestConfig(
max_tokens=args.max_completion_length,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
repetition_penalty=args.repetition_penalty,
stop=args.stop_words,
)
- 设置生成请求的参数,如生成长度、采样策略、惩罚系数等。
✅ 总结
该类是 Swift 框架中 GRPO 算法的核心训练器,完成了如下功能:
- 初始化模型、参考模型、奖励模型与奖励函数。
- 校验配置项合法性(如 batch size 可整除)。
- 支持多种推理后端(vLLM、LMDeploy、PtEngine)。
- 支持 Tensor Parallel、分布式训练。
- 提供异步推理、生成配置等功能。
如果你需要,我还可以为你绘制结构图和类依赖关系图,帮助更直观理解。
if local_world_size == self.args.num_infer_workers == get_device_count() and local_world_size > 1:
self.request_config.n = self.args.tensor_parallel_size
if self.infer_rank >= 0:
self.request_config.seed = self.infer_rank // self.args.tensor_parallel_size
self.model_accepts_loss_kwargs = False
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
self.log_completions = args.log_completions
self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'completions.jsonl'))
# Multi-step
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
self.epsilon_low = args.epsilon
self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
# Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle. # noqa
self._step = 0
# Buffer the batch to reuse generated outputs across multiple updates. For more details, see
# `_get_train_sampler` and `_prepare_inputs`.
self._buffered_inputs = [None] * args.gradient_accumulation_steps
if self.args.async_generate:
self.add_callback(GRPOCallback(self))
self.set_multi_turn_engine_default_max_tokens()
def split_batches(self):
"""Sync weights in batches
Only split LLM layers for now:
1. N batches for layers
2. other, embeds, lm_heads in one batch
3. multi-modal components in one batch
"""
model = self.accelerator.unwrap_model(self.model)
if self.args.move_model_batches is None:
# All in one
return [[n for n, p in model.named_parameters() if 'ref_model' not in n]], [None]
model_arch = get_model_arch(model.model_meta.model_arch)
non_llm_parameters = []
llm_embeds = []
parameters = []
pattern = r'\.(\d+)\.'
layer_count = None
# Get the number of layers in LLM modules
for name, module in model.named_modules():
if isinstance(module, ModuleList):
if model_arch is not None and isinstance(model_arch, MultiModelKeys):
llm = model_arch.language_model
vision_tower = model_arch.vision_tower
if any(vt in name for vt in vision_tower):
continue
if isinstance(llm, list):
llm = llm[0]
if name.startswith('base_model'):
name = name.replace('base_model.', '')
if llm in name:
layer_count = len(module)
else:
layer_count = len(module)
assert layer_count is not None, 'Cannot find ModuleList to split modules.'
n_layers = ceil(layer_count / self.args.move_model_batches)
for _ in range(self.args.move_model_batches):
parameters.append([])
def replace_lora(name):
if 'lora_' in name:
return ''
else:
return name.replace('base_layer.', '')
def remove_lora_and_prefix(names):
names = set([re.sub(r'^_model\.', '', replace_lora(n)) for n in names])
return [n for n in names if n]
def split_llm(name):
match = re.search(pattern, name)
if match:
number = match.group(1)
group = int(number) // n_layers
parameters[group].append(name)
else:
llm_embeds.append(name)
for name, parameter in model.named_parameters():
if 'ref_model' in name:
continue
if model_arch is not None and isinstance(model_arch, MultiModelKeys):
llm = model_arch.language_model
vision_tower = model_arch.vision_tower
if any(vt in name for vt in vision_tower):
non_llm_parameters.append(name)
elif isinstance(llm, list):
llm = llm[0]
if llm in name:
split_llm(name)
else:
non_llm_parameters.append(name)
else:
split_llm(name)
if llm_embeds:
parameters.append(llm_embeds)
if non_llm_parameters:
parameters.append(non_llm_parameters)
parameters = [p for p in parameters if p]
parameters_no_lora = [remove_lora_and_prefix(p_list) for p_list in parameters]
return parameters, parameters_no_lora
def prepare_vllm(self, model, fast_infer_device):
from swift.tuners import Swift
from swift.llm import VllmEngine
from swift.llm.infer.infer_engine import GRPOVllmEngine
_, _, _, local_world_size = get_dist_setting()
if local_world_size == self.args.num_infer_workers == get_device_count() and local_world_size > 1:
# Compatibility with TP
cls = GRPOVllmEngine
else:
cls = VllmEngine
with Swift.grpo_context(model, self.template.processor):
self.engine = cls(
model.model_dir,
model.model_info.torch_dtype,
model_type=model.model_meta.model_type,
device=fast_infer_device[self.local_infer_rank],
tensor_parallel_size=self.args.tensor_parallel_size,
gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
enable_prefix_caching=self.args.vllm_enable_prefix_caching,
max_num_seqs=self.args.vllm_max_num_seqs,
enforce_eager=self.args.vllm_enforce_eager,
limit_mm_per_prompt=self.args.vllm_limit_mm_per_prompt,
num_infer_workers=self.args.num_infer_workers,
enable_sleep_mode=self.args.sleep_level > 0,
use_async_engine=False,
distributed_executor_backend='external_launcher',
max_model_len=self.args.vllm_max_model_len)
self.engine.default_template = self.template
Poe
我们继续逐行解释你提供的 Swift 框架中 GRPOTrainer 的后续代码。这部分主要涉及:
- 推理设备的配置
- 奖励模型准备
- 多轮训练参数设置
- 模型参数的分批管理(offload)
- vLLM 的初始化等
🔧 设置推理配置(request_config)
python
运行复制
if local_world_size == self.args.num_infer_workers == get_device_count() and local_world_size > 1:
- 如果本地 GPU 数量和推理工作线程数一致,并且 GPU 数量大于 1,说明使用了 Tensor Parallel(TP)推理。
python
运行复制
self.request_config.n = self.args.tensor_parallel_size
- 设置每个 prompt 的生成数(
n)为 TP 大小,用于并行生成。
python
运行复制
if self.infer_rank >= 0:
self.request_config.seed = self.infer_rank // self.args.tensor_parallel_size
- 如果当前是推理进程,为每个 TP 分组设置不同的随机种子,避免生成重复内容。
🧪 奖励模型准备
python
运行复制
self.model_accepts_loss_kwargs = False
- 标记模型是否接受
loss_kwargs,默认不支持。
python
运行复制
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
- 如果奖励函数是模型对象,使用
Accelerate框架将其封装并切换为评估模式。
📁 完成记录和日志初始化
python
运行复制
self.log_completions = args.log_completions
self.jsonl_writer = JsonlWriter(os.path.join(self.args.output_dir, 'completions.jsonl'))
- 是否记录生成内容,并初始化写入
.jsonl文件的对象。
🔁 GRPO 多步迭代参数设置(对应论文中的 μ 和 ε)
python
运行复制
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
self.epsilon_low = args.epsilon
self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
- 设置 GRPO 的核心迭代次数 (
μ) 和 epsilon 探索参数。
🔄 训练状态相关变量
python
运行复制
self._step = 0
- 当前训练步数。
python
运行复制
self._buffered_inputs = [None] * args.gradient_accumulation_steps
- 缓存梯度累积周期内的输入数据。
python
运行复制
if self.args.async_generate:
self.add_callback(GRPOCallback(self))
- 如果启用了异步生成,注册自定义回调函数。
python
运行复制
self.set_multi_turn_engine_default_max_tokens()
- 设置多轮对话生成的默认最大 token 数。
🔀 模型参数分批(offload)处理:split_batches
python
运行复制
def split_batches(self):
- 将模型参数分组,以支持分批迁移(例如 CPU/GPU offload)。
python
运行复制
model = self.accelerator.unwrap_model(self.model)
- 解包模型,去除
Accelerator封装层。
python
运行复制
if self.args.move_model_batches is None:
return [[n for n, p in model.named_parameters() if 'ref_model' not in n]], [None]
- 如果未设置分批数量,则所有参数一组。
🔍 模型结构识别与分层
python
运行复制
model_arch = get_model_arch(model.model_meta.model_arch)
non_llm_parameters = []
llm_embeds = []
parameters = []
pattern = r'\.(\d+)\.'
- 获取模型结构,准备分类参数。
- 正则从参数名中提取层数编号。
python
运行复制
layer_count = None
for name, module in model.named_modules():
if isinstance(module, ModuleList):
...
layer_count = len(module)
- 遍历模型模块,查找 LLM 的层数。
python
运行复制
assert layer_count is not None, 'Cannot find ModuleList to split modules.'
- 必须找到层数信息,否则报错。
📦 分批构建
python
运行复制
n_layers = ceil(layer_count / self.args.move_model_batches)
for _ in range(self.args.move_model_batches):
parameters.append([])
- 按照层数将参数平均分成
move_model_batches组。
🔧 工具函数(去掉 lora、前缀名)
python
运行复制
def replace_lora(name): ...
def remove_lora_and_prefix(names): ...
- 清理参数名,去除 Lora 层、模型封装前缀。
🧩 按照层号进行分组
python
运行复制
def split_llm(name):
match = re.search(pattern, name)
if match:
number = match.group(1)
group = int(number) // n_layers
parameters[group].append(name)
else:
llm_embeds.append(name)
- 根据参数名中的层号对参数分组。
🧹 参数分类
python
运行复制
for name, parameter in model.named_parameters():
if 'ref_model' in name:
continue
...
split_llm(name)
- 把 LLM 层、embedding 层、非LLM 参数分类并分组。
✅ 返回最终分组结果
python
运行复制
if llm_embeds: parameters.append(llm_embeds)
if non_llm_parameters: parameters.append(non_llm_parameters)
parameters = [p for p in parameters if p]
parameters_no_lora = [remove_lora_and_prefix(p_list) for p_list in parameters]
return parameters, parameters_no_lora
- 返回两个参数列表:一个是完整参数分组,一个是去除 Lora 和前缀后的名称。
🚀 vLLM 引擎初始化:prepare_vllm
python
运行复制
def prepare_vllm(self, model, fast_infer_device):
- 使用 vLLM 后端进行高性能推理。
python
运行复制
from swift.tuners import Swift
from swift.llm import VllmEngine
from swift.llm.infer.infer_engine import GRPOVllmEngine
- 导入推理引擎。
python
运行复制
_, _, _, local_world_size = get_dist_setting()
if local_world_size == self.args.num_infer_workers == get_device_count() and local_world_size > 1:
cls = GRPOVllmEngine
else:
cls = VllmEngine
- 根据是否为 TP 设置选择合适的引擎类。
⚙️ 使用 grpo_context 上下文管理器
python
运行复制
with Swift.grpo_context(model, self.template.processor):
- 使用
grpo_context临时启用 Swift 的上下文环境(可能设置模型为 eval,关闭 hook)。
🧠 初始化 vLLM 引擎实例
python
运行复制
self.engine = cls(
model.model_dir,
model.model_info.torch_dtype,
model_type=model.model_meta.model_type,
device=fast_infer_device[self.local_infer_rank],
tensor_parallel_size=self.args.tensor_parallel_size,
gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
enable_prefix_caching=self.args.vllm_enable_prefix_caching,
max_num_seqs=self.args.vllm_max_num_seqs,
enforce_eager=self.args.vllm_enforce_eager,
limit_mm_per_prompt=self.args.vllm_limit_mm_per_prompt,
num_infer_workers=self.args.num_infer_workers,
enable_sleep_mode=self.args.sleep_level > 0,
use_async_engine=False,
distributed_executor_backend='external_launcher',
max_model_len=self.args.vllm_max_model_len)
- 创建并配置 vLLM 推理引擎,参数包括模型路径、数据类型、推理设备、并行设置、缓存策略等。
python
运行复制
self.engine.default_template = self.template
- 设置默认模板,避免多线程访问时出错。
✅ 总结
这一部分代码主要做了以下关键事情:
- 根据分布式设备情况配置生成参数。
- 初始化奖励模型、日志工具。
- 支持多轮对话的设置。
- 将大模型参数智能分批,便于 offload 管理。
- 初始化 vLLM 推理引擎,提升生成效率。
如果你需要图示解释,比如:
- 模型分批结构图
- 推理引擎架构图
- 参数分组可视化
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)