使用Transformer库进行知识蒸馏,以生成SQL查询的任务为例

概要

模型蒸馏(Model Distillation)是将一个复杂模型(教师模型)的知识迁移到一个较小、效率更高的模型(学生模型)的过程。数据准备是蒸馏过程中的关键步骤,直接影响学生模型的性能。
本文将以生成SQL查询的任务为例,详细介绍从数据集准备到模型蒸馏的完整流程。我们假设教师模型是一个大型语言模型(Qwen2.5-32B),学生模型是一个小型模型(Qwen2.5-3B)。

整体流程步骤

示例模型:Qwen2.5-32B,Qwen2.5-3B

任务背景

我们的目标是构建一个轻量级模型,能够根据自然语言问题生成SQL查询。例如:

输入:“本公司有多少员工?”
输出:SELECT COUNT(*) FROM employee;

大型模型在生成SQL方面表现出色,但推理速度慢、资源消耗大。通过蒸馏,我们希望训练一个小型模型,既能接近教师模型的性能,又能在边缘设备或低资源环境下运行。

数据集是模型蒸馏的核心,直接影响学生模型的性能。我们需要准备两类数据:
标注数据集:包含自然语言问题和对应的SQL查询,用于监督学习。
未标注数据集:用于生成教师模型的软标签(soft labels),传递教师模型的知识。

1.数据集准备

1.1 收集标注数据集

标注数据集由问题-SQL对组成,来源可以是:

  • 内部数据:从公司数据库的查询日志中提取。
  • 公开数据集:如Spider数据集(包含自然语言问题和SQL查询)。
  • 手动标注:根据业务场景生成。

示例数据
我们假设有一个名为employee的表,结构如下:

CREATE TABLE employee (
    id INT PRIMARY KEY,
    name VARCHAR(50),
    department VARCHAR(50),
    hire_date DATE
);

例如:
以下是5条标注数据样本(保存为train.json):

[
    {
        "input": "公司有多少员工?",
        "table_name": "employee",
        "sql": "SELECT COUNT(*) FROM employee;"
    },
    {
        "input": "公司有哪些部门?",
        "table_name": "employee",
        "sql": "SELECT DISTINCT department FROM employeea;"
    },
    {
        "input": "2023年入职的员工有多少?",
        "table_name": "employee",
        "sql": "SELECT COUNT(*) FROM employee WHERE YEAR(hire_date) = 2023;"
    },
    {
        "input": "公司研发部门的员工是谁?",
        "table_name": "employee",
        "sql": "SELECT name FROM employee WHERE department = '研发';"
    },
    {
        "input": "员工入职日期最早的是谁?",
        "table_name": "employee",
        "sql": "SELECT name FROM employee ORDER BY hire_date ASC LIMIT 1;"
    }
]
1.2 收集未标注数据集

未标注数据集仅包含自然语言问题,用于生成教师模型的软标签。来源可以是:

  • 用户输入日志。
  • 业务场景中的问题模板。

示例数据
以下是5条未标注数据样本(保存为unlabeled.json):

[
    {"input": "公司员工总数是多少?"},
    {"input": "研发部门有多少人?"},
    {"input": "最近入职的员工是谁?"},
    {"input": "销售部门的员工有哪些?"},
    {"input": "2022年之后入职的员工有多少?"}
]
1.3 生成教师模型的软标签

我们使用教师模型(Qwen2.5-32B)对未标注数据生成软标签。软标签是模型输出的概率分布,表示对不同SQL类型的置信度。

import json
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset, load_from_disk

# 用transformer库加载 tokenizer 和教师模型
teacher_model_path = "/app/home/qwen-32B" #替换成你下载的模型的路径
teacher_model = AutoModelForCausalLM.from_pretrained(
    teacher_model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto"
).eval()
tokenizer = AutoTokenizer.from_pretrained(teacher_model_path)

# 加载未标注数据
with open("unlabeled.json", "r") as f:
    unlabeled_data = json.load(f)

# 表结构
table_structure = {
    "table_name": "employee",
    "columns": [
        {"name": "id", "type": "INT", "comment": "主键"},
        {"name": "name", "type": "VARCHAR", "comment": "员工姓名"},
        {"name": "department", "type": "VARCHAR", "comment": "部门"},
        {"name": "hire_date", "type": "DATE", "comment": "入职日期"}
    ]
}

# 生成软标签
soft_labeled_data = []
for sample in unlabeled_data:
    input_text = sample["input"]
    prompt = f"""
    输入:{input_text}
    表结构:{json.dumps(table_structure, ensure_ascii=False)}
    请生成SQL查询,并返回SQL类型(aggregation, selection, join)的概率分布。
    """
    response = teacher_model.generate(prompt)
    soft_labeled_data.append({
        "input": input_text,
        "table_name": "employee_yinka",
        "sql": response["sql"],
        "soft_label": response["probabilities"]  # 假设返回[0.9, 0.05, 0.05]
    })

# 保存软标签数据
with open("soft_labeled.json", "w") as f:
    json.dump(soft_labeled_data, f, ensure_ascii=False, indent=2)
1.4 验证软标签数据集

查看是否生成的json文件如下所示:

[
    {
        "input": "公司员工总数是多少?",
        "table_name": "employee",
        "sql": "SELECT COUNT(*) FROM employee;",
        "soft_label": {"aggregation": 0.95, "selection": 0.03, "join": 0.02}
    },
    {
        "input": "研发部门有多少人?",
        "table_name": "employee",
        "sql": "SELECT COUNT(*) FROM employee WHERE department = '研发';",
        "soft_label": {"aggregation": 0.90, "selection": 0.08, "join": 0.02}
    }
]

2.开始蒸馏

2.1 选择蒸馏方法

我们采用Response-based Distillation(基于响应的蒸馏),学生模型直接学习教师模型生成的SQL。此外,结合软标签,优化学生模型对SQL类型的理解。
损失函数:

  • 硬标签损失:学生模型生成的SQL与教师模型SQL的交叉熵损失。
  • 软标签损失:学生模型的概率分布与教师模型软标签的KL散度(Kullback-Leibler Divergence)。
2.2 准备学生模型

注意:一定要确认你的教师模型和学生模型的词汇表大小是一样的,否则无法直接蒸馏,需要更改模型结构或直接使用学生模型的Tokenizer

import json
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from torch.utils.data import Dataset, DataLoader
from torch.nn import CrossEntropyLoss
from torch.nn.functional import kl_div, log_softmax

# 自定义数据集
class SQLDataset(Dataset):
    def __init__(self, data_file, tokenizer):
        with open(data_file, "r") as f:
            self.data = json.load(f)
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        input_text = f"question: {sample['input']} table: {sample['table_name']} [id, name, department, hire_date]"
        target_sql = sample["sql"]
        soft_label = torch.tensor(list(sample["soft_label"].values()), dtype=torch.float32)

        inputs = self.tokenizer(input_text, return_tensors="pt", padding="max_length", max_length=128, truncation=True)
        targets = self.tokenizer(target_sql, return_tensors="pt", padding="max_length", max_length=64, truncation=True)

        return {
            "input_ids": inputs["input_ids"].squeeze(),
            "attention_mask": inputs["attention_mask"].squeeze(),
            "labels": targets["input_ids"].squeeze(),
            "soft_label": soft_label
        }

# 加载模型和分词器
student_model_path = "/app/home/Qwen2.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(teacher_model_path)  # 假设与教师模型一致
student_model = AutoModelForCausalLM.from_pretrained(student_model_path)
student_model.resize_token_embeddings(tokenizer.vocab_size)  # 调整学生模型的嵌入层
student_model.train()  # 设置为训练模式
student_model.to(device)

# 加载数据集
train_dataset = SQLDataset("soft_labeled.json", tokenizer)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# 训练设置
optimizer = torch.optim.AdamW(student_model.parameters(), lr=5e-5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
student_model.to(device)

# 训练循环
alpha = 0.5
temperature = 2.0
num_epochs = 3

for epoch in range(num_epochs):
    student_model.train()
    total_loss = 0
    for batch in train_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        soft_labels = batch["soft_label"].to(device)

        # 前向传播
        outputs = student_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        hard_loss = outputs.loss

        # 计算软标签损失
        logits = outputs.logits[:, 0, :] / temperature  # 假设分类SQL类型
        student_probs = log_softmax(logits, dim=-1)
        soft_loss = kl_div(student_probs, soft_labels / temperature, reduction="batchmean")

        # 总损失
        loss = alpha * hard_loss + (1 - alpha) * soft_loss
        total_loss += loss.item()

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}")

# 保存模型
student_model.save_pretrained("distilled_sql")
tokenizer.save_pretrained("distilled_sql")

验证结果

input_text = "question: 公司有多少员工? table: employee [id, name, department, hire_date]"
inputs = tokenizer(input_text, return_tensors="pt").to(device)
outputs = student_model.generate(**inputs, max_length=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
# 输出:SELECT COUNT(*) FROM employee;

注意事项

确保标注数据和软标签的SQL可执行,避免错误传播;训练数据应覆盖实际使用场景(如聚合、过滤、排序) ;生成软标签可能需要GPU支持,分批处理未标注数据

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐