Transformer中为什么要使用多头注意力?
大模型小知识(面试八股)
·
参考视频:面试必刷:大模型为什么要使用多头注意力?_哔哩哔哩_bilibili
详解文章:Transformer内容详解(通透版)-CSDN博客
单头注意力的劣势:单头注意力只能从一个角度“看”输入序列,计算得到的注意力权重反映的是一种特定的关注模式。
多头注意力将注意力分为了多个“头”,每个头独立计算注意力,关注输入的不同子空间或不同方面的特征。这样,模型能够并行地捕捉到多种不同类型的语义关系。
将输入投射到多个不同的低维空间,分别计算注意力,最后再concat拼接,通过线性变换融合,丰富了模型的表达能力,使得Transformer能够学习复杂的组合特征。同时,每个注意力头的参数量和计算复杂度降低,有助于提升训练的稳定性和效率,有利于收敛。
单头注意力:
import torch
import torch.nn as nn
class Self_Attention(nn.Module):
def __init__(self, dim, dk, dv):
super().__init__()
self.scale = dk ** -0.5
self.q = nn.Linear(dim, dk)
self.k = nn.Linear(dim, dk)
self.v = nn.Linear(dim, dv)
def forward(self, x):
# x: [batch, seq_len, dim]
q = self.q(x) # [batch, seq_len, dk]
k = self.k(x) # [batch, seq_len, dk]
v = self.v(x) # [batch, seq_len, dv]
attn = (q @ k.transpose(-2, -1)) * self.scale # [batch, seq_len, seq_len]
attn = attn.softmax(dim=-1)
out = attn @ v # [batch, seq_len, dv]
return out
多头注意力:
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, dim, dk, dv, num_heads):
super().__init__()
self.num_heads = num_heads
self.dk = dk
self.dv = dv
self.q_linear = nn.Linear(dim, dk * num_heads)
self.k_linear = nn.Linear(dim, dk * num_heads)
self.v_linear = nn.Linear(dim, dv * num_heads)
self.out_linear = nn.Linear(dv * num_heads, dim)
def forward(self, x):
B, N, _ = x.shape # batch, seq_len, dim
Q = self.q_linear(x).view(B, N, self.num_heads, self.dk).transpose(1, 2) # [B, heads, N, dk]
K = self.k_linear(x).view(B, N, self.num_heads, self.dk).transpose(1, 2)
V = self.v_linear(x).view(B, N, self.num_heads, self.dv).transpose(1, 2)
# Attention
attn = (Q @ K.transpose(-2, -1)) / (self.dk ** 0.5) # [B, heads, N, N]
attn = attn.softmax(dim=-1)
out = attn @ V # [B, heads, N, dv]
out = out.transpose(1, 2).reshape(B, N, self.num_heads * self.dv) # [B, N, heads*dv]
out = self.out_linear(out) # [B, N, dim]
return out
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)