Introduction to LoRA and Gradient Checkpointing

The method introductions were provided by DeepSeek.

LoRA (parameter-efficient fine-tuning via low-rank matrix decomposition)

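LoRA is a parameter-efficient fine-tuning (PEFT) technique. Its core idea is low-rank matrix decomposition: a lightweight adapter is added on top of the weight matrices of a large pretrained model (e.g., an LLM or Stable Diffusion), and only this small set of parameters is trained to adapt the model to a downstream task.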

  • Mathematical formulation

Suppose a weight matrix of the original model is $W \in \mathbb{R}^{d\times k}$. LoRA adds the product of two low-rank matrices $B$ and $A$ to it:

$$W' = W + \Delta W = W + BA, \qquad B \in \mathbb{R}^{d\times r},\; A \in \mathbb{R}^{r\times k},\; r \ll \min(d,k)$$

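where:

  • $r$ is the rank, which controls the capacity of the adapter: $\Delta W = BA$ has rank at most $r$, and the adapter has $r(d+k)$ trainable parameters instead of the $dk$ in $W$.

  • During fine-tuning only $A$ and $B$ are updated; the original $W$ stays frozen.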
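The formula can be checked numerically with a small NumPy sketch. The sizes d, k and r below are illustrative assumptions, not taken from a specific model. B is zero-initialised and A is small and random (the initialisation described in the LoRA paper), so W' equals W before any training step; the sketch also counts the r(d+k) trainable parameters against the dk of W.

```python
import numpy as np

# Illustrative sizes (assumptions, not taken from any specific model layer)
d, k, r = 768, 320, 4
rng = np.random.default_rng(0)

W = rng.standard_normal((d, k))          # frozen pretrained weight, d x k
B = np.zeros((d, r))                     # trainable, d x r, zero-initialised
A = 0.01 * rng.standard_normal((r, k))   # trainable, r x k, small random init


def lora_weight(W, B, A):
    """Return W' = W + B @ A; B @ A is d x k with rank at most r."""
    return W + B @ A


W_prime = lora_weight(W, B, A)

full_params = W.size            # d * k parameters if W itself were trained
lora_params = B.size + A.size   # r * (d + k) parameters actually trained

# Stand-in for B after some updates: the update B @ A still has rank <= r
B_trained = rng.standard_normal((d, r))
delta_rank = np.linalg.matrix_rank(B_trained @ A)

print(full_params, lora_params, delta_rank)
```

With these sizes the adapter trains 4352 parameters instead of 245760, under 2% of the full matrix.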

Gradient Checkpointing (reducing GPU memory by recomputation)

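  • Checkpoint segments: split the forward pass into several segments and cache only the input and output of each segment (the activations at segment boundaries).

  • Recomputation: during the backward pass, re-run the forward computation of each segment to regenerate its intermediate activations, which are discarded as soon as they have been used.

  • Memory savings: activation memory drops from O(N) (N = number of layers) to O(√N) when the N layers are split into about √N segments of √N layers each; the price is roughly one extra forward pass of computation.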
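The segment-and-recompute scheme can be illustrated without a deep-learning framework. This sketch assumes a toy "network" of N identical layers x ↦ sin(x) and computes d(output)/d(input) twice: once keeping every activation, and once keeping only the inputs of √N-sized segments and recomputing each segment during the backward pass. It counts how many activations are held at the peak. It is a simplified model of the idea, not how PyTorch implements it.

```python
import math


def grad_full(x0, n):
    """Plain backprop through n layers of sin: store every layer input."""
    acts, x = [], x0
    for _ in range(n):
        acts.append(x)          # keep the input of every layer
        x = math.sin(x)
    g = 1.0
    for a in reversed(acts):
        g *= math.cos(a)        # chain rule: d sin(a) / da = cos(a)
    return g, len(acts)         # gradient, activations held at peak


def grad_checkpointed(x0, n, seg):
    """Same gradient, but only segment inputs are kept during the forward pass."""
    checkpoints, x = [], x0
    for i in range(n):
        if i % seg == 0:
            checkpoints.append(x)   # segment boundary: cache this activation
        x = math.sin(x)

    g, peak = 1.0, 0
    for s in reversed(range(len(checkpoints))):
        # Recompute this segment's activations from its cached input
        length = min(seg, n - s * seg)
        acts, x = [], checkpoints[s]
        for _ in range(length):
            acts.append(x)
            x = math.sin(x)
        # activations held: all cached segment inputs plus this segment's recomputed ones
        peak = max(peak, len(checkpoints) + len(acts))
        for a in reversed(acts):
            g *= math.cos(a)
        # acts is rebound on the next iteration, so this segment's activations are dropped
    return g, peak


n = 100
seg = math.isqrt(n)   # sqrt(N)-sized segments
g_full, mem_full = grad_full(0.5, n)
g_ckpt, mem_ckpt = grad_checkpointed(0.5, n, seg)
print(mem_full, mem_ckpt)
```

Both functions return the same gradient, but the checkpointed one holds at most 2√N = 20 activations instead of N = 100, while every layer's forward computation runs twice.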

Model loading and parameter setup

1. Load the pretrained model architecture and weights.

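from diffusers import DDIMScheduler, StableDiffusionPipeline

# config: the training configuration object (its definition is not shown in this post)
pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision)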
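The snippets read hyperparameters from a config object whose definition is not shown in this post. To try the steps outside the original training script, a minimal stand-in with exactly the fields used here could look like this. The class names and all default values are hypothetical placeholders, not the author's settings.

```python
from dataclasses import dataclass, field


@dataclass
class PretrainedConfig:
    model: str = "<hub-id-or-local-path>"  # hypothetical placeholder for the SD checkpoint
    revision: str = "main"


@dataclass
class TrainConfig:
    # hypothetical values; substitute your own
    learning_rate: float = 3e-4
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_weight_decay: float = 1e-4
    adam_epsilon: float = 1e-8


@dataclass
class Config:
    pretrained: PretrainedConfig = field(default_factory=PretrainedConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    use_lora: bool = True   # False -> full UNet fine-tuning


config = Config()
print(config.pretrained.revision, config.use_lora, config.train.adam_beta2)
```

With such an object, StableDiffusionPipeline.from_pretrained receives config.pretrained.model and config.pretrained.revision, and the optimizer step reads the config.train fields.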

2. Freeze the parameters and disable the safety checker.

# freeze parameters of models to save more memory
pipeline.vae.requires_grad_(False)
pipeline.text_encoder.requires_grad_(False)
pipeline.unet.requires_grad_(not config.use_lora)

# disable safety checker
pipeline.safety_checker = None

3. Set the sampling scheduler (e.g., DDIM).

# switch to DDIM scheduler
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)

4. Set up the LoRA layers.

from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor  # older diffusers API; may be unavailable in newer releases

if config.use_lora:
    # Set correct lora layers
    lora_attn_procs = {}
    for name in pipeline.unet.attn_processors.keys():
        # attn1 is self-attention (no text input); attn2 is cross-attention over the text embeddings
        cross_attention_dim = (None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim)
        # hidden size = channel width of the UNet block this processor belongs to
        if name.startswith("mid_block"):
            hidden_size = pipeline.unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = pipeline.unet.config.block_out_channels[block_id]

        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

    # install the processors once, after the dict covers every attention layer (not inside the loop)
    pipeline.unet.set_attn_processor(lora_attn_procs)

    # this is a hack to synchronize gradients properly. the module that registers the parameters we care about (in
    # this case, AttnProcsLayers) needs to also be used for the forward pass. AttnProcsLayers doesn't have a
    # `forward` method, so we wrap it to add one and capture the rest of the unet parameters using a closure.
    class _Wrapper(AttnProcsLayers):
        def forward(self, *args, **kwargs):
            return pipeline.unet(*args, **kwargs)

    unet = _Wrapper(pipeline.unet.attn_processors)
else:
    # full fine-tuning: the UNet itself is trained
    unet = pipeline.unet
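The branch that picks hidden_size depends on the naming scheme of the UNet's attention processors. This pure-Python sketch reproduces just that lookup so it can be tested without loading a model. It assumes Stable Diffusion 1.x UNet values (block_out_channels = (320, 640, 1280, 1280), cross_attention_dim = 768). The processor names are example strings in the format diffusers uses, not a full listing.

```python
BLOCK_OUT_CHANNELS = (320, 640, 1280, 1280)  # assumed SD 1.x UNet config
CROSS_ATTENTION_DIM = 768                     # assumed SD 1.x text-embedding width


def lora_dims(name):
    """Return (hidden_size, cross_attention_dim) for one attention-processor name."""
    # attn1 = self-attention (no text input), attn2 = cross-attention over text embeddings
    cross = None if name.endswith("attn1.processor") else CROSS_ATTENTION_DIM
    if name.startswith("mid_block"):
        hidden = BLOCK_OUT_CHANNELS[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        # up blocks run from the deepest (widest) level back to the shallowest
        hidden = list(reversed(BLOCK_OUT_CHANNELS))[block_id]
    elif name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        hidden = BLOCK_OUT_CHANNELS[block_id]
    else:
        raise ValueError(f"unexpected processor name: {name}")
    return hidden, cross


examples = [
    "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor",
    "mid_block.attentions.0.transformer_blocks.0.attn1.processor",
    "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor",
]
for n in examples:
    print(n, lora_dims(n))
```

Unlike the loop above, this version raises an error for names that match none of the prefixes, instead of silently reusing the previous hidden_size.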

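5. Enable gradient checkpointing on the UNet. The original snippet called _set_gradient_checkpointing, a private helper of diffusers' UNet2DConditionModel, as pipeline.unet._set_gradient_checkpointing(unet, True), with the LoRA wrapper unet as the argument. In older diffusers releases that helper only sets the gradient_checkpointing flag on the module it receives, if that module has one. The wrapper does not, so the call may have no effect. The public enable_gradient_checkpointing() sets the flag on every UNet block that supports it. Checkpointing then works per block, so the LoRA processors inside those blocks are recomputed along with the rest of each block rather than being checkpointed separately.

# set gradient checkpointing on every UNet block that supports it (public diffusers API)
pipeline.unet.enable_gradient_checkpointing()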
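Whether checkpointing is actually switched on depends on which module receives the gradient_checkpointing flag. The toy PyTorch re-creation below is only an illustration. ToyUNet, Block and Wrapper are hypothetical stand-ins, and the private helper is assumed to behave as in older diffusers releases: it sets the flag only on the module passed in, if that module has the attribute. The example contrasts calling that helper on a wrapper with visiting every submodule through nn.Module.apply.

```python
from functools import partial

import torch.nn as nn


class Block(nn.Module):
    """Hypothetical UNet block carrying the flag that enables checkpointing."""

    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False
        self.proj = nn.Linear(8, 8)


class Wrapper(nn.Module):
    """Hypothetical stand-in for the LoRA wrapper: it has no such flag."""


class ToyUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([Block(), Block()])

    def _set_gradient_checkpointing(self, module, value=False):
        # assumed behaviour: only touches `module` itself
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def enable_gradient_checkpointing(self):
        # nn.Module.apply calls the function on every submodule and on self
        self.apply(partial(self._set_gradient_checkpointing, value=True))


unet = ToyUNet()
unet._set_gradient_checkpointing(Wrapper(), True)
after_private = [b.gradient_checkpointing for b in unet.blocks]

unet.enable_gradient_checkpointing()
after_public = [b.gradient_checkpointing for b in unet.blocks]
print(after_private, after_public)
```

Passing the wrapper leaves every block's flag at False. Walking the module tree sets each one to True.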

6. Initialize the optimizer.

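import torch

# assumption: optimizer_cls is not defined in the excerpt; plain AdamW is used here
optimizer_cls = torch.optim.AdamW

# with LoRA, unet is the _Wrapper, so unet.parameters() yields only the LoRA weights
optimizer = optimizer_cls(
    unet.parameters(),
    lr=config.train.learning_rate,
    betas=(config.train.adam_beta1, config.train.adam_beta2),
    weight_decay=config.train.adam_weight_decay,
    eps=config.train.adam_epsilon,
)

# accelerator: an accelerate.Accelerator instance created earlier in the training script
unet, optimizer = accelerator.prepare(unet, optimizer)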
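To make the hyperparameters passed to the optimizer concrete, here is a single AdamW update for one scalar parameter in plain Python. It assumes optimizer_cls is AdamW, with decoupled weight decay and bias-corrected moments. It also shows where the memory saving of training only the LoRA weights comes from: AdamW keeps two state values (m and v) for every trainable parameter, and frozen parameters get none.

```python
import math


def adamw_step(p, g, m, v, t, lr, beta1, beta2, weight_decay, eps):
    """One AdamW step for a scalar parameter p with gradient g at step t (1-based)."""
    p = p * (1 - lr * weight_decay)          # decoupled weight decay
    m = beta1 * m + (1 - beta1) * g          # first moment (uses beta1)
    v = beta2 * v + (1 - beta2) * g * g      # second moment (uses beta2)
    m_hat = m / (1 - beta1 ** t)             # bias corrections
    v_hat = v / (1 - beta2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)  # eps guards against division by zero
    return p, m, v


# Illustrative values (assumptions)
p, m, v = adamw_step(p=1.0, g=0.5, m=0.0, v=0.0, t=1,
                     lr=1e-3, beta1=0.9, beta2=0.999, weight_decay=1e-2, eps=1e-8)
print(p, m, v)
```

On the first step, m_hat / sqrt(v_hat) is about sign(g), so the parameter moves by roughly lr regardless of the gradient's magnitude.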

7. Run the training part of the code.

The following warning appears:

UserWarning: torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.5 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants. 

I have not found a suitable fix yet, and I am not sure whether it needs to be fixed.
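The warning comes from torch.utils.checkpoint being called without an explicit use_reentrant argument. As the message itself says, the current default behaviour is kept for now. In code you control, passing the argument explicitly removes the warning. The sketch below applies torch.utils.checkpoint.checkpoint with use_reentrant=False to a hypothetical toy block and checks that the input gradient matches an ordinary forward/backward pass. When the checkpoint call lives inside diffusers, the argument cannot be passed from outside. In that case one workaround is to filter that specific message with Python's warnings.filterwarnings; newer diffusers releases may already pass the argument themselves (check your installed version).

```python
import warnings

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Hypothetical toy block standing in for a UNet sub-block
block = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# Reference: normal forward/backward keeps every intermediate activation
block(x).sum().backward()
ref_grad = x.grad.clone()
x.grad = None

# Checkpointed: the block's inner activations are recomputed during backward.
# Passing use_reentrant explicitly is what avoids the UserWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = checkpoint(block, x, use_reentrant=False)
    out.sum().backward()
ckpt_grad = x.grad.clone()

print(torch.allclose(ref_grad, ckpt_grad))
print([str(w.message) for w in caught if "use_reentrant" in str(w.message)])
```

PyTorch's documentation recommends use_reentrant=False for new code; use_reentrant=True keeps the older behaviour if something depends on it.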
