Classifier Free Guidance (CFG)

简介:是生成模型推理时的一种操作,目的是使的生成图像与输入的text prompt更接近。
问题:需要推理两次。
伪代码如下:
功能

  • 当引导系数 s=0 时,等价于无引导生成
  • 当 s 增大时,生成结果更接近提示语义,但可能降低多样性
# Class-Free Guidance(CFG)伪代码实现
def sample_with_cfg(model, x, prompt, guidance_scale, num_steps=20):
    """
    使用无类别引导(CFG)生成图像
    参数:
    model: 扩散模型
    x: 初始化噪声 
    prompt: 文本提示
    guidance_scale: 引导强度系数
    num_steps: 扩散过程步数
    """
    for i in range(num_steps):
    	timestep_embed = timestep_embed_layer(i)
        # 第一遍推理: 无引导(无条件),预测无提示条件下的噪声
        noise_pred_uncond = model(x, timestep_embed, context=None)
        # 第二遍推理: 有引导(有条件),生成考虑提示的噪声预测
        noise_pred_cond = model(x, timestep_embed, context=encode_prompt(prompt))
        # CFG核心: 结合两次预测的结果,通过引导系数调整条件预测的影响
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
        # 使用预测的噪声更新当前样本
        x = scheduler.step(noise_pred, i, x).prev_sample
    
    return decode_latents_to_image(x)  # 将最终隐变量解码为图像

CFG-Distilled

CFG-Distilled的做法就是

  • 将模型使用cfg推理的结果直接蒸馏到一个新模型上,这个新模型就不需要两次推理了。
  • 同时为了新模型也能根据不同的guidance_scale产生不同结果,将guidance_scale直接embed后输入模型中,应该是为了对模型的变动最小化,因此直接加到timestep_embed上。

FLUX.1 [dev] 用到的指引蒸馏技术似乎来自论文 On Distillation of Guided Diffusion Models

# guidance modulation
guidance_in = TimestepEmbedder(hidden_size, get activation_layer("silu")**factory_kwargs)

def cfg_distillation_training(teacher_model, student_model, dataloader, num_epochs):
    optimizer = AdamW(student_model.parameters(), lr=1e-5)
    
    for epoch in range(num_epochs):
        for batch in dataloader:
            images, prompts = batch
            x_t = add_noise(images, timestep=t) # 初始化噪声
            guidance_scale = random_guidance_scale(min_g, max_g)
            # 1. 从教师模型获取CFG目标, 教师模型执行完整CFG(两遍推理)
            with torch.no_grad():
                noise_pred_teacher_cfg = sample_with_cfg(teacher_model, x_t, prompts, guidance_scale)
            # 2. 学生模型单次前向传播,将guidance_scale embed到模型中
            timestep_embed = timestep_embed_layer(i)
            guaidance_embed = guidance_in(guidance_scale)
            noise_pred_student = model(x, timestep_embed + guaidance_embed, context=encode_prompt(prompt))
            # 3. 计算蒸馏损失
            loss = mse_loss(noise_pred_student, noise_pred_teacher_cfg)
            
            # 4. 反向传播和优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
Logo

中国智能体开发者社区,聚焦智能体与大模型开发,提供前沿资讯、实用工具链、开源项目及行业案例。通过技术沙龙、开发者大赛等活动,促进经验交流与协作,助力开发者快速构建创新智能应用。

更多推荐