多头注意力机制(Multi-Head Attention)是Transformer模型的核心组件之一,由Vaswani等人在2017年的论文《Attention Is All You Need》中提出。它通过并行计算多个独立的注意力头(Attention Head),增强模型捕捉不同子空间语义信息的能力,从而提高对复杂上下文关系的建模效果。


核心思想

  1. 单头注意力(Single-Head Attention)的局限性

    • 传统注意力机制通过查询(Query)、键(Key)、值(Value)的交互计算权重,但单一注意力头可能只能关注一种模式的信息。
    • 例如:在机器翻译中,可能需要同时关注句子的语法结构、语义重点、指代关系等不同维度的信息。
  2. 多头注意力的设计动机

    • 通过多个独立的注意力头并行计算,每个头学习不同的关注模式(如局部依赖、长程依赖等)。
    • 类似卷积神经网络(CNN)中多个滤波器提取不同特征。

具体结构

  1. 输入拆分

    • 输入向量(Query、Key、Value)被线性投影到多个低维子空间,每个子空间对应一个注意力头。
    • 例如:假设输入维度为d_model=512,头数为h=8,则每个头的维度为d_k = d_v = 512/8 = 64
  2. 每个头的独立计算

    • 每个头通过缩放点积注意力(Scaled Dot-Product Attention)计算:
      [
      \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
      ]
    • √d_k的作用是缩放点积,防止数值过大导致梯度消失。
  3. 多头输出的拼接与整合

    • 所有头的输出拼接后,通过线性变换(权重矩阵W^O)映射回原始维度。
      [
      \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, …, \text{head}_h)W^O
      ]

优势与作用

  1. 多样化特征提取

    • 不同头可以关注输入的不同部分。例如:
      • 一个头可能捕捉局部依赖(如短语结构),
      • 另一个头可能捕捉长距离依赖(如句子主题)。
  2. 增强模型鲁棒性

    • 并行计算降低了对单一注意力模式的依赖,提升模型泛化能力。
  3. 高效并行计算

    • 所有注意力头可以并行计算,充分利用GPU等硬件加速。

实际应用示例

  • 机器翻译
    • 某头可能关注动词与宾语的匹配,
    • 另一头可能关注代词与先行词的关系。
  • 文本摘要
    • 不同头可能分别关注关键词、时间顺序或因果关系。

代码实现(简化版)

import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, h=8):
        super().__init__()
        self.d_model = d_model
        self.h = h
        self.d_k = d_model // h
        
        # 线性变换矩阵
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        
        # 线性投影并分头
        Q = self.W_Q(Q).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        K = self.W_K(K).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        V = self.W_V(V).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        
        # 计算缩放点积注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)
        
        # 拼接多头结果并输出
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_O(output)

常见问题

  1. 头数如何选择?

    • 典型配置为8或16个头(如BERT、GPT),但需根据任务和模型规模调整。头数过多可能导致计算冗余。
  2. 与CNN/RNN的区别?

    • 多头注意力无需递归或卷积操作,直接建模全局依赖,且支持并行计算。
  3. 参数量如何?

    • 参数量主要来自线性变换矩阵(W_Q, W_K, W_V, W_O),总参数量为4*d_model²

多头注意力机制通过分而治之的策略,显著提升了Transformer对复杂语义关系的建模能力,成为现代NLP模型的基石。

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐