使用qwen2.5系列模型在特定任务上进行知识蒸馏,教师模型为qwen2.5-3b(qwen2.5-7b),学生模型为qwen2.5-0.5b,尝试只使用KL散度、微调学生模型加KL散度和不微调学生模型加交叉熵加KL散度等不同思路,并且使用了KL散度不同变种(反向KL散度,偏向前向KL散度,偏向反向KL散度)。 

        在模型蒸馏的方法中,白盒蒸馏能够利用教师模型的内部信息,如中间层的特征表示、概率分布等,从而更有效地将知识传递给学生模型。而KL散度作为一种衡量两个概率分布差异的指标,在白盒蒸馏中发挥着重要作用。通过最小化学生模型和教师模型输出概率分布之间的KL散度,可以让学生模型学习到教师模型的知识。本文将通过详细解析两段Python代码,展示如何使用白盒蒸馏的四种基于KL散度的方法对大模型进行蒸馏训练,让你从代码层面深入理解大模型蒸馏的实现过程。

一、代码讲解

1. utils.py文件

        这个文件主要实现了白盒蒸馏中四种基于KL散度的计算方法,分别是前向KL散度、反向KL散度、偏向前KL散度和偏向反KL散度。

        实现了白盒蒸馏中四种基于 KL 散度的计算方法,分别为前向 KL 散度、反向 KL 散度、偏向前 KL 散度和偏向反 KL 散度。这些方法均通过对学生模型和教师模型的对数概率进行温度平滑处理,将其转换为概率分布和对数概率分布,再依据不同的 KL 散度公式进行计算,并根据 `reduction` 参数进行缩减操作,同时考虑了填充部分的屏蔽处理。其中,偏向前和偏向反 KL 散度还引入了加权系数 `skew_lambda` 来混合学生模型和教师模型的概率分布,以实现更灵活的知识迁移。

1.1 前向KL散度计算函数compute_fkl

def compute_fkl(
        logits, 
        teacher_logits, 
        target, 
        padding_id,
        reduction="sum",
        temp = 1.0, 
        
    ):
        logits = logits / temp
        teacher_logits = teacher_logits / temp

        log_probs = torch.log_softmax(logits, -1, dtype=torch.float32)
        teacher_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32)
        teacher_log_probs = torch.log_softmax(teacher_logits, -1, dtype=torch.float32)
        kl = (teacher_probs * (teacher_log_probs - log_probs)) 
        kl = kl.sum(-1)
        if reduction == "sum":
            pad_mask = target.eq(padding_id)
            kl = kl.masked_fill_(pad_mask, 0.0)
            kl = kl.sum()

        return kl

输入参数

   logits:学生模型的输出对数概率。在神经网络中,最后一层的输出通常是未经过归一化的对数概率值,这些值可以通过 softmax 函数转换为概率分布。

   teacher_logits:教师模型的输出对数概率。同样是未经过归一化的对数概率值,代表了教师模型对输入数据的预测结果。

   target:目标标签。在训练过程中,这是真实的标签数据,用于计算损失和屏蔽填充部分。

   padding_id:填充 ID。在自然语言处理中,为了将不同长度的序列统一成相同的长度,通常会在较短的序列后面填充特定的 ID,这里的 padding_id 就是用于识别这些填充部分的 ID。

   reduction:损失的缩减方式,默认为 sum。可以选择 sum 表示对所有元素求和,也可以选择其他方式,如 mean 表示求平均值。

   temp:温度参数,用于平滑概率分布。温度参数可以控制概率分布的平滑程度,当 temp 大于 1 时,概率分布会更加平滑;当 temp 小于 1 时,概率分布会更加尖锐。

计算过程

        1.温度平滑

        将学生模型和教师模型的对数概率除以温度参数 temp,这样可以使概率分布更加平滑,避免某些类别概率过于集中。

        2.计算概率和对数概率

    log_probs:通过 torch.log_softmax 函数将学生模型的对数概率转换为对数概率分布。softmax 函数将对数概率转换为概率分布,再取对数得到对数概率分布

    teacher_probs:通过 torch.softmax 函数将教师模型的对数概率转换为概率分布。 

    teacher_log_probs:通过 torch.log_softmax 函数将教师模型的对数概率转换为对数概率分布。

        3.计算前向 KL 散度

        根据前向 KL 散度的公式 ,其中p是教师模型的概率分布,q是学生模型的概率分布。在代码中,kl = (teacher_probs * (teacher_log_probs - log_probs)) 实现了这一计算。

        4.缩减操作

    kl = kl.sum(-1):对最后一个维度求和,得到每个样本的 KL 散度。

        如果 reduction 为 sum,则使用 pad_mask 屏蔽填充部分的计算,并对所有样本的 KL 散度求和。pad_mask = target.eq(padding_id) 用于生成一个布尔掩码,标记出填充部分;kl = kl.masked_fill_(pad_mask, 0.0) 将填充部分的 KL 散度置为 0;kl = kl.sum() 对所有样本的 KL 散度求和。

1.2 反向 KL 散度计算函数 compute_rkl

def compute_rkl(
        logits, 
        teacher_logits, 
        target, 
        padding_id,
        reduction="sum", 
        temp = 1.0
    ):
        logits = logits / temp
        teacher_logits = teacher_logits / temp

        probs = torch.softmax(logits, -1, dtype=torch.float32)
        log_probs = torch.log_softmax(logits, -1, dtype=torch.float32)
        teacher_log_probs = torch.log_softmax(teacher_logits, -1, dtype=torch.float32)
        kl = (probs * (log_probs - teacher_log_probs))
        kl = kl.sum(-1)
        if reduction == "sum":
            pad_mask = target.eq(padding_id)
            kl = kl.masked_fill_(pad_mask, 0.0)
            kl = kl.sum()
        return kl

输入参数:与 compute_fkl 函数相同。

计算过程

        1.温度平滑

        与 compute_fkl 函数相同,将学生模型和教师模型的对数概率除以温度参数 temp

        2.计算概率和对数概率

    probs:通过 torch.softmax 函数将学生模型的对数概率转换为概率分布。

    log_probs:通过 torch.log_softmax 函数将学生模型的对数概率转换为对数概率分布。

    teacher_log_probs:通过 torch.log_softmax 函数将教师模型的对数概率转换为对数概率分布。

        3.计算反向 KL 散度

        根据反向 KL 散度的公式 ,其中p是教师模型的概率分布, q是学生模型的概率分布。在代码中,kl = (probs * (log_probs - teacher_log_probs))实现了这一计算。

        4.缩减操作

        与 compute_fkl 函数相同,对最后一个维度求和,根据 reduction 参数进行缩减操作。

1.3 偏向前 KL 散度计算函数 compute_skewed_fkl

def compute_skewed_fkl(
        logits, 
        teacher_logits, 
        target, 
        padding_id, 
        reduction="sum", 
        temp = 1.0,
        skew_lambda = 0.1
    ):
        logits = logits / temp
        teacher_logits = teacher_logits / temp

        probs = torch.softmax(logits, -1, dtype=torch.float32)
        teacher_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32)
        mixed_probs = skew_lambda * teacher_probs + (1 - skew_lambda) * probs
        mixed_log_probs = torch.log(mixed_probs)
        teacher_log_probs = torch.log_softmax(teacher_logits, -1, dtype=torch.float32)
        kl = (teacher_probs * (teacher_log_probs - mixed_log_probs))
        kl = kl.sum(-1)
        if reduction == "sum":
            pad_mask = target.eq(padding_id)
            kl = kl.masked_fill_(pad_mask, 0.0)
            kl = kl.sum()

            
        return kl

输入参数

        除了与 compute_fkl 函数相同的参数外,还增加了 skew_lambda 参数,它是一个加权系数,用于混合学生模型和教师模型的概率分布。

计算过程

        1.温度平滑

        将学生模型和教师模型的对数概率除以温度参数 temp

        2.计算概率分布

    probs:学生模型的概率分布。

    teacher_probs:教师模型的概率分布。

        3.混合概率分布

        通过加权系数 skew_lambda 混合学生模型和教师模型的概率分布得到 mixed_probs,即 mixed_probs = skew_lambda * teacher_probs + (1 - skew_lambda) * probs

        4.计算对数概率

        计算混合概率分布的对数概率 mixed_log_probs 和教师模型的对数概率 teacher_log_probs

        5.计算偏向前 KL 散度

        根据偏向前 KL 散度的公式,计算教师模型的概率分布与混合概率分布之间的 KL 散度。

        6.缩减操作

        与 compute_fkl 函数相同,对最后一个维度求和,根据 reduction 参数进行缩减操作。

1.4 偏向反 KL 散度计算函数 compute_skewed_rkl

def compute_skewed_rkl(
    logits, 
    teacher_logits, 
    target,
    padding_id,
    reduction="sum", 
    temp = 1.0,
    skew_lambda = 0.1
):
    logits = logits / temp
    teacher_logits = teacher_logits / temp
    
    probs = torch.softmax(logits, -1, dtype=torch.float32)
    teacher_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32)
    mixed_probs = (1 - skew_lambda) * teacher_probs + skew_lambda * probs
    mixed_log_probs = torch.log(mixed_probs)
    log_probs = torch.log_softmax(logits, -1, dtype=torch.float32)
    kl = (probs * (log_probs - mixed_log_probs))
    kl = kl.sum(-1)
    
    if reduction == "sum":
        pad_mask = target.eq(padding_id)
        kl = kl.masked_fill_(pad_mask, 0.0)
        kl = kl.sum()
    return kl

输入参数:与 compute_skewed_fkl 函数相同。

计算过程

        1.温度平滑

        将学生模型和教师模型的对数概率除以温度参数 temp

        2.计算概率分布

        计算学生模型的概率分布 probs 和教师模型的概率分布 teacher_probs

        3.混合概率分布

        通过加权系数 skew_lambda 混合学生模型和教师模型的概率分布得到 mixed_probs,但加权方式与 compute_skewed_fkl 函数不同,即 mixed_probs = (1 - skew_lambda) * teacher_probs + skew_lambda * probs

        4.计算对数概率

        计算混合概率分布的对数概率 mixed_log_probs 和学生模型的对数概率 log_probs

        5.计算偏向反 KL 散度

        根据偏向反 KL 散度的公式,计算学生模型的概率分布与混合概率分布之间的 KL 散度。

        6.缩减操作

        与 compute_fkl 函数相同,对最后一个维度求和,根据 reduction 参数进行缩减操作。

2.train.py 文件

        这个文件实现了模型的训练过程,使用 KGTrainer 类继承自 Trainer 类,自定义了损失计算方法。

        定义了 KGTrainer类,继承自 Trainer`类,自定义了损失计算方法。在 `compute_loss` 方法中,计算学生模型和教师模型的输出,处理输出形状不匹配的问题,调用 `compute_fkl` 函数计算前向 KL 散度,并根据 `if_use_entropy` 参数决定是否结合学生模型的损失。主程序部分加载了学生模型 `Qwen2.5 - 0.5B - Instruct` 和教师模型 `Qwen2.5 - 7B - Instruct`,配置了训练参数,加载数据集,创建 `KGTrainer` 实例进行训练,并保存模型和训练状态。

2.1 KGTrainer 类

class KGTrainer(Trainer):
    
    def __init__(
        self,
        model = None,
        teacher_model = None,
        if_use_entropy = False,
        args = None,
        data_collator = None, 
        train_dataset = None,
        eval_dataset = None,
        tokenizer = None,
        model_init = None, 
        compute_metrics = None, 
        callbacks = None,
        optimizers = (None, None), 
        preprocess_logits_for_metrics = None,
    ):
        super().__init__(
            model,
            args,
            data_collator,
            train_dataset,
            eval_dataset,
            tokenizer,
            model_init,
            compute_metrics,
            callbacks,
            optimizers,
            preprocess_logits_for_metrics,
        )
        self.teacher_model = teacher_model
        self.if_use_entropy = if_use_entropy
        
    
    def compute_loss(self, model, inputs, return_outputs=False):
        
        outputs = model(**inputs)
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)
        
        loss = outputs.loss
        logits = outputs.logits
        teacher_logits = teacher_outputs.logits
        
        # 如果教师模型和学生模型输出形状不匹配,对学生模型进行padding或对教师模型进行截断
        if logits.shape[-1] != teacher_logits.shape[-1]:
            teacher_logits = teacher_logits[:, :, :logits.shape[-1]]
        
        labels = inputs['labels']
        kl = compute_fkl(logits, teacher_logits, labels, padding_id=-100, temp=2.0)
        
        if self.if_use_entropy:
            loss_total = 0.5 * kl + 0.5 * loss
        else:
            loss_total = kl
        
        return (loss_total, outputs) if return_outputs else loss_total
  •  __init__ 方法
    • 初始化 KGTrainer 类,继承自 Trainer 类。
    • 接收学生模型 model、教师模型 teacher_model、是否使用熵损失 if_use_entropy 等参数。
    • 将教师模型和是否使用熵损失的标志保存为类的属性。
  • compute_loss 方法
    • 计算模型输出
      • outputs = model(**inputs):计算学生模型的输出。
      • teacher_outputs = self.teacher_model(**inputs):计算教师模型的输出,使用 torch.no_grad() 上下文管理器,避免计算教师模型的梯度,因为教师模型不需要更新参数。
    • 获取对数概率和损失
      • loss = outputs.loss:获取学生模型的损失。
      • logits = outputs.logits:获取学生模型的对数概率。
      • teacher_logits = teacher_outputs.logits:获取教师模型的对数概率。
    • 处理输出形状不匹配问题
      • 如果学生模型和教师模型的输出形状不匹配,对教师模型的输出进行截断,使其与学生模型的输出形状一致。
    • 计算 KL 散度
      • kl = compute_fkl(logits, teacher_logits, labels, padding_id=-100, temp=2.0):调用 compute_fkl 函数计算前向 KL 散度。
    • 计算总损失
      • 如果 if_use_entropy 为 True,则总损失为 KL 散度和学生模型损失的加权和,权重均为 0.5。
      • 如果 if_use_entropy 为 False,则总损失为 KL 散度。
    • 返回损失:根据 return_outputs 参数,返回总损失或总损失和学生模型的输出。

2.2 主程序

if __name__ == '__main__':
    
    # 学生模型
    model = AutoModelForCausalLM.from_pretrained("Qwen2.5-0.5B-Instruct")
    model.cuda()
    print(model.print_trainable_parameters())
    
    tokenizer = AutoTokenizer.from_pretrained("Qwen2.5-0.5B-Instruct")
    
    # 教师模型,在给定数据上通过lora微调
    teacher_model = AutoModelForCausalLM.from_pretrained("Qwen2.5-7B-Instruct")
    teacher_model.cuda()
    teacher_model.eval()
    
    args = TrainingArguments(output_dir='./results', 
                            num_train_epochs=10, 
                            do_train=True, 
                            per_device_train_batch_size=2,
                            gradient_accumulation_steps=16,
                            logging_steps=10,
                            report_to='tensorboard',
                            save_strategy='epoch',
                            save_total_limit=10,
                            bf16=True,
                            learning_rate=0.0005,
                            lr_scheduler_type='cosine',
                            dataloader_num_workers=8,
                            dataloader_pin_memory=True)
    data_collator = DefaultDataCollator()
    dataset = SFTDataset('data.json', tokenizer=tokenizer, max_seq_len=512)
    trainer = KGTrainer(model=model,
                        teacher_model=teacher_model, 
                        if_use_entropy = True,
                        args=args, 
                        train_dataset=dataset, 
                        tokenizer=tokenizer, 
                        data_collator=data_collator)
    # 如果是初次训练resume_from_checkpoint为false,接着checkpoint继续训练,为True
    trainer.train(resume_from_checkpoint=False)
    trainer.save_model('./saves')
    trainer.save_state()
  • 加载学生模型 Qwen2.5-0.5B-Instruct 和教师模型 Qwen2.5-7B-Instruct,并将它们移动到 GPU 上。
  • 配置训练参数 TrainingArguments,包括训练轮数、批次大小、学习率等。
  • 加载数据集 SFTDataset,使用 DefaultDataCollator 进行数据整理。
  • 创建 KGTrainer 实例,调用 train 方法进行训练,并保存模型和训练状态。

二、总体代码

2.1  数据处理(dataset.py)

import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
import os
import pandas as pd

from torch.utils.data import IterableDataset, Dataset
import json
import numpy as np
from transformers import  PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import PretrainedConfig
from transformers import Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator, DataCollatorForTokenClassification, AutoConfig

class SFTDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_seq_len):
        super().__init__()
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.padding_id = tokenizer.pad_token_id
        with open(self.data_path, 'r', encoding='utf-8') as f:
            self.data = json.load(f)
            
    def __len__(self):
        return len(self.data)    
    
    def __getitem__(self, index):
        line = self.data[index]
        instruction_text = line['instruction']
        input_text = line['input']
        output_text = line['output']
        query = instruction_text + input_text
        answer = output_text + self.tokenizer.eos_token
        messages = []
        messages.append({'role': 'user', 'content': query})   
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False) 
        
        prompt_input_ids = self.tokenizer.encode(prompt)
        answer_input_ids = self.tokenizer.encode(answer)
        
        input_ids = prompt_input_ids + answer_input_ids
        labels = [-100] * len(prompt_input_ids) + answer_input_ids
        attention_mask = [1] * len(input_ids)
        text_len = len(input_ids)
        
        if text_len > self.max_seq_len:
            input_ids = input_ids[:self.max_seq_len]
            labels = labels[:self.max_seq_len]
            attention_mask = attention_mask[:self.max_seq_len]
        else:
            input_ids = input_ids + [self.tokenizer.pad_token_id] * (self.max_seq_len - text_len)
            labels = labels + [-100] * (self.max_seq_len - text_len)
            attention_mask = attention_mask + [0] * (self.max_seq_len - text_len)
        
        # input_ids = input_ids[:-1]
        # labels = labels[1:]
        return {'input_ids': torch.tensor(input_ids), 'attention_mask':torch.tensor(attention_mask), 'labels': torch.tensor(labels)}

2.2 utils.py

import torch

# 计算前向kl散度
def compute_fkl(
        logits, 
        teacher_logits, 
        target, 
        padding_id,
        reduction="sum",
        temp = 1.0, 
        
    ):
        logits = logits / temp
        teacher_logits = teacher_logits / temp

        log_probs = torch.log_softmax(logits, -1, dtype=torch.float32)
        teacher_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32)
        teacher_log_probs = torch.log_softmax(teacher_logits, -1, dtype=torch.float32)
        kl = (teacher_probs * (teacher_log_probs - log_probs)) 
        kl = kl.sum(-1)
        if reduction == "sum":
            pad_mask = target.eq(padding_id)
            kl = kl.masked_fill_(pad_mask, 0.0)
            kl = kl.sum()

        return kl
# 计算反向kl散度
def compute_rkl(
        logits, 
        teacher_logits, 
        target, 
        padding_id,
        reduction="sum", 
        temp = 1.0
    ):
        logits = logits / temp
        teacher_logits = teacher_logits / temp

        probs = torch.softmax(logits, -1, dtype=torch.float32)
        log_probs = torch.log_softmax(logits, -1, dtype=torch.float32)
        teacher_log_probs = torch.log_softmax(teacher_logits, -1, dtype=torch.float32)
        kl = (probs * (log_probs - teacher_log_probs))
        kl = kl.sum(-1)
        if reduction == "sum":
            pad_mask = target.eq(padding_id)
            kl = kl.masked_fill_(pad_mask, 0.0)
            kl = kl.sum()
        return kl

# 计算偏向前kl散度
def compute_skewed_fkl(
        logits, 
        teacher_logits, 
        target, 
        padding_id, 
        reduction="sum", 
        temp = 1.0,
        skew_lambda = 0.1
    ):
        logits = logits / temp
        teacher_logits = teacher_logits / temp

        probs = torch.softmax(logits, -1, dtype=torch.float32)
        teacher_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32)
        mixed_probs = skew_lambda * teacher_probs + (1 - skew_lambda) * probs
        mixed_log_probs = torch.log(mixed_probs)
        teacher_log_probs = torch.log_softmax(teacher_logits, -1, dtype=torch.float32)
        kl = (teacher_probs * (teacher_log_probs - mixed_log_probs))
        kl = kl.sum(-1)
        if reduction == "sum":
            pad_mask = target.eq(padding_id)
            kl = kl.masked_fill_(pad_mask, 0.0)
            kl = kl.sum()

            
        return kl
# 计算偏向反kl散度    
def compute_skewed_rkl(
    logits, 
    teacher_logits, 
    target,
    padding_id,
    reduction="sum", 
    temp = 1.0,
    skew_lambda = 0.1
):
    logits = logits / temp
    teacher_logits = teacher_logits / temp
    
    probs = torch.softmax(logits, -1, dtype=torch.float32)
    teacher_probs = torch.softmax(teacher_logits, -1, dtype=torch.float32)
    mixed_probs = (1 - skew_lambda) * teacher_probs + skew_lambda * probs
    mixed_log_probs = torch.log(mixed_probs)
    log_probs = torch.log_softmax(logits, -1, dtype=torch.float32)
    kl = (probs * (log_probs - mixed_log_probs))
    kl = kl.sum(-1)
    
    if reduction == "sum":
        pad_mask = target.eq(padding_id)
        kl = kl.masked_fill_(pad_mask, 0.0)
        kl = kl.sum()


    return kl

 2.3 train.py

from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator
from peft import LoraConfig, get_peft_model, TaskType
from peft import PeftModel
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer, TrainingArguments
from dataset import SFTDataset
from utils import compute_fkl, compute_rkl, compute_skewed_fkl, compute_skewed_rkl


class KGTrainer(Trainer):
    
    def __init__(
        self,
        model = None,
        teacher_model = None,
        if_use_entropy = False,
        args = None,
        data_collator = None, 
        train_dataset = None,
        eval_dataset = None,
        tokenizer = None,
        model_init = None, 
        compute_metrics = None, 
        callbacks = None,
        optimizers = (None, None), 
        preprocess_logits_for_metrics = None,
    ):
        super().__init__(
            model,
            args,
            data_collator,
            train_dataset,
            eval_dataset,
            tokenizer,
            model_init,
            compute_metrics,
            callbacks,
            optimizers,
            preprocess_logits_for_metrics,
        )
        self.teacher_model = teacher_model
        self.if_use_entropy = if_use_entropy
        
    
    def compute_loss(self, model, inputs, return_outputs=False):
        
        outputs = model(**inputs)
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)
        
        loss = outputs.loss
        logits = outputs.logits
        teacher_logits = teacher_outputs.logits
        
        # 如果教师模型和学生模型输出形状不匹配,对学生模型进行padding或对教师模型进行截断
        if logits.shape[-1] != teacher_logits.shape[-1]:
            # gap = teacher_logits.shape[-1] - logits.shape[-1]
            # if gap > 0:
            #     pad_logits = torch.zeros((logits.shape[0], logits.shape[1], gap)).to(logits.device)
            #     logits = torch.cat([logits, pad_logits], dim=-1)
            
            teacher_logits = teacher_logits[:, :, :logits.shape[-1]]
        
        labels = inputs['labels']
        kl = compute_fkl(logits, teacher_logits, labels, padding_id=-100, temp=2.0)
        
        if self.if_use_entropy:
            loss_total = 0.5 * kl + 0.5 * loss
        else:
            loss_total = kl
        
        return (loss_total, outputs) if return_outputs else loss_total
        

if __name__ == '__main__':
    
    # 学生模型
    model = AutoModelForCausalLM.from_pretrained("Qwen2.5-0.5B-Instruct")
    model.cuda()
    print(model.print_trainable_parameters())
    
    tokenizer = AutoTokenizer.from_pretrained("Qwen2.5-0.5B-Instruct")
    
    # 教师模型,在给定数据上通过lora微调
    teacher_model = AutoModelForCausalLM.from_pretrained("Qwen2.5-7B-Instruct")
    teacher_model.cuda()
    teacher_model.eval()
    
    args = TrainingArguments(output_dir='./results', 
                            num_train_epochs=10, 
                            do_train=True, 
                            per_device_train_batch_size=2,
                            gradient_accumulation_steps=16,
                            logging_steps=10,
                            report_to='tensorboard',
                            save_strategy='epoch',
                            save_total_limit=10,
                            bf16=True,
                            learning_rate=0.0005,
                            lr_scheduler_type='cosine',
                            dataloader_num_workers=8,
                            dataloader_pin_memory=True)
    data_collator = DefaultDataCollator()
    dataset = SFTDataset('data.json', tokenizer=tokenizer, max_seq_len=512)
    trainer = KGTrainer(model=model,
                        teacher_model=teacher_model, 
                        if_use_entropy = True,
                        args=args, 
                        train_dataset=dataset, 
                        tokenizer=tokenizer, 
                        data_collator=data_collator)
    # 如果是初次训练resume_from_checkpoint为false,接着checkpoint继续训练,为True
    trainer.train(resume_from_checkpoint=False)
    trainer.save_model('./saves')
    trainer.save_state()

总结

  1. 模型优化:模型蒸馏技术借助白盒蒸馏方法,能够让小模型(学生模型)有效地学习大模型(教师模型)的知识,从而在保持较高性能的同时,显著降低计算成本和存储需求,提高模型在资源受限环境中的部署效率。
  2. 灵活性与可扩展性:四种基于 KL 散度的计算方法为模型蒸馏提供了多种选择,研究人员和开发者可以根据具体任务和数据特点,灵活调整参数,如温度参数 temp 和加权系数 skew_lambda,以达到最佳的蒸馏效果。同时,KGTrainer 类的自定义损失计算方法也为后续的扩展和改进提供了便利。
  3. 推动技术发展:通过不断优化蒸馏方法和训练策略,能够进一步挖掘大模型的潜力,让小模型更好地学习大模型的知识,为自然语言处理、计算机视觉等领域的发展提供有力支持,推动人工智能技术在更多场景中的应用。
    Logo

    火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

    更多推荐