You asked why self.model.release_memory_occupation is executed right after actor_infer finishes initializing.

What the logs mean

These logs show that actor_infer releases its GPU memory immediately after initialization completes:

[actor_infer-0] actor_infer-0 initialized
[actor_infer-0] Cache flushed successfully!
[actor_infer-0] self.model.release_memory_occupation exec ....

Why this is done

This is the memory-optimization strategy used in colocated mode.

1. Characteristics of colocated mode

In your configuration, actor_train and actor_infer share the same set of GPUs (a single node with 8 cards). This means:

  • During training: actor_train needs the GPU memory
  • During inference: actor_infer needs the GPU memory
  • They cannot both hold the GPUs at the same time, otherwise you run out of memory (a rough estimate follows below)
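
To see why alternation is necessary, here is a back-of-the-envelope estimate for a hypothetical 7B-parameter model on a single 80 GB GPU. The model size, optimizer layout, and KV-cache budget below are illustrative assumptions, not values taken from the ROLL configuration:

    # Rough memory arithmetic only; 7B, bf16 + fp32 Adam, and a ~20 GB KV cache
    # are assumptions for illustration.
    params = 7e9

    train_gb = (
        params * 2        # bf16 weights
        + params * 2      # bf16 gradients
        + params * 4      # fp32 master weights
        + params * 4 * 2  # fp32 Adam moments (m and v)
    ) / 1e9               # ≈ 112 GB before activations

    infer_gb = params * 2 / 1e9 + 20  # bf16 weights + assumed KV cache ≈ 34 GB

    print(f"train ≈ {train_gb:.0f} GB, infer ≈ {infer_gb:.0f} GB, per-GPU capacity = 80 GB")

Even with states sharded across 8 GPUs, keeping both roles resident at once leaves far less headroom than letting them take turns via load/offload.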

2. Why offload immediately after initialization

In ActorWorker.initialize, offload_states() is called as soon as initialization completes: 1

Key steps:

  1. Create the strategy and initialize the model (the model is on the GPU at this point)
  2. If resume is configured, load the checkpoint
  3. Immediately call strategy.offload_states() to move the model off the GPU

3. offload_states in the SGLang strategy

In the SGLang strategy, offload_states calls release_memory_occupation: 2

What it does:

  1. Check whether model_params is included in the states to offload
  2. If the model is currently on the GPU, call release_memory_occupation() to free the memory
  3. Set the is_model_in_gpu = False flag
  4. Clear the receive manager, run garbage collection, and empty the device cache

4. What release_memory_occupation does

In the SGLang scheduler, this method does the following: 3

Key operations:

  1. Exports the model's static state (buffers, persistent buffers, etc.)
  2. Moves the model to the CPU: model.to('cpu')
  3. Pauses the KV cache's memory occupation
  4. Flushes the cache

The complete memory-management flow

In colocated mode, memory management follows this pattern:

  • Initialization: ActorTrain and ActorInfer each run initialize() (model loaded onto the GPU) followed immediately by offload_states() (GPU released)
  • Training phase: ActorTrain runs load_states() (back to GPU) → train_step() → offload_states() (GPU released)
  • Inference phase: ActorInfer runs load_states() (back to GPU) → generate() → offload_states() (GPU released)
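
The same alternation can be written as a minimal sketch of one pipeline step. actor_train and actor_infer here are hypothetical handles wrapping the load_states/offload_states calls cited below; this is not the real ROLL cluster API:

    def run_step(actor_train, actor_infer, batch):
        # Inference phase: bring actor_infer back onto the GPU, generate, release.
        actor_infer.load_states()       # resume_memory_occupation under the hood
        rollouts = actor_infer.generate(batch)
        actor_infer.offload_states()    # release_memory_occupation + cache flush

        # Training phase: the same GPUs are now free for actor_train.
        actor_train.load_states()
        metrics = actor_train.train_step(rollouts)
        actor_train.offload_states()
        return metrics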

Why the model is not kept on the GPU after initialization

If the model were not offloaded immediately after initialization:

  1. Wasted memory: actor_infer may not be used right away, yet it would keep holding GPU memory
  2. OOM risk: actor_train needs a large amount of memory during training; with actor_infer still on the GPU there would not be enough
  3. No time-sharing: the two clusters could not take turns on the same set of GPUs

Restoring the model when it is needed again

When actor_infer is needed, load_states() is called to bring the model back onto the GPU: 4

Restore flow:

  1. Check whether the model is already on the GPU
  2. If not, call resume_memory_occupation() to move it from CPU back to GPU
  3. Set is_model_in_gpu = True

Notes

This design is the key mechanism that lets the ROLL framework support colocated mode. Offloading immediately after initialization ensures:

  1. Memory efficiency: GPU memory is occupied only when it is actually needed
  2. Flexible scheduling: the system can switch quickly between training and inference
  3. No OOM: multiple clusters never hold the GPUs at the same time, so memory does not blow up

That is why the logs show release_memory_occupation executing right after initialization.


Citations

File: roll/pipeline/base_worker.py (L43-58)


    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def initialize(self, pipeline_config):
        super().initialize(pipeline_config)

        self.strategy = create_strategy(worker=self)

        if self.worker_config.model_args.model_type == "diffusion_module":
            self.strategy.initialize(model_provider=default_diffusion_module_provider)
        else:
            self.strategy.initialize(model_provider=default_actor_model_provider)
        
        self.tokenizer = self.strategy.tokenizer
        if self.pipeline_config.resume_from_checkpoint:
            load_dir = download_model(self.pipeline_config.resume_from_checkpoint)
            self.strategy.load_checkpoint(load_dir=load_dir, tag="checkpoint")

File: roll/distributed/strategy/sglang_strategy.py (L262-266)

    def load_states(self, *args, **kwargs):
        if not self.model.is_model_in_gpu:
            self.model.resume_memory_occupation()
            logger.info("self.model.resume_memory_occupation exec ....")
            self.model.is_model_in_gpu = True

File: roll/distributed/strategy/sglang_strategy.py (L268-276)

    def offload_states(self, include=None, non_blocking=False):
        if include is None or OffloadStateType.model_params in include:
            if self.model.is_model_in_gpu:
                self.model.release_memory_occupation()
                logger.info("self.model.release_memory_occupation exec ....")
                self.model.is_model_in_gpu = False
        self.recv_manager.clear()
        gc.collect()
        current_platform.empty_cache()

File: roll/third_party/sglang/v0410post2_patch/scheduler.py (L450-457)

    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
        self.stashed_model_static_state = _export_static_state(
            self.tp_worker.worker.model_runner.model
        )
        self.tp_worker.worker.model_runner.model.to('cpu')
        self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
        self.flush_cache()
        return ReleaseMemoryOccupationReqOutput()

You asked why the model has to be on the GPU during initialization

Clarifying the question

From the context, I understand that you are really asking: why is the model loaded onto the GPU during initialization, only to be offloaded to the CPU right afterwards?

Why initialization must happen on the GPU

1. Architectural requirements of the inference engine

Inference engines such as SGLang and vLLM must create the model on the GPU during initialization; this is a consequence of their architecture: 1

Key reasons:

  • Model loading: get_model() initializes the model structure and allocates its memory on the GPU
  • KV cache initialization: the inference engine pre-allocates KV cache space on the GPU
  • CUDA graph optimization: some inference engines capture CUDA graphs during initialization, which must be done on the GPU (see the sketch below)
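
The CUDA-graph point is easy to see with plain PyTorch. This is a generic illustration, not SGLang code: graph capture records GPU kernel launches, so the module and its example inputs must already be resident on the device.

    import torch

    # A tiny module and a static input buffer, both on the GPU.
    model = torch.nn.Linear(16, 16).cuda()
    static_in = torch.zeros(1, 16, device="cuda")

    # Warm up on a side stream before capture, as the PyTorch docs recommend.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            model(static_in)
    torch.cuda.current_stream().wait_stream(s)

    # Capture: every kernel launched here is recorded; CPU tensors cannot take part.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out = model(static_in)

    # Replay later by copying new data into the static input buffer.
    static_in.copy_(torch.randn(1, 16, device="cuda"))
    g.replay()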

2. Model configuration and validation

Several configuration and validation steps during initialization have to run on the GPU: 2

Operations that must be done on the GPU:

  • Detect the GPU compute capability
  • Set the dtype (automatically fall back to fp16 if the GPU does not support bf16)
  • Load quantization parameters (such as FP8 KV cache scaling factors)
  • Initialize the attention mechanism (sliding window size, etc.)

3. Distributed communication initialization

In multi-GPU setups, the communication groups are also initialized on the GPU: 3

The warmup all-reduce creates a tensor on the GPU and runs a communication operation to verify that the group is usable.

Why offload immediately after initialization

Although initialization must happen on the GPU, in colocated mode the model is offloaded as soon as initialization completes: 4

Reasons:

  1. Memory sharing: actor_train and actor_infer share the same GPUs; if actor_infer kept holding them, actor_train could not train
  2. Deferred use: actor_infer is not used until the first step's model_update, so there is no need to keep it on the GPU
  3. Avoiding OOM: prevents multiple clusters from occupying the GPUs at the same time and running out of memory

The complete flow

  1. Pipeline.__init__() → ActorWorker.initialize() → create_strategy() → EngineSA(**config)
  2. The SGLang engine must initialize on the GPU: load_model(), KV cache allocation, dtype and compute-capability checks, CUDA graph capture (if enabled)
  3. Initialization finishes with the model on the GPU (model.is_model_in_gpu = True)
  4. Colocated-mode optimization: offload_states() → release_memory_occupation() moves the model to the CPU (model.is_model_in_gpu = False)
  5. The model is brought back with load_states() the next time it is needed

Technical constraints

Why the GPU initialization step cannot be skipped:

  1. Inference-engine design: SGLang/vLLM's get_model() loads the model directly onto the GPU
  2. CUDA dependencies: many initialization steps (such as attention-backend selection) require a CUDA context
  3. Memory sizing: computing the KV cache size requires first placing the model on the GPU and measuring the remaining memory (see the sketch below)
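
Point 3 can be illustrated with a short sketch: the KV cache budget is derived from whatever memory remains once the weights are resident, so it cannot be computed with the model parked on the CPU. The fraction, layer count, and head sizes below are made-up parameters for illustration, not SGLang's actual sizing logic:

    import torch

    def estimate_kv_cache_tokens(mem_fraction=0.85, num_layers=32,
                                 num_kv_heads=8, head_dim=128, dtype_bytes=2):
        # Free memory can only be measured after the weights occupy the GPU.
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        used_bytes = total_bytes - free_bytes
        budget = total_bytes * mem_fraction - used_bytes

        # Per-token KV footprint: K and V, for every layer and every KV head.
        bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes
        return max(int(budget // bytes_per_token), 0)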

Notes

Initializing the model on the GPU is an architectural constraint of the inference engine and cannot be bypassed. With the offload mechanism of colocated mode, however, the GPU memory is released immediately after initialization, so training and inference can share it. This design is very efficient in single-node multi-GPU setups and avoids the overhead of loading the model repeatedly.


Citations

File: roll/third_party/sglang/v052_patch/model_runner.py (L32-75)

    def load_model(self):
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == current_platform.device_type:
            if current_platform.get_device_capability()[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                if current_platform.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
            model_loader_extra_config=self.server_args.model_loader_extra_config,
        )
        if self.device == "cpu":
            self.model_config = adjust_config_with_unaligned_cpu_tp(
                self.model_config, self.load_config, self.tp_size
            )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        # Remove monkey_patch when linear.py quant remove dependencies with vllm
        monkey_patch_vllm_parallel_state()
        monkey_patch_isinstance_for_vllm_base_layer()

        self.model = get_model(
            model_config=self.model_config,
            load_config=self.load_config,
            device_config=DeviceConfig(self.device),
        )

File: roll/third_party/sglang/v0410post2_patch/model_runner.py (L30-120)

    def load_model(self):
        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

        # This can reduce thread conflicts and speed up weight loading.
        if self.device != "cpu":
            torch.set_num_threads(1)
        if self.device == "cuda":
            if torch.cuda.get_device_capability()[0] < 8:
                logger.info(
                    "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
                )
                self.server_args.dtype = "float16"
                self.model_config.dtype = torch.float16
                if torch.cuda.get_device_capability()[1] < 5:
                    raise RuntimeError("SGLang only supports sm75 and above.")

        set_cuda_arch()

        # Prepare the model config
        self.load_config = LoadConfig(
            load_format=self.server_args.load_format,
            download_dir=self.server_args.download_dir,
            model_loader_extra_config=self.server_args.model_loader_extra_config,
        )
        if self.device == "cpu":
            self.model_config = adjust_config_with_unaligned_cpu_tp(
                self.model_config, self.load_config, self.tp_size
            )
        if self.server_args.load_format == "gguf":
            monkey_patch_vllm_gguf_config()

        # Load the model
        # Remove monkey_patch when linear.py quant remove dependencies with vllm
        monkey_patch_vllm_parallel_state()
        monkey_patch_isinstance_for_vllm_base_layer()

        self.model = get_model(
            model_config=self.model_config,
            load_config=self.load_config,
            device_config=DeviceConfig(self.device),
        )
        monkey_patch_vllm_parallel_state(reverse=True)
        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

        if self.server_args.kv_cache_dtype == "fp8_e4m3":
            if self.server_args.quantization_param_path is not None:
                if callable(getattr(self.model, "load_kv_cache_scales", None)):
                    self.model.load_kv_cache_scales(
                        self.server_args.quantization_param_path
                    )
                    logger.info(
                        "Loaded KV cache scaling factors from %s",
                        self.server_args.quantization_param_path,
                    )
                else:
                    raise RuntimeError(
                        "Using FP8 KV cache and scaling factors provided but "
                        "model %s does not support loading scaling factors.",
                        self.model.__class__,
                    )
            else:
                logger.warning(
                    "Using FP8 KV cache but no scaling factors "
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!"
                )

        # Parse other args
        self.sliding_window_size = None
        if hasattr(self.model, "get_attention_sliding_window_size"):
            self.sliding_window_size = self.model.get_attention_sliding_window_size()
        elif self.model_config.attention_chunk_size is not None:
            self.sliding_window_size = self.model_config.attention_chunk_size
            logger.info(
                f"Setting sliding_window_size to be attention_chunk_size: {self.sliding_window_size}"
            )

        self.dtype = self.model_config.dtype

        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        self.weight_load_mem_usage = before_avail_memory - after_avail_memory
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
            f"avail mem={after_avail_memory:.2f} GB, "
            f"mem usage={self.weight_load_mem_usage:.2f} GB."
        )

File: roll/third_party/vllm/worker_helper.py (L65-87)

    def setup_collective_group(self, comm_plan, backend, rank_in_cluster):
        self.model_update_comm_plan = getattr(self, "model_update_comm_plan", {})
        rank, comm_plan_args = get_dist_info_from_comm_plan(comm_plan, rank_in_cluster=rank_in_cluster,
                                                            rank_in_worker=dist.get_rank())
        if rank is None:
            logger.info(f"no comm_plan found for rank {rank_in_cluster}/{dist.get_rank()}")
            return
        group_name = comm_plan_args["group_name"]
        master_addr = comm_plan_args["master_addr"]
        master_port = comm_plan_args["master_port"]
        world_size = len(comm_plan_args["tgt_devices"]) + 1
        src_pp_rank = comm_plan_args["src_pp_rank"]
        collective.init_collective_group(world_size, rank, backend=backend, group_name=group_name,
                                         master_addr=master_addr, master_port=master_port)
        # A small all_reduce for warmup.
        collective.allreduce(torch.zeros(1).to(current_platform.device_type), group_name=group_name)
        self.model_update_comm_plan[src_pp_rank] = dict(rank=rank,
                                                        world_size=world_size,
                                                        src_pp_rank=src_pp_rank,
                                                        group_name=group_name,
                                                        comm_plan=comm_plan,
                                                        comm_plan_args=comm_plan_args)
        logger.info(f"warmup setup_collective_group: {group_name} rank: {rank} world_size: {world_size}")

File: roll/pipeline/base_worker.py (L44-61)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def initialize(self, pipeline_config):
        super().initialize(pipeline_config)

        self.strategy = create_strategy(worker=self)

        if self.worker_config.model_args.model_type == "diffusion_module":
            self.strategy.initialize(model_provider=default_diffusion_module_provider)
        else:
            self.strategy.initialize(model_provider=default_actor_model_provider)
        
        self.tokenizer = self.strategy.tokenizer
        if self.pipeline_config.resume_from_checkpoint:
            load_dir = download_model(self.pipeline_config.resume_from_checkpoint)
            self.strategy.load_checkpoint(load_dir=load_dir, tag="checkpoint")
        self.logger.info(f"{self.worker_name} initialized")

        self.strategy.offload_states()

At model_update time: the weights are temporarily loaded on actor_train, broadcast to actor_infer, and then offloaded again. Where is this logic, and why is it done this way?
