前言

Hello,大家好,我是GISer Liu😁,一名热爱AI技术的GIS开发者,本系列文章是作者参加DataWhale2025年1月份学习赛,旨在讲解Transformer模型的理论和实践。😲

Transformer模型自2017年由Vaswani等人提出以来,已经成为自然语言处理(NLP)领域的重要基石。与传统的RNN和CNN不同,Transformer完全依赖于自注意力机制(Self-Attention)来捕捉输入序列中的全局依赖关系。本文将详细介绍如何使用PyTorch实现一个完整的Transformer模型,并对其进行训练和评估。


一、数据准备与预处理

1. 数据格式

我们使用一个简单的翻译任务作为示例,输入是中文句子,输出是英文句子。每个句子由多个单词组成,单词通过空格分隔。为了处理不同长度的句子,我们使用占位符P来填充不足的部分。


# Encoder_input    Decoder_input          Decoder_output(预测下一个字符)

sentences = [

    ['我 是 学 生 P', 'S I am a student', 'I am a student E'],  # S: 开始符号, E: 结束符号

    ['我 喜 欢 学 习', 'S I like learning P', 'I like learning P E'],

    ['我 是 男 生 P', 'S I am a boy', 'I am a boy E']

]

2. 构建词汇表

我们需要为输入(中文)和输出(英文)分别构建词汇表,并将单词映射到唯一的索引。


# 中文词汇表

src_vocab = {'P': 0, '我': 1, '是': 2, '学': 3, '生': 4, '喜': 5, '欢': 6, '习': 7, '男': 8}

src_idx2word = {src_vocab[key]: key for key in src_vocab}

src_vocab_size = len(src_vocab)



# 英文词汇表

tgt_vocab = {'S': 0, 'E': 1, 'P': 2, 'I': 3, 'am': 4, 'a': 5, 'student': 6, 'like': 7, 'learning': 8, 'boy': 9}

idx2word = {tgt_vocab[key]: key for key in tgt_vocab}

tgt_vocab_size = len(tgt_vocab)

3. 数据转换

将句子转换为索引序列,方便模型处理。


def make_data(sentences):

    enc_inputs, dec_inputs, dec_outputs = [], [], []

    for i in range(len(sentences)):

        enc_input = [[src_vocab[n] for n in sentences[i][0].split()]]  # Encoder输入

        dec_input = [[tgt_vocab[n] for n in sentences[i][1].split()]]  # Decoder输入

        dec_output = [[tgt_vocab[n] for n in sentences[i][2].split()]]  # Decoder输出

        enc_inputs.extend(enc_input)

        dec_inputs.extend(dec_input)

        dec_outputs.extend(dec_output)

    return torch.LongTensor(enc_inputs), torch.LongTensor(dec_inputs), torch.LongTensor(dec_outputs)



enc_inputs, dec_inputs, dec_outputs = make_data(sentences)

print(enc_inputs)  # Encoder输入

print(dec_inputs)  # Decoder输入

print(dec_outputs)  # Decoder输出

输出:


tensor([[1, 2, 3, 4, 0],

        [1, 5, 6, 3, 7],

        [1, 2, 8, 4, 0]])

tensor([[0, 3, 4, 5, 6],

        [0, 3, 7, 8, 2],

        [0, 3, 4, 5, 9]])

tensor([[3, 4, 5, 6, 1],

        [3, 7, 8, 2, 1],

        [3, 4, 5, 9, 1]])

4. 数据加载器

使用PyTorch的DataLoader将数据分批加载。


class MyDataSet(Data.Dataset):

    def __init__(self, enc_inputs, dec_inputs, dec_outputs):

        super(MyDataSet, self).__init__()

        self.enc_inputs = enc_inputs

        self.dec_inputs = dec_inputs

        self.dec_outputs = dec_outputs



    def __len__(self):

        return self.enc_inputs.shape[0]



    def __getitem__(self, idx):

        return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]



loader = Data.DataLoader(MyDataSet(enc_inputs, dec_inputs, dec_outputs), batch_size=2, shuffle=False)


二、位置编码(Positional Encoding)

由于Transformer模型不包含任何循环或卷积结构,因此需要一种方法来注入序列的位置信息。位置编码通过将正弦和余弦函数应用于不同频率的序列位置来实现这一点。


class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):

        super(PositionalEncoding, self).__init__()

        self.dropout = nn.Dropout(p=dropout)

        pos_table = np.array([

            [pos / np.power(10000, 2 * i / d_model) for i in range(d_model)]

            if pos != 0 else np.zeros(d_model) for pos in range(max_len)

        ])

        pos_table[1:, 0::2] = np.sin(pos_table[1:, 0::2])  # 偶数位置使用正弦函数

        pos_table[1:, 1::2] = np.cos(pos_table[1:, 1::2])  # 奇数位置使用余弦函数

        self.pos_table = torch.FloatTensor(pos_table)  # [max_len, d_model]



    def forward(self, enc_inputs):

        enc_inputs += self.pos_table[:enc_inputs.size(1), :]  # 添加位置编码

        return self.dropout(enc_inputs)

代码解释
  • d_model:模型的维度。

  • max_len:序列的最大长度。

  • pos_table:位置编码矩阵,形状为 (max_len, d_model)

  • forward:将位置编码添加到输入序列中。


三、自注意力机制(Self-Attention)

自注意力机制是Transformer的核心组件,它允许模型在处理每个单词时关注输入序列中的其他单词。

1. 缩放点积注意力

class ScaledDotProductAttention(nn.Module):

    def __init__(self):

        super(ScaledDotProductAttention, self).__init__()



    def forward(self, Q, K, V, attn_mask):

        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # 计算注意力分数

        scores.masked_fill_(attn_mask, -1e9)  # 掩码处理

        attn = nn.Softmax(dim=-1)(scores)  # 计算注意力权重

        context = torch.matmul(attn, V)  # 加权求和

        return context, attn

2. 多头注意力机制

class MultiHeadAttention(nn.Module):

    def __init__(self):

        super(MultiHeadAttention, self).__init__()

        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)

        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)

        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)

        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)



    def forward(self, input_Q, input_K, input_V, attn_mask):

        residual, batch_size = input_Q, input_Q.size(0)

        Q = self.W_Q(input_Q).view(batch_size, -1, n_heads, d_k).transpose(1, 2)

        K = self.W_K(input_K).view(batch_size, -1, n_heads, d_k).transpose(1, 2)

        V = self.W_V(input_V).view(batch_size, -1, n_heads, d_v).transpose(1, 2)

        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

        context, attn = ScaledDotProductAttention()(Q, K, V, attn_mask)

        context = context.transpose(1, 2).reshape(batch_size, -1, n_heads * d_v)

        output = self.fc(context)

        return nn.LayerNorm(d_model)(output + residual), attn


四、前馈神经网络(Feed Forward Network)

前馈神经网络由两个线性变换和一个ReLU激活函数组成,用于进一步处理注意力机制的输出。


class FF(nn.Module):

    def __init__(self):

        super(FF, self).__init__()

        self.fc = nn.Sequential(

            nn.Linear(d_model, d_ff, bias=False),

            nn.ReLU(),

            nn.Linear(d_ff, d_model, bias=False)

        )



    def forward(self, inputs):

        residual = inputs

        output = self.fc(inputs)

        return nn.LayerNorm(d_model)(output + residual)


五、Encoder与Decoder

1. Encoder

class EncoderLayer(nn.Module):

    def __init__(self):

        super(EncoderLayer, self).__init__()

        self.enc_self_attn = MultiHeadAttention()

        self.pos_ffn = FF()



    def forward(self, enc_inputs, enc_self_attn_mask):

        enc_outputs, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)

        enc_outputs = self.pos_ffn(enc_outputs)

        return enc_outputs, attn



class Encoder(nn.Module):

    def __init__(self):

        super(Encoder, self).__init__()

        self.src_emb = nn.Embedding(src_vocab_size, d_model)

        self.pos_emb = PositionalEncoding(d_model)

        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])



    def forward(self, enc_inputs):

        enc_outputs = self.src_emb(enc_inputs)

        enc_outputs = self.pos_emb(enc_outputs.transpose(0, 1)).transpose(0, 1)

        enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)

        enc_self_attns = []

        for layer in self.layers:

            enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)

            enc_self_attns.append(enc_self_attn)

        return enc_outputs, enc_self_attns

2. Decoder

class DecoderLayer(nn.Module):

    def __init__(self):

        super(DecoderLayer, self).__init__()

        self.dec_self_attn = MultiHeadAttention()

        self.dec_enc_attn = MultiHeadAttention()

        self.pos_ffn = FF()



    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):

        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)

        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)

        dec_outputs = self.pos_ffn(dec_outputs)

        return dec_outputs, dec_self_attn, dec_enc_attn



class Decoder(nn.Module):

    def __init__(self):

        super(Decoder, self).__init__()

        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)

        self.pos_emb = PositionalEncoding(d_model)

        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])



    def forward(self, dec_inputs, enc_inputs, enc_outputs):

        dec_outputs = self.tgt_emb(dec_inputs)

        dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).transpose(0, 1)

        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs)

        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs)

        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask + dec_self_attn_subsequence_mask), 0)

        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)

        dec_self_attns, dec_enc_attns = [], []

        for layer in self.layers:

            dec_outputs, dec_self_attn, dec_enc_attn = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)

            dec_self_attns.append(dec_self_attn)

            dec_enc_attns.append(dec_enc_attn)

        return dec_outputs, dec_self_attns, dec_enc_attns


六、Transformer模型


class Transformer(nn.Module):

    def __init__(self):

        super(Transformer, self).__init__()

        self.Encoder = Encoder()

        self.Decoder = Decoder()

        self.projection = nn.Linear(d_model, tgt_vocab_size, bias=False)



    def forward(self, enc_inputs, dec_inputs):

        enc_outputs, enc_self_attns = self.Encoder(enc_inputs)

        dec_outputs, dec_self_attns, dec_enc_attns = self.Decoder(dec_inputs, enc_inputs, enc_outputs)

        dec_logits = self.projection(dec_outputs)

        dec_logits = dec_logits.view(-1, dec_logits.size(-1))

        return dec_logits, enc_self_attns, dec_self_attns, dec_enc_attns


七、训练与评估

1. 训练

model = Transformer()

criterion = nn.CrossEntropyLoss(ignore_index=0)

optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)



for epoch in range(50):

    for enc_inputs, dec_inputs, dec_outputs in loader:

        outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)

        loss = criterion(outputs, dec_outputs.view(-1))

        print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))

        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

2. 测试

def test(model, enc_input, start_symbol):

    enc_outputs, enc_self_attns = model.Encoder(enc_input)

    dec_input = torch.zeros(1, tgt_len).type_as(enc_input.data)

    next_symbol = start_symbol

    for i in range(0, tgt_len):

        dec_input[0][i] = next_symbol

        dec_outputs, _, _ = model.Decoder(dec_input, enc_input, enc_outputs)

        projected = model.projection(dec_outputs)

        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]

        next_word = prob.data[i]

        next_symbol = next_word.item()

    return dec_input



enc_inputs, _, _ = next(iter(loader))

predict_dec_input = test(model, enc_inputs[1].view(1, -1), start_symbol=tgt_vocab["S"])

predict, _, _, _ = model(enc_inputs[1].view(1, -1), predict_dec_input)

predict = predict.data.max(1, keepdim=True)[1]

print([src_idx2word[int(i)] for i in enc_inputs[1]], '->', [idx2word[n.item()] for n in predict.squeeze()])


总结

本文详细介绍了如何使用PyTorch实现一个完整的Transformer模型,包括数据预处理、位置编码、自注意力机制、多头注意力机制、前馈神经网络、Encoder与Decoder的实现,以及模型的训练与评估。通过这个实现,读者可以更好地理解Transformer模型的工作原理,并将其应用于各种NLP任务中。


🎉OK!今天就学习到这里了!🙂


项目地址


thank_watch

如果觉得我的文章对您有帮助,三连+关注便是对我创作的最大鼓励!或者一个star🌟也可以😂.

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐