本文将介绍一下如何使用 Unsloth 框架和 TRL (Transformer Reinforcement Learning) 库,通过 GRPO (Group Relative Policy Optimization) 强化学习算法对 Qwen 2.5 (3B) 大模型进行微调(Fine-tuning)。不讲原理,直接上代码:

源代码已经放在:https://github.com/ArronAI007/Awesome-AGI/blob/main/LLM%20Pipeline/Fine-Tune/trl/01_Train_Qwen_2_5(3B)_To_Reason_With_GRPO.ipynb

整个代码实现了以下主要功能:

  • 环境配置:安装并配置 Unsloth、vLLM 和 TRL 等高性能训练库。

  • 模型加载与量化:加载 Qwen 2.5-3B-Instruct 模型,并使用 4-bit 量化(QLoRA)以降低显存占用,同时启用 vLLM 进行快速推理。

  • 数据集准备:加载 GSM8K(小学数学)数据集,并将其格式化为包含 和 XML 标签的 Prompt,旨在训练模型具备“思维链”(Chain-of-Thought)推理能力。

  • 定义奖励函数 (Reward Functions):定义了一组规则来评价模型的输出,包括:答案准确性、是否为整数、XML 格式是否规范等。这是强化学习(RL)的核心部分。

  • GRPO 训练:使用 GRPO 算法进行训练。GRPO 会让模型针对同一个问题生成多个回答,通过对比这些回答的奖励分数来优化策略,而不需要额外的价值模型(Value Model)。

  • 保存与推理:保存微调后的 LoRA 权重,并演示如何加载权重进行推理测试。

一、环境安装与设置

import os, numpy
# 设置环境变量,让 Unsloth 在 vLLM 中预留更多显存用于上下文
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"
# 获取当前 numpy 版本以防止依赖冲突
numpy_version = f"numpy=={numpy.__version__}"
# Install dependencies with numpy version preservation
# 安装 Unsloth 及其依赖(Unsloth 用于加速训练,vLLM 用于加速推理)
!uv pip install unsloth_zoo
!uv pip install --upgrade unsloth vllm==0.9.2 {numpy_version} torchvision bitsandbytes xformers
!uv pip install triton==3.2.0
!uv pip install transformers==4.55.4
!uv pip install --no-deps trl==0.22.2

二、加载Model和Tokenizer​​​​​​​

from unsloth import FastLanguageModel
import torch
# 设置最大上下文长度
max_seq_length = 1024
# 加载预训练模型和分词器
# 功能:加载 Qwen2.5-3B-Instruct 模型,使用 4-bit 量化加载以节省显存,开启 fast_inference (vLLM)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen2.5-3B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True,               # 4bit 量化加载
    fast_inference = True,             # 启用 vLLM 快速推理引擎
    max_lora_rank = 8,                 # LoRA 秩
    gpu_memory_utilization = 0.9,      # 显存利用率上限
)

三、配置 LoRA (低秩适应)​​​​​​​

# 配置 PEFT (Parameter-Efficient Fine-Tuning)
# 功能:将模型转换为 LoRA 模式,只训练新增的少量参数,冻结原模型参数
model = FastLanguageModel.get_peft_model(
    model,
    r = 8,  # LoRA 的秩
    # 指定需要应用 LoRA 的模块(注意力层和前馈网络层)
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = 8,
    use_gradient_checkpointing = "unsloth",       # 使用梯度检查点节省显存
    random_state = 1234,
)

四、数据集处理与格式化​​​​​​​

import re
from datasets import load_dataset, Dataset
# 系统提示词,强制模型使用特定的 XML 格式输出推理过程和答案
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
# 定义 XML 格式模板
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""
# 函数:从模型输出中提取 XML 标签内的答案
def extract_xml_answer(text):
    if "" not in text or "" not in text:
        return ""
    return text.split(" ", 1)[-1].split(" ", 1)[0].strip()
# 函数:从 GSM8K 数据集的原始答案字段中提取最终数值(通常在 #### 之后)
def extract_hash_answer(text):
    return text.split("####")[-1].strip() if "####" in text else None
# 函数:加载并预处理 GSM8K 数据集
# 功能:加载 OpenAI 的 GSM8K 数据集,并将每个样本转化为包含 system prompt 和 user prompt 的对话格式
def get_gsm8k_dataset(split = "train"):
    data = load_dataset("openai/gsm8k", "main")[split]
    return data.map(
        lambda x: {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": x["question"]},
            ],
            "answer": extract_hash_answer(x["answer"]),           # 提取标准答案用于后续奖励计算
        }
    )
# 加载处理好的数据集
dataset = get_gsm8k_dataset()

五、定义奖励函数 (Reward Functions)

这是 GRPO 的核心,模型生成的每个结果都会经过这些函数打分。​​​​​​​

# 奖励函数 1:正确性奖励
# 功能:检查模型生成的答案(从 XML 中提取)是否与标准答案完全一致。正确得 2.0 分,否则 0 分。
def correctness_reward_func(prompts, completions, answer, **kwargs):
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    # 打印日志方便调试
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
# 奖励函数 2:整数奖励
# 功能:检查提取出的答案是否为数字。是则得 0.5 分。
def int_reward_func(completions, **kwargs):
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
# 奖励函数 3:严格格式奖励
# 功能:使用正则检查输出是否严格符合 <reasoning>...\n<answer>... 的格式结构。
def strict_format_reward_func(completions, **kwargs):
    pattern = r"^\n.*?\n\n\n.*?\n\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
# 奖励函数 4:宽松格式奖励
# 功能:检查输出是否至少包含了 XML 标签,允许格式上有少量空白差异。
def soft_format_reward_func(completions, **kwargs):
    pattern = r".*?\s*.*?"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]
# 辅助函数:计算 XML 标签的完整性
def count_xml(text):
    count = 0.0
    # 检查各个标签是否存在,每存在一个加分,如果格式混乱(如多余换行)则扣分
    if text.count("\n") == 1:
        count += 0.125
    if text.count("\n\n") == 1:
        count += 0.125
    if text.count("\n\n") == 1:
        count += 0.125
        count -= len(text.split("\n\n")[-1])*0.001       # 惩罚项
    if text.count("\n") == 1:
        count += 0.125
        count -= (len(text.split("\n")[-1]) - 1)*0.001   # 惩罚项
    return count
# 奖励函数 5:XML 计数奖励
# 功能:基于 XML 标签的完整性和位置给予分数。
def xmlcount_reward_func(completions, **kwargs):
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

六、配置与启动 GRPO 训练​​​​​​​

from trl import GRPOConfig, GRPOTrainer
# 配置训练参数
training_args = GRPOConfig(
    use_vllm = True,                  # 使用 vLLM 生成样本(极快)
    learning_rate = 5e-6,             # 学习率
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",             # 使用 8-bit 优化器节省显存
    logging_steps = 1,
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 1,
    num_generations = 4,              # GRPO 核心:每个 prompt 生成 4 个回答进行对比
    max_prompt_length = 256,
    max_completion_length = 200,
    max_steps = 250,                  # 训练总步数
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "none",
    output_dir = "outputs",
)
    # 初始化 GRPO 训练器
    # 功能:将模型、奖励函数列表和训练配置结合
    trainer = GRPOTrainer(
        model = model,
        processing_class = tokenizer,
        reward_funcs = [
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            int_reward_func,
            correctness_reward_func,
        ],
        args = training_args,
        train_dataset = dataset,
    )
      # 开始训练
      # 功能:模型开始根据 prompt 生成多个回答,根据奖励函数的反馈更新 LoRA 权重,使模型更倾向于生成高分回答(格式正确且答案正确)。
      trainer.train()

      七、保存与推理测试​​​​​​​

      # 保存训练好的 LoRA 适配器
      model.save_lora("grpo_saved_lora")
        # --- 推理部分 ---
        from vllm import SamplingParams
        # 测试用的查询
        query = "How many r's are in strawberry?"
        # 构建聊天模板
        text = tokenizer.apply_chat_template([
            {"role" : "user", "content" : query},
        ], tokenize = False, add_generation_prompt = True)
        # 设置采样参数
        sampling_params = SamplingParams(
            temperature = 0.8,
            top_p = 0.95,
            max_tokens = 1024,
        )
        # 生成回答(不加载 LoRA 或 加载 LoRA)
        # 这里演示了如何使用 model.fast_generate 进行快速推理
        output = model.fast_generate(
            [text],
            sampling_params = sampling_params,
            lora_request = None,                    # 这里设为 None 表示用基础模型,若要用训练后的模型需加载 LoRA
        )[0].outputs[0].text
        print(output)
          # 构建聊天模板
          text = tokenizer.apply_chat_template([
              {"role" : "system", "content" : SYSTEM_PROMPT},
              {"role" : "user", "content" : query},
          ], tokenize = False, add_generation_prompt = True)
          sampling_params = SamplingParams(
              temperature = 0.8,
              top_p = 0.95,
              max_tokens = 1024,
          )
          # 再次生成,这次加载刚才保存的 LoRA 权重
          # 功能:验证经过 GRPO 训练后的模型表现
          output = model.fast_generate(
              text,
              sampling_params = sampling_params,
              lora_request = model.load_lora("grpo_saved_lora"),
          )[0].outputs[0].text
          print(output)

          至此,完整的微调代码就介绍完了。

          Logo

          火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

          更多推荐