Fine-Tuning Stable Diffusion with LoRA and Enabling Gradient Checkpointing
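Introduction to LoRA and Gradient Checkpointing
The method introductions in this section were provided by DeepSeek.
LoRA (parameter-efficient fine-tuning via low-rank matrix decomposition)
LoRA is a Parameter-Efficient Fine-Tuning (PEFT) technique. Its core idea is to use a low-rank matrix decomposition to add a lightweight adapter to the weight matrices of a large pretrained model (such as an LLM or Stable Diffusion), so that training only a small number of parameters is enough to adapt the model to a downstream task.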
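- Mathematical principle
Suppose a weight matrix of the original model is $W \in \mathbb{R}^{d \times k}$. LoRA adds the product of two low-rank matrices, $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$, on top of it:
$$W' = W + BA$$
where:
- $r$ is the rank, which controls the complexity of the adapter
- $r \ll \min(d, k)$
During fine-tuning only $A$ and $B$ are updated; the original $W$ stays frozen.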
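To make the parameter saving concrete, here is a minimal sketch that counts the trainable parameters of one weight matrix with and without LoRA. The 768 x 768 shape and rank 4 are illustrative assumptions, not values taken from the training script, and both helper functions are hypothetical names introduced only for this example.

```python
def full_param_count(d: int, k: int) -> int:
    """Trainable parameters when the whole d x k matrix W is updated."""
    return d * k


def lora_param_count(d: int, k: int, r: int) -> int:
    """Trainable parameters when only B (d x r) and A (r x k) are updated."""
    return d * r + r * k


# Illustrative (assumed) sizes: a 768 x 768 projection with rank r = 4.
d, k, r = 768, 768, 4
full = full_param_count(d, k)
lora = lora_param_count(d, k, r)
print(f"full fine-tuning: {full} parameters")
print(f"LoRA (r={r}): {lora} parameters ({100 * lora / full:.2f}% of full)")
```

Under these assumed sizes the adapter trains 6,144 parameters instead of 589,824, which is about 1% of the full matrix.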
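Gradient Checkpointing (reducing GPU memory with gradient checkpoints)
- Checkpoint segments: the forward pass is split into several segments, and only the activations at segment boundaries (each segment's input and output) are cached.
- Dynamic recomputation: during the backward pass, each segment's forward computation is run again to recover its intermediate activations, which are discarded as soon as they have been used.
- Memory saving: activation memory drops from O(N) (N = number of layers) to roughly O(√N) when each segment is about √N layers long, at the cost of about one extra forward pass.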
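The O(√N) figure can be checked with a small counting model. The sketch below assumes a plain chain of N equally sized layers and counts how many activations are held at the worst point of the backward pass: one checkpoint per segment, plus the activations recomputed inside the segment currently being processed. The function peak_activations is a hypothetical helper for illustration; real memory use in a UNet also depends on layer sizes and on the framework.

```python
import math


def peak_activations(n_layers: int, segment_len: int) -> int:
    """Activations held at the worst moment of the backward pass.

    One checkpoint is kept per segment, and while a segment is being
    back-propagated its (at most segment_len) activations are recomputed
    and held until that segment is done.
    """
    n_segments = math.ceil(n_layers / segment_len)
    return n_segments + segment_len


n = 100
no_checkpointing = n  # without checkpointing, every layer's activation is kept
best_len = min(range(1, n + 1), key=lambda s: peak_activations(n, s))

print("no checkpointing:", no_checkpointing)
print("best segment length:", best_len)
print("peak with checkpointing:", peak_activations(n, best_len))
```

In this simplified model with N = 100, the minimum falls at a segment length of 10 = √N, giving 20 stored activations instead of 100. The price is that each segment's forward pass is computed a second time during the backward pass.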
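Model Loading and Parameter Setup
1. Load the pretrained model architecture and weights.
from diffusers import StableDiffusionPipeline
pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision)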
2. Freeze the parameters and disable the safety checker.
# freeze parameters of models to save more memory
pipeline.vae.requires_grad_(False)
pipeline.text_encoder.requires_grad_(False)
pipeline.unet.requires_grad_(not config.use_lora)
# disable safety checker
pipeline.safety_checker = None
3. Set the sampling scheduler (for example DDIM).
from diffusers import DDIMScheduler
# switch to DDIM scheduler
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
4. Set up the LoRA layers.
# These import paths match the older diffusers releases that still ship LoRAAttnProcessor.
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor

# Set correct lora layers
lora_attn_procs = {}
for name in pipeline.unet.attn_processors.keys():
    # attn1 is self-attention, so it has no cross-attention dimension
    cross_attention_dim = (None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim)
    if name.startswith("mid_block"):
        hidden_size = pipeline.unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = pipeline.unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
pipeline.unet.set_attn_processor(lora_attn_procs)

# this is a hack to synchronize gradients properly. the module that registers the parameters we care about (in
# this case, AttnProcsLayers) needs to also be used for the forward pass. AttnProcsLayers doesn't have a
# `forward` method, so we wrap it to add one and capture the rest of the unet parameters using a closure.
class _Wrapper(AttnProcsLayers):
    def forward(self, *args, **kwargs):
        return pipeline.unet(*args, **kwargs)

unet = _Wrapper(pipeline.unet.attn_processors)
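To see what the loop above computes, the sketch below runs the same name-to-size logic on a stand-in config, without loading any model. The block_out_channels and cross_attention_dim values are the ones assumed here for a Stable Diffusion v1.x UNet, and the three processor names are only examples of the naming pattern; check both against pipeline.unet.config and pipeline.unet.attn_processors in your own environment. The function lora_dims is a hypothetical helper.

```python
from typing import Optional, Tuple

# Stand-in for pipeline.unet.config (assumed Stable Diffusion v1.x values).
BLOCK_OUT_CHANNELS = [320, 640, 1280, 1280]
CROSS_ATTENTION_DIM = 768


def lora_dims(name: str) -> Tuple[int, Optional[int]]:
    """Return (hidden_size, cross_attention_dim) for one processor name."""
    # attn1 is self-attention, so it has no cross-attention input.
    cross_dim = None if name.endswith("attn1.processor") else CROSS_ATTENTION_DIM
    if name.startswith("mid_block"):
        hidden_size = BLOCK_OUT_CHANNELS[-1]
    elif name.startswith("up_blocks"):
        # up blocks run from the deepest level outwards, hence the reversal
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(BLOCK_OUT_CHANNELS))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden_size = BLOCK_OUT_CHANNELS[block_id]
    else:
        raise ValueError(f"unexpected processor name: {name}")
    return hidden_size, cross_dim


# Example names following the pattern the loop relies on (assumed).
example_names = [
    "down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor",
    "mid_block.attentions.0.transformer_blocks.0.attn2.processor",
    "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor",
]
for example in example_names:
    print(example, "->", lora_dims(example))
```

Unlike the original loop, this version raises an error for a name that matches none of the three prefixes, which makes an unexpected layout visible instead of silently reusing the previous hidden_size.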
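5. Call _set_gradient_checkpointing, a method of the UNet2DConditionModel class in the diffusers package, to enable Gradient Checkpointing on the module that holds the LoRA layers. Because this is a private method, its behaviour may differ between diffusers releases; the public entry point is pipeline.unet.enable_gradient_checkpointing().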
# set gradient checkpointing
pipeline.unet._set_gradient_checkpointing(unet, True)
6. Initialize the optimizer.
optimizer = optimizer_cls(
    unet.parameters(),
    lr=config.train.learning_rate,
    betas=(config.train.adam_beta1, config.train.adam_beta2),
    weight_decay=config.train.adam_weight_decay,
    eps=config.train.adam_epsilon,
)
unet, optimizer = accelerator.prepare(unet, optimizer)
7. Start the training part of the code.
The following warning appears:
UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.
I have not yet found a suitable fix, and I am not sure whether it needs to be fixed at all.
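The warning does not stop training. It asks callers of torch.utils.checkpoint to pass use_reentrant explicitly, and in this setup that call appears to be made inside library code rather than in the training script. If the goal is only to keep the training log readable, one option is to filter this specific message with the standard warnings module. The sketch below is an assumption about what may be acceptable, not a fix for the underlying call; it emits a stand-in warning with the same text so that it runs without PyTorch, and both function names are hypothetical.

```python
import warnings

# Regex matching the start of the message shown above.
PATTERN = (
    r"torch\.utils\.checkpoint: the use_reentrant parameter "
    r"should be passed explicitly"
)


def silence_use_reentrant_warning() -> None:
    """Ignore only the use_reentrant UserWarning; other warnings still show."""
    warnings.filterwarnings("ignore", message=PATTERN, category=UserWarning)


def count_visible(messages) -> int:
    """Emit each message as a UserWarning and count how many get through."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        silence_use_reentrant_warning()
        for msg in messages:
            warnings.warn(msg, UserWarning)
    return len(caught)


stand_in = (
    "torch.utils.checkpoint: the use_reentrant parameter should be passed "
    "explicitly. In version 2.5 we will raise an exception if use_reentrant "
    "is not passed."
)
visible = count_visible([stand_in, "some unrelated warning"])
print("warnings still visible:", visible)
```

In a real training script, silence_use_reentrant_warning() would be called once before the training loop. Since the message itself says an exception is planned for a later PyTorch version, filtering is at best a stopgap until the call site passes use_reentrant.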