【大模型知识蒸馏】从零开始进行模型蒸馏,利用Transformer库对Qwen2.5-32B模型进行知识蒸馏
模型蒸馏(Model Distillation)是将一个复杂模型(教师模型)的知识迁移到一个较小、效率更高的模型(学生模型)的过程。数据准备是蒸馏过程中的关键步骤,直接影响学生模型的性能。本文将以生成SQL查询的任务为例,详细介绍从数据集准备到模型蒸馏的完整流程。我们假设教师模型是一个大型语言模型(Qwen2.5-7B),学生模型是一个小型模型(Qwen2.5-0.5B)。
使用Transformer库进行知识蒸馏,以生成SQL查询的任务为例
概要
模型蒸馏(Model Distillation)是将一个复杂模型(教师模型)的知识迁移到一个较小、效率更高的模型(学生模型)的过程。数据准备是蒸馏过程中的关键步骤,直接影响学生模型的性能。
本文将以生成SQL查询的任务为例,详细介绍从数据集准备到模型蒸馏的完整流程。我们假设教师模型是一个大型语言模型(Qwen2.5-32B),学生模型是一个小型模型(Qwen2.5-3B)。
整体流程步骤
示例模型:Qwen2.5-32B,Qwen2.5-3B
任务背景
我们的目标是构建一个轻量级模型,能够根据自然语言问题生成SQL查询。例如:
输入:“本公司有多少员工?”
输出:SELECT COUNT(*) FROM employee;
大型模型在生成SQL方面表现出色,但推理速度慢、资源消耗大。通过蒸馏,我们希望训练一个小型模型,既能接近教师模型的性能,又能在边缘设备或低资源环境下运行。
数据集是模型蒸馏的核心,直接影响学生模型的性能。我们需要准备两类数据:
标注数据集:包含自然语言问题和对应的SQL查询,用于监督学习。
未标注数据集:用于生成教师模型的软标签(soft labels),传递教师模型的知识。
1.数据集准备
1.1 收集标注数据集
标注数据集由问题-SQL对组成,来源可以是:
- 内部数据:从公司数据库的查询日志中提取。
- 公开数据集:如Spider数据集(包含自然语言问题和SQL查询)。
- 手动标注:根据业务场景生成。
示例数据
我们假设有一个名为employee的表,结构如下:
CREATE TABLE employee (
id INT PRIMARY KEY,
name VARCHAR(50),
department VARCHAR(50),
hire_date DATE
);
例如:
以下是5条标注数据样本(保存为train.json):
[
{
"input": "公司有多少员工?",
"table_name": "employee",
"sql": "SELECT COUNT(*) FROM employee;"
},
{
"input": "公司有哪些部门?",
"table_name": "employee",
"sql": "SELECT DISTINCT department FROM employeea;"
},
{
"input": "2023年入职的员工有多少?",
"table_name": "employee",
"sql": "SELECT COUNT(*) FROM employee WHERE YEAR(hire_date) = 2023;"
},
{
"input": "公司研发部门的员工是谁?",
"table_name": "employee",
"sql": "SELECT name FROM employee WHERE department = '研发';"
},
{
"input": "员工入职日期最早的是谁?",
"table_name": "employee",
"sql": "SELECT name FROM employee ORDER BY hire_date ASC LIMIT 1;"
}
]
1.2 收集未标注数据集
未标注数据集仅包含自然语言问题,用于生成教师模型的软标签。来源可以是:
- 用户输入日志。
- 业务场景中的问题模板。
示例数据
以下是5条未标注数据样本(保存为unlabeled.json):
[
{"input": "公司员工总数是多少?"},
{"input": "研发部门有多少人?"},
{"input": "最近入职的员工是谁?"},
{"input": "销售部门的员工有哪些?"},
{"input": "2022年之后入职的员工有多少?"}
]
1.3 生成教师模型的软标签
我们使用教师模型(Qwen2.5-32B)对未标注数据生成软标签。软标签是模型输出的概率分布,表示对不同SQL类型的置信度。
import json
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset, load_from_disk
# 用transformer库加载 tokenizer 和教师模型
teacher_model_path = "/app/home/qwen-32B" #替换成你下载的模型的路径
teacher_model = AutoModelForCausalLM.from_pretrained(
teacher_model_path,
torch_dtype=torch.bfloat16,
device_map="auto"
).eval()
tokenizer = AutoTokenizer.from_pretrained(teacher_model_path)
# 加载未标注数据
with open("unlabeled.json", "r") as f:
unlabeled_data = json.load(f)
# 表结构
table_structure = {
"table_name": "employee",
"columns": [
{"name": "id", "type": "INT", "comment": "主键"},
{"name": "name", "type": "VARCHAR", "comment": "员工姓名"},
{"name": "department", "type": "VARCHAR", "comment": "部门"},
{"name": "hire_date", "type": "DATE", "comment": "入职日期"}
]
}
# 生成软标签
soft_labeled_data = []
for sample in unlabeled_data:
input_text = sample["input"]
prompt = f"""
输入:{input_text}
表结构:{json.dumps(table_structure, ensure_ascii=False)}
请生成SQL查询,并返回SQL类型(aggregation, selection, join)的概率分布。
"""
response = teacher_model.generate(prompt)
soft_labeled_data.append({
"input": input_text,
"table_name": "employee_yinka",
"sql": response["sql"],
"soft_label": response["probabilities"] # 假设返回[0.9, 0.05, 0.05]
})
# 保存软标签数据
with open("soft_labeled.json", "w") as f:
json.dump(soft_labeled_data, f, ensure_ascii=False, indent=2)
1.4 验证软标签数据集
查看是否生成的json文件如下所示:
[
{
"input": "公司员工总数是多少?",
"table_name": "employee",
"sql": "SELECT COUNT(*) FROM employee;",
"soft_label": {"aggregation": 0.95, "selection": 0.03, "join": 0.02}
},
{
"input": "研发部门有多少人?",
"table_name": "employee",
"sql": "SELECT COUNT(*) FROM employee WHERE department = '研发';",
"soft_label": {"aggregation": 0.90, "selection": 0.08, "join": 0.02}
}
]
2.开始蒸馏
2.1 选择蒸馏方法
我们采用Response-based Distillation(基于响应的蒸馏),学生模型直接学习教师模型生成的SQL。此外,结合软标签,优化学生模型对SQL类型的理解。
损失函数:
- 硬标签损失:学生模型生成的SQL与教师模型SQL的交叉熵损失。
- 软标签损失:学生模型的概率分布与教师模型软标签的KL散度(Kullback-Leibler Divergence)。
2.2 准备学生模型
注意:一定要确认你的教师模型和学生模型的词汇表大小是一样的,否则无法直接蒸馏,需要更改模型结构或直接使用学生模型的Tokenizer
import json
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from torch.utils.data import Dataset, DataLoader
from torch.nn import CrossEntropyLoss
from torch.nn.functional import kl_div, log_softmax
# 自定义数据集
class SQLDataset(Dataset):
def __init__(self, data_file, tokenizer):
with open(data_file, "r") as f:
self.data = json.load(f)
self.tokenizer = tokenizer
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
input_text = f"question: {sample['input']} table: {sample['table_name']} [id, name, department, hire_date]"
target_sql = sample["sql"]
soft_label = torch.tensor(list(sample["soft_label"].values()), dtype=torch.float32)
inputs = self.tokenizer(input_text, return_tensors="pt", padding="max_length", max_length=128, truncation=True)
targets = self.tokenizer(target_sql, return_tensors="pt", padding="max_length", max_length=64, truncation=True)
return {
"input_ids": inputs["input_ids"].squeeze(),
"attention_mask": inputs["attention_mask"].squeeze(),
"labels": targets["input_ids"].squeeze(),
"soft_label": soft_label
}
# 加载模型和分词器
student_model_path = "/app/home/Qwen2.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(teacher_model_path) # 假设与教师模型一致
student_model = AutoModelForCausalLM.from_pretrained(student_model_path)
student_model.resize_token_embeddings(tokenizer.vocab_size) # 调整学生模型的嵌入层
student_model.train() # 设置为训练模式
student_model.to(device)
# 加载数据集
train_dataset = SQLDataset("soft_labeled.json", tokenizer)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
# 训练设置
optimizer = torch.optim.AdamW(student_model.parameters(), lr=5e-5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student_model.to(device)
# 训练循环
alpha = 0.5
temperature = 2.0
num_epochs = 3
for epoch in range(num_epochs):
student_model.train()
total_loss = 0
for batch in train_loader:
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["labels"].to(device)
soft_labels = batch["soft_label"].to(device)
# 前向传播
outputs = student_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
hard_loss = outputs.loss
# 计算软标签损失
logits = outputs.logits[:, 0, :] / temperature # 假设分类SQL类型
student_probs = log_softmax(logits, dim=-1)
soft_loss = kl_div(student_probs, soft_labels / temperature, reduction="batchmean")
# 总损失
loss = alpha * hard_loss + (1 - alpha) * soft_loss
total_loss += loss.item()
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}")
# 保存模型
student_model.save_pretrained("distilled_sql")
tokenizer.save_pretrained("distilled_sql")
验证结果
input_text = "question: 公司有多少员工? table: employee [id, name, department, hire_date]"
inputs = tokenizer(input_text, return_tensors="pt").to(device)
outputs = student_model.generate(**inputs, max_length=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# 输出:SELECT COUNT(*) FROM employee;
注意事项
确保标注数据和软标签的SQL可执行,避免错误传播;训练数据应覆盖实际使用场景(如聚合、过滤、排序) ;生成软标签可能需要GPU支持,分批处理未标注数据
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)