DeepSeek-V4:面向高效百万 Token 上下文智能的探索

0. 写在前面

DeepSeek-V4 是 DeepSeek 面向百万 token 长上下文能力推出的重要技术报告。与单纯扩大模型参数规模不同,DeepSeek-V4 更关注如何在超长上下文场景下同时兼顾推理效率、KV cache 占用和长程信息建模能力。本文围绕 DeepSeek-V4: Towards Highly Efficient Million-Token Context Intelligence 进行学习型精读,重点梳理其 Hybrid Attention Architecture、Compressed Sparse Attention,CSA、Heavily Compressed Attention,HCA、Manifold-Constrained Hyper-Connections,mHC、Muon optimizer、DeepSeekMoE、Multi-Token Prediction,MTP,以及长上下文推理优化等关键内容。

写这篇博客的目的并不是复现 DeepSeek-V4 的官方工程系统,而是尝试从论文阅读者和代码学习者的角度,把其中比较抽象的架构设计拆解成更容易理解的模块。对于刚接触长上下文大模型、MoE 架构和 KV cache 优化的读者来说,直接阅读技术报告可能会遇到术语密集、系统细节复杂、算法与工程交织较深等问题。因此,本文会先用较直观的方式解释 DeepSeek-V4 为什么强调百万 token 上下文,再分析 CSA、HCA、mHC、MoE、MTP 等核心设计背后的动机,最后给出一组非官方 toy-level PyTorch 推测实现,帮助理解这些结构的基本思想。

需要说明的是,本文代码仅用于学习和理解论文思想,不是 DeepSeek 官方源码,也不能复现 DeepSeek-V4 的真实训练、推理效率或模型能力。真实 DeepSeek-V4 涉及大规模 MoE、DSA、mHC、Muon optimizer、FP4/FP8 混合精度、fused kernel、分布式训练和复杂 KV cache 管理等系统级工程,远超一个简化示例的范围。本文更适合作为一篇“论文精读 + 思想拆解 + 非官方最小代码理解”的学习笔记。


1. 论文基本信息

DeepSeek-V4: Towards Highly Efficient Million-Token Context Intelligence 是 DeepSeek-AI 发布的一篇面向长上下文大语言模型的技术报告,公开于 arXiv,编号为 arXiv:2606.19348。该报告围绕百万 token 上下文建模展开,重点研究如何在极长序列输入条件下降低单 token 推理计算量、减少 KV cache 占用,并维持模型对长程依赖和复杂任务的建模能力。

DeepSeek-V4 系列包括 DeepSeek-V4-ProDeepSeek-V4-Flash 两个版本,均采用 Mixture-of-Experts,MoE 架构并支持 1M tokens 上下文长度。其中,DeepSeek-V4-Pro 的总参数量为 1.6T,激活参数量为 49B;DeepSeek-V4-Flash 的总参数量为 284B,激活参数量为 13B。从定位上看,Pro 版本更强调复杂推理、代码和智能体任务能力,Flash 版本则更强调推理效率、响应速度和部署成本控制。

相较于以往模型主要依赖扩大上下文窗口的做法,DeepSeek-V4 更强调长上下文场景下的效率优化。其核心方法包括由 CSA 与 HCA 组成的 Hybrid Attention Architecture、用于增强传统 residual connection 的 mHC,以及用于提升训练收敛速度和稳定性的 Muon optimizer。从研究定位来看,DeepSeek-V4 不仅是一项模型架构更新,也是一项面向超长上下文推理的系统工程探索。其技术路线将注意力压缩、稀疏计算、MoE 扩展、混合精度表示和推理系统优化结合起来,为代码理解、多文档推理、长程 agent 任务和复杂知识整合等应用场景提供了新的参考。


2. 为什么 DeepSeek-V4 要关注百万 Token 上下文?

最近大模型的发展已经不只是“参数越大越好”。长上下文能力、推理成本、KV cache 显存占用、agent 任务表现,正在变成越来越重要的竞争点。DeepSeek-V4 这篇技术报告正好集中讨论了这些问题:如何让模型在百万 token 级别上下文中仍然保持较高效率,并且能够服务代码、推理、agent 等复杂任务。

传统 Transformer 的 self-attention 复杂度与序列长度强相关。对于普通聊天任务,几十 K token 上下文可能已经够用;但对于更复杂的任务,长上下文会变得非常重要。例如:

  • 阅读整个代码仓库并修复 bug;
  • 分析多文件、多日志、多轮测试结果;
  • 对几十篇论文进行综合综述;
  • 对法律文书、合同、证据链进行交叉引用;
  • 在 agent 任务中保留完整工具调用轨迹;
  • 在长期任务中持续引用历史状态和中间结论。

这些任务有一个共同点:信息不是集中在某一个片段里,而是分散在很长的上下文中。

传统 RAG 可以缓解上下文窗口不足的问题,但 RAG 也有局限。例如,检索可能漏掉关键片段,检索结果之间缺乏全局关系,多轮 agent 的中间状态容易被截断,长程推理时模型也很难同时看到足够完整的证据链。因此,百万 token 上下文的价值不只是“能塞更多内容”,而是让模型在更完整的上下文空间中进行全局推理,减少信息截断和检索遗漏带来的错误。


3. DeepSeek-V4 的整体架构理解

DeepSeek-V4 仍然保留 Transformer 架构,并继续使用 DeepSeek 系列中的一些重要设计,例如 DeepSeekMoEMulti-Token Prediction,MTP。在此基础上,DeepSeek-V4 引入了几个关键升级:

  1. Hybrid Attention Architecture

    • Compressed Sparse Attention,CSA
    • Heavily Compressed Attention,HCA
  2. Manifold-Constrained Hyper-Connections,mHC

    • 用于增强传统 residual connection
  3. Muon optimizer

    • 用于提升训练收敛速度和稳定性
  4. 系统级训练与推理优化

    • MoE fused kernel
    • TileLang
    • deterministic kernel
    • tensor-level checkpointing
    • hybrid ZeRO
    • two-stage contextual parallelism
    • heterogeneous KV cache
    • on-disk KV cache storage
    • FP4 / FP8 mixed precision

可以把整体结构粗略理解成:

Input Tokens
    |
Token Embedding
    |
[Hybrid Attention + mHC + DeepSeekMoE] × N
    |
MTP / LM Head
    |
Output Tokens

其中每一层大致可以理解为:

Hidden States
    |
Hybrid Attention
    |---- CSA: compressed KV cache + sparse attention
    |---- HCA: heavily compressed KV cache + dense attention
    |
mHC / enhanced residual connection
    |
DeepSeekMoE FFN
    |
Output Hidden States

这说明 DeepSeek-V4 的重点不是单个技巧,而是把模型架构、优化器、训练系统、推理系统和长上下文任务结合到一起。


4. 核心技术一:Hybrid Attention Architecture

DeepSeek-V4 最重要的结构升级之一是 Hybrid Attention Architecture,可以翻译为“混合注意力架构”。它由两个部分组成:

Hybrid Attention Architecture = CSA + HCA

其中:

模块 英文全称 中文理解 主要思想
CSA Compressed Sparse Attention 压缩稀疏注意力 沿序列维度压缩 KV cache,然后执行 DeepSeek Sparse Attention
HCA Heavily Compressed Attention 重压缩注意力 / 高度压缩注意力 对 KV cache 进行更强压缩,但保留 dense attention

这里一定要注意:CSA 和 HCA 讨论的是 KV cache 的压缩,不是简单把原始文本压缩成摘要。

KV cache 是大模型推理中的核心缓存结构。随着上下文长度增加,KV cache 会占用越来越多显存。对于百万 token 上下文,KV cache 的存储和访问成本会变得非常高。因此,DeepSeek-V4 的思路不是让每一个 token 都以完整 KV 形式长期保留,而是对 KV cache 进行压缩、稀疏选择和系统级管理。


5. 核心技术二:CSA,Compressed Sparse Attention

5.1 CSA 的基本思想

CSA 的全称是 Compressed Sparse Attention,可以翻译为“压缩稀疏注意力”。官方技术报告中的关键思想可以概括为:

CSA compresses the KV caches along the sequence dimension and then performs DeepSeek Sparse Attention.

对应中文是:

CSA 沿序列维度压缩 KV cache,然后执行 DeepSeek Sparse Attention。

也就是说,CSA 有两个动作:

  1. 沿序列维度压缩 KV cache;
  2. 在压缩后的 KV cache 上执行 DeepSeek Sparse Attention,简称 DSA。

传统 attention 的形式是:

Attention(Q, K, V) = softmax(QK^T / sqrt(d))V

如果序列长度是 n,那么 QK^T 的规模大致是 n × n。对于 1M token 来说,直接做 dense attention 成本会非常高。

CSA 的直觉是:

原始 KV cache
    |
沿序列维度压缩
    |
compressed KV cache
    |
稀疏选择重要位置
    |
执行 sparse attention

可以用阅读长文档的过程做类比:近处内容需要更细粒度地看,远处内容可以压缩成更粗粒度的表示,而真正参与注意力计算的是更重要、更相关的部分。CSA 的目标不是完整保留所有 token-level 细节,而是在长上下文中尽量降低计算量和缓存开销。

5.2 Dense Attention、Sparse Attention 与 CSA 的计算例子

为了理解 DeepSeek-V4 中的 Compressed Sparse Attention,CSA,我们先从最基础的 Q、K、V 计算开始。注意,这里的例子是为了帮助理解 attention 的计算逻辑,不是 DeepSeek-V4 的官方实现。

第一步:从 token 向量到 Q、K、V

假设一句话有 3 个 token:

token 1 = 我
token 2 = 喜欢
token 3 = 学习

经过 token embedding 之后,每个 token 会变成一个 hidden vector。为了方便计算,假设每个向量只有 2 维:

h1 = [1, 0]   # “我”
h2 = [0, 1]   # “喜欢”
h3 = [1, 1]   # “学习”

在 attention 中,每个 hidden vector 会分别经过三个 projection matrix,也就是三个投影矩阵:

Q = h × WQ
K = h × WK
V = h × WV

其中:

矩阵 中文含义 作用
WQ Query 投影矩阵 把 hidden vector 变成 Query
WK Key 投影矩阵 把 hidden vector 变成 Key
WV Value 投影矩阵 把 hidden vector 变成 Value

projection matrix 可以理解为“线性变换矩阵”或“映射矩阵”。它的作用是把同一个 token 的 hidden vector 映射成三种不同用途的向量:Q 用于查询,K 用于匹配,V 用于提供真正被汇总的信息内容。

为了让计算更简单,这里先假设三个 projection matrix 都是单位矩阵:

WQ = WK = WV = I

所以:

Q1 = K1 = V1 = h1 = [1, 0]
Q2 = K2 = V2 = h2 = [0, 1]
Q3 = K3 = V3 = h3 = [1, 1]

这里需要记住一句话:

Q 和 K 用来计算“应该关注谁”,V 才是真正被加权汇总的信息内容。

第二步:Dense Attention 如何计算

现在我们以 token 3,也就是“学习”为当前 token,计算它应该关注哪些 token。

当前 token 的 Query 是:

Q3 = [1, 1]

所有 token 的 Key 是:

K1 = [1, 0]
K2 = [0, 1]
K3 = [1, 1]

所有 token 的 Value 是:

V1 = [1, 0]
V2 = [0, 1]
V3 = [1, 1]

attention score 的计算公式是:

score = Q · K / sqrt(d)

这里向量维度 d = 2,所以:

sqrt(d) = sqrt(2) ≈ 1.414

分别计算 token 3 对三个 token 的关注分数:

Q3 · K1 = [1, 1] · [1, 0] = 1
score_31 = 1 / 1.414 ≈ 0.707

Q3 · K2 = [1, 1] · [0, 1] = 1
score_32 = 1 / 1.414 ≈ 0.707

Q3 · K3 = [1, 1] · [1, 1] = 2
score_33 = 2 / 1.414 ≈ 1.414

所以原始 scores 是:

scores = [0.707, 0.707, 1.414]

接着对 scores 做 softmax:
softmax 的数学表示可以写成:

softmax ( x i ) = e x i ∑ j = 1 n e x j \text{softmax}(x_i)=\frac{e^{x_i}}{\sum_{j=1}^{n}e^{x_j}} softmax(xi)=j=1nexjexi

softmax([0.707, 0.707, 1.414])
≈ [0.248, 0.248, 0.503]

得到 attention 权重:

被关注 token score attention weight
token 1:我 0.707 0.248
token 2:喜欢 0.707 0.248
token 3:学习 1.414 0.503

最后用这些权重加权求和 V:

output_3 = 0.248 × V1 + 0.248 × V2 + 0.503 × V3

这里的output3是第 3 个 token 的 attention 输出向量,捏可以理解为是学习了上下文的表示向量。

代入:

V1 = [1, 0]
V2 = [0, 1]
V3 = [1, 1]

计算:

output_3 = 0.248 × [1, 0]
         + 0.248 × [0, 1]
         + 0.503 × [1, 1]

         = [0.248, 0]
         + [0, 0.248]
         + [0.503, 0.503]

         = [0.751, 0.751]

这就是 dense attention 的计算结果。

Dense attention 的特点是:

所有可见 token 都参与 attention 计算。

在这个例子中,token 1、token 2、token 3 全部参与最终输出:

output_3 = 0.248V1 + 0.248V2 + 0.503V3

也就是说,dense attention 是“全部都看”,只是不同 token 的权重不同。
第三步:Sparse Attention 如何计算

现在我们把 dense attention 改成 sparse attention。

前面 dense attention 已经算出了 token 3 对三个 token 的 attention score:

scores = [0.707, 0.707, 1.414]

这三个分数分别对应:

token 1:“我”     score = 0.707
token 2:“喜欢”   score = 0.707
token 3:“学习”   score = 1.414

这里要先明确一个容易误解的地方:

top-1 不是指“第 1 个 token”
top-1 是指“分数排名第 1 的 token”

也就是说,top-1 表示只保留 attention score 最大的那个位置。

在当前例子中,三个分数是:

[0.707, 0.707, 1.414]

其中最大的是:

1.414

它对应的是:

token 3:“学习”

所以如果我们只保留 top-1,也就是只保留分数最高的 token,那么只保留 token 3:

top-1 = token 3

此时 token 1 和 token 2 虽然原来也有 attention score,但是在 sparse attention 中会被屏蔽掉,不再参与后面的 V 加权求和。

为了实现这种“屏蔽”,通常会把不保留的位置设为负无穷:

sparse_scores = [-∞, -∞, 1.414]

这里的 -∞ 可以理解为:

这个位置被屏蔽了
这个位置后面不会分到注意力权重

为什么设成 -∞ 之后 softmax 权重会变成 0 呢?

因为 softmax 的数学形式是:

softmax(x_i) = e^(x_i) / sum(e^(x_j))

如果某个位置是 -∞,那么:

e^(-∞) = 0

所以被设为 -∞ 的位置,在 softmax 之后权重就是 0。

因此:

softmax([-∞, -∞, 1.414]) = [0, 0, 1]

这表示:

token 1 的注意力权重 = 0
token 2 的注意力权重 = 0
token 3 的注意力权重 = 1

所以输出变成:

sparse_output_3 = 0 × V1 + 0 × V2 + 1 × V3
                = V3

前面我们设定:

V3 = [1, 1]

所以:

sparse_output_3 = [1, 1]

这就是 sparse attention top-1 的计算过程。

和 dense attention 对比一下:

Dense Attention:
output_3 = 0.248 × V1 + 0.248 × V2 + 0.503 × V3

dense attention 中,token 1、token 2、token 3 的 V 都参与了最终输出。

而 sparse attention top-1 中:

Sparse Attention top-1:
sparse_output_3 = 1 × V3

只有分数最高的 token 3 参与了最终输出。

直观理解就是:

Dense Attention:
token 3 同时看 token 1、token 2、token 3。

Sparse Attention top-1:
token 3 只看 attention score 最高的那个 token。
在这个例子中,分数最高的是 token 3,所以只看 token 3。

接着看 top-2 的情况。

top-2 的意思是:

只保留 attention score 排名前 2 的 token

原始 scores 仍然是:

scores = [0.707, 0.707, 1.414]

其中分数最高的是 token 3:

token 3: 1.414

第二高的是 token 1 或 token 2,因为它们的分数相同:

token 1: 0.707
token 2: 0.707

这里为了方便演示,我们假设保留 token 1 和 token 3,屏蔽 token 2。

于是 sparse scores 变成:

sparse_scores = [0.707, -∞, 1.414]

注意,这里不是直接使用 dense attention 里面原来的权重:

dense attention 原来的权重是 [0.248, 0.248, 0.503]

而是要对新的 sparse_scores 重新做 softmax:

softmax([0.707, -∞, 1.414])

因为 token 2 被设成了 -∞,所以 token 2 的权重变成 0。剩下 token 1 和 token 3 的权重会重新归一化,并且它们加起来等于 1。

softmax 后大约是:

softmax([0.707, -∞, 1.414])
≈ [0.330, 0.000, 0.670]

这表示:

token 1 的注意力权重 ≈ 0.330
token 2 的注意力权重 = 0.000
token 3 的注意力权重 ≈ 0.670

所以输出为:

sparse_output_3 = 0.330 × V1 + 0.000 × V2 + 0.670 × V3

因为 token 2 的权重是 0,所以可以省略:

sparse_output_3 = 0.330 × V1 + 0.670 × V3

代入前面的 Value 向量:

V1 = [1, 0]
V3 = [1, 1]

得到:

sparse_output_3 = 0.330 × [1, 0] + 0.670 × [1, 1]
                = [0.330, 0] + [0.670, 0.670]
                = [1.000, 0.670]

所以 top-2 sparse attention 的输出是:

sparse_output_3 = [1.000, 0.670]

这里最关键的是理解 sparse attention 的两个动作:

第一步:根据 attention score 选择 top-k 个重要位置
第二步:只对保留下来的位置重新做 softmax 和 V 加权求和

因此,sparse attention 不是简单地把 dense attention 里面的小权重删掉,而是先屏蔽掉不重要的位置,再在剩下的位置之间重新分配注意力权重。

所以 dense attention 和 sparse attention 的区别可以总结为:

方法 参与计算的 token softmax 范围 输出方式 直观理解
Dense Attention 所有 token 都参与 对所有 score 做 softmax 所有 V 加权求和 全部都看
Sparse Attention top-1 只保留分数最高的 1 个 token 只对保留位置重新 softmax 只使用 1 个 V 只看最重要的一个
Sparse Attention top-2 只保留分数最高的 2 个 token 只对保留位置重新 softmax 只使用 2 个 V 只看最重要的两个

一句话总结:

Dense attention 是所有 token 都看;
Sparse attention 是先选出 attention score 最高的 top-k 个 token,
然后只让这些 token 参与后续的 softmax 和 V 加权求和。

第四步:Compressed Sparse Attention 如何计算

现在进一步理解 DeepSeek-V4 中的 Compressed Sparse Attention,CSA

普通 sparse attention 是:

原始 K/V
    ↓
选择重要位置
    ↓
对重要位置做 attention

而 Compressed Sparse Attention 的直觉是:

原始 KV cache
    ↓
先沿序列维度压缩 KV cache
    ↓
得到 compressed K/V
    ↓
再在 compressed K/V 上做 sparse attention

也就是说,CSA 比普通 sparse attention 多了一个关键步骤:

compression

它不是直接在完整 KV cache 上做稀疏选择,而是先把 KV cache 沿序列维度压缩,再在压缩后的表示中选择重要位置。

举一个简化例子。假设原来有 8 个历史 token 的 KV cache:

K1, K2, K3, K4, K5, K6, K7, K8
V1, V2, V3, V4, V5, V6, V7, V8

如果压缩比例是 2,可以把每两个 token 压缩成一个 compressed KV:

CK1 = compress(K1, K2)
CK2 = compress(K3, K4)
CK3 = compress(K5, K6)
CK4 = compress(K7, K8)

CV1 = compress(V1, V2)
CV2 = compress(V3, V4)
CV3 = compress(V5, V6)
CV4 = compress(V7, V8)

原来有 8 个 KV 位置,现在变成 4 个 compressed KV 位置:

原始长度: 8
压缩后长度: 4

然后当前 token 的 Query 不再和 8 个完整 Key 全部匹配,而是和 4 个 compressed Key 匹配:

Q · CK1
Q · CK2
Q · CK3
Q · CK4

假设得到的分数是:

compressed_scores = [2.0, 0.5, 1.5, -1.0]

如果只保留 top-2,那么选择:

CK1: 2.0
CK3: 1.5

其余位置丢掉:

sparse_compressed_scores = [2.0, -∞, 1.5, -∞]

softmax 后:

softmax([2.0, -∞, 1.5, -∞])
≈ [0.622, 0.000, 0.378, 0.000]

最终输出是:

output = 0.622 × CV1 + 0.378 × CV3

这就是 Compressed Sparse Attention 的直觉:

先把长 KV cache 压短,
再只选择压缩表示中的重要位置参与 attention。

第五步:Dense Attention、Sparse Attention 与 CSA 对比

可以把三者放在一起看:

方法 是否压缩 KV 是否稀疏选择 参与 attention 的对象
Dense Attention 所有原始 K/V
Sparse Attention 部分原始 K/V
Compressed Sparse Attention 部分 compressed K/V

如果用一句话理解:

Dense Attention:所有原始 K/V 都看。
Sparse Attention:只看部分重要的原始 K/V。
Compressed Sparse Attention:先压缩 K/V,再只看部分重要的 compressed K/V。

第六步:和 HCA 的区别

DeepSeek-V4 中除了 CSA,还有 HCA,也就是 Heavily Compressed Attention

HCA 的全称是:

Heavily Compressed Attention

可以翻译为:

重压缩注意力

或者:

高度压缩注意力

这里的关键词有两个:

Heavily Compressed:更强程度压缩
Attention:注意力

所以 HCA 的核心直觉是:

先把 KV cache 压缩得更短,
然后在这个更短的 compressed K/V 上做 dense attention。

也就是说,HCA 不是直接在原始 K/V 上做 dense attention,而是在 强压缩之后的 K/V 上做 dense attention。

它的流程可以写成:

原始 KV cache
    ↓
更强程度压缩
    ↓
得到更短的 compressed K/V
    ↓
当前 Query 和所有 compressed Key 计算分数
    ↓
对所有 compressed positions 做 softmax
    ↓
加权求和所有 compressed Value

这里最容易误解的是:

HCA 里的 dense attention
不是看所有原始 K/V,
而是看所有压缩后的 compressed K/V。

也就是说,HCA 的 dense 是指:

在 compressed K/V 空间里全部都看

不是指:

在原始百万 token 的 K/V 空间里全部都看

这两者差别很大。


先回顾一下 CSA 的做法。

CSA 是:

原始 KV cache
    ↓
压缩
    ↓
得到 compressed K/V
    ↓
再做 sparse attention,只选择 top-k 重要位置

也就是说,CSA 有两个减少计算量的动作:

第一步:压缩 K/V,让 K/V 位置变少
第二步:稀疏选择,只保留 top-k 重要 compressed KV 位置

如果原来有 8 个 KV 位置:

(K1,V1), (K2,V2), (K3,V3), (K4,V4),
(K5,V5), (K6,V6), (K7,V7), (K8,V8)

CSA 可以先压缩成 4 个 compressed KV 位置:

(CK1,CV1), (CK2,CV2), (CK3,CV3), (CK4,CV4)

然后再从这 4 个里面选择 top-k。

例如 compressed scores 是:

compressed_scores = [2.0, 0.5, 1.5, -1.0]

如果 CSA 保留 top-2,那么只选:

CK1 / CV1
CK3 / CV3

屏蔽掉:

CK2 / CV2
CK4 / CV4

于是 sparse compressed scores 是:

sparse_compressed_scores = [2.0, -∞, 1.5, -∞]

softmax 后大约是:

[0.622, 0.000, 0.378, 0.000]

最终输出是:

output = 0.622 × CV1 + 0.378 × CV3

也就是说,CSA 在压缩之后,仍然只使用一部分 compressed Value。

所以 CSA 的直觉是:

压缩后再挑重点。

HCA 的思路不一样。

HCA 也会压缩 K/V,但是压缩得更狠。压缩完成之后,HCA 不再做 top-k sparse selection,而是对所有压缩后的 K/V 做 dense attention。

举一个简化例子。

假设原来还是有 8 个 KV 位置:

(K1,V1), (K2,V2), (K3,V3), (K4,V4),
(K5,V5), (K6,V6), (K7,V7), (K8,V8)

CSA 的教学例子里,我们可能把 8 个位置压缩成 4 个:

(CK1,CV1), (CK2,CV2), (CK3,CV3), (CK4,CV4)

而 HCA 更强压缩,可能把 8 个位置压缩成 2 个:

(HK1,HV1), (HK2,HV2)

这里为了区分,可以用:

HK = HCA compressed Key
HV = HCA compressed Value

如果用相邻分块压缩作为教学例子,可以写成:

HK1 = compress(K1, K2, K3, K4)
HK2 = compress(K5, K6, K7, K8)

HV1 = compress(V1, V2, V3, V4)
HV2 = compress(V5, V6, V7, V8)

这表示:

HK1 / HV1 代表 token 1 到 token 4 这一大段的压缩表示
HK2 / HV2 代表 token 5 到 token 8 这一大段的压缩表示

注意,这里的压缩比 CSA 更强。

CSA 中可能是:

2 个 token 压缩成 1 个 compressed KV

HCA 中可能是:

4 个 token 压缩成 1 个 compressed KV

这就是 HCA 里 Heavily Compressed 的含义:

压缩得更狠,压缩后的 KV 序列更短。

接下来 HCA 做 dense attention。

假设当前 Query 是 Q。

压缩后有两个 compressed Key:

HK1, HK2

那么 HCA 会计算:

Q · HK1
Q · HK2

假设得到分数:

hca_scores = [1.2, 0.8]

HCA 不会像 CSA 那样只保留 top-1,也不会把某些位置设成 -∞

HCA 会直接对所有 compressed scores 做 softmax:

softmax([1.2, 0.8])
≈ [0.599, 0.401]

然后加权求和所有 compressed Value:

hca_output = 0.599 × HV1 + 0.401 × HV2

这里两个 compressed Value 都参与了输出。

所以 HCA 的 dense attention 是:

在压缩后的 KV 位置上全部都看。

不是:

在原始 KV 位置上全部都看。

这一点非常关键。


可以把 CSA 和 HCA 的完整过程放在一起对比。

假设原来有 8 个原始 KV 位置。

CSA 可以理解为:

原始 KV:
8 个位置
    ↓
中等程度压缩
    ↓
4 个 compressed KV 位置
    ↓
sparse attention
    ↓
只选择 top-2 个 compressed KV 位置参与输出

所以 CSA 实际使用的是:

2 个 compressed KV 位置

HCA 可以理解为:

原始 KV:
8 个位置
    ↓
更强程度压缩
    ↓
2 个 heavily compressed KV 位置
    ↓
dense attention
    ↓
2 个 heavily compressed KV 位置全部参与输出

所以 HCA 实际使用的是:

所有 heavily compressed KV 位置

也就是说:

CSA 是“压缩后再挑一部分看”。
HCA 是“压缩得更短,然后压缩后的都看”。

为什么 HCA 压缩得更狠之后还可以做 dense attention?

因为 dense attention 的成本主要取决于 K/V 位置数量。

如果原始 K/V 很长,dense attention 很贵。

例如原来有:

1,000,000 个原始 KV 位置

如果直接 dense attention,当前 Query 要看:

1,000,000 个 Key

这会非常贵。

但是 HCA 先把 KV cache 强压缩。

例如把:

1,000,000 个原始 KV 位置

压缩成:

10,000 个 heavily compressed KV 位置

那么 HCA 的 dense attention 实际是在这 10,000 个 compressed KV 上做,而不是在 1,000,000 个原始 KV 上做。

所以计算从:

Q 看 1,000,000 个原始 Key

变成:

Q 看 10,000 个 heavily compressed Key

虽然 HCA 在 compressed KV 上是 dense 的,但因为 compressed KV 的长度已经大幅缩短,所以 dense 的成本仍然可以接受。

换句话说:

HCA 不是靠 sparse selection 省计算,
而是靠更强的 KV 压缩省计算。

用复杂度直觉看也更清楚。

假设原始 KV 长度是:

n

如果直接 dense attention,那么当前 Query 要访问:

n 个 K/V 位置

如果 CSA 先压缩到:

n_csa

然后只选 top-k,那么实际参与 V 加权求和的位置大约是:

k

所以 CSA 的直觉是:

n → n_csa → k

如果 HCA 更强压缩到:

n_hca

然后在 n_hca 上做 dense attention,那么实际参与 V 加权求和的位置是:

n_hca

所以 HCA 的直觉是:

n → n_hca

其中通常可以理解为:

n_hca < n_csa

也就是说,HCA 压缩后的序列更短。

因此:

CSA:
先压缩到一个中等长度,再通过 top-k 进一步减少参与计算的位置。

HCA:
直接压缩到更短长度,然后在这个短序列上全部参与计算。

可以用一个更大的例子说明。

假设原始 KV cache 有:

1,000,000 个 KV 位置

CSA 可以粗略理解为:

1,000,000 个原始 KV
    ↓
压缩成 100,000 个 compressed KV
    ↓
sparse attention 只选 top-k 个重要位置

如果 top-k 是 1,024,那么 CSA 最后真正参与加权求和的是:

1,024 个 compressed Value

也就是说,CSA 的特点是:

保留较多 compressed 候选位置,
但最后只选择其中最相关的少数位置。

HCA 可以粗略理解为:

1,000,000 个原始 KV
    ↓
更强压缩成 10,000 个 heavily compressed KV
    ↓
dense attention 使用这 10,000 个 heavily compressed KV

也就是说,HCA 的特点是:

候选位置本身已经非常少,
所以可以把这些压缩后的位置全部看一遍。

这里的 100,000、10,000、1,024 都是为了帮助理解的示意数字,不代表官方固定参数。


为什么 DeepSeek-V4 需要同时有 CSA 和 HCA?

因为二者解决的问题不完全一样。

CSA 更像是:

从较细粒度的压缩表示里,挑出和当前 Query 最相关的重点信息。

它适合保留一些细节相关性。

例如,长文档或长代码里某些片段和当前问题高度相关,CSA 可以通过 top-k 选择把它们挑出来。

HCA 更像是:

提供一个更粗粒度、更全局的上下文背景。

它把很长的 KV cache 压缩得更短,然后在这个较短的全局压缩表示上做 dense attention。虽然每个 compressed KV 更粗糙,但它能让模型看到更完整的压缩后全局信息。

所以可以这样理解:

CSA 更偏向“找重点细节”。
HCA 更偏向“保留全局背景”。

二者是互补关系,不是谁完全替代谁。


再用读书来类比。

Dense attention 像是:

把整本书每一句都仔细看。

Sparse attention 像是:

从整本书里挑最相关的几句话看。

CSA 像是:

先把整本书按段落压缩成段落摘要,
再从这些段落摘要里挑最相关的几个看。

HCA 像是:

先把整本书压缩成更短的章节摘要,
然后把所有章节摘要都看一遍。

所以:

CSA:摘要比较细,但只挑重点看。
HCA:摘要更粗,但压缩后的摘要全部看。

这个类比虽然不等于真实实现,但可以帮助理解二者的差异。


因此,CSA 和 HCA 的区别可以总结为:

模块 压缩程度 Attention 方式 实际访问对象 直观理解
CSA 中等压缩 Sparse attention top-k 个 compressed K/V 压缩后只看重点
HCA 更强压缩 Dense attention 所有 heavily compressed K/V 压缩得更狠,但压缩后的都看

更具体地说:

CSA = compression + sparse attention

意思是:

先压缩 K/V,
再从压缩后的 K/V 中选 top-k 个重要位置。

而:

HCA = heavier compression + dense attention

意思是:

先对 K/V 做更强压缩,
再让所有压缩后的 K/V 位置都参与 attention。

所以中文可以写成:

CSA = 压缩 + 稀疏注意力
HCA = 更强压缩 + 压缩空间内的密集注意力

这里最好把 HCA 写成“压缩空间内的密集注意力”,因为它不是对原始百万 token 做 dense attention,而是在更短的 heavily compressed KV 上做 dense attention。

一句话总结:

CSA 是先把 KV cache 压短,再从压缩后的 K/V 里挑重点;
HCA 是把 KV cache 压得更短,然后让压缩后的 K/V 全部参与 attention。

因此,HCA 的 dense attention 并不和高效率矛盾,因为它 dense 的对象已经不是原始长序列,而是强压缩后的短 K/V 序列。

最后总结

在 attention 中,Q、K、V 的作用可以这样记:

Q:当前 token 想找什么
K:每个 token 提供什么索引
V:真正被拿走的信息内容

Dense attention 的流程是:

Q 匹配所有 K
    ↓
softmax 得到所有位置的权重
    ↓
加权求和所有 V

Sparse attention 的流程是:

Q 匹配 K
    ↓
只保留重要 K
    ↓
只加权对应的 V

Compressed Sparse Attention 的流程是:

先压缩 KV cache
    ↓
Q 匹配 compressed K
    ↓
只保留重要 compressed K
    ↓
只加权对应的 compressed V

一句话总结:

Dense attention 是“所有位置都参与计算”;sparse attention 是“只让重要位置参与计算”;Compressed Sparse Attention 则是在长上下文场景下先压缩 KV cache,再在压缩后的表示中做稀疏注意力,从而降低计算和存储压力。

5.3 CSA 的 toy-level 代码理解

下面给出一个非官方 toy-level CSA 实现。它的核心步骤是:先压缩 K/V,再计算 Q 和压缩后 K 的相似度,最后只保留 top-k 位置参与加权求和。

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SequenceCompressor(nn.Module):
    def __init__(self, dim: int, compression_ratio: int = 4):
        super().__init__()
        if compression_ratio <= 0:
            raise ValueError("compression_ratio must be positive")
        self.dim = dim
        self.compression_ratio = compression_ratio
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() != 3:
            raise ValueError("x must be [batch, seq_len, dim]")
        b, n, d = x.shape
        if d != self.dim:
            raise ValueError(f"expected dim={self.dim}, got dim={d}")
        r = self.compression_ratio
        pad_len = (r - n % r) % r
        if pad_len > 0:
            pad = torch.zeros(b, pad_len, d, device=x.device, dtype=x.dtype)
            x = torch.cat([x, pad], dim=1)
        n2 = x.shape[1]
        x = x.view(b, n2 // r, r, d).mean(dim=2)
        return self.proj(x)


class ToyCompressedSparseAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 4,
        compression_ratio: int = 4,
        topk: int = 8,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("dim must be divisible by num_heads")
        if topk <= 0:
            raise ValueError("topk must be positive")
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.topk = topk
        self.q_proj = nn.Linear(dim, dim)
        self.k_compressor = SequenceCompressor(dim, compression_ratio)
        self.v_compressor = SequenceCompressor(dim, compression_ratio)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        x = x.view(b, n, self.num_heads, self.head_dim)
        return x.transpose(1, 2)

    def merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        b, h, n, d = x.shape
        x = x.transpose(1, 2).contiguous()
        return x.view(b, n, h * d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q = self.split_heads(self.q_proj(x))
        k_base = self.k_compressor(x)
        v_base = self.v_compressor(x)
        k = self.split_heads(self.k_proj(k_base))
        v = self.split_heads(self.v_proj(v_base))
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        topk = min(self.topk, scores.shape[-1])
        topk_values, topk_indices = torch.topk(scores, k=topk, dim=-1)
        sparse_scores = torch.full_like(scores, float("-inf"))
        sparse_scores.scatter_(-1, topk_indices, topk_values)
        attn = F.softmax(sparse_scores, dim=-1)
        out = torch.matmul(attn, v)
        out = self.merge_heads(out)
        return self.o_proj(out)


if __name__ == "__main__":
    x = torch.randn(2, 1024, 256)
    csa = ToyCompressedSparseAttention(
        dim=256,
        num_heads=8,
        compression_ratio=4,
        topk=16,
    )
    y = csa(x)
    print("input shape:", x.shape)
    print("output shape:", y.shape)

6. 核心技术三:HCA,Heavily Compressed Attention

6.1 HCA 的基本思想

HCA 的全称是 Heavily Compressed Attention,可以翻译为“重压缩注意力”或“高度压缩注意力”。与 CSA 不同,HCA 的特点是对 KV cache 做更强压缩,但仍然保留 dense attention。

可以粗略理解为:

原始 KV cache
    |
更激进的压缩
    |
更短的 compressed KV cache
    |
dense attention

CSA 和 HCA 的区别可以这样理解:

模块 KV 压缩程度 Attention 方式 更适合承担的作用
CSA 中等压缩 Sparse attention 保留关键细节和相关片段
HCA 更强压缩 Dense attention 提供全局语义背景

CSA 更像是“从压缩后的长上下文中找重点”,HCA 更像是“保留一个高度压缩的全局记忆”。二者不是替代关系,而是互补关系。

6.2 HCA 的 toy-level 代码理解

HCA 和 CSA 的主要区别是:CSA 是压缩后再稀疏选择,而 HCA 是压缩得更狠,但在压缩后的 K/V 上仍然做 dense attention。

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SequenceCompressor(nn.Module):
    def __init__(self, dim: int, compression_ratio: int):
        super().__init__()
        if compression_ratio <= 0:
            raise ValueError("compression_ratio must be positive")
        self.dim = dim
        self.compression_ratio = compression_ratio
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        r = self.compression_ratio
        pad_len = (r - n % r) % r
        if pad_len > 0:
            pad = torch.zeros(b, pad_len, d, device=x.device, dtype=x.dtype)
            x = torch.cat([x, pad], dim=1)
        n2 = x.shape[1]
        x = x.view(b, n2 // r, r, d).mean(dim=2)
        return self.proj(x)


class ToyHeavilyCompressedAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 4, compression_ratio: int = 16):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("dim must be divisible by num_heads")
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.q_proj = nn.Linear(dim, dim)
        self.k_compressor = SequenceCompressor(dim, compression_ratio)
        self.v_compressor = SequenceCompressor(dim, compression_ratio)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)

    def split_heads(self, x):
        b, n, d = x.shape
        x = x.view(b, n, self.num_heads, self.head_dim)
        return x.transpose(1, 2)

    def merge_heads(self, x):
        b, h, n, d = x.shape
        x = x.transpose(1, 2).contiguous()
        return x.view(b, n, h * d)

    def forward(self, x):
        q = self.split_heads(self.q_proj(x))
        compressed_k = self.k_compressor(x)
        compressed_v = self.v_compressor(x)
        k = self.split_heads(self.k_proj(compressed_k))
        v = self.split_heads(self.v_proj(compressed_v))
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        out = self.merge_heads(out)
        return self.o_proj(out)


if __name__ == "__main__":
    x = torch.randn(2, 1024, 256)
    hca = ToyHeavilyCompressedAttention(
        dim=256,
        num_heads=8,
        compression_ratio=16,
    )
    y = hca(x)
    print("input shape:", x.shape)
    print("output shape:", y.shape)

这个例子体现了 HCA 的核心直觉:

原始序列 x
    ↓
更强压缩 K/V
    ↓
Q attend compressed K/V
    ↓
dense attention

它不是 DeepSeek-V4 的官方实现,只是帮助理解“更强压缩 + dense attention”的结构思想。


7. 核心技术四:mHC,Manifold-Constrained Hyper-Connections

7.1 mHC 的基本思想

mHC 的全称是 Manifold-Constrained Hyper-Connections,可以翻译为“流形约束超连接”。

这个名字比较抽象,可以先拆成两部分理解:

Hyper-Connections = 超连接,也就是比普通 residual connection 更丰富的连接方式
Manifold-Constrained = 流形约束,也就是给这些连接方式加上稳定性约束

所以,mHC 可以先粗略理解为:

mHC = 有约束的多通道残差连接

要理解 mHC,需要先理解普通 residual connection。

普通 Transformer 中常见的 residual connection 是:

x_new = x_old + f(x_old)

其中:

x_old      = 这一层输入进来的原始 hidden state
f(x_old)   = attention 或 MLP 处理后得到的新信息
x_new      = 加上新信息后的 hidden state

它的意思是:

新的表示 = 原来的表示 + 这一层学到的修改量

举个简单例子:

x_old = [1.0, 2.0]
f(x_old) = [0.1, -0.2]

那么:

x_new = x_old + f(x_old)
      = [1.0, 2.0] + [0.1, -0.2]
      = [1.1, 1.8]

可以看到,输出不是完全替换原来的 x_old,而是在原来的基础上做了一点修改。

这就是 residual connection 的核心:

不要让每一层都完全重写 hidden state,
而是在原始表示的基础上补充新信息。

可以用改论文来类比:

x_old = 原来的论文
f(x_old) = 修改意见
x_new = 原来的论文 + 修改意见

也就是说,residual connection 会保留原来的信息,同时加入当前层学到的新信息。

这种结构有两个重要好处。

第一个好处是保留信息。

如果没有 residual connection,每一层都可能把前面的信息覆盖掉。模型越深,早期信息越容易丢失。

有 residual connection 之后,原始信息可以通过这条路径继续往后传:

x_old 直接传到 x_new

第二个好处是训练更稳定。

在深层神经网络中,梯度需要从后面一层一层传回前面。如果每一层都很复杂,梯度可能会变弱或不稳定。

residual connection 给梯度提供了一条更直接的通路:

后面的梯度
    ↓
通过 residual path 回到前面的 hidden state

所以深层 Transformer 才能更容易训练。


普通 residual connection 可以理解成一条主要的信息通道:

x0 → x1 → x2 → x3 → x4

每一层都沿着这条通道更新:

上一层 hidden state
    ↓
加上 attention 或 MLP 产生的新信息
    ↓
下一层 hidden state

这种方式稳定、简单,但也有一个问题:

信息主要沿着一条 residual stream 传播,连接方式比较单一。

也就是说,普通 residual connection 更像是一条主干道。信息可以稳定地往后走,但路线比较固定。


Hyper-Connections 可以理解为对普通 residual connection 的扩展。

普通 residual connection 是一条信息通道:

一条 residual stream:
x0 → x1 → x2 → x3

Hyper-Connections 则可以理解为多条信息通道:

多条 residual streams:

stream 1: s1_0 → s1_1 → s1_2 → s1_3
stream 2: s2_0 → s2_1 → s2_2 → s2_3
stream 3: s3_0 → s3_1 → s3_2 → s3_3

直观理解就是:

普通 residual connection:
一条路传信息。

Hyper-Connections:
多条路一起传信息。

这样做的好处是,模型的信息流动方式更丰富。

普通 residual connection 里,每一层主要只能沿着同一条 hidden state 通道继续更新。而 Hyper-Connections 允许模型保留多条 residual streams,让不同信息可以在不同通道中传播。

可以把它类比成写论文时保留多个版本:

普通 residual connection:
只保留一个主版本,每次都在这个版本上修改。

Hyper-Connections:
同时保留多个版本,不同版本可以记录不同方向的信息。

例如:

stream 1 可能更偏向保留原始语义信息
stream 2 可能更偏向保留上下文交互信息
stream 3 可能更偏向保留深层抽象信息

这只是帮助理解的类比,不代表真实模型中每条 stream 一定有这样明确的语义分工。

Hyper-Connections 的目标是:

让深层模型拥有更丰富的信息流动路径。

但是,Hyper-Connections 也会带来一个问题:连接方式变多以后,信息流可能变得不稳定。

普通 residual connection 虽然简单,但它有一个很重要的稳定性来源:

x_old 可以比较直接地传到 x_new

也就是说,模型至少有一条接近“原样保留”的路径。

这条路径可以理解为:

identity path

中文可以叫:

恒等映射路径

它的意思是:

即使当前层学到的新信息 f(x) 不够好,
原始信息 x 也可以继续传下去。

但是 Hyper-Connections 增加了更多连接以后,多个信息通道之间可能会互相混合。

比如:

stream 1 的信息流到 stream 2
stream 2 的信息流到 stream 3
stream 3 的信息又混回 stream 1

如果这种混合没有约束,就可能出现两个问题。

第一个问题是信息被过度放大:

某些通道的信息越来越强
    ↓
数值不稳定
    ↓
训练变困难

第二个问题是信息被过度削弱:

某些通道的信息越来越弱
    ↓
重要信息被冲淡
    ↓
深层模型难以保留早期信息

所以 Hyper-Connections 的矛盾是:

连接更多,表达能力更强;
但是连接太自由,稳定性可能变差。

mHC 的作用就是解决这个矛盾。

mHC 不是简单地增加更多连接,也不是让多条通道随便混合,而是:

允许多条 residual streams 存在,
但要求这些连接方式满足一定的稳定性约束。

这里的 Manifold-Constrained 可以先不用按很复杂的数学理解。对于入门阅读来说,可以先把它理解成:

给连接方式加规则。

也就是说,mHC 的直观含义是:

Hyper-Connections 提供更多信息通道;
Manifold-Constrained 给这些通道之间的连接加规则;
mHC 让信息可以多路径流动,但不至于乱流。

可以用交通来类比:

普通 residual connection:
只有一条主路,稳定,但路线少。

Hyper-Connections:
开了很多条路,路线更多,但如果没有规则,容易混乱。

mHC:
开很多条路,同时加交通规则,让信息能多路径流动,但整体仍然稳定。

这个类比能帮助理解 mHC 的核心思想:

不是只追求更多连接,
而是追求有约束、更稳定的连接。

因此,mHC 和普通 residual connection 的关系可以这样理解:

结构 信息通道 连接特点 优点 风险
Residual Connection 一条主要通道 x_old 直接加到输出上 简单、稳定、容易训练 信息流动方式较单一
Hyper-Connections 多条通道 多个 residual streams 之间可以交互 表达能力更强,信息路径更多 连接太自由可能不稳定
mHC 有约束的多条通道 多通道连接受到约束 同时增强表达能力和稳定性 结构更复杂

从 DeepSeek-V4 的角度看,CSA 和 HCA 主要解决的是长上下文 attention 的计算和 KV cache 压力;而 mHC 解决的是深层模型中信息怎么稳定传播的问题。

可以这样区分:

CSA / HCA:
解决“长上下文中 attention 怎么算得更省”。

mHC:
解决“很深的模型中信息怎么传得更稳”。

这里还要注意,mHC 不能简单等同于普通 gated residual。

普通 gated residual 通常可以写成:

x_new = x_old + gate × f(x_old)

它的意思是给当前层新增信息加一个门控,控制 f(x_old) 加多少。

但是 mHC 的重点不是简单控制一个 f(x_old) 加多少,而是扩展 residual connection 的连接结构,让信息可以通过多个 residual streams 传播,并且对这些连接加约束。

所以:

Gated residual:
主要是在一条 residual stream 上控制新增信息的比例。

mHC:
关注多条 residual streams 之间如何稳定连接。

因此,博客里的 toy gated residual 只能用来帮助理解“增强残差路径”的动机,不能说它就是 mHC 的真实实现。

更准确的表述应该是:

本文后面的 ToyGatedResidual 不是 mHC。
它只是用一个简单代码例子说明:在普通 residual connection 之外,可以加入额外机制来调节信息流。
真实 mHC 涉及 Hyper-Connections 和 manifold constraint,不能用几行 gated residual 代码完整复现。

一句话总结:

Residual connection 是一条稳定的信息直通路径;
Hyper-Connections 把这条路径扩展成多条信息通道;
mHC 则是在多条通道之间加入约束规则,
让模型既能获得更丰富的信息流动方式,又能保持深层训练的稳定性。

所以,mHC 的重点不是 attention 怎么算,也不是 KV cache 怎么压缩,而是:

深层 Transformer 中,信息如何在层与层之间更稳定、更灵活地传播。

7.2 mHC 的 toy-level 代码理解

mHC 不是普通 gated residual,也不是简单的残差相加。真实 mHC 涉及更复杂的连接约束和表示传播方式。为了帮助理解“增强 residual connection”这个动机,可以先写一个极简的 gated residual 作为类比。

普通 residual connection 是:

x = x + f(x)

toy gated residual 可以写成:

x = x + gate(x) * f(x)

代码如下:

import torch
import torch.nn as nn


class ToyGatedResidual(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor, fx: torch.Tensor) -> torch.Tensor:
        return x + self.gate(x) * fx


class ToyBlockWithEnhancedResidual(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )
        self.residual = ToyGatedResidual(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        fx = self.ffn(self.norm(x))
        return self.residual(x, fx)


if __name__ == "__main__":
    x = torch.randn(2, 128, 256)
    block = ToyBlockWithEnhancedResidual(dim=256)
    y = block(x)
    print("input shape:", x.shape)
    print("output shape:", y.shape)

需要强调的是:

ToyGatedResidual 不是 mHC。
ToyGatedResidual 只是为了帮助理解“增强残差路径”的动机。

更准确地说,mHC 的意义在于:当模型很深、上下文很长、模块很复杂时,模型需要更稳定、更灵活的层间信息流。普通 residual connection 是直接相加,而 mHC 试图让层间连接具有更强的表达能力和更好的训练稳定性。


8. 核心技术五:Muon Optimizer

8.1 Muon Optimizer 的基本思想

DeepSeek-V4 引入了 Muon optimizer,用于提升训练收敛速度和稳定性。

对于大模型训练来说,优化器不是细节,而是核心工程之一。尤其是 DeepSeek-V4 这种模型同时包含超大规模 MoE、长上下文注意力压缩、多精度训练、多种并行策略以及后训练中的 SFT / RL / distillation,训练稳定性非常重要。

从论文角度看,Muon optimizer 的作用可以概括为:

更快收敛 + 更稳定训练

不过,本文不会尝试复现 Muon optimizer,因为完整训练 DeepSeek-V4 级别模型涉及大规模分布式训练系统,不适合用一个简单脚本模拟。

8.2 Optimizer 在训练流程中的位置

优化器和模型结构不同,它不是一个简单模块,不能像 attention 或 MoE 那样用几十行代码准确复现。因此,本文不实现 Muon optimizer。为了帮助理解优化器在训练中的位置,下面用 AdamW 写一个普通训练步骤示意。

import torch
import torch.nn as nn


model = nn.Linear(256, 32000)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=0.01,
)

input_x = torch.randn(4, 256)
target = torch.randint(0, 32000, (4,))

criterion = nn.CrossEntropyLoss()

logits = model(input_x)
loss = criterion(logits, target)

optimizer.zero_grad()
loss.backward()
optimizer.step()

print("loss:", loss.item())

这里用 AdamW 只是演示 optimizer 在训练流程中的作用,不代表 DeepSeek-V4 使用 AdamW,也不代表 Muon 的真实实现。


9. DeepSeek-V4 的效率来自哪里?

DeepSeek-V4 在 1M-token context 场景下,相比 DeepSeek-V3.2 显著降低了单 token 推理 FLOPs 和 KV cache 占用。这背后的原因不是一个技巧,而是多个层面的共同作用。

9.1 注意力层面的优化

传统长上下文 attention 的主要瓶颈是:

长序列 attention 计算成本高
KV cache 占用大
KV cache 读写开销大

DeepSeek-V4 通过 CSA 和 HCA 对 KV cache 进行压缩,并结合稀疏注意力和强压缩 dense attention,从架构层面降低长上下文成本。

9.2 MoE 层面的优化

MoE 的核心思想是:模型总参数很多,但每次前向只激活部分专家。

例如:

DeepSeek-V4-Pro: 1.6T total / 49B activated
DeepSeek-V4-Flash: 284B total / 13B activated

这说明 MoE 的重点是用更大的参数容量提升模型能力,同时控制每次前向计算的实际成本。

9.3 MoE 的 toy-level 代码理解

MoE,Mixture-of-Experts,可以理解为“专家混合”。普通 FFN 是所有 token 都经过同一个 MLP:

x → FFN → output

MoE 则是先用 router 判断每个 token 应该交给哪些 expert:

x → router → top-k experts → weighted sum

下面是一个极简 toy MoE 实现:

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyExpert(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)


class ToyMoE(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        num_experts: int = 4,
        top_k: int = 2,
    ):
        super().__init__()
        if top_k > num_experts:
            raise ValueError("top_k must be <= num_experts")
        self.num_experts = num_experts
        self.top_k = top_k
        self.router = nn.Linear(dim, num_experts)
        self.experts = nn.ModuleList(
            [ToyExpert(dim, hidden_dim) for _ in range(num_experts)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        router_logits = self.router(x)
        router_probs = F.softmax(router_logits, dim=-1)
        topk_probs, topk_indices = torch.topk(
            router_probs,
            k=self.top_k,
            dim=-1,
        )
        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
        output = torch.zeros_like(x)

        for expert_id, expert in enumerate(self.experts):
            expert_out = expert(x)
            weight = torch.zeros(b, n, device=x.device, dtype=x.dtype)

            for k in range(self.top_k):
                weight = weight + torch.where(
                    topk_indices[:, :, k] == expert_id,
                    topk_probs[:, :, k],
                    torch.zeros_like(topk_probs[:, :, k]),
                )

            output = output + expert_out * weight.unsqueeze(-1)

        return output


if __name__ == "__main__":
    x = torch.randn(2, 16, 256)
    moe = ToyMoE(
        dim=256,
        hidden_dim=1024,
        num_experts=4,
        top_k=2,
    )
    y = moe(x)
    print("input shape:", x.shape)
    print("output shape:", y.shape)

这个 toy MoE 体现了 MoE 的核心逻辑:

router 给每个 token 分配 expert
    ↓
只激活 top-k expert
    ↓
把多个 expert 的输出加权求和

这就是为什么 MoE 模型可以拥有很大的总参数量,但每次实际激活的参数量相对较小。

9.4 精度层面的优化

DeepSeek-V4 官方模型卡中列出的精度包括:

Base model: FP8 Mixed
Instruct model: FP4 + FP8 Mixed

其中,instruct 模型使用 FP4 + FP8 mixed precision,可以进一步降低存储和计算成本。

9.5 系统工程层面的优化

DeepSeek-V4 技术报告还提到一系列基础设施优化,例如:

  • single fused kernel for MoE modules;
  • TileLang;
  • deterministic kernel libraries;
  • tensor-level checkpointing;
  • hybrid ZeRO;
  • two-stage contextual parallelism;
  • heterogeneous KV cache;
  • on-disk KV cache storage。

这说明大模型效率不是只靠一个公式,而是来自模型架构、优化器、训练系统、推理系统、精度策略和硬件适配的共同作用。


10. 预训练与后训练流程

DeepSeek-V4 的预训练规模非常大。官方技术报告中提到:

DeepSeek-V4-Flash: 32T tokens
DeepSeek-V4-Pro: 33T tokens

也就是说,两个模型都在超过 32T 的高质量 token 上进行了预训练。

后训练方面,DeepSeek-V4 使用了两阶段范式:

阶段一:独立培养不同领域专家
阶段二:通过 on-policy distillation 进行统一模型整合

可以粗略理解为:先针对数学、代码、agent、指令跟随等领域训练不同专家能力,再通过统一蒸馏,把不同领域能力整合到一个模型中。

这对 agent 任务尤其重要,因为 agent 任务通常不是单一步骤,而是包含任务理解、计划生成、工具调用、文件读取、代码修改、测试反馈、错误修复和多轮反思。因此,DeepSeek-V4 的后训练目标不是只提升普通问答,而是进一步强化长程推理和 agentic capabilities。

10.1 MTP 的基本思想

MTP 的全称是 Multi-Token Prediction,可以翻译为“多 token 预测”。普通语言模型训练时,通常是预测下一个 token:

输入:  我 喜欢
预测:  学习

也就是:

predict next token

MTP 的直觉是:模型不仅预测下一个 token,还可以同时预测后面多个 token,从而提供更丰富的训练信号。

例如:

输入:  我 喜欢
预测:  学习 深度 模型

10.2 MTP 的 toy-level 代码理解

下面是一个非常简化的 toy MTP head:

import torch
import torch.nn as nn


class ToyMTPHead(nn.Module):
    def __init__(self, dim: int, vocab_size: int, num_future_tokens: int = 3):
        super().__init__()
        self.num_future_tokens = num_future_tokens
        self.heads = nn.ModuleList(
            [
                nn.Linear(dim, vocab_size)
                for _ in range(num_future_tokens)
            ]
        )

    def forward(self, hidden_states: torch.Tensor):
        logits_list = []

        for head in self.heads:
            logits = head(hidden_states)
            logits_list.append(logits)

        return logits_list


if __name__ == "__main__":
    batch_size = 2
    seq_len = 8
    dim = 256
    vocab_size = 32000

    hidden_states = torch.randn(batch_size, seq_len, dim)

    mtp_head = ToyMTPHead(
        dim=dim,
        vocab_size=vocab_size,
        num_future_tokens=3,
    )

    logits_list = mtp_head(hidden_states)

    for i, logits in enumerate(logits_list):
        print(f"predict t+{i+1} logits shape:", logits.shape)

如果要计算 toy MTP loss,可以这样理解:

import torch
import torch.nn.functional as F


def toy_mtp_loss(logits_list, input_ids):
    total_loss = 0.0
    valid_heads = 0

    for i, logits in enumerate(logits_list):
        shift = i + 1

        if input_ids.shape[1] <= shift:
            continue

        pred_logits = logits[:, :-shift, :].contiguous()
        labels = input_ids[:, shift:].contiguous()

        loss = F.cross_entropy(
            pred_logits.view(-1, pred_logits.size(-1)),
            labels.view(-1),
        )

        total_loss = total_loss + loss
        valid_heads += 1

    return total_loss / max(valid_heads, 1)


if __name__ == "__main__":
    batch_size = 2
    seq_len = 8
    vocab_size = 32000

    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

    logits_list = [
        torch.randn(batch_size, seq_len, vocab_size),
        torch.randn(batch_size, seq_len, vocab_size),
        torch.randn(batch_size, seq_len, vocab_size),
    ]

    loss = toy_mtp_loss(logits_list, input_ids)
    print("toy mtp loss:", loss.item())

MTP 的直觉可以总结为:

普通 LM:
当前位置只预测下一个 token。

MTP:
当前位置同时预测未来多个 token。

这可以增加训练信号密度,也有助于模型学习更长跨度的 token 关系。这里的代码只是 toy-level 理解,不代表 DeepSeek-V4 的真实 MTP 实现。


11. 为什么 DeepSeek-V4 对 Agent 很重要?

DeepSeek 官方发布中特别强调了 agent 能力。这很合理,因为 agent 任务天然依赖长上下文。

一个典型代码修复 agent 的上下文可能包含:

用户问题
仓库结构
相关源码文件
错误日志
测试输出
历史修改
工具调用轨迹
中间推理状态
失败尝试
最终补丁

如果上下文窗口不足,模型很容易丢失关键信息。例如,忘记之前读过的文件,忘记测试失败原因,忘记自己已经尝试过的修复方案,对不同日志之间的关系判断错误,或者无法从长轨迹中定位真正的 bug。

百万 token 上下文的价值在于,它可以让 agent 在更完整的任务轨迹中进行推理,而不是不断依赖外部检索和截断摘要。这对代码智能体、科研智能体、文档智能体、医学文献分析、法律证据链分析等场景都非常重要。


12. 非官方 PyTorch 推测复现说明

下面进入完整 toy 模型部分。

再次强调:

本文代码不是 DeepSeek 官方实现。
本文代码不是 DeepSeek-V4 真实源码。
本文代码不能复现 DeepSeek-V4 的真实性能。
本文代码只用于理解 CSA、HCA、MoE、MTP 和 Hybrid Attention 的结构直觉。

本文完整 toy 模型会组合以下部分:

  • Toy CSA:先压缩 K/V,再做 top-k sparse attention;
  • Toy HCA:更强压缩 K/V,再做 dense attention;
  • Toy Hybrid Attention:用 gate 融合 CSA 和 HCA 输出;
  • Toy MoE:用 router 选择 top-k experts;
  • Toy Gated Residual:用作 mHC 的非官方类比;
  • Toy MTP Head:同时预测未来多个 token。

为了保持代码简单,本文没有实现真实 DeepSeek Sparse Attention、真实 KV cache 管理、真实 mHC、真实 DeepSeekMoE、真实 MTP、FP4 / FP8、fused kernel、TileLang、分布式训练,也没有实现 causal mask 的严格长上下文推理逻辑。


13. 完整 Toy 代码

下面给出完整代码。建议保存为:

toy_deepseek_v4_style_model.py

然后直接运行:

python toy_deepseek_v4_style_model.py

完整代码如下:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SequenceCompressor(nn.Module):
    def __init__(self, dim: int, compression_ratio: int = 4):
        super().__init__()
        if compression_ratio <= 0:
            raise ValueError("compression_ratio must be positive")
        self.dim = dim
        self.compression_ratio = compression_ratio
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() != 3:
            raise ValueError("x must be [batch, seq_len, dim]")

        b, n, d = x.shape
        if d != self.dim:
            raise ValueError(f"expected dim={self.dim}, got dim={d}")

        r = self.compression_ratio
        pad_len = (r - n % r) % r

        if pad_len > 0:
            pad = torch.zeros(b, pad_len, d, device=x.device, dtype=x.dtype)
            x = torch.cat([x, pad], dim=1)

        n2 = x.shape[1]
        x = x.view(b, n2 // r, r, d).mean(dim=2)
        return self.proj(x)


class ToyCompressedSparseAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 4,
        compression_ratio: int = 4,
        topk: int = 8,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("dim must be divisible by num_heads")
        if topk <= 0:
            raise ValueError("topk must be positive")

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.topk = topk

        self.q_proj = nn.Linear(dim, dim)
        self.k_compressor = SequenceCompressor(dim, compression_ratio)
        self.v_compressor = SequenceCompressor(dim, compression_ratio)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        x = x.view(b, n, self.num_heads, self.head_dim)
        return x.transpose(1, 2)

    def merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        b, h, n, d = x.shape
        x = x.transpose(1, 2).contiguous()
        return x.view(b, n, h * d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q = self.split_heads(self.q_proj(x))

        k_base = self.k_compressor(x)
        v_base = self.v_compressor(x)

        k = self.split_heads(self.k_proj(k_base))
        v = self.split_heads(self.v_proj(v_base))

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        topk = min(self.topk, scores.shape[-1])
        topk_values, topk_indices = torch.topk(scores, k=topk, dim=-1)

        sparse_scores = torch.full_like(scores, float("-inf"))
        sparse_scores.scatter_(-1, topk_indices, topk_values)

        attn = F.softmax(sparse_scores, dim=-1)
        out = torch.matmul(attn, v)

        out = self.merge_heads(out)
        return self.o_proj(out)


class ToyHeavilyCompressedAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 4,
        compression_ratio: int = 16,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("dim must be divisible by num_heads")

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q_proj = nn.Linear(dim, dim)
        self.k_compressor = SequenceCompressor(dim, compression_ratio)
        self.v_compressor = SequenceCompressor(dim, compression_ratio)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        x = x.view(b, n, self.num_heads, self.head_dim)
        return x.transpose(1, 2)

    def merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        b, h, n, d = x.shape
        x = x.transpose(1, 2).contiguous()
        return x.view(b, n, h * d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q = self.split_heads(self.q_proj(x))

        k_base = self.k_compressor(x)
        v_base = self.v_compressor(x)

        k = self.split_heads(self.k_proj(k_base))
        v = self.split_heads(self.v_proj(v_base))

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)

        out = torch.matmul(attn, v)
        out = self.merge_heads(out)
        return self.o_proj(out)


class ToyHybridAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 4,
        csa_compression: int = 4,
        hca_compression: int = 16,
        topk: int = 8,
    ):
        super().__init__()
        self.csa = ToyCompressedSparseAttention(
            dim=dim,
            num_heads=num_heads,
            compression_ratio=csa_compression,
            topk=topk,
        )

        self.hca = ToyHeavilyCompressedAttention(
            dim=dim,
            num_heads=num_heads,
            compression_ratio=hca_compression,
        )

        self.gate = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Sigmoid(),
        )

        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        csa_out = self.csa(x)
        hca_out = self.hca(x)
        gate = self.gate(x)
        y = gate * csa_out + (1.0 - gate) * hca_out
        return self.out_proj(y)


class ToyGatedResidual(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor, fx: torch.Tensor) -> torch.Tensor:
        return x + self.gate(x) * fx


class ToyExpert(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class ToyMoE(nn.Module):
    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        num_experts: int = 4,
        top_k: int = 2,
    ):
        super().__init__()
        if top_k > num_experts:
            raise ValueError("top_k must be <= num_experts")

        self.num_experts = num_experts
        self.top_k = top_k
        self.router = nn.Linear(dim, num_experts)
        self.experts = nn.ModuleList(
            [ToyExpert(dim, hidden_dim) for _ in range(num_experts)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        router_logits = self.router(x)
        router_probs = F.softmax(router_logits, dim=-1)

        topk_probs, topk_indices = torch.topk(
            router_probs,
            k=self.top_k,
            dim=-1,
        )

        topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
        output = torch.zeros_like(x)

        for expert_id, expert in enumerate(self.experts):
            expert_out = expert(x)
            weight = torch.zeros(b, n, device=x.device, dtype=x.dtype)

            for k in range(self.top_k):
                weight = weight + torch.where(
                    topk_indices[:, :, k] == expert_id,
                    topk_probs[:, :, k],
                    torch.zeros_like(topk_probs[:, :, k]),
                )

            output = output + expert_out * weight.unsqueeze(-1)

        return output


class ToyTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 4,
        moe_hidden_dim: int = 1024,
        num_experts: int = 4,
        top_k: int = 2,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = ToyHybridAttention(dim=dim, num_heads=num_heads)

        self.norm2 = nn.LayerNorm(dim)
        self.moe = ToyMoE(
            dim=dim,
            hidden_dim=moe_hidden_dim,
            num_experts=num_experts,
            top_k=top_k,
        )

        self.resid1 = ToyGatedResidual(dim)
        self.resid2 = ToyGatedResidual(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out = self.attn(self.norm1(x))
        x = self.resid1(x, attn_out)

        moe_out = self.moe(self.norm2(x))
        x = self.resid2(x, moe_out)

        return x


class ToyMTPHead(nn.Module):
    def __init__(self, dim: int, vocab_size: int, num_future_tokens: int = 3):
        super().__init__()
        self.num_future_tokens = num_future_tokens
        self.heads = nn.ModuleList(
            [nn.Linear(dim, vocab_size) for _ in range(num_future_tokens)]
        )

    def forward(self, hidden_states: torch.Tensor):
        logits_list = []
        for head in self.heads:
            logits = head(hidden_states)
            logits_list.append(logits)
        return logits_list


class ToyDeepSeekV4StyleModel(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        dim: int = 256,
        num_layers: int = 2,
        num_heads: int = 8,
        max_seq_len: int = 2048,
        moe_hidden_dim: int = 1024,
        num_experts: int = 4,
        top_k: int = 2,
        num_future_tokens: int = 3,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(vocab_size, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.blocks = nn.ModuleList(
            [
                ToyTransformerBlock(
                    dim=dim,
                    num_heads=num_heads,
                    moe_hidden_dim=moe_hidden_dim,
                    num_experts=num_experts,
                    top_k=top_k,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm = nn.LayerNorm(dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        self.mtp_head = ToyMTPHead(
            dim=dim,
            vocab_size=vocab_size,
            num_future_tokens=num_future_tokens,
        )

    def forward(self, input_ids: torch.Tensor):
        if input_ids.dim() != 2:
            raise ValueError("input_ids must be [batch, seq_len]")

        b, n = input_ids.shape
        if n > self.max_seq_len:
            raise ValueError(f"seq_len={n} exceeds max_seq_len={self.max_seq_len}")

        positions = torch.arange(n, device=input_ids.device).unsqueeze(0).expand(b, n)

        x = self.token_emb(input_ids) + self.pos_emb(positions)

        for block in self.blocks:
            x = block(x)

        hidden_states = self.norm(x)

        next_token_logits = self.lm_head(hidden_states)
        mtp_logits_list = self.mtp_head(hidden_states)

        return next_token_logits, mtp_logits_list


def toy_mtp_loss(logits_list, input_ids):
    total_loss = 0.0
    valid_heads = 0

    for i, logits in enumerate(logits_list):
        shift = i + 1

        if input_ids.shape[1] <= shift:
            continue

        pred_logits = logits[:, :-shift, :].contiguous()
        labels = input_ids[:, shift:].contiguous()

        loss = F.cross_entropy(
            pred_logits.view(-1, pred_logits.size(-1)),
            labels.view(-1),
        )

        total_loss = total_loss + loss
        valid_heads += 1

    return total_loss / max(valid_heads, 1)


def main():
    torch.manual_seed(42)

    vocab_size = 32000
    batch_size = 2
    seq_len = 128
    dim = 256
    num_layers = 2
    num_heads = 8

    model = ToyDeepSeekV4StyleModel(
        vocab_size=vocab_size,
        dim=dim,
        num_layers=num_layers,
        num_heads=num_heads,
        max_seq_len=2048,
        moe_hidden_dim=1024,
        num_experts=4,
        top_k=2,
        num_future_tokens=3,
    )

    input_ids = torch.randint(
        low=0,
        high=vocab_size,
        size=(batch_size, seq_len),
    )

    next_token_logits, mtp_logits_list = model(input_ids)

    print("input_ids shape:", input_ids.shape)
    print("next_token_logits shape:", next_token_logits.shape)

    for i, logits in enumerate(mtp_logits_list):
        print(f"mtp logits t+{i+1} shape:", logits.shape)

    loss = toy_mtp_loss(mtp_logits_list, input_ids)
    print("toy mtp loss:", loss.item())


if __name__ == "__main__":
    main()

预期输出类似:

input_ids shape: torch.Size([2, 128])
next_token_logits shape: torch.Size([2, 128, 32000])
mtp logits t+1 shape: torch.Size([2, 128, 32000])
mtp logits t+2 shape: torch.Size([2, 128, 32000])
mtp logits t+3 shape: torch.Size([2, 128, 32000])
toy mtp loss: ...

14. 代码结构解释

上面的完整 toy 代码主要包含以下模块。

14.1 SequenceCompressor

SequenceCompressor 用 mean pooling 模拟沿序列维度压缩:

[batch, seq_len, dim]
        |
按 compression_ratio 分块
        |
mean pooling
        |
[batch, compressed_len, dim]

如果 seq_len = 1024compression_ratio = 4,那么压缩后的长度大约是 256。真实 DeepSeek-V4 的 KV cache 压缩方式会复杂得多,这里只是为了展示“沿序列维度压缩”的基本直觉。

14.2 ToyCompressedSparseAttention

ToyCompressedSparseAttention 用来模拟 CSA 的基本思想:

原始序列 x
    |
生成 Q
    |
压缩 x 得到 compressed K/V
    |
Q attend compressed K/V
    |
top-k 稀疏选择

它体现的是:不让每个 token attend 到完整历史,而是在压缩后的表示中选择重要位置。

14.3 ToyHeavilyCompressedAttention

ToyHeavilyCompressedAttention 用来模拟 HCA 的基本思想。它使用更大的压缩比例,然后在更短的 compressed K/V 上做 dense attention。

代码差异可以简单概括为:

CSA: 压缩 K/V + top-k sparse attention
HCA: 更强压缩 K/V + dense attention

14.4 ToyHybridAttention

ToyHybridAttention 把 Toy CSA 和 Toy HCA 组合起来:

output = gate * CSA(x) + (1 - gate) * HCA(x)

真实 DeepSeek-V4 的融合方式不一定是这样。这里的 gate 只是为了帮助理解“稀疏细节”和“全局压缩语义”的互补关系。

14.5 ToyMoE

ToyMoE 用 router 为每个 token 选择 top-k experts,再把 expert 输出加权求和。它模拟的是 MoE 的核心思想:模型可以拥有多个 expert,但每次只激活其中一部分。

14.6 ToyGatedResidual

DeepSeek-V4 使用的是 mHC,即 Manifold-Constrained Hyper-Connections。本文代码中的 ToyGatedResidual 不是 mHC,只是为了帮助理解“增强残差路径”这个动机。

这一点必须明确:

ToyGatedResidual 不是 mHC。
ToyGatedResidual 只是为了帮助理解增强 residual connection 的一种 toy 写法。

14.7 ToyMTPHead

ToyMTPHead 用多个 head 同时预测未来多个 token。它模拟的是 MTP 的基本直觉:当前位置不仅预测下一个 token,也可以预测更远的未来 token,从而提供更密集的训练信号。


15. 复杂度直觉分析

假设原始序列长度为 n,压缩比例为 r。

普通 dense attention 的复杂度可以粗略理解为:

O(n²)

如果把 K/V 压缩到 n/r,那么 attention 复杂度可以粗略变成:

O(n × n/r)

如果再做 top-k 稀疏选择,那么复杂度可以进一步接近:

O(n × k)

当然,真实 DeepSeek-V4 的效率不能只用这个公式解释,因为还涉及:

  • KV cache 的实际存储布局;
  • 稀疏 attention 的 kernel 实现;
  • MoE expert routing;
  • GPU / NPU 内存带宽;
  • distributed communication;
  • FP4 / FP8 mixed precision;
  • fused kernel;
  • heterogeneous KV cache;
  • on-disk KV cache。

因此,本文的复杂度分析只是直觉说明,不是对官方系统的精确复现。


16. 本文 toy 代码和官方 DeepSeek-V4 的差距

本文代码和官方 DeepSeek-V4 差距非常大,主要体现在:

项目 本文 toy 代码 官方 DeepSeek-V4
模型规模 几层小模型 284B / 1.6T MoE
Attention 简化压缩 + top-k CSA + HCA + DSA
KV cache 无真实缓存系统 复杂 KV cache 压缩与管理
Residual 简单 gated residual mHC
FFN Toy MoE DeepSeekMoE
MTP Toy MTP head 官方 MTP 机制
优化器 未训练 Muon optimizer
精度 FP32 / 默认 PyTorch FP8 Mixed / FP4 + FP8 Mixed
Kernel PyTorch eager fused kernel / TileLang 等
训练 无真实训练 32T / 33T tokens 级别预训练
推理 toy forward 1M context 高效推理系统

所以,本文代码只能叫:

非官方 toy-level 推测实现

不能叫:

DeepSeek-V4 官方源码复现

也不能叫:

DeepSeek-V4 完整复现

17. 我对 DeepSeek-V4 的理解

DeepSeek-V4 的意义可以从三个角度理解。

17.1 长上下文从“能放进去”走向“能高效用起来”

很多模型都在扩大上下文窗口,但上下文长度变长以后,真正困难的是:

推理成本
KV cache 显存
长程信息利用能力

DeepSeek-V4 的重点是通过 CSA、HCA 和系统优化,让百万 token 上下文不只是形式上支持,而是更接近工程可用。

17.2 MoE 继续成为大模型扩展的重要路线

DeepSeek-V4-Pro 总参数达到 1.6T,但激活参数为 49B。这说明 MoE 的核心优势仍然很明显:

总容量大
激活成本相对可控

对于需要更强知识、推理、代码和 agent 能力的大模型来说,MoE 仍然是一条重要路线。

17.3 Agent 会成为长上下文模型的重要应用场景

DeepSeek-V4 官方发布中特别提到 agent 能力。这说明长上下文不只是为了“读长文档”,更是为了支撑复杂 agent workflow。

未来很多模型能力可能不只体现在单轮问答,而是体现在:

  • 能不能读完整仓库;
  • 能不能持续调用工具;
  • 能不能根据测试结果修复代码;
  • 能不能在长任务中保持状态;
  • 能不能从长轨迹中定位错误原因。

这也是百万 token context 的真正价值。


18. 总结

DeepSeek-V4 是一次面向百万 token 上下文的系统级模型升级。

它的核心不是简单“参数更大”,而是通过以下组合提升长上下文效率:

Hybrid Attention Architecture
= CSA + HCA

CSA
= KV cache compression + DeepSeek Sparse Attention

HCA
= heavier KV cache compression + dense attention

DeepSeekMoE
= large total parameters + limited activated parameters

MTP
= multi-token prediction

mHC
= enhanced residual connections

Muon optimizer
= faster convergence + improved training stability

System optimization
= fused kernel + TileLang + ZeRO + checkpointing + heterogeneous KV cache + on-disk KV cache

如果用一句话总结 DeepSeek-V4:

DeepSeek-V4 面向百万 token 长上下文和 agentic intelligence,通过 KV cache 压缩、稀疏注意力、MoE、MTP、mHC、Muon optimizer 和系统级推理优化,共同降低超长上下文推理的计算与存储成本。

本文的 toy PyTorch 代码只是为了帮助理解 Hybrid Attention、MoE、MTP 和增强残差路径的基本思路,不能代表官方实现。


Logo

中国智能体开发者社区,聚焦智能体与大模型开发,提供前沿资讯、实用工具链、开源项目及行业案例。通过技术沙龙、开发者大赛等活动,促进经验交流与协作,助力开发者快速构建创新智能应用。

更多推荐