【昇腾CANN训练营·行业篇】长序列推理救星:FlashDecoding算子开发与KV Cache并行规约实战
摘要:2025年昇腾CANN训练营第二季推出FlashDecoding专题课程,聚焦大模型推理中的Decode阶段性能优化。针对长上下文场景下Attention计算的访存瓶颈,提出KVCache切分策略(Split-K),通过OnlineSoftmax数学公式实现分块结果的无损合并。课程详细讲解AscendC实现方案,包括Stage1分块计算和Stage2全局规约两个核心Kernel,并分析异步流
训练营简介 2025年昇腾CANN训练营第二季,基于CANN开源开放全场景,推出0基础入门系列、码力全开特辑、开发者案例等专题课程,助力不同阶段开发者快速提升算子开发技能。获得Ascend C算子中级认证,即可领取精美证书,完成社区任务更有机会赢取华为手机,平板、开发板等大奖。
报名链接:https://www.hiascend.com/developer/activities/cann20252#cann-camp-2502-intro

前言
在大模型推理的 Decode(生成) 阶段,每一次 Token 生成,Query 的长度只有 1。 如果上下文长度(KV Cache)达到了 100k,计算 $Attention(Q, K, V)$ 本质上就是把一个 $1 \times D$ 的向量,和一个 $100k \times D$ 的大矩阵做乘法(GEMV)。
这是一个典型的 Memory Bound(访存受限) 场景。 更糟糕的是,FlashAttention V2 的并行策略是 Batch x Head。在 Decode 阶段,Batch=1,Head=32,意味着只有 32 个 AI Core 在干活,剩下的几百个 Core 都在围观!
FlashDecoding 的破局之道在于:既然 Head 不够分,那就把 KV Cache 切开分! 把 100k 的 KV Cache 切成 100 份,分给 100 个核同时算,最后再把结果拼起来。这听起来简单,但Softmax 的归一化怎么处理?分块计算的结果怎么合并?
本期文章将带你攻克这个大模型推理的“提速神器”。
一、 核心图解:人多力量大 vs 众口难调
FlashDecoding 的核心是 Split-K(K轴切分),但难点在于 Reduce(合并)。

二、 算法原理:Online Softmax 的合并公式
假设我们将 KV Cache 切分成两块:$Block_1$ 和 $Block_2$。 Core 1 计算出:$O_1, m_1, l_1$ (局部输出,局部最大值,局部指数和)。 Core 2 计算出:$O_2, m_2, l_2$。
我们要得到的全局结果是基于全局最大值 $m_{global} = \max(m_1, m_2)$ 的。 利用 Softmax 的数学性质,合并公式如下:
-
更新全局 Max: $m_{new} = \max(m_1, m_2)$
-
计算缩放因子: $factor_1 = e^{m_1 - m_{new}}$, $factor_2 = e^{m_2 - m_{new}}$
-
合并指数和: $l_{new} = l_1 \cdot factor_1 + l_2 \cdot factor_2$
-
合并输出: $O_{new} = O_1 \cdot factor_1 + O_2 \cdot factor_2$
-
最终归一化: $O_{final} = O_{new} / l_{new}$
这就是我们需要在 Ascend C 算子中实现的逻辑。
三、 实战:Ascend C 实现 FlashDecoding
FlashDecoding 通常分为两个 Kernel:
-
FlashDecodingStage1: 各个 Core 独立计算分块的 Attention,输出
Partial_O(未归一化),LogSumExp(即 m+log(l))。 -
FlashDecodingStage2: 一个 Reduce Kernel,读取所有 Partial 结果,应用合并公式,计算最终 Output。
我们重点实现 Stage 1 的核心逻辑。
3.1 Kernel 类定义
输入 Query 是 [1, HeadDim],KV Cache 是 [SeqLen, HeadDim]。
class KernelFlashDecoding {
public:
__aicore__ inline void Init(GM_ADDR query, GM_ADDR key, GM_ADDR value,
GM_ADDR output, GM_ADDR stats, // stats 存 max 和 sum
uint32_t seqLen, uint32_t headDim) {
// Tiling: 每个核处理 seqLen 的一部分 (例如 128 个 Token)
// block_idx 决定当前核处理哪一段 KV
// ...
}
};
3.2 Compute 核心逻辑 (Stage 1)
这是一个典型的 GEMV(矩阵向量乘)+ Softmax 过程。
__aicore__ inline void Compute(int32_t i) {
// 1. Load Query (常驻 UB)
// Query 很小,只有 1xDim,直接广播使用
// 2. Load KV Block
// keyLocal: [BlockSize, Dim]
DataCopy(keyLocal, keyGm[offset], blockSize * dim);
// 3. QK^T (MatMul or Vector Mul)
// 由于 Q 是向量,这里其实是 Batch Vector Mul
// scores[j] = dot(Q, K[j])
// 建议使用 Muls + ReduceSum 或者 Cube 1xN 计算
// 结果 scores: [BlockSize]
// 4. Local Softmax Statistics
// 4.1 找局部最大值 m_block
ReduceMax(maxLoc, scores, blockSize);
float m_val = maxLoc.GetValue(0);
// 4.2 减 Max 并求 Exp
// scores = exp(scores - m_val)
Adds(scores, scores, -m_val, blockSize);
Exp(scores, scores, blockSize);
// 4.3 求局部 Sum l_block
ReduceSum(sumLoc, scores, blockSize);
float l_val = sumLoc.GetValue(0);
// 5. P * V (Weighted Sum)
// O_block = scores * V_block
// V_block: [BlockSize, Dim]
// 这是一个加权求和过程,Vector 逐行相乘再相加
// 6. Write Back Partial Results
// 我们不除以 l_val,而是直接把非归一化的 O_block 和 (m_val, l_val) 写回去
// 留给 Stage 2 做全局合并
DataCopy(outputGm[block_idx], o_block, dim);
DataCopy(statsGm[block_idx], {m_val, l_val}, 2);
}
3.3 进阶:Stage 2 (全局规约)
Stage 2 的 Kernel 只需要启动一个核(或者每个 Head 启动一个核)。 它读取 Stage 1 产生的所有 Partial Outputs 和 Stats。
// 伪代码:合并逻辑
void ReduceStage() {
float global_m = -INF;
float global_l = 0;
Vector global_o = {0};
// 遍历所有分块
for (int i = 0; i < num_blocks; i++) {
float m_i = stats[i].m;
float l_i = stats[i].l;
Vector o_i = outputs[i];
// 更新全局 Max
float new_m = max(global_m, m_i);
// 计算缩放系数
float factor_global = exp(global_m - new_m);
float factor_i = exp(m_i - new_m);
// 更新全局 Accumulator
global_l = global_l * factor_global + l_i * factor_i;
global_o = global_o * factor_global + o_i * factor_i;
global_m = new_m;
}
// 最终归一化
FinalOutput = global_o / global_l;
}
四、 性能优化的“胜负手”
FlashDecoding 的瓶颈在于 Reduce 阶段 和 KV Cache 的读取。
-
异步流水线:Stage 1 中读取 KV Block 和计算 QK 应高度并行。利用 Ascend C 的
TPipe掩盖 MTE 搬运延迟。 -
Atomic Add (原子归约): 如果不使用 Stage 2 Kernel,也可以让 Stage 1 的每个核直接通过
SetAtomicAdd将结果累加到 Global Memory。-
难点:原子加只支持 Sum,不支持 Max。所以需要两轮:第一轮原子更新 Max,第二轮原子累加 Sum 和 Output。这通常不如独立的 Stage 2 Kernel 高效。
-
-
Page Attention 适配: 实际场景中 KV Cache 是非连续的(Paged)。在 Load KV Block 时,需要根据 Block Table 进行 Gather 操作(参考第四十四期稀疏计算),而不能简单的 DataCopy。
五、 总结
FlashDecoding 是大模型推理性能优化的皇冠。
-
并行维度变革:从
Batch并行转向SeqLen并行,解决了 Decode 阶段 GPU/NPU 利用率低的问题。 -
数学技巧:利用 Online Softmax 的重缩放(Rescaling)公式,实现了分块结果的无损合并。
-
Ascend C 实践:通过 Stage1(Map)+ Stage2(Reduce)的模式,完美契合了 AI Core 的计算架构。
掌握了 FlashDecoding,你就掌握了让 DeepSeek、Llama3 等长文本模型在昇腾上飞速运行的秘诀。
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐

所有评论(0)