【深度学习】Transformer(Attention Is All You Need)
本文介绍了Transformer模型的架构与核心组件。该模型完全基于注意力机制,摒弃了传统的CNN和RNN结构,具有更强的并行化能力和训练效率。模型采用编解码器结构,编码器由多头自注意力层、残差连接、层归一化和前馈网络组成。重点解析了自注意力机制的计算过程,包括查询、键、值的线性变换和缩放点积注意力公式。多头注意力通过分割嵌入维度到多个子空间并行计算,最后合并结果。模型还使用残差连接缓解梯度消失,
前言
本文是阅读论文《Attention Is All You Need》的笔记。在本文中 Google(2017)提出了 Transformer 结构,一种完全用注意力机制组成的模型结构,相比之前的 CNN 和 RNN 等模型,并行化能力变强了,训练效率提升。在 Transformer 之前,大家更多是在魔改模型结构,比如 CNN、RNN、ResNet、CRNN、CRNN-Attention… 这些模型基本都基于卷积、RNN 记忆缓存和注意力机制。Transformer 论文解决机器翻译问题,输入可能是很长的文字序列。CNN 利用局部感受野,如果要捕获到所有输入信息,需要叠加很多层,不适合处理变长序列。而 RNN 是自回归的,每层需要逐个状态传递信息,无法并行化,对于长序列内存有限制,且一句话读到后面,容易忘记句首信息。注意力机制在当时常常作为 RNN 的辅助模块,用来提点,常见的是 CRNN 后面接一个 Attention,或者像 LAS 一样利用 Attention 来信息对齐。作者的贡献在于完全舍弃卷积和 RNN,只用注意力机制,并行化处理输入,大幅提高训练效率,且在翻译任务中性能 SOTA。
本文从模型结构、编码器、解码器、位置编码 4 方面说明 Transformer,理解每个模块的设计。
一、模型结构
机器翻译任务流行采用编解码器(Encoder-Decoder)结构,比如英-汉翻译,输入是英文符号序列 x = ( x 1 , … , x n ) \mathbf{x}=(x_1,\dots, x_n) x=(x1,…,xn),经过编码器处理 z = E n c o d e r ( x ) , z = ( z 1 , … , z n ) \mathbf{z}=Encoder(\mathbf{x}), \mathbf{z}=(z_1,\dots,z_n) z=Encoder(x),z=(z1,…,zn),再由 Decoder 生成中文序列 y = ( y 1 , … , y m ) \mathbf{y}=(y_1,\dots,y_m) y=(y1,…,ym)。注意模型是自回归的,需要利用上一个汉字来解码下一个字符,即 y n = D e c o d e r ( z , y n − 1 ) , n ≤ m y_n=Decoder(\mathbf{z}, y_{n-1}), n≤m yn=Decoder(z,yn−1),n≤m 。
图 1.1 Transformer模型图
二、编码器
图 1.1 左边是 Embedding(可以理解为文本词元 Token) 结合位置编码,输入到粗线框中,我们称粗线框模块为编码器块,左边的 “N×” 意思是 N 个一模一样的编码器块顺序连接,N 一般为 6 或 12。编码器块由多头自注意力层、残差连接、层归一化(LayerNorm)和前馈网络(FFN)组成。
2.1 自注意力(Self-attention)
注意力机制,一般是权重(注意力分数)与输入的加权求和。相较平均,利用不同权重处理信息,更加符合现实场景。比如打 MOBA 类 5v5,一般想要取得团队胜利,资源的分配是按权重设计的,假设总权重为 1,
| 位置 | 中单 | Carry | 优势路 | 劣势路 | 辅助打野 |
|---|---|---|---|---|---|
| 分配资源 | 0.3 | 0.4 | 0.15 | 0.1 | 0.05 |
表 2.1 游戏不同位置资源权重
每个位置获得的资源是不一样的。那么权重怎么获得?以游戏举例,可以统计几百场比赛每个位置前期获得的金币收益,做归一化 softmax \text{softmax} softmax( softmax \text{softmax} softmax 后的权重总和是1,权重作为百分比更容易理解一些,也可以用其他归一化方法),统计得到每个位置的资源权重。我们发现权重是可以通过具体问题来进行统计的。但是作为一个模型,没有这些先验信息,所有的权重要靠模型自己去学,这就有了自注意力。
自注意力层,对于输入 { x i ∈ R d } i = 1 L \{\mathbf{x_{i}}\in \mathbb{R}^d\}_{i=1}^L {xi∈Rd}i=1L,其中 d d d 是 Embedding 维度,我们需要算出权重,再用权重和输入加权求和。
Transformer 引入了 3 个线性变换 W Q ∈ R d × d q W^Q \in \mathbb{R}^{d×d_q} WQ∈Rd×dq, W K ∈ R d × d k W^K \in \mathbb{R}^{d×d_k} WK∈Rd×dk, W V ∈ R d × d v W^V \in \mathbb{R}^{d×d_v} WV∈Rd×dv,使得
Q = X W Q (2.1) Q = XW^Q \tag{2.1} Q=XWQ(2.1)
K = X W K (2.2) K = XW^K \tag{2.2} K=XWK(2.2)
V = X W V (2.3) V = XW^V \tag{2.3} V=XWV(2.3)
然后
Z = Attention ( Q , K , V ) = softmax ( Q K T d k ) V (2.4) Z = \text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V \tag{2.4} Z=Attention(Q,K,V)=softmax(dkQKT)V(2.4)
一般情况下 d q = d k = d v d_q=d_k=d_v dq=dk=dv,这样 Q Q Q 和 K K K 的维度是 d × d d×d d×d, Q Q Q 可以和 K K K 的转置做内积。 也允许 d q ≠ d k d_q\ne d_k dq=dk,可以用线性变换将维度对齐[3]。
观察(2.4)式,这里做了缩放 1 d k \frac{1}{\sqrt{d_k}} dk1。因为 Q K T QK^T QKT 做内积,计算出的每个值是与 d k d_k dk 相关的,比如论文中设 d k = 512 d_k=512 dk=512,每一次点乘的值可能会很大或很小(负数),过了 softmax \text{softmax} softmax 大的值会接近 1,小的值会接近 0,容易导致梯度接近 0,梯度消失。通过缩放将 d k d_k dk 的影响抵消掉,而在论文中假设做内积的向量 q q q 和 k k k 服从均值 0 方差 1 的独立同分布,故
V a r ( q ⋅ k d k ) = 1 d k V a r ( ∑ i = 1 d k q i k i ) = ∑ i = 1 d k V a r ( q i ) V a r ( k i ) d k = 1 Var\left(\frac{q · k}{\sqrt{d_k}}\right) = \frac{1}{d_k}Var\left(\sum_{i=1}^{d_k}q_i k_i\right)=\frac{\sum_{i=1}^{d_k} Var(q_i)Var(k_i)}{d_k} = 1 Var(dkq⋅k)=dk1Var(i=1∑dkqiki)=dk∑i=1dkVar(qi)Var(ki)=1
第二个等号用到 q i q_i qi 之间独立, k i k_i ki 之间独立,那么 q i k i q_i k_i qiki 和 q j k j ( i ≠ j ) q_j k_j(i \neq j) qjkj(i=j) 是独立的。
Q Q Q、 K K K、 V V V分别称为询问(query)、键(key)和值(value),(2,4) 式 softmax ( Q K T d k ) \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) softmax(dkQKT) 是权重,与 V V V 加权求和得到输出。
2.2 多头注意力层(Multi-Head Attention)
将 Embedding 维度 d d d 拆分成 h h h 个子空间(要求 h h h 整除 d d d),在每个子空间做线性映射,然后将每个子空间的信息 Concat \text{Concat} Concat 起来,如下
MultiHead ( X Q , X K , X V ) = Concat ( head 1 , … , head h ) W O where head i = Attention ( X Q W i Q , X K W i K , X V W i V ) , i = 1 , . . . , h \begin{aligned} \text{MultiHead}(X_Q, X_K, X_V) &= \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O \\ \text{where} \ \text{head}_i &= \text{Attention}(X_QW_i^Q, X_KW_i^K, X_VW_i^V), \ i=1,...,h \end{aligned} MultiHead(XQ,XK,XV)where headi=Concat(head1,…,headh)WO=Attention(XQWiQ,XKWiK,XVWiV), i=1,...,h
其中 W i Q ∈ R d × d k W_i^Q \in \mathbb{R}^{d×d_k} WiQ∈Rd×dk, W i K ∈ R d × d k W_i^K \in \mathbb{R}^{d×d_k} WiK∈Rd×dk, W i V ∈ R d × d v W_i^V \in \mathbb{R}^{d×d_v} WiV∈Rd×dv, W O ∈ R h d v × d W^O \in \mathbb{R}^{hd_v×d} WO∈Rhdv×d。
下面举例说明多头自注意力层的计算。
假设参数:
- d = 512
- d k d_k dk = 256(key 的嵌入维度)
- d v d_v dv = 384(value 的嵌入维度)
- h = 8
- h_dim = 512 // 8 = 64
- batch_size = 32
- seq_len = 100
- W Q W_Q WQ: (512, 512)
- W K W_K WK: (512, 256)
- W V W_V WV: (512, 384)
- W O W_O WO: (512, 512)
| 步骤 | 操作 | 输入维度 | 输出维度 |
|---|---|---|---|
| 输入 | X Q X_Q XQ, X K X_K XK, X V X_V XV | (100,32,512),(100,32,256),(100,32,384) | (100,32,512)×3 |
| Reshape | 多头分割 | (100,32,512) | (256,100,64) |
| 注意力 | Attention | (256,100,64) | (256,100,64) |
| 输出 | 合并+投影 | (256,100,64) | (100,32,512) |
Reshape 多头分割是
# Q: (100, 32, 512) -> (100, 32*8, 64) -> (256, 100, 64)
q = Q.view(100, 32*8, 64).transpose(0, 1) # (256, 100, 64)
# K: (100, 32, 512) -> (100, 32*8, 64) -> (256, 100, 64)
k = K.view(100, 32*8, 64).transpose(0, 1) # (256, 100, 64)
# V: (100, 32, 512) -> (100, 32*8, 64) -> (256, 100, 64)
v = V.view(100, 32*8, 64).transpose(0, 1) # (256, 100, 64)
维度说明:
- 256 = 32 * 8(批次大小 × 头数)
- 100 是序列长度
- 64 是每个头的维度
最后的合并+投影是
# 1. 转置并 reshape
# (256, 100, 64) -> (100, 256, 64) -> (100*32, 512)
attn_output = attn_output.transpose(0, 1).contiguous().view(100*32, 512)
# = (3200, 512)
# 2. 输出投影
out_proj_weight: (512, 512)
out_proj_bias: (512,)
attn_output = attn_output @ out_proj_weight.T + out_proj_bias
# (3200, 512) @ (512, 512) = (3200, 512)
# 3. Reshape 回原始格式
attn_output = attn_output.view(100, 32, 512) # (100, 32, 512)
2.3 残差连接、层归一化(LayerNorm)和前馈网络(FFN)
如图 1.1,过了多头自注意力层(_sa_block),会进行残差连接,可以缓解梯度消失,稳定训练。然后进行 LayerNorm \text{LayerNorm} LayerNorm,是特征维度进行归一化,可以提高稳定性,加速训练。下面对 BatchNorm \text{BatchNorm} BatchNorm 和 LayerNorm \text{LayerNorm} LayerNorm 做一下比较。
BatchNorm \text{BatchNorm} BatchNorm 是在 batch 维度上归一化。这种方式比较符合直觉,比如我们想拟合一个班学生的身高分布,在模型训练前,算出平均身高和方差,做归一化,稳定训练,我们关心的是整个样本集合的分布。在 CNN 类的网络中效果好,推理的时候如果是逐个样本推理,可以使用训练时保存的均值方差等参数。
而 LayerNorm \text{LayerNorm} LayerNorm 是对单个样本进行归一化,比如机器翻译,输入一个 token 序列,对这个序列做归一化。每个样本独立归一化,不依赖批次大小,适合序列模型,实时计算统计量。 BatchNorm \text{BatchNorm} BatchNorm 不适合处理变长序列,处理变长序列可能需要 padding 操作,导致统计量不稳定。所以 Transformer 中采用 LayerNorm \text{LayerNorm} LayerNorm 。
编码器块的执行顺序为:
- 自注意力: x + sa_block ( x ) x + \text{sa\_block}(x) x+sa_block(x) (残差连接)
- LayerNorm \text{LayerNorm} LayerNorm: norm1 ( x + sa_block ( x ) ) \text{norm1}(x + \text{sa\_block}(x)) norm1(x+sa_block(x))
- 前馈网络: x + FFN ( x ) = x + m a x ( 0 , x W 1 + b 1 ) W 2 + b 2 x + \text{FFN}(x) = x + max(0, xW_1+b_1)W_2 + b_2 x+FFN(x)=x+max(0,xW1+b1)W2+b2(残差连接)
- LayerNorm \text{LayerNorm} LayerNorm: norm2 ( x + FFN ( x ) ) \text{norm2}(x + \text{FFN}(x)) norm2(x+FFN(x))
FFN 加入了线性变换,容易控制特征尺寸,比如 FFN 将维度扩展到 4 × d m o d e l 4 × d_{model} 4×dmodel,再压缩回 d m o d e l d_{model} dmodel,可以有效控制参数量,提供更大的表示空间,增强模型的拟合能力。另一方面,通过 ReLU \text{ReLU} ReLU 可以引入非线性,增强模型的表达能力。
三、解码器
右边黑色框为解码器块,相较编码器块,解码器多了掩码多头注意力,并且融合了来自编码器的信息。最终经过线性变换和 softmax \text{softmax} softmax 来预测下一个 token 的概率。
3.1 掩码多头注意力
Transformer 用于变长序列的训练,会生成 Key Padding Mask,标记哪些位置是 padding,应该被忽略。我们这一节主要说明解码器的因果掩码(Causal Mask)。
在机器翻译中,编码器用于融合上下文语义信息,而解码器负责生成目标语言序列。过程是自回归的,需要利用之前生成 token 的信息,而不能看到后面解码的信息。所以需要因果掩码来屏蔽后续文本信息。
import torch
# 生成大小为 5×5 的 causal mask
mask = torch.triu(torch.full((5, 5), float('-inf')), diagonal=1)
print(mask)
输出
tensor([[ 0., -inf, -inf, -inf, -inf],
[ 0., 0., -inf, -inf, -inf],
[ 0., 0., 0., -inf, -inf],
[ 0., 0., 0., 0., -inf],
[ 0., 0., 0., 0., 0.]])
计算流程
# 1. 计算注意力分数
attn_scores = Q @ K^T / sqrt(d_k) # (batch*heads, seq_len, seq_len)
# 2. 应用 mask
if attn_mask is not None:
attn_scores = attn_mask + attn_scores # mask 中的 -inf 会让对应位置变成 -inf
# 3. Softmax(-inf 会被映射到 0)
attn_weights = softmax(attn_scores, dim=-1) # -inf → 0, 正常值 → 概率分布
# 4. 加权求和
output = attn_weights @ V
3.2 交叉注意力
编码器输出的信息通过交叉注意力传输给解码器,是在掩码多头注意力模块之后。
整体流程
# 步骤 1: 编码器处理源序列
memory = encoder(src) # (src_len, batch, d_model)
# 步骤 2: 解码器使用编码器输出
output = decoder(tgt, memory) # (tgt_len, batch, d_model)
交叉注意力核心是 Q Q Q 来自上一层的解码器, K K K 和 V V V 来自编码器。
# 交叉注意力计算
Q = x @ W_Q # (tgt_len, batch, d_model) - 来自解码器
K = memory @ W_K # (src_len, batch, d_model) - 来自编码器
V = memory @ W_V # (src_len, batch, d_model) - 来自编码器
# 注意力分数
attn_scores = Q @ K^T / sqrt(d_k) # (tgt_len, src_len)
四、位置编码
以上是 Transformer 的大致框架,我们看到,核心运算都来自 self_attention 和 FFN,其中基本都是矩阵点积运算,长距离依赖和下一个 token 一样,只做一次点积,这样 Transformer 可以并行处理数据。但由于加和运算的可交换性,导致无法区分输入 token 的位置,翻译时 上海 和 海上,两个含义差异巨大的名词,对于模型是一样的运算。为了区分 token 位置,文中引入了位置编码。
P E ( p o s , 2 i ) = s i n ( p o s / 1000 0 2 i / d m o d e l ) P E ( p o s , 2 i + 1 ) = c o s ( p o s / 1000 0 2 i / d m o d e l ) (4.1) \begin{aligned} PE_{(pos, 2i)} = sin(pos/10000^{2i/d_{model}}) \\ PE_{(pos, 2i+1)} = cos(pos/10000^{2i/d_{model}}) \tag{4.1} \end{aligned} PE(pos,2i)=sin(pos/100002i/dmodel)PE(pos,2i+1)=cos(pos/100002i/dmodel)(4.1)
其中 p o s pos pos 是位置, i i i 是维度。
既然是为了区分 token 的位置,我给第 1 个 token 向量加 1, 第 2 个 token 向量加 2,…,以此类推,不就做出区分了吗,为什么要用(4.1)这么复杂的编码?
我们举个简单的例子,比如输入形状为 (batch=1, seq_len=5, d_model=4),我们来计算正余弦位置编码。
# 输入 x: (1, 5, 4)
x = torch.randn(1, 5, 4)
# 例如:
# x = [[[ 0.1, 0.2, 0.3, 0.4],
# [ 0.5, 0.6, 0.7, 0.8],
# [ 0.9, 1.0, 1.1, 1.2],
# [ 1.3, 1.4, 1.5, 1.6],
# [ 1.7, 1.8, 1.9, 2.0]]]
# pe 形状: (5, 4)
# 位置编码:
pe = [
# 位置 0: [sin(0*1.0), cos(0*1.0), sin(0*0.01), cos(0*0.01)]
[ 0.0000, 1.0000, 0.0000, 1.0000],
# 位置 1: [sin(1*1.0), cos(1*1.0), sin(1*0.01), cos(1*0.01)]
[ 0.8415, 0.5403, 0.0100, 0.9999],
# 位置 2: [sin(2*1.0), cos(2*1.0), sin(2*0.01), cos(2*0.01)]
[ 0.9093, -0.4161, 0.0200, 0.9998],
# 位置 3: [sin(3*1.0), cos(3*1.0), sin(3*0.01), cos(3*0.01)]
[ 0.1411, -0.9900, 0.0300, 0.9996],
# 位置 4: [sin(4*1.0), cos(4*1.0), sin(4*0.01), cos(4*0.01)]
[-0.7568, -0.6536, 0.0400, 0.9992]
]
# 缩放输入
xscale = sqrt(4) = 2.0
x_scaled = x * xscale
# 添加位置编码
x_with_pos = x_scaled + pe
对于 s i n ( ω ∗ p o s ) sin(ω * pos) sin(ω∗pos) 或 c o s ( ω ∗ p o s ) cos(ω * pos) cos(ω∗pos) 我们称 ω ω ω 是角频率,决定了函数变化的快慢。频率高 ( ω ω ω 大),函数变化快,相邻位置的值差异大;频率低 ( ω ω ω 小),函数变化慢,相邻位置的值差异小。对于维度 i i i,频率为
ω = 1 / 1000 0 ( 2 i / d m o d e l ) \omega = 1 / 10000^{(2i/d_{model})} ω=1/10000(2i/dmodel)
| PE | 0 | 1 | 2 | 3 |
|---|---|---|---|---|
| 维度 i i i | 0 | 0 | 1 | 1 |
| 频率 ω \omega ω | 1 | 1 | 0.01 | 0.01 |
| 位置 p o s pos pos=0 | 0.0000 | 1.0000 | 0.0000 | 1.0000 |
| 位置 p o s pos pos=1 | 0.8415 | 0.5403 | 0.0100 | 0.9999 |
| 位置 p o s pos pos=2 | 0.9093 | -0.4161 | 0.0200 | 0.9998 |
| 位置 p o s pos pos=3 | 0.1411 | -0.9900 | 0.0300 | 0.9996 |
| 位置 p o s pos pos=4 | -0.7568 | -0.6536 | 0.0400 | 0.9992 |
表 4.1 正余弦位置编码举例
观察表 4.1,我们发现正余弦位置编码第 1 列相邻位置差异大,第 4 列相邻位置差异小,正余弦位置编码是利用高频捕获局部信息,比如在维度 0 时,位置 0 和 位置 1 有明显的差异,易于区分。
利用低频捕获全局信息,比如维度 1 时,PE(0,3) 和 PE(1,3) 几乎相同,但 PE(0,3) 和 PE(4,3) 差异较大。我们举例中 d m o d e l = 4 d_{model} = 4 dmodel=4,若 d m o d e l d_{model} dmodel 比较大,比如 512,则差异会更加明显。所以正余弦编码可以多尺度地捕获输入位置信息。可以看出,如果按照本节开头,给输入向量简单加一个常数,则会导致相邻位置差异大,但远距离位置可能无法区分。
由于
P E ( p o s + k , 2 i ) = P E ( p o s , 2 i ) ∗ c o s ( ω i ∗ k ) + P E ( p o s , 2 i + 1 ) ∗ s i n ( ω i ∗ k ) P E ( p o s + k , 2 i + 1 ) = P E ( p o s , 2 i + 1 ) ∗ c o s ( ω i ∗ k ) − P E ( p o s , 2 i ) ∗ s i n ( ω i ∗ k ) \begin{aligned} PE_{(pos+k, 2i)} = PE_{(pos, 2i)} * cos(ω_i * k) + PE_{(pos, 2i+1)} * sin(ω_i * k) \\ PE_{(pos+k, 2i+1)} = PE_{(pos, 2i+1)} * cos(ω_i * k) - PE_{(pos, 2i)} * sin(ω_i * k) \end{aligned} PE(pos+k,2i)=PE(pos,2i)∗cos(ωi∗k)+PE(pos,2i+1)∗sin(ωi∗k)PE(pos+k,2i+1)=PE(pos,2i+1)∗cos(ωi∗k)−PE(pos,2i)∗sin(ωi∗k)
对于任意位置 p o s + k pos + k pos+k,可以表示为 p o s pos pos 的线性组合,这个特性使得模型易于学习相对位置关系。
正余弦编码每个位置都有唯一的编码表示,且编码是确定的,不用引入额外的训练参数,编码稳定,不会随训练过程变化。可以处理未见过的序列长度。值域在 [ − 1 , 1 ] [-1, 1] [−1,1] 之间,好控制,有利于模型训练的稳定性。
总结
Transformer 是深度学习里程碑式的成果,提高并行效率的同时,能保证性能 SOTA,这使得模型拥有了处理海量数据的能力,开启了 LLM 的“大航海时代”。
参考文献
[1]: Chan W , Jaitly N , Le Q V ,et al.Listen, Attend and Spell[J].Computer Science, 2015.DOI:10.48550/arXiv.1508.01211.
[2]: 张奇,桂韬,郑锐,黄萱菁,《大规模语言模型:从理论到实践(第2版)》,2025.
[3]: https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/transformer.py
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)