引言

GRPO(Group Relative Policy Optimization 是由 DeepSeek 团队 在最新研究中提出的一种高效强化学习方法。本文参考 Swift 官方文档,在服务器环境中完整复现 GRPO 的训练流程,并记录了从环境配置到实训结果的全过程。

论文链接:https://arxiv.org/pdf/2402.03300

准备环境

创建 Conda 虚拟环境 myswift 并安装所需依赖包

conda create -n myswift python=3.10
conda activate myswift
pip install 'ms-swift'
pip install wandb
pip install vllm
pip install math_verify==0.5.2

两种训练方式:

我本来用方式一进行训练,但是我的环境是单张NVIDIA RTX 4090 ,显存24GB,实在不够用,所以我换成了方式二。但是我依旧会记录方式一的步骤,你可以试着操作,可以按照我的方式成功运行。

维度 colocate 模式 独立 server 模式
进程数 1 个进程同时运行 vLLM 与训练 2 个进程:① rollout server(推理)② rlhf(训练)
通信方式 内部函数调用,无网络延迟 训练端通过 HTTP/gRPC 调用 server
显存分配 训练与推理共用同一 GPU server 与训练各占不同 GPU
启动命令 swift rlhf --vllm_mode colocate swift rollout → swift rlhf --vllm_mode client
适用场景 单卡、小模型(≤24GB) 多卡、大模型、高吞吐部署
优缺点 简单、省显存,但推理吞吐较低 可扩展、高吞吐,但占显存多

方式一-独立 server 模式

Rollout 阶段

使用 swift rollout 生成模型在当前任务上的若干候选输出(completions)。

1. 手动下载模型和数据集(可选择自动下载)

手动拉取数据集和模型,这里主要是为了防止奇奇怪怪的报错,你也可以选择不手动,也就是说自动命令行下载也行。

魔塔数据集网址https://modelscope.cn/datasets/zouxuhong/Countdown-Tasks-3to4

(myswift) root@zhao:/mnt/zhao#git clone https://www.modelscope.cn/datasets/zouxuhong/Countdown-Tasks-3to4.git

下载完成,数据集本地位置/mnt/zhao/Countdown-Tasks-3to4/

魔塔qwen模型网址https://modelscope.cn/models/Qwen/Qwen2.5-0.5B-Instruct

(myswift) root@zhao:/mnt/zhao# modelscope download --model Qwen/Qwen2.5-0.5B-Instruct --local_dir ./models/Qwen2.5-0.5B-Instruct

下载完成,模型本地位置/mnt/zhao/models/Qwen2.5-0.5B-Instruct/

2.rollout
CUDA_VISIBLE_DEVICES=0 \
swift rollout \
    --model /mnt/zhao/models/Qwen2.5-0.5B-Instruct \
    --model_type qwen2_5 \
    --dataset 'zouxuhong/Countdown-Tasks-3to4#100' \
    --vllm_max_num_seqs 32 \
    --vllm_gpu_memory_utilization 0.5 \
    --vllm_max_model_len 2048

这说明vLLM已经成功在 8000 端口启动了 rollout 推理服务,可以看到:

INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)

注意:rollout 服务要保持运行(不要关掉该窗口)。

rollout 阶段仅用于支撑训练,即启动推理服务供后续 RLHF 调用。此阶段不会生成 JSONL 结果文件,但会创建空的结果目录。

/mnt/zhao/result/Qwen2.5-0.5B-Instruct/deploy_result/

RLHF 阶段(训练)

打开新的终端,再执行swift rlhf 命令。

使用 swift rlhf 加载 rollouts,再根据 reward function 优化策略。

1. 获取wandb的api-key

wandb可用于可视化训练损失、奖励和生成结果。

在wandb官网https://wandb.ai/site/注册账号,获得api-key(免费)

2. 下载官方示例的 plugin.py

在swift官网下载plugin.py放在examples/train/grpo/plugin目录下

https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/plugin.py

3. GRPO训练命令
CUDA_VISIBLE_DEVICES=0 \
swift rlhf \
    --rlhf_type grpo \
    --model /mnt/zhao/models/Qwen2.5-0.5B-Instruct \
    --external_plugins /mnt/zhao/examples/train/grpo/plugin/plugin.py \
    --reward_funcs external_countdown format \
    --use_vllm true \
    --vllm_mode colocate \ # 表示训练与推理在同一进程中运行
    --train_type full \
    --torch_dtype bfloat16 \
    --dataset ms::/mnt/zhao/Countdown-Tasks-3to4#100 \
    --max_length 2048 \
    --max_completion_length 512 \
    --num_train_epochs 2 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --learning_rate 5e-7 \
    --eval_steps 50 \
    --save_steps 50 \
    --save_total_limit 5 \
    --logging_steps 1 \
    --output_dir output/GRPO_COUNTDOWN \
    --warmup_ratio 0.01 \
    --dataloader_num_workers 2 \
    --num_generations 2 \
    --temperature 1.0 \
    --system 'You are a helpful assistant. You first think about the reasoning process in your mind and then provide the user with the answer.' \
    --log_completions true \
    --report_to wandb \
    --beta 0.001 \ # GRPO 中的正则化系数
    --num_iterations 1
参数 含义 说明
--rlhf_type 强化学习方法类型 这里为 grpo(Group Relative Policy Optimization)
--external_plugins 外部插件路径 自定义 reward 函数逻辑文件
--reward_funcs 奖励函数 支持多种类型,如 format、external_countdown
--use_vllm 是否启用 vLLM 提升生成速度与显存利用率
--vllm_mode 推理运行模式 colocate(训练与推理同进程)或 client(独立 server)
--torch_dtype 精度类型 建议 bfloat16 以兼顾性能与显存
--gradient_accumulation_steps 梯度累积 小显存优化策略,累计多步梯度再反向传播
--beta GRPO 正则化系数 控制更新强度,防止策略过拟合
--num_generations rollout 生成样本数 每条 prompt 生成多少条候选
--report_to wandb 报告方式 将指标上传至 W&B 方便可视化分析
--system 系统提示词 控制生成风格或角色(类似 system prompt)

内存不够用,还差十多个G,就到这里了。

方式二-colocate 模式

1. 先用小参数离线跑通(可省略,最简单的训练方式)

CUDA_VISIBLE_DEVICES=0 \
swift rlhf \
    --rlhf_type grpo \
    --model /mnt/zhao/models/Qwen2.5-0.5B-Instruct \
    --external_plugins /mnt/zhao/examples/train/grpo/plugin/plugin.py \
    --reward_funcs external_countdown format \
    --use_vllm true \
    --vllm_mode colocate \
    --vllm_gpu_memory_utilization 0.25 \
    --train_type full \
    --torch_dtype bfloat16 \
    --dataset ms::/mnt/zhao/Countdown-Tasks-3to4#100 \
    --max_length 2048 \
    --max_completion_length 256 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --num_generations 2 \
    --learning_rate 5e-7 \
    --eval_steps 50 \
    --save_steps 50 \
    --save_total_limit 5 \
    --logging_steps 1 \
    --output_dir output/GRPO_COUNTDOWN \
    --warmup_ratio 0.01 \
    --dataloader_num_workers 2 \
    --temperature 1.0 \
    --system 'You are a helpful assistant. You first think about the reasoning process in your mind and then provide the user with the answer.' \
    --log_completions true \
    --beta 0.001 \
    --num_iterations 1

结果展示:

Train: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:38<00:00,  1.55s/it]
[INFO:swift] last_model_checkpoint: /mnt/zhao/output/GRPO_COUNTDOWN/v16-20251111-121942/checkpoint-25
[INFO:swift] best_model_checkpoint: None
[INFO:swift] images_dir: /mnt/zhao/output/GRPO_COUNTDOWN/v16-20251111-121942/images
[INFO:swift] End time of running main: 2025-11-11 12:20:41.842218
[rank0]:[W1111 12:20:42.443431073 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

模型最终权重在这里:

/mnt/zhao/output/GRPO_COUNTDOWN/v16-20251111-121942/checkpoint-25

实验中的loss曲线等在下面目录中:

/mnt/zhao/output/GRPO_COUNTDOWN/v16-20251111-121942/images/

2. 用wandb展示实验效果

2.1. 登录 wandb详细步骤(2 分钟):

在终端执行

wandb login

会看到:

wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

复制 API key,浏览器打开 https://wandb.ai/authorize,登录后把 40 位字符串复制到剪贴板。

回到终端粘贴(不会显示字符,直接回车) 出现:

Successfully logged in to Weights & Biases!

2.2. GRPO训练
CUDA_VISIBLE_DEVICES=0 \
swift rlhf \
    --rlhf_type grpo \
    --model /mnt/zhao/models/Qwen2.5-0.5B-Instruct \
    --external_plugins /mnt/zhao/examples/train/grpo/plugin/plugin.py \
    --reward_funcs external_countdown format \
    --use_vllm true \
    --vllm_mode colocate \
    --vllm_gpu_memory_utilization 0.25 \
    --train_type full \
    --torch_dtype bfloat16 \
    --dataset ms::/mnt/zhao/Countdown-Tasks-3to4#100 \
    --max_length 2048 \
    --max_completion_length 256 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --num_generations 2 \
    --learning_rate 5e-7 \
    --eval_steps 50 \
    --save_steps 50 \
    --save_total_limit 5 \
    --logging_steps 1 \
    --output_dir output/GRPO_COUNTDOWN \
    --warmup_ratio 0.01 \
    --dataloader_num_workers 2 \
    --temperature 1.0 \
    --system 'You are a helpful assistant. You first think about the reasoning process in your mind and then provide the user with the answer.' \
    --log_completions true \
    --report_to wandb \
    --beta 0.001 \
    --num_iterations 1

结果展示:

{'train_runtime': 121.99, 'train_samples_per_second': 0.82, 'train_steps_per_second': 0.205, 'train_loss': 8e-07, 'epoch': 1.0, 'global_step/max_steps': '25/25', 'percentage': '100.00%', 'elapsed_time': '1m 59s', 'remaining_time': '0s', 'memory(GiB)': 11.73, 'train_speed(iter/s)': 0.209785}
Train: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 25/25 [02:04<00:00,  4.98s/it]
[INFO:swift] last_model_checkpoint: /mnt/zhao/output/GRPO_COUNTDOWN/v18-20251111-141732/checkpoint-25
[INFO:swift] best_model_checkpoint: None
[INFO:swift] End time of running main: 2025-11-11 14:19:58.407963
wandb: 
wandb: 🚀 View run /mnt/zhao/output/GRPO_COUNTDOWN/v18-20251111-141732 at: 
wandb: Find logs at: wandb/run-20251111_141752-vd1oc7ol/logs
[rank0]:[W1111 14:20:01.944944379 ProcessGroupNCCL.cpp:1538] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐