训练营简介 2025年昇腾CANN训练营第二季,基于CANN开源开放全场景,推出0基础入门系列、码力全开特辑、开发者案例等专题课程,助力不同阶段开发者快速提升算子开发技能。获得Ascend C算子中级认证,即可领取精美证书,完成社区任务更有机会赢取华为手机,平板、开发板等大奖。

报名链接:https://www.hiascend.com/developer/activities/cann20252#cann-camp-2502-intro

前言

在大模型推理的 Decode(生成) 阶段,每一次 Token 生成,Query 的长度只有 1。 如果上下文长度(KV Cache)达到了 100k,计算 $Attention(Q, K, V)$ 本质上就是把一个 $1 \times D$ 的向量,和一个 $100k \times D$ 的大矩阵做乘法(GEMV)。

这是一个典型的 Memory Bound(访存受限) 场景。 更糟糕的是,FlashAttention V2 的并行策略是 Batch x Head。在 Decode 阶段,Batch=1,Head=32,意味着只有 32 个 AI Core 在干活,剩下的几百个 Core 都在围观!

FlashDecoding 的破局之道在于:既然 Head 不够分,那就把 KV Cache 切开分! 把 100k 的 KV Cache 切成 100 份,分给 100 个核同时算,最后再把结果拼起来。这听起来简单,但Softmax 的归一化怎么处理?分块计算的结果怎么合并?

本期文章将带你攻克这个大模型推理的“提速神器”。

一、 核心图解:人多力量大 vs 众口难调

FlashDecoding 的核心是 Split-K(K轴切分),但难点在于 Reduce(合并)

二、 算法原理:Online Softmax 的合并公式

假设我们将 KV Cache 切分成两块:$Block_1$ 和 $Block_2$。 Core 1 计算出:$O_1, m_1, l_1$ (局部输出,局部最大值,局部指数和)。 Core 2 计算出:$O_2, m_2, l_2$。

我们要得到的全局结果是基于全局最大值 $m_{global} = \max(m_1, m_2)$ 的。 利用 Softmax 的数学性质,合并公式如下:

  1. 更新全局 Max: $m_{new} = \max(m_1, m_2)$

  2. 计算缩放因子: $factor_1 = e^{m_1 - m_{new}}$, $factor_2 = e^{m_2 - m_{new}}$

  3. 合并指数和: $l_{new} = l_1 \cdot factor_1 + l_2 \cdot factor_2$

  4. 合并输出: $O_{new} = O_1 \cdot factor_1 + O_2 \cdot factor_2$

  5. 最终归一化: $O_{final} = O_{new} / l_{new}$

这就是我们需要在 Ascend C 算子中实现的逻辑。

三、 实战:Ascend C 实现 FlashDecoding

FlashDecoding 通常分为两个 Kernel:

  1. FlashDecodingStage1: 各个 Core 独立计算分块的 Attention,输出 Partial_O (未归一化), LogSumExp (即 m+log(l))。

  2. FlashDecodingStage2: 一个 Reduce Kernel,读取所有 Partial 结果,应用合并公式,计算最终 Output。

我们重点实现 Stage 1 的核心逻辑。

3.1 Kernel 类定义

输入 Query 是 [1, HeadDim],KV Cache 是 [SeqLen, HeadDim]

class KernelFlashDecoding {
public:
    __aicore__ inline void Init(GM_ADDR query, GM_ADDR key, GM_ADDR value, 
                                GM_ADDR output, GM_ADDR stats, // stats 存 max 和 sum
                                uint32_t seqLen, uint32_t headDim) {
        // Tiling: 每个核处理 seqLen 的一部分 (例如 128 个 Token)
        // block_idx 决定当前核处理哪一段 KV
        // ...
    }
};

3.2 Compute 核心逻辑 (Stage 1)

这是一个典型的 GEMV(矩阵向量乘)+ Softmax 过程。

__aicore__ inline void Compute(int32_t i) {
    // 1. Load Query (常驻 UB)
    // Query 很小,只有 1xDim,直接广播使用
    
    // 2. Load KV Block
    // keyLocal: [BlockSize, Dim]
    DataCopy(keyLocal, keyGm[offset], blockSize * dim);

    // 3. QK^T (MatMul or Vector Mul)
    // 由于 Q 是向量,这里其实是 Batch Vector Mul
    // scores[j] = dot(Q, K[j])
    // 建议使用 Muls + ReduceSum 或者 Cube 1xN 计算
    // 结果 scores: [BlockSize]
    
    // 4. Local Softmax Statistics
    // 4.1 找局部最大值 m_block
    ReduceMax(maxLoc, scores, blockSize);
    float m_val = maxLoc.GetValue(0);
    
    // 4.2 减 Max 并求 Exp
    // scores = exp(scores - m_val)
    Adds(scores, scores, -m_val, blockSize);
    Exp(scores, scores, blockSize);
    
    // 4.3 求局部 Sum l_block
    ReduceSum(sumLoc, scores, blockSize);
    float l_val = sumLoc.GetValue(0);

    // 5. P * V (Weighted Sum)
    // O_block = scores * V_block
    // V_block: [BlockSize, Dim]
    // 这是一个加权求和过程,Vector 逐行相乘再相加
    
    // 6. Write Back Partial Results
    // 我们不除以 l_val,而是直接把非归一化的 O_block 和 (m_val, l_val) 写回去
    // 留给 Stage 2 做全局合并
    
    DataCopy(outputGm[block_idx], o_block, dim);
    DataCopy(statsGm[block_idx], {m_val, l_val}, 2);
}

3.3 进阶:Stage 2 (全局规约)

Stage 2 的 Kernel 只需要启动一个核(或者每个 Head 启动一个核)。 它读取 Stage 1 产生的所有 Partial OutputsStats

// 伪代码:合并逻辑
void ReduceStage() {
    float global_m = -INF;
    float global_l = 0;
    Vector global_o = {0};

    // 遍历所有分块
    for (int i = 0; i < num_blocks; i++) {
        float m_i = stats[i].m;
        float l_i = stats[i].l;
        Vector o_i = outputs[i];

        // 更新全局 Max
        float new_m = max(global_m, m_i);
        
        // 计算缩放系数
        float factor_global = exp(global_m - new_m);
        float factor_i = exp(m_i - new_m);

        // 更新全局 Accumulator
        global_l = global_l * factor_global + l_i * factor_i;
        global_o = global_o * factor_global + o_i * factor_i;
        
        global_m = new_m;
    }

    // 最终归一化
    FinalOutput = global_o / global_l;
}

四、 性能优化的“胜负手”

FlashDecoding 的瓶颈在于 Reduce 阶段KV Cache 的读取

  1. 异步流水线:Stage 1 中读取 KV Block 和计算 QK 应高度并行。利用 Ascend C 的 TPipe 掩盖 MTE 搬运延迟。

  2. Atomic Add (原子归约): 如果不使用 Stage 2 Kernel,也可以让 Stage 1 的每个核直接通过 SetAtomicAdd 将结果累加到 Global Memory。

    • 难点:原子加只支持 Sum,不支持 Max。所以需要两轮:第一轮原子更新 Max,第二轮原子累加 Sum 和 Output。这通常不如独立的 Stage 2 Kernel 高效。

  3. Page Attention 适配: 实际场景中 KV Cache 是非连续的(Paged)。在 Load KV Block 时,需要根据 Block Table 进行 Gather 操作(参考第四十四期稀疏计算),而不能简单的 DataCopy。

五、 总结

FlashDecoding 是大模型推理性能优化的皇冠。

  1. 并行维度变革:从 Batch 并行转向 SeqLen 并行,解决了 Decode 阶段 GPU/NPU 利用率低的问题。

  2. 数学技巧:利用 Online Softmax 的重缩放(Rescaling)公式,实现了分块结果的无损合并。

  3. Ascend C 实践:通过 Stage1(Map)+ Stage2(Reduce)的模式,完美契合了 AI Core 的计算架构。

掌握了 FlashDecoding,你就掌握了让 DeepSeek、Llama3 等长文本模型在昇腾上飞速运行的秘诀。

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐