[LLM Training] The ROLL framework and self.model.release_memory_occupation
You are asking why self.model.release_memory_occupation is executed after actor_infer finishes initializing.
What the logs mean
These logs show that actor_infer performed a memory-release operation immediately after initialization completed:
[actor_infer-0] actor_infer-0 initialized
[actor_infer-0] Cache flushed successfully!
[actor_infer-0] self.model.release_memory_occupation exec ....
Why this is done
This is a memory-optimization strategy for colocated mode.
1. Characteristics of colocated mode
In your configuration, actor_train and actor_infer share the same set of GPUs (a single node with 8 cards). This means:
- During training: actor_train needs to occupy GPU memory
- During inference: actor_infer needs to occupy GPU memory
- The two cannot occupy the GPUs at the same time, otherwise an OOM occurs (see the rough estimate below)
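As a back-of-the-envelope illustration of why the two cannot coexist (the 7B parameter count and the 16-bytes-per-parameter rule of thumb for mixed-precision Adam are assumptions, not numbers taken from the ROLL code), the sketch below compares inference-only weights with full training state:

```python
# Rough memory estimate (illustrative numbers, not from the ROLL code).
params = 7e9  # assume a 7B-parameter model

# Inference: bf16 weights only (2 bytes per parameter), ignoring the KV cache.
infer_gb = params * 2 / 1024**3

# Mixed-precision training with Adam: bf16 weights + bf16 grads
# + fp32 master weights + fp32 momentum + fp32 variance ~= 16 bytes per parameter.
train_gb = params * 16 / 1024**3

print(f"inference weights: ~{infer_gb:.0f} GB")  # ~13 GB
print(f"training state:    ~{train_gb:.0f} GB")  # ~104 GB (sharded across ranks in practice)
```

Even with the training state sharded across 8 GPUs, there is little room left for a resident inference engine plus its KV cache, which is why the two clusters must take turns.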
2. Why offload immediately after initialization
In the ActorWorker.initialize method, offload_states() is called as soon as initialization finishes: 1
Key steps:
- Create the strategy and initialize the model (at this point the model is on the GPU)
- If resume is configured, load the checkpoint
- Immediately call strategy.offload_states() to move the model from GPU to CPU
3. The offload_states implementation in the SGLang strategy
In the SGLang strategy, offload_states calls release_memory_occupation: 2
What it does:
- Check whether model_params needs to be offloaded
- Check whether this is colocated mode (is_train_infer_colocated)
- If the model is on the GPU, call release_memory_occupation() to free the memory
- Set the is_model_in_gpu = False flag
- Clear caches and run garbage collection
4. The implementation of release_memory_occupation
In the SGLang scheduler, this method will: 3
Key operations:
- Export the model's static state (e.g. buffers, persistent buffers)
- Move the model to the CPU: model.to('cpu')
- Pause the KV cache's memory occupation
- Flush the cache
The complete memory-management flow
In colocated mode, memory management follows the pattern sketched below:
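A minimal sketch of the alternation, condensed from the behavior described above (the step loop and the generate / train_step calls are hypothetical stand-ins, not the actual ROLL pipeline API; only load_states / offload_states come from the cited code):

```python
def colocated_rollout_and_train(actor_infer, actor_train, prompts, num_steps: int):
    """Illustrative colocated lifecycle (simplified; not the actual ROLL pipeline code).

    Both clusters are assumed to have run initialize() already, which (per the cited
    code) ends with strategy.offload_states(), so neither holds the GPUs on entry.
    """
    for step in range(num_steps):
        # 1) Bring the inference engine back onto the GPU for rollout generation.
        actor_infer.strategy.load_states()        # resume_memory_occupation(): model back on GPU
        rollouts = actor_infer.generate(prompts)  # hypothetical call standing in for the rollout phase
        actor_infer.strategy.offload_states()     # release_memory_occupation(): GPUs freed again

        # 2) The same GPUs are now free for training.
        actor_train.strategy.load_states()
        actor_train.train_step(rollouts)          # hypothetical call standing in for the training phase
        actor_train.strategy.offload_states()
```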
Why the model is not kept on the GPU after initialization
If it were not offloaded immediately after initialization, this would cause:
- Wasted memory: actor_infer may not be used right after initialization, yet it would still hold GPU memory
- OOM risk: actor_train needs a large amount of memory during training; if actor_infer were still on the GPU, memory would run out
- No alternation: the two clusters could not take turns working on the same set of GPUs
Restoring the model for later use
When actor_infer is needed, load_states() is called to restore the model to the GPU: 4
Restore flow:
- Check whether the model is on the GPU
- If it is not, call resume_memory_occupation() to move the model from CPU back to GPU
- Set is_model_in_gpu = True
Notes
This design is the key mechanism by which the ROLL framework supports colocated mode. Offloading immediately after initialization ensures:
- Memory efficiency: GPU memory is occupied only when it is actually needed
- Flexible scheduling: the pipeline can switch quickly between training and inference
- OOM avoidance: multiple clusters never occupy the GPUs at the same time and overflow memory
That is why you see release_memory_occupation executed right after initialization in the logs.
Citations
File: roll/pipeline/base_worker.py (L43-58)
```python
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def initialize(self, pipeline_config):
    super().initialize(pipeline_config)
    self.strategy = create_strategy(worker=self)
    if self.worker_config.model_args.model_type == "diffusion_module":
        self.strategy.initialize(model_provider=default_diffusion_module_provider)
    else:
        self.strategy.initialize(model_provider=default_actor_model_provider)
    self.tokenizer = self.strategy.tokenizer
    if self.pipeline_config.resume_from_checkpoint:
        load_dir = download_model(self.pipeline_config.resume_from_checkpoint)
        self.strategy.load_checkpoint(load_dir=load_dir, tag="checkpoint")
```
File: roll/distributed/strategy/sglang_strategy.py (L262-266)
```python
def load_states(self, *args, **kwargs):
    if not self.model.is_model_in_gpu:
        self.model.resume_memory_occupation()
        logger.info("self.model.resume_memory_occupation exec ....")
        self.model.is_model_in_gpu = True
```
File: roll/distributed/strategy/sglang_strategy.py (L268-276)
```python
def offload_states(self, include=None, non_blocking=False):
    if include is None or OffloadStateType.model_params in include:
        if self.model.is_model_in_gpu:
            self.model.release_memory_occupation()
            logger.info("self.model.release_memory_occupation exec ....")
            self.model.is_model_in_gpu = False
    self.recv_manager.clear()
    gc.collect()
    current_platform.empty_cache()
```
File: roll/third_party/sglang/v0410post2_patch/scheduler.py (L450-457)
```python
def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
    self.stashed_model_static_state = _export_static_state(
        self.tp_worker.worker.model_runner.model
    )
    self.tp_worker.worker.model_runner.model.to('cpu')
    self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
    self.flush_cache()
    return ReleaseMemoryOccupationReqOutput()
```
You are asking why the model has to be on the GPU after initialization.
Clarifying the question
Based on the context, I understand that you are actually asking: why does the model need to be loaded onto the GPU during initialization, only to be offloaded to the CPU immediately afterwards?
Why model initialization must happen on the GPU
1. Architectural requirements of the inference engine
Inference engines such as SGLang and vLLM must create the model on the GPU at initialization time; this is dictated by their architecture: 1
Key reasons:
- Model loading: the get_model() method needs to build the model structure and allocate its memory on the GPU
- KV cache initialization: the inference engine needs to pre-allocate KV cache space on the GPU
- CUDA graph optimization: some inference engines build CUDA graphs at initialization, which must be done on the GPU (see the sketch after this list)
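A minimal, self-contained sketch of that last point (illustrative only; this uses the generic torch.cuda.CUDAGraph API, not SGLang's capture code): graph capture records CUDA kernels against static GPU buffers, so it is only possible once the model already lives on the GPU.

```python
import torch

# Build a tiny model on the GPU; capture would fail for a CPU model.
model = torch.nn.Linear(16, 16).cuda().eval()
static_x = torch.zeros(1, 16, device="cuda")

# Warm up on a side stream before capture, as the CUDA graph API requires.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_x)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass into a CUDA graph.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_y = model(static_x)

# Replay re-runs the captured kernels against the same static buffers.
static_x.copy_(torch.randn(1, 16, device="cuda"))
g.replay()
```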
2. Model configuration and validation
Several configuration and validation steps during initialization have to run on the GPU: 2
Operations that must be done on the GPU:
- Detect the GPU compute capability
- Set the dtype (if the GPU does not support bf16, automatically fall back to fp16)
- Load quantization parameters (e.g. FP8 KV cache scaling factors)
- Initialize the attention mechanism (sliding window size, etc.)
3. Distributed-communication initialization
In multi-GPU scenarios, initializing the communication groups is also done on the GPU: 3
The warmup all-reduce creates a tensor on the GPU and runs a communication operation to verify that the group is usable.
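A minimal sketch of the same warmup idea using plain torch.distributed (the cited ROLL code goes through its collective.init_collective_group / collective.allreduce helpers instead; the NCCL backend and env:// rendezvous here are assumptions for the sketch):

```python
import torch
import torch.distributed as dist

def warmup_group(backend: str = "nccl") -> None:
    # Assumes the usual env:// rendezvous variables (RANK, WORLD_SIZE, MASTER_ADDR,
    # MASTER_PORT), e.g. when launched with torchrun.
    dist.init_process_group(backend=backend)
    # A tiny all-reduce on a GPU tensor: if this returns, the NCCL group works.
    t = torch.zeros(1, device="cuda")
    dist.all_reduce(t)
    torch.cuda.synchronize()
```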
Why offload immediately after initialization
Although initialization must happen on the GPU, in colocated mode the model is offloaded as soon as initialization completes: 4
Reasons:
- Memory sharing: actor_train and actor_infer share the same set of GPUs; if actor_infer held them permanently, actor_train could not train
- Deferred use: actor_infer is not used until the model_update of the first step, so there is no need for it to hold the GPUs until then
- OOM avoidance: prevents multiple clusters from occupying the GPUs at the same time and running out of memory
The complete flow
Initialize on the GPU → offload_states() immediately → load_states() when the cluster is actually needed → offload_states() again afterwards (see the lifecycle sketch above).
Technical constraints
Why GPU initialization cannot be skipped:
- Inference-engine design: SGLang/vLLM's get_model() is hard-coded to load the model onto the GPU
- CUDA dependency: many initialization steps (such as attention-backend selection) depend on a CUDA context
- Memory allocation: sizing the KV cache requires first placing the model on the GPU and measuring the remaining free memory (see the sketch after this list)
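A minimal sketch of that measurement step, assuming CUDA is available (the 0.9 utilization factor and the per-token formula are illustrative simplifications, not SGLang's actual accounting):

```python
import torch

def estimate_kv_cache_tokens(num_layers: int, num_kv_heads: int, head_dim: int,
                             dtype_bytes: int = 2, gpu_mem_utilization: float = 0.9) -> int:
    """Rough KV-cache sizing: measure free GPU memory *after* the weights are loaded,
    then divide by the per-token KV footprint. Illustrative only."""
    free_bytes, _total_bytes = torch.cuda.mem_get_info()  # only meaningful once the model is on the GPU
    budget = int(free_bytes * gpu_mem_utilization)        # keep headroom for activations
    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes  # K and V per layer
    return budget // bytes_per_token
```

Because the free-memory measurement only makes sense once the weights are resident, the cache cannot be sized without first putting the model on the GPU.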
Notes
Having to initialize the model on the GPU is an architectural constraint of the inference engine and cannot be bypassed. However, the offload mechanism of colocated mode releases the GPU memory right after initialization, allowing training and inference to share it. This design is very efficient in single-node multi-GPU setups and avoids the overhead of repeatedly reloading the model.
Citations
File: roll/third_party/sglang/v052_patch/model_runner.py (L32-75)
```python
def load_model(self):
    before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
    logger.info(
        f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
    )

    # This can reduce thread conflicts and speed up weight loading.
    if self.device != "cpu":
        torch.set_num_threads(1)
    if self.device == current_platform.device_type:
        if current_platform.get_device_capability()[0] < 8:
            logger.info(
                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
            )
            self.server_args.dtype = "float16"
            self.model_config.dtype = torch.float16
            if current_platform.get_device_capability()[1] < 5:
                raise RuntimeError("SGLang only supports sm75 and above.")

    set_cuda_arch()

    # Prepare the model config
    self.load_config = LoadConfig(
        load_format=self.server_args.load_format,
        download_dir=self.server_args.download_dir,
        model_loader_extra_config=self.server_args.model_loader_extra_config,
    )
    if self.device == "cpu":
        self.model_config = adjust_config_with_unaligned_cpu_tp(
            self.model_config, self.load_config, self.tp_size
        )
    if self.server_args.load_format == "gguf":
        monkey_patch_vllm_gguf_config()

    # Load the model
    # Remove monkey_patch when linear.py quant remove dependencies with vllm
    monkey_patch_vllm_parallel_state()
    monkey_patch_isinstance_for_vllm_base_layer()
    self.model = get_model(
        model_config=self.model_config,
        load_config=self.load_config,
        device_config=DeviceConfig(self.device),
    )
```
File: roll/third_party/sglang/v0410post2_patch/model_runner.py (L30-120)
```python
def load_model(self):
    before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
    logger.info(
        f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
    )

    # This can reduce thread conflicts and speed up weight loading.
    if self.device != "cpu":
        torch.set_num_threads(1)
    if self.device == "cuda":
        if torch.cuda.get_device_capability()[0] < 8:
            logger.info(
                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
            )
            self.server_args.dtype = "float16"
            self.model_config.dtype = torch.float16
            if torch.cuda.get_device_capability()[1] < 5:
                raise RuntimeError("SGLang only supports sm75 and above.")

    set_cuda_arch()

    # Prepare the model config
    self.load_config = LoadConfig(
        load_format=self.server_args.load_format,
        download_dir=self.server_args.download_dir,
        model_loader_extra_config=self.server_args.model_loader_extra_config,
    )
    if self.device == "cpu":
        self.model_config = adjust_config_with_unaligned_cpu_tp(
            self.model_config, self.load_config, self.tp_size
        )
    if self.server_args.load_format == "gguf":
        monkey_patch_vllm_gguf_config()

    # Load the model
    # Remove monkey_patch when linear.py quant remove dependencies with vllm
    monkey_patch_vllm_parallel_state()
    monkey_patch_isinstance_for_vllm_base_layer()
    self.model = get_model(
        model_config=self.model_config,
        load_config=self.load_config,
        device_config=DeviceConfig(self.device),
    )
    monkey_patch_vllm_parallel_state(reverse=True)
    monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

    if self.server_args.kv_cache_dtype == "fp8_e4m3":
        if self.server_args.quantization_param_path is not None:
            if callable(getattr(self.model, "load_kv_cache_scales", None)):
                self.model.load_kv_cache_scales(
                    self.server_args.quantization_param_path
                )
                logger.info(
                    "Loaded KV cache scaling factors from %s",
                    self.server_args.quantization_param_path,
                )
            else:
                raise RuntimeError(
                    "Using FP8 KV cache and scaling factors provided but "
                    "model %s does not support loading scaling factors.",
                    self.model.__class__,
                )
        else:
            logger.warning(
                "Using FP8 KV cache but no scaling factors "
                "provided. Defaulting to scaling factors of 1.0. "
                "This may lead to less accurate results!"
            )

    # Parse other args
    self.sliding_window_size = None
    if hasattr(self.model, "get_attention_sliding_window_size"):
        self.sliding_window_size = self.model.get_attention_sliding_window_size()
    elif self.model_config.attention_chunk_size is not None:
        self.sliding_window_size = self.model_config.attention_chunk_size
        logger.info(
            f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}"
        )

    self.dtype = self.model_config.dtype

    after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
    self.weight_load_mem_usage = before_avail_memory - after_avail_memory
    logger.info(
        f"Load weight end. "
        f"type={type(self.model).__name__}, "
        f"dtype={self.dtype}, "
        f"avail mem={after_avail_memory:.2f} GB, "
        f"mem usage={self.weight_load_mem_usage:.2f} GB."
    )
```
File: roll/third_party/vllm/worker_helper.py (L65-87)
```python
def setup_collective_group(self, comm_plan, backend, rank_in_cluster):
    self.model_update_comm_plan = getattr(self, "model_update_comm_plan", {})
    rank, comm_plan_args = get_dist_info_from_comm_plan(comm_plan, rank_in_cluster=rank_in_cluster,
                                                        rank_in_worker=dist.get_rank())
    if rank is None:
        logger.info(f"no comm_plan found for rank {rank_in_cluster}/{dist.get_rank()}")
        return
    group_name = comm_plan_args["group_name"]
    master_addr = comm_plan_args["master_addr"]
    master_port = comm_plan_args["master_port"]
    world_size = len(comm_plan_args["tgt_devices"]) + 1
    src_pp_rank = comm_plan_args["src_pp_rank"]
    collective.init_collective_group(world_size, rank, backend=backend, group_name=group_name,
                                     master_addr=master_addr, master_port=master_port)
    # A small all_reduce for warmup.
    collective.allreduce(torch.zeros(1).to(current_platform.device_type), group_name=group_name)
    self.model_update_comm_plan[src_pp_rank] = dict(rank=rank,
                                                    world_size=world_size,
                                                    src_pp_rank=src_pp_rank,
                                                    group_name=group_name,
                                                    comm_plan=comm_plan,
                                                    comm_plan_args=comm_plan_args)
    logger.info(f"warmup setup_collective_group: {group_name} rank: {rank} world_size: {world_size}")
```
File: roll/pipeline/base_worker.py (L44-61)
```python
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def initialize(self, pipeline_config):
    super().initialize(pipeline_config)
    self.strategy = create_strategy(worker=self)
    if self.worker_config.model_args.model_type == "diffusion_module":
        self.strategy.initialize(model_provider=default_diffusion_module_provider)
    else:
        self.strategy.initialize(model_provider=default_actor_model_provider)
    self.tokenizer = self.strategy.tokenizer
    if self.pipeline_config.resume_from_checkpoint:
        load_dir = download_model(self.pipeline_config.resume_from_checkpoint)
        self.strategy.load_checkpoint(load_dir=load_dir, tag="checkpoint")
    self.logger.info(f"{self.worker_name} initialized")
    self.strategy.offload_states()
```
Follow-up question: at model_update time, the actor_train weights are temporarily loaded, broadcast to actor_infer, and then offloaded again. Where does this logic live, and why is it done this way?