【生成模型】【基础知识】CFG与CFG蒸馏
是生成模型推理时的一种操作,目的是使的生成图像与输入的text prompt更接近。
·
Classifier Free Guidance (CFG)
简介:是生成模型推理时的一种操作,目的是使的生成图像与输入的text prompt更接近。
问题:需要推理两次。
伪代码如下:
功能:
- 当引导系数 s=0 时,等价于无引导生成
- 当 s 增大时,生成结果更接近提示语义,但可能降低多样性
# Class-Free Guidance(CFG)伪代码实现
def sample_with_cfg(model, x, prompt, guidance_scale, num_steps=20):
"""
使用无类别引导(CFG)生成图像
参数:
model: 扩散模型
x: 初始化噪声
prompt: 文本提示
guidance_scale: 引导强度系数
num_steps: 扩散过程步数
"""
for i in range(num_steps):
timestep_embed = timestep_embed_layer(i)
# 第一遍推理: 无引导(无条件),预测无提示条件下的噪声
noise_pred_uncond = model(x, timestep_embed, context=None)
# 第二遍推理: 有引导(有条件),生成考虑提示的噪声预测
noise_pred_cond = model(x, timestep_embed, context=encode_prompt(prompt))
# CFG核心: 结合两次预测的结果,通过引导系数调整条件预测的影响
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
# 使用预测的噪声更新当前样本
x = scheduler.step(noise_pred, i, x).prev_sample
return decode_latents_to_image(x) # 将最终隐变量解码为图像
CFG-Distilled
CFG-Distilled的做法就是
- 将模型使用cfg推理的结果直接蒸馏到一个新模型上,这个新模型就不需要两次推理了。
- 同时为了新模型也能根据不同的guidance_scale产生不同结果,将guidance_scale直接embed后输入模型中,应该是为了对模型的变动最小化,因此直接加到timestep_embed上。
FLUX.1 [dev] 用到的指引蒸馏技术似乎来自论文 On Distillation of Guided Diffusion Models。
# guidance modulation
guidance_in = TimestepEmbedder(hidden_size, get activation_layer("silu"),**factory_kwargs)
def cfg_distillation_training(teacher_model, student_model, dataloader, num_epochs):
optimizer = AdamW(student_model.parameters(), lr=1e-5)
for epoch in range(num_epochs):
for batch in dataloader:
images, prompts = batch
x_t = add_noise(images, timestep=t) # 初始化噪声
guidance_scale = random_guidance_scale(min_g, max_g)
# 1. 从教师模型获取CFG目标, 教师模型执行完整CFG(两遍推理)
with torch.no_grad():
noise_pred_teacher_cfg = sample_with_cfg(teacher_model, x_t, prompts, guidance_scale)
# 2. 学生模型单次前向传播,将guidance_scale embed到模型中
timestep_embed = timestep_embed_layer(i)
guaidance_embed = guidance_in(guidance_scale)
noise_pred_student = model(x, timestep_embed + guaidance_embed, context=encode_prompt(prompt))
# 3. 计算蒸馏损失
loss = mse_loss(noise_pred_student, noise_pred_teacher_cfg)
# 4. 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
更多推荐
所有评论(0)