Transformer架构详解:从理论到实践

在深度学习领域,Transformer 架构的提出彻底改变了序列建模的方式。无论是自然语言处理(NLP)中的机器翻译,还是计算机视觉(CV)中的图像分类,Transformer 都展现了强大的能力。本文将从其设计背景、核心思想、架构细节到代码实现,全面解析这一划时代的模型。


一、Transformer的提出背景

1. 传统序列模型的局限性

在 Transformer 出现之前,序列建模主要依赖 循环神经网络(RNN)长短期记忆网络(LSTM),但这些模型存在明显缺陷:

  • 无法并行计算:RNN 需按时间步逐词处理,训练效率低下。
  • 长距离依赖建模困难:当序列长度超过 50 词时,LSTM 也难以捕捉远距离关联(如“虽然……但是”)。
  • 计算复杂度高:传统模型的复杂度与序列长度呈线性或平方关系,难以处理长文本。

2. Transformer的突破

2017 年,Vaswani 等人在论文《Attention is All You Need》中提出了 Transformer,其核心创新在于:

  • 完全基于注意力机制:摒弃循环结构,通过自注意力(Self-Attention)直接建模全局依赖。
  • 并行计算:所有位置同时处理,大幅提升训练速度。
  • 位置编码:通过数学方法注入词序信息,替代传统的递归或卷积操作。

二、Transformer的核心思想

1. 自注意力机制(Self-Attention)

自注意力机制是 Transformer 的核心,其目标是 动态计算序列中每个词与其他词的相关性。具体步骤包括:

  1. 生成Q、K、V矩阵
    输入序列通过线性变换得到查询(Query)、键(Key)、值(Value)矩阵。
  2. 计算注意力分数
    通过点积计算查询与键的相似度,缩放后应用 Softmax 归一化。
  3. 加权求和
    将权重矩阵与值矩阵相乘,得到上下文向量。
数学公式

Attention(Q,K,V)=Softmax(QKTdk)V \text{Attention}(Q, K, V) = \text{Softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right)V Attention(Q,K,V)=Softmax(dk QKT)V

其中,dkd_kdk 为键向量的维度,缩放因子 dk\sqrt{d_k}dk 用于防止内积过大导致梯度消失。

2. 多头注意力(Multi-Head Attention)

为了捕捉不同子空间的语义信息,Transformer 将自注意力扩展为 多头注意力

  • 并行计算:将 Q、K、V 拆分为多个头(如 8 头),分别计算注意力后拼接结果。
  • 增强表达能力:不同头关注语法、语义等不同层面的特征。
公式

MultiHead(Q,K,V)=Concat(head1,…,headh)WO \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O MultiHead(Q,K,V)=Concat(head1,,headh)WO

3. 位置编码(Positional Encoding)

自注意力机制本身无法感知词序,因此 Transformer 通过 位置编码 注入位置信息:

  • 正弦与余弦函数
    PE(pos,2i)=sin⁡(pos100002i/d),PE(pos,2i+1)=cos⁡(pos100002i/d) PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right), \quad PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right) PE(pos,2i)=sin(100002i/dpos),PE(pos,2i+1)=cos(100002i/dpos)
  • 可学习参数:某些变体(如 BERT)使用可学习的嵌入向量表示位置。

三、Transformer架构详解

1. 整体架构

Transformer 由 编码器(Encoder)解码器(Decoder) 组成,每部分包含多个相同的层。
在这里插入图片描述

编码器(Encoder)
  • 输入处理:词嵌入 + 位置编码。
  • 核心层
    1. 多头自注意力层:计算输入序列内部的依赖关系。
    2. 前馈神经网络(FFN):两层全连接网络,激活函数为 ReLU。
  • 残差连接与层归一化:每个子层后接残差连接和 LayerNorm。
解码器(Decoder)
  • 输入处理:目标序列的词嵌入 + 位置编码。
  • 核心层
    1. 掩码多头自注意力层:防止未来词信息泄露(训练时)。
    2. 编码器-解码器注意力层:关联编码器输出与解码器状态。
    3. 前馈神经网络(FFN)

2. 残差连接与层归一化

  • 残差连接:将子层输入直接加到输出上,缓解梯度消失。
  • 层归一化:对每个样本的特征维度进行归一化,加速收敛。

四、代码实现(PyTorch)

1. 多头注意力层

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads
        self.wq = nn.Linear(d_model, d_model)  # 查询变换
        self.wk = nn.Linear(d_model, d_model)  # 键变换
        self.wv = nn.Linear(d_model, d_model)  # 值变换
        self.dense = nn.Linear(d_model, d_model)  # 最终线性层

    def split_heads(self, x, batch_size):
        # 将输入拆分为多个头 [batch_size, seq_len, num_heads, depth]
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.permute(0, 2, 1, 3)  # [batch_size, num_heads, seq_len, depth]

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        q = self.wq(q)  # [batch_size, seq_len, d_model]
        k = self.wk(k)
        v = self.wv(v)
        
        # 拆分多头
        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)
        
        # 计算缩放点积注意力
        matmul_qk = torch.matmul(q, k.transpose(-2, -1))  # [batch_size, num_heads, seq_len, seq_len]
        dk = torch.tensor(k.size(-1), dtype=torch.float32)
        scaled_attention_logits = matmul_qk / torch.sqrt(dk)
        
        if mask is not None:
            scaled_attention_logits += (mask * -1e9)  # 掩码处理(解码器用)
        
        attention_weights = torch.softmax(scaled_attention_logits, dim=-1)
        output = torch.matmul(attention_weights, v)  # [batch_size, num_heads, seq_len, depth]
        
        # 拼接多头结果
        output = output.permute(0, 2, 1, 3).contiguous()
        output = output.view(batch_size, -1, self.d_model)
        output = self.dense(output)
        return output

2. 编码器层实现

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, dff, dropout_rate=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dff),  # 第一层全连接
            nn.ReLU(),
            nn.Linear(dff, d_model)    # 第二层全连接
        )
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x, mask):
        # 多头自注意力子层
        attn_output = self.mha(x, x, x, mask)  # Q=K=V=x
        attn_output = self.dropout1(attn_output)
        out1 = self.layernorm1(x + attn_output)  # 残差连接 + 层归一化
        
        # 前馈网络子层
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        out2 = self.layernorm2(out1 + ffn_output)
        return out2

五、Transformer的应用与优化

1. 应用场景

  • 自然语言处理
    • BERT:双向编码器,用于文本分类、问答。
    • GPT:单向解码器,用于文本生成。
  • 计算机视觉
    • ViT(Vision Transformer):将图像分块输入Transformer。
  • 语音处理
    • 语音识别:结合Transformer与CTC损失函数。

2. 训练优化策略

  • 动态学习率:使用 Warmup 策略逐步增加学习率。
  • 标签平滑:将硬标签(0或1)替换为软标签(如0.1和0.9),缓解过拟合。
  • 混合精度训练:FP16与FP32混合计算,节省显存并加速。

六、Transformer的优缺点

1. 优势

  • 全局依赖建模:自注意力机制直接捕捉任意距离的词间关系。
  • 高效并行计算:所有位置同时处理,训练速度远超RNN。
  • 多模态适配性:适用于文本、图像、语音等多种数据类型。

2. 局限性

  • 计算复杂度高:自注意力复杂度为 O(n2)O(n^2)O(n2),处理长序列时资源消耗大。
  • 位置编码依赖:固定编码可能无法充分表达复杂的位置关系(如旋转、缩放)。

七、未来发展方向

  1. 稀疏注意力
    • Longformer:结合局部窗口与全局标记,复杂度降为 O(n)O(n)O(n)
    • BigBird:分块稀疏化,保留关键连接。
  2. 线性复杂度变体
    • Performer:通过核方法近似注意力矩阵。
  3. 多模态融合
    • CLIP:联合训练文本与图像编码器。
Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐