CANN-ATB图优化-昇腾NPU怎么把PyTorch计算图改造成融合算子
·
CANN-ATB图优化-昇腾NPU怎么把PyTorch计算图改造成融合算子
HuggingFace 模型的 PyTorch 计算图是逐算子定义的——一层 Transformer 里有 11 个独立 kernel。ATB 的图优化器把它改造成 3-4 个融合 kernel。这篇拆解 ATB 的图优化具体做了什么替换。
优化器的工作流程
1. 解析 HuggingFace 模型 → 得到 PyTorch FX Graph
2. Pattern Matching → 识别可融合的子图模式
3. 算子替换 → 把子图替换成融合算子
4. 常量折叠 → 编译时算出可以提前计算的值
5. 内存规划 → 分配融合算子的中间 buffer
步骤 2 是核心,ATB 有一组预定义的 Pattern 规则。
Pattern 规则清单
Pattern 1:QKV + RoPE → MergedMatMul + RotaryEmbedding
匹配前:
x → Linear(Q) → RoPE(Q)
x → Linear(K) → RoPE(K)
x → Linear(V)
3 个独立子图,3 次 HBM 读 x
匹配后:
x → MergedMatMul(Q,K,V) → RotaryEmbedding(Q,K)
1 个融合算子,1 次 HBM 读 x
条件:Q/K/V Linear 共享同一个输入 x,且权重在同一个设备上。
Pattern 2:SDPA → FlashAttention
匹配前:
Q·K^T → Scale → Mask → Softmax → Drop → ·V
匹配后:
FlashAttention(Q, K, V, mask, scale)
条件:输入 Q/K/V 的 dtype 是 float16/bfloat16,head_dim 是 128 的倍数。
Pattern 3:Gate + Up + SiLU → FusedGateUp
匹配前:
x → Linear(Gate) → SiLU
x → Linear(Up) → ×
1 个乘法 + 1 个 Down Linear
匹配后:
x → FusedGateUp → Down Linear
MergedMatMul(Gate,Up) + SiLU + ElementWise 一次完成
Pattern 4:Add + LayerNorm → FusedAddNorm
匹配前:
residual + sublayer_out → LayerNorm
匹配后:
FusedAddNorm(residual, sublayer_out)
优化效果
Llama2-7B 单层 kernel 数量:
| 优化阶段 | kernel 数 | HBM 读写 (GB) |
|---|---|---|
| 原始 PyTorch | 11 | 3.8 |
| + Pattern 1,2 | 6 | 1.6 |
| + Pattern 3 | 4 | 1.1 |
| + Pattern 4 | 3 | 0.8 |
从 11 个 kernel 降到 3 个,HBM 读写减少 79%。
调试方法
图优化的替换过程可以通过日志查看:
import atb
model = atb.LLM("meta-llama/Llama-2-7b-hf", device="npu:0",
log_level="debug")
# 日志会输出:
# [GraphOptimizer] Pattern QKV+RoPE matched at layer 0
# [GraphOptimizer] Replaced with MergedMatMul+RotaryEmbedding
# [GraphOptimizer] Pattern SDPA matched at layer 0
# [GraphOptimizer] Replaced with FlashAttention
如果某个 Pattern 没匹配上,常见原因:
- 模型结构跟 Pattern 不一致(比如自定义的 Attention 实现)
- dtype 或维度不满足条件
- 模型用了
torch.jit.script而不是torch.fx——ATB 只支持 FX Graph
自定义 Pattern
如果你的模型有 ATB 没覆盖的融合模式,可以注册自定义 Pattern:
from atb import GraphOptimizer, Pattern, FusionOp
# 定义自定义 Pattern
class MyPattern(Pattern):
def match(self, graph):
# 匹配逻辑:找到自定义的子图
for node in graph.nodes:
if node.op == "call_function" and node.target == my_custom_op:
return [node]
return None
def replace(self, matched_nodes):
# 替换逻辑
return FusionOp("my_fused_op", ...)
optimizer = GraphOptimizer()
optimizer.register_pattern(MyPattern())
ATB 的图优化是推理加速的第一道关卡。如果图优化没做好,后面再怎么调 Tiling 和调度都没用——kernel 数量决定了性能天花板。调试时先看日志确认所有 Pattern 都匹配上了。仓库在这里:
https://atomgit.com/cann/ATB
更多推荐


所有评论(0)