【大模型手撕】pytorch实现Self-Attention,Multi-Head-Attention,Grouped-Query-Attention
本文介绍了使用PyTorch实现Self-Attention,Multi-Head-Attention,Grouped-Query-Attention的代码
pytorch手写Attention
Self-Attention
代码:
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
def __init__(self,embed_dim):
super(SelfAttention,self).__init__()
self.embed_dim=embed_dim
self.WQ=nn.Linear(embed_dim,embed_dim)
self.WK=nn.Linear(embed_dim,embed_dim)
self.WV=nn.Linear(embed_dim,embed_dim)
self.dropout=nn.Dropout(0.1)
def forward(self,x,iscausal):
"""
输入序列x(batch_size,seq_len,embed_dim)
"""
# (batch_size,seq_len,embed_dim)
Q=self.WQ(x)
# (batch_size,seq_len,embed_dim)
K=self.WK(x)
# (batch_size,seq_len,embed_dim)
V=self.WV(x)
# K(batch_size,seq_len,embed_dim) ,K.transpose(-2,-1)交换张量的最后一个维度和倒数第二个维度
attention_scores=torch.matmul(Q,K.transpose(-2,-1))/(self.embed_dim**0.5)
# 被掩码的位置设为 -inf
if iscausal:
# 生成一个 (seq_len, seq_len) 的上三角矩阵
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
attention_scores=attention_scores.masked_fill(mask==1,float('-inf'))
# 沿着哪个维度进行 Softmax 计算(dim=-1 表示最后一个维度),对 seq_len 维度计算,让每个 Query 的注意力总和为 1
attention_weights=torch.softmax(attention_scores,-1)
output=torch.matmul(attention_weights,V)
return output,attention_weights
测试:
batch_size=2
seq_len=5
embed_dim=5
x=torch.rand(batch_size,seq_len,embed_dim)
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
print("mask:")
print(mask)
print("-"*50)
self_attention=SelfAttention(embed_dim)
output_with_mask,weights_with_mask=self_attention(x,iscausal=True)
print("output_with_mask:")
print(output_with_mask)
print("weights_with_mask:")
print(weights_with_mask)
print("-"*50)
output_without_mask,weights_without_mask=self_attention(x,iscausal=False)
print("output_without_mask:")
print(output_without_mask)
print("weights_without_mask:")
print(weights_without_mask)
结果:
mask:
tensor([[0., 1., 1., 1., 1.],
[0., 0., 1., 1., 1.],
[0., 0., 0., 1., 1.],
[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 0.]])
--------------------------------------------------
output_with_mask:
tensor([[[ 7.9193e-01, -2.4524e-01, -1.4308e-01, 1.0665e-01, -1.1564e-01],
[ 7.6841e-01, -3.4178e-01, -1.3795e-01, 7.9713e-02, -1.8217e-01],
[ 8.2544e-01, -2.9498e-01, -5.1584e-02, 5.1098e-02, -1.2348e-01],
[ 7.6096e-01, -2.4629e-01, -9.5025e-02, 1.1596e-01, -2.4383e-01],
[ 7.0658e-01, -2.9171e-01, -1.2499e-01, 1.4677e-01, -2.1432e-01]],
[[ 6.4955e-01, -7.1487e-02, 7.6053e-04, 2.5570e-01, 7.1325e-02],
[ 7.5188e-01, -1.7114e-01, -5.5963e-02, 1.4976e-01, -1.5343e-01],
[ 7.2989e-01, -1.1941e-01, -8.5776e-02, 1.8858e-01, -1.7605e-01],
[ 7.6747e-01, -1.8748e-01, -4.1792e-02, 1.2632e-01, -1.0911e-01],
[ 7.2204e-01, -1.3516e-01, -6.3888e-02, 1.8362e-01, -1.0724e-01]]],
grad_fn=<UnsafeViewBackward0>)
weights_with_mask:
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.4874, 0.5126, 0.0000, 0.0000, 0.0000],
[0.3224, 0.3519, 0.3257, 0.0000, 0.0000],
[0.2442, 0.2569, 0.2344, 0.2646, 0.0000],
[0.1942, 0.1951, 0.1755, 0.2123, 0.2228]],
[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.4710, 0.5290, 0.0000, 0.0000, 0.0000],
[0.3236, 0.3420, 0.3344, 0.0000, 0.0000],
[0.2372, 0.2532, 0.2297, 0.2798, 0.0000],
[0.1986, 0.2023, 0.2036, 0.1898, 0.2057]]],
grad_fn=<SoftmaxBackward0>)
--------------------------------------------------
output_without_mask:
tensor([[[ 0.7117, -0.2931, -0.1214, 0.1422, -0.2092],
[ 0.7153, -0.2916, -0.1187, 0.1398, -0.2083],
[ 0.7159, -0.2925, -0.1177, 0.1389, -0.2070],
[ 0.7104, -0.2917, -0.1218, 0.1436, -0.2120],
[ 0.7066, -0.2917, -0.1250, 0.1468, -0.2143]],
[[ 0.7257, -0.1405, -0.0612, 0.1784, -0.1052],
[ 0.7272, -0.1414, -0.0627, 0.1772, -0.1101],
[ 0.7236, -0.1370, -0.0637, 0.1817, -0.1086],
[ 0.7324, -0.1501, -0.0565, 0.1689, -0.1016],
[ 0.7220, -0.1352, -0.0639, 0.1836, -0.1072]]],
grad_fn=<UnsafeViewBackward0>)
weights_without_mask:
tensor([[[0.1965, 0.2041, 0.1833, 0.2004, 0.2157],
[0.1965, 0.2067, 0.1906, 0.1992, 0.2071],
[0.1919, 0.2095, 0.1939, 0.1964, 0.2082],
[0.1914, 0.2014, 0.1838, 0.2074, 0.2159],
[0.1942, 0.1951, 0.1755, 0.2123, 0.2228]],
[[0.1975, 0.2042, 0.1980, 0.2035, 0.1968],
[0.1918, 0.2155, 0.1994, 0.1989, 0.1944],
[0.1960, 0.2072, 0.2026, 0.1919, 0.2023],
[0.1941, 0.2072, 0.1880, 0.2290, 0.1817],
[0.1986, 0.2023, 0.2036, 0.1898, 0.2057]]],
grad_fn=<SoftmaxBackward0>)
代码中的一些用法解释:
1)torch.nn.Linear ()
torch.nn.Linear() 是 PyTorch 最基础的全连接层(线性变换层),用于执行以下操作:
nn.Linear(in_features, out_features, bias=True),in_features 是输入特征维度,out_features是输出特征维度,bias表示是否使用偏置项,默认为 True
2)K.transpose(-2,-1)
在 PyTorch 中,torch.transpose() 用于交换张量的两个维度。参数 -2 和 -1 是指张量的倒数第二个维度和最后一个维度。
K.transpose(-2, -1) 和 K.transpose(-1, -2) 都是交换最后两个维度,它们的效果完全相同。
3)torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) # 生成上三角部分为 1
此函数解释:
- torch.ones(seq_len, seq_len) 生成 seq_len × seq_len 矩阵,所有元素都是 1
- torch.triu(…, diagonal=1)
torch.triu()如果不设置diagonal,即diagonal默认为0,表示取该矩阵的上三角部分为1(包含主对角线及以上的元素),其余元素置为 0。
如果diagonal=1,表示把主对角线本身置为 0,上三角部分是主对角线以上的元素
这样返回一个seq_len=5的mask矩阵:
tensor([[0., 1., 1., 1., 1.],
[0., 0., 1., 1., 1.],
[0., 0., 0., 1., 1.],
[0., 0., 0., 0., 1.],
[0., 0., 0., 0., 0.]])
4)attention_scores.masked_fill(mask==1,float('-inf'))
作用: 根据 mask矩阵 进行掩码处理,将 mask == 1 的位置填充为 -inf(负无穷),使其在 softmax 计算时权重变为 0
masked_fill 属于 PyTorch 张量(torch.Tensor)的方法,用于根据布尔掩码(mask)填充指定值。masked_fill 的语法: tensor.masked_fill(mask, value)
tensor:要修改的张量
mask:布尔掩码(True/False 或 0/1)
value:要填充的值(如 -inf)
解释:attention_scores.masked_fill(mask==1,float('-inf'))
mask == 1 选取 应该被屏蔽的位置(即 上三角部分)
masked_fill(mask == 1, -inf) 把上三角部分设为-inf
这样,Softmax 后被屏蔽的部分变成 0,不会影响注意力计算。
5) torch.softmax(attention_scores,-1)
作用: 对 attention_scores 进行 Softmax 归一化,确保注意力权重(attn_weights)的总和为 1,控制每个 Token 对序列中其他 Token 的关注程度
torch.softmax(input, dim) 语法:
input:要进行 Softmax 计算的张量
dim:沿着哪个维度进行 Softmax 计算(dim=-1 表示最后一个维度)
Multi-Head-Attention
import torch
import torch.nn as nn
import math
class Multi_Head_Attention(nn.Module):
def __init__(self,embed_dim,nums_heads):
super(Multi_Head_Attention,self).__init__()
assert embed_dim % nums_heads ==0,"embed_dim 必须能被 num_heads 整除"
self.embed_dim=embed_dim
self.nums_heads=nums_heads
self.head_dim=embed_dim//nums_heads
self.WQ=nn.Linear(embed_dim,embed_dim)
self.WK=nn.Linear(embed_dim,embed_dim)
self.WV=nn.Linear(embed_dim,embed_dim)
self.fc=nn.Linear(embed_dim,embed_dim)
self.scale=math.sqrt(embed_dim)
def forward(self,x,iscausal):
batch_size,seq_len,embed_dim=x.shape
# Q,K,V: batch_size,seq_len,embed_dim
Q=self.WQ(x)
K=self.WK(x)
V=self.WV(x)
# Q,K,V: batch_size,seq_len,embed_dim -> batch_size,seq_len,self.nums_heads,self.head_dim -> batch_size,self.nums_heads,seq_len,self.head_dim
Q=Q.view(batch_size,seq_len,self.nums_heads,self.head_dim).transpose(1,2)
K=K.view(batch_size,seq_len,self.nums_heads,self.head_dim).transpose(1,2)
V=V.view(batch_size,seq_len,self.nums_heads,self.head_dim).transpose(1,2)
# batch_size,self.nums_heads,seq_len,seq_len
attn_scores=torch.matmul(Q,K.transpose(-1,-2))/self.scale
if iscausal:
mask = torch.triu(torch.ones(seq_len,seq_len),diagonal=1)
attn_scores=attn_scores.masked_fill(mask==1,float('-inf'))
attn_weights=torch.softmax(attn_scores,-1)
# batch_size,self.nums_heads,seq_len,head_dim
output=torch.matmul(attn_weights,V)
# batch_size,self.nums_heads,seq_len,head_dim -> batch_size,seq_len,self.nums_heads,head_dim -> batch_size,seq_len,self.embed_dim
output=output.transpose(1,2).contiguous().view(batch_size,seq_len,self.embed_dim)
output=self.fc(output)
return output,attn_weights
测试:
batch_size=2
seq_len=5
nums_heads=3
embed_dim=6
x=torch.randn(batch_size,seq_len,embed_dim)
mask=create_mask(seq_len)
multiheadattention=Multi_Head_Attention(embed_dim,nums_heads)
output,weights=multiheadattention(x,mask)
print(output)
print(weights)
代码中的一些用法解释:
1)output=output.transpose(1,2).contiguous()中的contiguous()
contiguous() 是一个重要的张量方法,用于确保张量在内存中是连续存储的
Grouped-Query-Attention
import torch
import torch.nn as nn
class GroupedQueryAttention(nn.Module):
def __init__(self,embed_dim,nums_head,nums_group,dropout):
super().__init__()
self.embed_dim = embed_dim
self.nums_head = nums_head
self.nums_group = nums_group
assert nums_head%nums_group == 0
assert embed_dim%nums_head == 0
self.head_dim = embed_dim/nums_head
self.num_heads_per_group = nums_head/nums_group
self.scale = self.head_dim**0.5
self.wq = nn.Linear(embed_dim,embed_dim)
self.wk = nn.Linear(embed_dim,self.nums_group * self.head_dim)
self.wv = nn.Linear(embed_dim,self.nums_group * self.head_dim)
self.wo = nn.Linear(embed_dim,embed_dim)
self.dropout = nn.Dropout(dropout)
def forward(self,x,is_causal):
batch_size,seq_len,_ = x.shape
# q: (batch_size,seq_len,embed_dim)
q = self.wq(x)
# k v: (batch_size,seq_len,nums_group * head_dim)
k = self.wk(x)
v = self.wv(x)
# q: (batch_size,seq_len,embed_dim) -> (batch_size,seq_len,nums_head,head_dim) -> (batch_size,nums_head,seq_len,head_dim)
q = q.view(batch_size,seq_len,self.nums_head,self.head_dim).transpose(1,2)
# k: (batch_size,seq_len,nums_group * head_dim) -> (batch_size,seq_len,nums_group,head_dim) -> (batch_size,nums_group,seq_len,head_dim)
k = k.view(batch_size,seq_len,self.nums_group,self.head_dim).transpose(1,2)
# 将每组的键和值复制到组内的每个头
# k.unsqueeze(2): (batch_size,nums_group,seq_len,head_dim) -> (batch_size,nums_group,1,seq_len,head_dim)
# k.repeat: (batch_size,nums_group,1,seq_len,head_dim) -> (batch_size,nums_group,num_heads_per_group,seq_len,head_dim)
# k : (batch_size,self.nums_head,seq_len,self.head_dim)
k = k.unsqueeze(2).repeat(1,1,self.num_heads_per_group,1,1).view(batch_size,self.nums_head,seq_len,self.head_dim)
# v: (batch_size,seq_len,nums_group * head_dim) -> (batch_size,seq_len,nums_group,head_dim) -> (batch_size,nums_group,seq_len,head_dim)
v = v.view(batch_size,seq_len,self.nums_group,self.head_dim).transpose(1,2)
# v : (batch_size,self.nums_head,seq_len,self.head_dim)
v = v.unsqueeze(2).repeat(1,1,self.num_heads_per_group,1,1).view(batch_size,self.nums_head,seq_len,self.head_dim)
# attn_scores: (batch_size,self.nums_head,seq_len,seq_len)
attn_scores = torch.matmul(q,k.transpose(-1,-2)) / self.scale
if is_causal:
mask = torch.triu(torch.ones(seq_len,seq_len),diagonal=1)
attn_scores = attn_scores.masked_fill(mask==1,float('-inf'))
# attn_weights: (batch_size,self.nums_head,seq_len,seq_len)
attn_weights = torch.softmax(attn_scores,dim=-1)
attn_weights = self.dropout(attn_weights)
# attn_outputs : (batch_size,self.nums_head,seq_len,self.head_dim)
attn_outputs = torch.matmul(attn_weights,v)
# attn_outputs : (batch_size,self.nums_head,seq_len,self.head_dim) -> (batch_size,seq_len,self.nums_head,self.head_dim) -> (batch_size,seq_len,self.embed_dim)
attn_outputs = attn_outputs.transpose(1,2).contiguous.view(batch_size,seq_len,self.embed_dim)
# outputs: (batch_size,seq_len,self.embed_dim)
outputs = self.wo(attn_outputs)
outputs = self.dropout(outputs)
return outputs
关于注意力机制的问题
为什么qk的乘积要除以 d k \sqrt{d_k} dk
避免qk点积值过大导致梯度消失问题,使得Softmax归一化时结果更稳定
为什么qk点积值过大导致梯度消失问题?


为什么注意力得分和最终输出要过dropout
- 对attn_weights使用 Dropout:随机将部分注意力权重置为 0,强制模型不过度依赖某些特定 token 之间的关联。
- 对outputs使用 Dropout:随机将输出特征向量中的部分维度置为 0,防止模型过度依赖特定维度的特征。
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)