CANN-ATB图优化-昇腾NPU怎么把PyTorch计算图改造成融合算子

HuggingFace 模型的 PyTorch 计算图是逐算子定义的——一层 Transformer 里有 11 个独立 kernel。ATB 的图优化器把它改造成 3-4 个融合 kernel。这篇拆解 ATB 的图优化具体做了什么替换。

优化器的工作流程

1. 解析 HuggingFace 模型 → 得到 PyTorch FX Graph
2. Pattern Matching → 识别可融合的子图模式
3. 算子替换 → 把子图替换成融合算子
4. 常量折叠 → 编译时算出可以提前计算的值
5. 内存规划 → 分配融合算子的中间 buffer

步骤 2 是核心,ATB 有一组预定义的 Pattern 规则。

Pattern 规则清单

Pattern 1:QKV + RoPE → MergedMatMul + RotaryEmbedding

匹配前:
  x → Linear(Q) → RoPE(Q)
  x → Linear(K) → RoPE(K)
  x → Linear(V)
  3 个独立子图,3 次 HBM 读 x

匹配后:
  x → MergedMatMul(Q,K,V) → RotaryEmbedding(Q,K)
  1 个融合算子,1 次 HBM 读 x

条件:Q/K/V Linear 共享同一个输入 x,且权重在同一个设备上。

Pattern 2:SDPA → FlashAttention

匹配前:
  Q·K^T → Scale → Mask → Softmax → Drop → ·V

匹配后:
  FlashAttention(Q, K, V, mask, scale)

条件:输入 Q/K/V 的 dtype 是 float16/bfloat16,head_dim 是 128 的倍数。

Pattern 3:Gate + Up + SiLU → FusedGateUp

匹配前:
  x → Linear(Gate) → SiLU
  x → Linear(Up) → ×
  1 个乘法 + 1 个 Down Linear

匹配后:
  x → FusedGateUp → Down Linear
  MergedMatMul(Gate,Up) + SiLU + ElementWise 一次完成

Pattern 4:Add + LayerNorm → FusedAddNorm

匹配前:
  residual + sublayer_out → LayerNorm

匹配后:
  FusedAddNorm(residual, sublayer_out)

优化效果

Llama2-7B 单层 kernel 数量:

优化阶段 kernel 数 HBM 读写 (GB)
原始 PyTorch 11 3.8
+ Pattern 1,2 6 1.6
+ Pattern 3 4 1.1
+ Pattern 4 3 0.8

从 11 个 kernel 降到 3 个,HBM 读写减少 79%。

调试方法

图优化的替换过程可以通过日志查看:

import atb

model = atb.LLM("meta-llama/Llama-2-7b-hf", device="npu:0",
                 log_level="debug")

# 日志会输出:
# [GraphOptimizer] Pattern QKV+RoPE matched at layer 0
# [GraphOptimizer] Replaced with MergedMatMul+RotaryEmbedding
# [GraphOptimizer] Pattern SDPA matched at layer 0
# [GraphOptimizer] Replaced with FlashAttention

如果某个 Pattern 没匹配上,常见原因:

  1. 模型结构跟 Pattern 不一致(比如自定义的 Attention 实现)
  2. dtype 或维度不满足条件
  3. 模型用了 torch.jit.script 而不是 torch.fx——ATB 只支持 FX Graph

自定义 Pattern

如果你的模型有 ATB 没覆盖的融合模式,可以注册自定义 Pattern:

from atb import GraphOptimizer, Pattern, FusionOp

# 定义自定义 Pattern
class MyPattern(Pattern):
    def match(self, graph):
        # 匹配逻辑:找到自定义的子图
        for node in graph.nodes:
            if node.op == "call_function" and node.target == my_custom_op:
                return [node]
        return None

    def replace(self, matched_nodes):
        # 替换逻辑
        return FusionOp("my_fused_op", ...)

optimizer = GraphOptimizer()
optimizer.register_pattern(MyPattern())

ATB 的图优化是推理加速的第一道关卡。如果图优化没做好,后面再怎么调 Tiling 和调度都没用——kernel 数量决定了性能天花板。调试时先看日志确认所有 Pattern 都匹配上了。仓库在这里:

https://atomgit.com/cann/ATB

Logo

中国智能体开发者社区,聚焦智能体与大模型开发,提供前沿资讯、实用工具链、开源项目及行业案例。通过技术沙龙、开发者大赛等活动,促进经验交流与协作,助力开发者快速构建创新智能应用。

更多推荐