Self-Attention

代码:

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self,embed_dim):
        super(SelfAttention,self).__init__()
        self.embed_dim=embed_dim

        self.WQ=nn.Linear(embed_dim,embed_dim)
        self.WK=nn.Linear(embed_dim,embed_dim)
        self.WV=nn.Linear(embed_dim,embed_dim)
        self.dropout=nn.Dropout(0.1)

    def forward(self,x,iscausal):
        """
        输入序列x(batch_size,seq_len,embed_dim)
        """
        # (batch_size,seq_len,embed_dim)
        Q=self.WQ(x) 
        # (batch_size,seq_len,embed_dim)
        K=self.WK(x)
        # (batch_size,seq_len,embed_dim)
        V=self.WV(x)
        # K(batch_size,seq_len,embed_dim) ,K.transpose(-2,-1)交换张量的最后一个维度和倒数第二个维度
        attention_scores=torch.matmul(Q,K.transpose(-2,-1))/(self.embed_dim**0.5)
        # 被掩码的位置设为 -inf
        if iscausal:
        	# 生成一个 (seq_len, seq_len) 的上三角矩阵
        	mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
            attention_scores=attention_scores.masked_fill(mask==1,float('-inf'))
        # 沿着哪个维度进行 Softmax 计算(dim=-1 表示最后一个维度),对 seq_len 维度计算,让每个 Query 的注意力总和为 1
        attention_weights=torch.softmax(attention_scores,-1)
        output=torch.matmul(attention_weights,V)
        return output,attention_weights
    

测试:

batch_size=2
seq_len=5
embed_dim=5
x=torch.rand(batch_size,seq_len,embed_dim)

mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
print("mask:")
print(mask)
print("-"*50)

self_attention=SelfAttention(embed_dim)

output_with_mask,weights_with_mask=self_attention(x,iscausal=True)
print("output_with_mask:")
print(output_with_mask)
print("weights_with_mask:")
print(weights_with_mask)
print("-"*50)

output_without_mask,weights_without_mask=self_attention(x,iscausal=False)
print("output_without_mask:")
print(output_without_mask)
print("weights_without_mask:")
print(weights_without_mask)

结果:

mask:
tensor([[0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1.],
        [0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0.]])
--------------------------------------------------
output_with_mask:
tensor([[[ 7.9193e-01, -2.4524e-01, -1.4308e-01,  1.0665e-01, -1.1564e-01],
         [ 7.6841e-01, -3.4178e-01, -1.3795e-01,  7.9713e-02, -1.8217e-01],
         [ 8.2544e-01, -2.9498e-01, -5.1584e-02,  5.1098e-02, -1.2348e-01],
         [ 7.6096e-01, -2.4629e-01, -9.5025e-02,  1.1596e-01, -2.4383e-01],
         [ 7.0658e-01, -2.9171e-01, -1.2499e-01,  1.4677e-01, -2.1432e-01]],

        [[ 6.4955e-01, -7.1487e-02,  7.6053e-04,  2.5570e-01,  7.1325e-02],
         [ 7.5188e-01, -1.7114e-01, -5.5963e-02,  1.4976e-01, -1.5343e-01],
         [ 7.2989e-01, -1.1941e-01, -8.5776e-02,  1.8858e-01, -1.7605e-01],
         [ 7.6747e-01, -1.8748e-01, -4.1792e-02,  1.2632e-01, -1.0911e-01],
         [ 7.2204e-01, -1.3516e-01, -6.3888e-02,  1.8362e-01, -1.0724e-01]]],
       grad_fn=<UnsafeViewBackward0>)
weights_with_mask:
tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4874, 0.5126, 0.0000, 0.0000, 0.0000],
         [0.3224, 0.3519, 0.3257, 0.0000, 0.0000],
         [0.2442, 0.2569, 0.2344, 0.2646, 0.0000],
         [0.1942, 0.1951, 0.1755, 0.2123, 0.2228]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4710, 0.5290, 0.0000, 0.0000, 0.0000],
         [0.3236, 0.3420, 0.3344, 0.0000, 0.0000],
         [0.2372, 0.2532, 0.2297, 0.2798, 0.0000],
         [0.1986, 0.2023, 0.2036, 0.1898, 0.2057]]],
       grad_fn=<SoftmaxBackward0>)
--------------------------------------------------
output_without_mask:
tensor([[[ 0.7117, -0.2931, -0.1214,  0.1422, -0.2092],
         [ 0.7153, -0.2916, -0.1187,  0.1398, -0.2083],
         [ 0.7159, -0.2925, -0.1177,  0.1389, -0.2070],
         [ 0.7104, -0.2917, -0.1218,  0.1436, -0.2120],
         [ 0.7066, -0.2917, -0.1250,  0.1468, -0.2143]],

        [[ 0.7257, -0.1405, -0.0612,  0.1784, -0.1052],
         [ 0.7272, -0.1414, -0.0627,  0.1772, -0.1101],
         [ 0.7236, -0.1370, -0.0637,  0.1817, -0.1086],
         [ 0.7324, -0.1501, -0.0565,  0.1689, -0.1016],
         [ 0.7220, -0.1352, -0.0639,  0.1836, -0.1072]]],
       grad_fn=<UnsafeViewBackward0>)
weights_without_mask:
tensor([[[0.1965, 0.2041, 0.1833, 0.2004, 0.2157],
         [0.1965, 0.2067, 0.1906, 0.1992, 0.2071],
         [0.1919, 0.2095, 0.1939, 0.1964, 0.2082],
         [0.1914, 0.2014, 0.1838, 0.2074, 0.2159],
         [0.1942, 0.1951, 0.1755, 0.2123, 0.2228]],

        [[0.1975, 0.2042, 0.1980, 0.2035, 0.1968],
         [0.1918, 0.2155, 0.1994, 0.1989, 0.1944],
         [0.1960, 0.2072, 0.2026, 0.1919, 0.2023],
         [0.1941, 0.2072, 0.1880, 0.2290, 0.1817],
         [0.1986, 0.2023, 0.2036, 0.1898, 0.2057]]],
       grad_fn=<SoftmaxBackward0>)
    

代码中的一些用法解释:

1)torch.nn.Linear ()

torch.nn.Linear() 是 PyTorch 最基础的全连接层(线性变换层),用于执行以下操作:
在这里插入图片描述

nn.Linear(in_features, out_features, bias=True),in_features 是输入特征维度,out_features是输出特征维度,bias表示是否使用偏置项,默认为 True

2)K.transpose(-2,-1)

在 PyTorch 中,torch.transpose() 用于交换张量的两个维度。参数 -2 和 -1 是指张量的倒数第二个维度和最后一个维度。

K.transpose(-2, -1) 和 K.transpose(-1, -2) 都是交换最后两个维度,它们的效果完全相同。

3)torch.triu(torch.ones(seq_len, seq_len), diagonal=1)


mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)  # 生成上三角部分为 1

此函数解释:

  • torch.ones(seq_len, seq_len) 生成 seq_len × seq_len 矩阵,所有元素都是 1
  • torch.triu(…, diagonal=1)
    torch.triu()如果不设置diagonal,即diagonal默认为0,表示取该矩阵的上三角部分为1(包含主对角线及以上的元素),其余元素置为 0。
    如果diagonal=1,表示把主对角线本身置为 0,上三角部分是主对角线以上的元素

这样返回一个seq_len=5的mask矩阵:

tensor([[0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1.],
        [0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0.]])

4)attention_scores.masked_fill(mask==1,float('-inf'))

作用: 根据 mask矩阵 进行掩码处理,将 mask == 1 的位置填充为 -inf(负无穷),使其在 softmax 计算时权重变为 0

masked_fill 属于 PyTorch 张量(torch.Tensor)的方法,用于根据布尔掩码(mask)填充指定值。masked_fill 的语法: tensor.masked_fill(mask, value)

tensor:要修改的张量
mask:布尔掩码(True/False 或 0/1)
value:要填充的值(如 -inf)

解释:attention_scores.masked_fill(mask==1,float('-inf'))

mask == 1 选取 应该被屏蔽的位置(即 上三角部分)
masked_fill(mask == 1, -inf) 把上三角部分设为-inf
这样,Softmax 后被屏蔽的部分变成 0,不会影响注意力计算。

5) torch.softmax(attention_scores,-1)
作用: 对 attention_scores 进行 Softmax 归一化,确保注意力权重(attn_weights)的总和为 1,控制每个 Token 对序列中其他 Token 的关注程度

torch.softmax(input, dim) 语法:

input:要进行 Softmax 计算的张量
dim:沿着哪个维度进行 Softmax 计算(dim=-1 表示最后一个维度)

Multi-Head-Attention

import torch
import torch.nn as nn
import math

class Multi_Head_Attention(nn.Module):
    def __init__(self,embed_dim,nums_heads):
        super(Multi_Head_Attention,self).__init__()
        
        assert embed_dim % nums_heads ==0,"embed_dim 必须能被 num_heads 整除"
        self.embed_dim=embed_dim
        self.nums_heads=nums_heads
        self.head_dim=embed_dim//nums_heads

        self.WQ=nn.Linear(embed_dim,embed_dim)
        self.WK=nn.Linear(embed_dim,embed_dim)
        self.WV=nn.Linear(embed_dim,embed_dim)

        self.fc=nn.Linear(embed_dim,embed_dim)
        self.scale=math.sqrt(embed_dim)

    def forward(self,x,iscausal):
        batch_size,seq_len,embed_dim=x.shape
        # Q,K,V: batch_size,seq_len,embed_dim
        Q=self.WQ(x)
        K=self.WK(x)
        V=self.WV(x)
        # Q,K,V: batch_size,seq_len,embed_dim -> batch_size,seq_len,self.nums_heads,self.head_dim -> batch_size,self.nums_heads,seq_len,self.head_dim
        Q=Q.view(batch_size,seq_len,self.nums_heads,self.head_dim).transpose(1,2)
        K=K.view(batch_size,seq_len,self.nums_heads,self.head_dim).transpose(1,2)
        V=V.view(batch_size,seq_len,self.nums_heads,self.head_dim).transpose(1,2)
        # batch_size,self.nums_heads,seq_len,seq_len
        attn_scores=torch.matmul(Q,K.transpose(-1,-2))/self.scale
        
        if iscausal:
        		mask = torch.triu(torch.ones(seq_len,seq_len),diagonal=1)
            attn_scores=attn_scores.masked_fill(mask==1,float('-inf'))
        attn_weights=torch.softmax(attn_scores,-1)
        # batch_size,self.nums_heads,seq_len,head_dim
        output=torch.matmul(attn_weights,V)
        # batch_size,self.nums_heads,seq_len,head_dim -> batch_size,seq_len,self.nums_heads,head_dim -> batch_size,seq_len,self.embed_dim
        output=output.transpose(1,2).contiguous().view(batch_size,seq_len,self.embed_dim)
        output=self.fc(output)
        return output,attn_weights
    

测试:

batch_size=2
seq_len=5
nums_heads=3
embed_dim=6
x=torch.randn(batch_size,seq_len,embed_dim)
mask=create_mask(seq_len)

multiheadattention=Multi_Head_Attention(embed_dim,nums_heads)
output,weights=multiheadattention(x,mask)
print(output)
print(weights)

代码中的一些用法解释:

1)output=output.transpose(1,2).contiguous()中的contiguous()

contiguous() 是一个重要的张量方法,用于确保张量在内存中是连续存储的

Grouped-Query-Attention

import torch 
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    def __init__(self,embed_dim,nums_head,nums_group,dropout):
        super().__init__()
        self.embed_dim = embed_dim
        self.nums_head = nums_head
        self.nums_group = nums_group
        assert nums_head%nums_group == 0
        assert embed_dim%nums_head == 0
        self.head_dim = embed_dim/nums_head
        self.num_heads_per_group = nums_head/nums_group
        self.scale = self.head_dim**0.5

        self.wq = nn.Linear(embed_dim,embed_dim)
        self.wk = nn.Linear(embed_dim,self.nums_group * self.head_dim)
        self.wv = nn.Linear(embed_dim,self.nums_group * self.head_dim)
        self.wo = nn.Linear(embed_dim,embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self,x,is_causal):
        batch_size,seq_len,_ = x.shape
        # q: (batch_size,seq_len,embed_dim)
        q = self.wq(x)
        # k v: (batch_size,seq_len,nums_group * head_dim)
        k = self.wk(x)
        v = self.wv(x)
        # q: (batch_size,seq_len,embed_dim) -> (batch_size,seq_len,nums_head,head_dim) ->  (batch_size,nums_head,seq_len,head_dim) 
        q = q.view(batch_size,seq_len,self.nums_head,self.head_dim).transpose(1,2)
        # k: (batch_size,seq_len,nums_group * head_dim) -> (batch_size,seq_len,nums_group,head_dim) ->  (batch_size,nums_group,seq_len,head_dim) 
        k = k.view(batch_size,seq_len,self.nums_group,self.head_dim).transpose(1,2)
        # 将每组的键和值复制到组内的每个头
        #  k.unsqueeze(2): (batch_size,nums_group,seq_len,head_dim) -> (batch_size,nums_group,1,seq_len,head_dim)
        #  k.repeat: (batch_size,nums_group,1,seq_len,head_dim) -> (batch_size,nums_group,num_heads_per_group,seq_len,head_dim)
        #  k : (batch_size,self.nums_head,seq_len,self.head_dim) 
        k = k.unsqueeze(2).repeat(1,1,self.num_heads_per_group,1,1).view(batch_size,self.nums_head,seq_len,self.head_dim) 
        
        # v: (batch_size,seq_len,nums_group * head_dim) -> (batch_size,seq_len,nums_group,head_dim) ->  (batch_size,nums_group,seq_len,head_dim) 
        v = v.view(batch_size,seq_len,self.nums_group,self.head_dim).transpose(1,2)
        #  v : (batch_size,self.nums_head,seq_len,self.head_dim) 
        v = v.unsqueeze(2).repeat(1,1,self.num_heads_per_group,1,1).view(batch_size,self.nums_head,seq_len,self.head_dim)

        # attn_scores:  (batch_size,self.nums_head,seq_len,seq_len) 
        attn_scores = torch.matmul(q,k.transpose(-1,-2)) / self.scale
        if is_causal:
            mask = torch.triu(torch.ones(seq_len,seq_len),diagonal=1)
            attn_scores = attn_scores.masked_fill(mask==1,float('-inf'))
        
        # attn_weights:  (batch_size,self.nums_head,seq_len,seq_len) 
        attn_weights = torch.softmax(attn_scores,dim=-1)
        attn_weights = self.dropout(attn_weights)
        # attn_outputs : (batch_size,self.nums_head,seq_len,self.head_dim) 
        attn_outputs = torch.matmul(attn_weights,v)
        # attn_outputs : (batch_size,self.nums_head,seq_len,self.head_dim)  -> (batch_size,seq_len,self.nums_head,self.head_dim) -> (batch_size,seq_len,self.embed_dim)
        attn_outputs = attn_outputs.transpose(1,2).contiguous.view(batch_size,seq_len,self.embed_dim)
        # outputs: (batch_size,seq_len,self.embed_dim)
        outputs = self.wo(attn_outputs)
        outputs = self.dropout(outputs)

        return outputs

关于注意力机制的问题

为什么qk的乘积要除以 d k \sqrt{d_k} dk

避免qk点积值过大导致梯度消失问题,使得Softmax归一化时结果更稳定

为什么qk点积值过大导致梯度消失问题?

在这里插入图片描述
在这里插入图片描述

为什么注意力得分和最终输出要过dropout

  1. 对attn_weights使用 Dropout:随机将部分注意力权重置为 0,强制模型不过度依赖某些特定 token 之间的关联。
  2. 对outputs使用 Dropout:随机将输出特征向量中的部分维度置为 0,防止模型过度依赖特定维度的特征。
Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐