DeepSeekNSA省力的同时还能提速!

一、引言:为什么要关注稀疏注意力?

在 Transformer 闪耀 NLP 舞台的时代,模型对长文本的处理仍是痛点

  • 计算瓶颈:全注意力机制需对每对 Token 计算相似度,复杂度为 O(n2)O(n^2)O(n2),长文本时计算量呈平方增长。
  • 显存压力:在 GPU 上训练或推理长序列时,显存消耗剧增,甚至溢出。
  • 应用场景:文档检索、代码审查、学术论文阅读等场合,需要模型理解成千上万的 Token。

2025 年 ACL’25 最佳论文——《Native Sparse Attention (NSA)》,提出了一种既大幅降低计算成本提升模型性能的稀疏注意力方案,开启了长上下文处理新纪元。

二、核心原理拆解

2.1 三条稀疏分支如何协作?
  1. 全局压缩注意力

    • 原理:将长度 nnn 的序列通过平均/最大池化等操作压缩到长度 mmm,再在压缩后序列上执行全注意力。
    • 计算复杂度由 O(n2)O(n^2)O(n2) 降至 O(m2)O(m^2)O(m2):前者表示随着序列长度平方倍增长,后者仅随压缩后长度平方增长。
    • 示例:当 n=65536n=65536n=65536m=1024m=1024m=1024 时,计算量从 4.3×1094.3\times10^94.3×109 降至 1.0×1061.0\times10^61.0×106(减少约四个数量级)。
    • 比喻:快速扫读目录,掌握整体框架。
  2. 选择性注意力(Selective Sparsity)

    • 原理:为每个 Token 计算重要性分数(基于键-值内积或门控网络),只保留 Top-kkk 个最重要的 Token 进行注意力计算,复杂度约 O(nk)O(nk)O(nk).
    • 动态调整kkk 可根据任务难度自动伸缩,平衡覆盖率与效率。
    • 比喻:用书签标记重点段落,只精读这些片段。
  3. 滑动窗口注意力(Sliding Window)

    • 原理:以步长 sss 和窗口宽度 www 划分序列,在每个窗口内执行全注意力,局部复杂度 O(ws)O(ws)O(ws)
    • 参数示例:常用 w=512,s=256w=512, s=256w=512,s=256,可根据 GPU 显存灵活设置。
    • 比喻:用放大镜逐段扫描,确保连续上下文不丢失。

三条分支并行计算后,通过可学习权重 (α,β,γ)(\alpha,\beta,\gamma)(α,β,γ) 加权融合:

NSA(X)=α Aglobal(X)+β Aselective(X)+γ Awindow(X) \mathrm{NSA}(X)=\alpha\,A_{global}(X)+\beta\,A_{selective}(X)+\gamma\,A_{window}(X) NSA(X)=αAglobal(X)+βAselective(X)+γAwindow(X)

2.2 补充知识点
  • 对比常见稀疏策略

    • 局部注意力(Local):仅滑窗,缺乏全局感知;
    • 块稀疏(Block Sparse):固定块交叉,局部全局互限;
    • BigBird:随机+局部+全局,随机性带来不稳定;
    • NSA:全局压缩+选择性+滑窗,系统性覆盖全局、重点、局部,兼顾效率与稳定。
  • 复杂度公式:NSA 总体复杂度为 O(m2+nk+ws)O(m^2 + nk + ws)O(m2+nk+ws),当 m,k,w,s≪nm,k,w,s\ll nm,k,w,sn 时,远低于 O(n2)O(n^2)O(n2)

  • 可视化示例:专栏提供的注意力热力图对比,直观展示 NSA 如何同时捕捉全局脉络与局部细节。

2.3 伪代码示例
# NSA 三分支核心伪码(示例)
import torch.nn.functional as F

def global_compress_attention(X, m):
    # X: [B, L, D]
    X_pooled = F.adaptive_avg_pool1d(X.transpose(1,2), m).transpose(1,2)
    return full_attention(X_pooled)

def selective_attention(X, k):
    # 计算重要性分数并取 top-k
    scores = (X @ X.transpose(-1,-2)).mean(dim=1)  # 简化示例
    idx = scores.topk(k, dim=-1).indices
    return full_attention(X[idx], X)

def sliding_window_attention(X, w, s):
    # 分段滑窗
    outs = []
    for i in range(0, X.size(1)-w+1, s):
        outs.append(full_attention(X[:, i:i+w], X[:, i:i+w]))
    return torch.cat(outs, dim=1)

# 融合输出
 def NSA_attention(X, config):
    g = global_compress_attention(X, config.m)
    s = selective_attention(X, config.k)
    l = sliding_window_attention(X, config.w, config.s)
    return config.alpha*g + config.beta*s + config.gamma*l

此伪代码仅用于流程示意,不可直接用于生产环境。


三、快速记忆核心知识点

  • 全局压缩 = 粗读目录:快速建立整体框架;
  • 选择性 = 精读标记:专注最重要内容;
  • 滑动窗口 = 放大镜:细致扫描局部上下文。

三步结合,让模型既能“见微知著”,又能“略窥全豹”。


四、量化评估

评估类型 序列长度 全注意力延迟 NSA 延迟 加速比 备注
推理(Forward) 64k 24s 2.07s ~11.6× 无性能损耗
训练(Backward) 32k 18s 3s ~6× 包含自定义反向算子
显存峰值 64k 100% ~65% 节省显存
通用基准(9项平均) +0.022 性能略升
数学推理(GSM8K) +0.034 推理能力增强
长文本任务(LongBench) +0.032 精度提升

解读:NSA 在长文本场景下实现10×以上加速,同时因聚焦重要信息,性能亦得到提升。



Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐