参考视频:面试必刷:大模型为什么要使用多头注意力?_哔哩哔哩_bilibili

详解文章:Transformer内容详解(通透版)-CSDN博客

单头注意力的劣势:单头注意力只能从一个角度“看”输入序列,计算得到的注意力权重反映的是一种特定的关注模式。

多头注意力将注意力分为了多个“头”,每个头独立计算注意力,关注输入的不同子空间或不同方面的特征。这样,模型能够并行地捕捉到多种不同类型的语义关系

将输入投射到多个不同的低维空间,分别计算注意力,最后再concat拼接,通过线性变换融合丰富了模型的表达能力,使得Transformer能够学习复杂的组合特征。同时,每个注意力头的参数量和计算复杂度降低,有助于提升训练的稳定性和效率,有利于收敛

单头注意力:

import torch
import torch.nn as nn
 
class Self_Attention(nn.Module):
    def __init__(self, dim, dk, dv):
        super().__init__()
        self.scale = dk ** -0.5
        self.q = nn.Linear(dim, dk)
        self.k = nn.Linear(dim, dk)
        self.v = nn.Linear(dim, dv)
 
    def forward(self, x):
        # x: [batch, seq_len, dim]
        q = self.q(x)  # [batch, seq_len, dk]
        k = self.k(x)  # [batch, seq_len, dk]
        v = self.v(x)  # [batch, seq_len, dv]
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [batch, seq_len, seq_len]
        attn = attn.softmax(dim=-1)
        out = attn @ v  # [batch, seq_len, dv]
        return out

多头注意力:

import torch
import torch.nn as nn
 
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, dk, dv, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.dk = dk
        self.dv = dv
        self.q_linear = nn.Linear(dim, dk * num_heads)
        self.k_linear = nn.Linear(dim, dk * num_heads)
        self.v_linear = nn.Linear(dim, dv * num_heads)
        self.out_linear = nn.Linear(dv * num_heads, dim)
 
    def forward(self, x):
        B, N, _ = x.shape  # batch, seq_len, dim
        Q = self.q_linear(x).view(B, N, self.num_heads, self.dk).transpose(1, 2)  # [B, heads, N, dk]
        K = self.k_linear(x).view(B, N, self.num_heads, self.dk).transpose(1, 2)
        V = self.v_linear(x).view(B, N, self.num_heads, self.dv).transpose(1, 2)
        # Attention
        attn = (Q @ K.transpose(-2, -1)) / (self.dk ** 0.5)  # [B, heads, N, N]
        attn = attn.softmax(dim=-1)
        out = attn @ V  # [B, heads, N, dv]
        out = out.transpose(1, 2).reshape(B, N, self.num_heads * self.dv)  # [B, N, heads*dv]
        out = self.out_linear(out)  # [B, N, dim]
        return out

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐