DeepSeek-V4:面向高效百万 Token 上下文智能的探索
DeepSeek-V4:面向高效百万 Token 上下文智能的探索
-
- 0. 写在前面
- 1. 论文基本信息
- 2. 为什么 DeepSeek-V4 要关注百万 Token 上下文?
- 3. DeepSeek-V4 的整体架构理解
- 4. 核心技术一:Hybrid Attention Architecture
- 5. 核心技术二:CSA,Compressed Sparse Attention
- 6. 核心技术三:HCA,Heavily Compressed Attention
- 7. 核心技术四:mHC,Manifold-Constrained Hyper-Connections
- 8. 核心技术五:Muon Optimizer
- 9. DeepSeek-V4 的效率来自哪里?
- 10. 预训练与后训练流程
- 11. 为什么 DeepSeek-V4 对 Agent 很重要?
- 12. 非官方 PyTorch 推测复现说明
- 13. 完整 Toy 代码
- 14. 代码结构解释
- 15. 复杂度直觉分析
- 16. 本文 toy 代码和官方 DeepSeek-V4 的差距
- 17. 我对 DeepSeek-V4 的理解
- 18. 总结
0. 写在前面
DeepSeek-V4 是 DeepSeek 面向百万 token 长上下文能力推出的重要技术报告。与单纯扩大模型参数规模不同,DeepSeek-V4 更关注如何在超长上下文场景下同时兼顾推理效率、KV cache 占用和长程信息建模能力。本文围绕 DeepSeek-V4: Towards Highly Efficient Million-Token Context Intelligence 进行学习型精读,重点梳理其 Hybrid Attention Architecture、Compressed Sparse Attention,CSA、Heavily Compressed Attention,HCA、Manifold-Constrained Hyper-Connections,mHC、Muon optimizer、DeepSeekMoE、Multi-Token Prediction,MTP,以及长上下文推理优化等关键内容。
写这篇博客的目的并不是复现 DeepSeek-V4 的官方工程系统,而是尝试从论文阅读者和代码学习者的角度,把其中比较抽象的架构设计拆解成更容易理解的模块。对于刚接触长上下文大模型、MoE 架构和 KV cache 优化的读者来说,直接阅读技术报告可能会遇到术语密集、系统细节复杂、算法与工程交织较深等问题。因此,本文会先用较直观的方式解释 DeepSeek-V4 为什么强调百万 token 上下文,再分析 CSA、HCA、mHC、MoE、MTP 等核心设计背后的动机,最后给出一组非官方 toy-level PyTorch 推测实现,帮助理解这些结构的基本思想。
需要说明的是,本文代码仅用于学习和理解论文思想,不是 DeepSeek 官方源码,也不能复现 DeepSeek-V4 的真实训练、推理效率或模型能力。真实 DeepSeek-V4 涉及大规模 MoE、DSA、mHC、Muon optimizer、FP4/FP8 混合精度、fused kernel、分布式训练和复杂 KV cache 管理等系统级工程,远超一个简化示例的范围。本文更适合作为一篇“论文精读 + 思想拆解 + 非官方最小代码理解”的学习笔记。
1. 论文基本信息
DeepSeek-V4: Towards Highly Efficient Million-Token Context Intelligence 是 DeepSeek-AI 发布的一篇面向长上下文大语言模型的技术报告,公开于 arXiv,编号为 arXiv:2606.19348。该报告围绕百万 token 上下文建模展开,重点研究如何在极长序列输入条件下降低单 token 推理计算量、减少 KV cache 占用,并维持模型对长程依赖和复杂任务的建模能力。
DeepSeek-V4 系列包括 DeepSeek-V4-Pro 与 DeepSeek-V4-Flash 两个版本,均采用 Mixture-of-Experts,MoE 架构并支持 1M tokens 上下文长度。其中,DeepSeek-V4-Pro 的总参数量为 1.6T,激活参数量为 49B;DeepSeek-V4-Flash 的总参数量为 284B,激活参数量为 13B。从定位上看,Pro 版本更强调复杂推理、代码和智能体任务能力,Flash 版本则更强调推理效率、响应速度和部署成本控制。
相较于以往模型主要依赖扩大上下文窗口的做法,DeepSeek-V4 更强调长上下文场景下的效率优化。其核心方法包括由 CSA 与 HCA 组成的 Hybrid Attention Architecture、用于增强传统 residual connection 的 mHC,以及用于提升训练收敛速度和稳定性的 Muon optimizer。从研究定位来看,DeepSeek-V4 不仅是一项模型架构更新,也是一项面向超长上下文推理的系统工程探索。其技术路线将注意力压缩、稀疏计算、MoE 扩展、混合精度表示和推理系统优化结合起来,为代码理解、多文档推理、长程 agent 任务和复杂知识整合等应用场景提供了新的参考。
2. 为什么 DeepSeek-V4 要关注百万 Token 上下文?
最近大模型的发展已经不只是“参数越大越好”。长上下文能力、推理成本、KV cache 显存占用、agent 任务表现,正在变成越来越重要的竞争点。DeepSeek-V4 这篇技术报告正好集中讨论了这些问题:如何让模型在百万 token 级别上下文中仍然保持较高效率,并且能够服务代码、推理、agent 等复杂任务。
传统 Transformer 的 self-attention 复杂度与序列长度强相关。对于普通聊天任务,几十 K token 上下文可能已经够用;但对于更复杂的任务,长上下文会变得非常重要。例如:
- 阅读整个代码仓库并修复 bug;
- 分析多文件、多日志、多轮测试结果;
- 对几十篇论文进行综合综述;
- 对法律文书、合同、证据链进行交叉引用;
- 在 agent 任务中保留完整工具调用轨迹;
- 在长期任务中持续引用历史状态和中间结论。
这些任务有一个共同点:信息不是集中在某一个片段里,而是分散在很长的上下文中。
传统 RAG 可以缓解上下文窗口不足的问题,但 RAG 也有局限。例如,检索可能漏掉关键片段,检索结果之间缺乏全局关系,多轮 agent 的中间状态容易被截断,长程推理时模型也很难同时看到足够完整的证据链。因此,百万 token 上下文的价值不只是“能塞更多内容”,而是让模型在更完整的上下文空间中进行全局推理,减少信息截断和检索遗漏带来的错误。
3. DeepSeek-V4 的整体架构理解
DeepSeek-V4 仍然保留 Transformer 架构,并继续使用 DeepSeek 系列中的一些重要设计,例如 DeepSeekMoE 和 Multi-Token Prediction,MTP。在此基础上,DeepSeek-V4 引入了几个关键升级:
-
Hybrid Attention Architecture
- Compressed Sparse Attention,CSA
- Heavily Compressed Attention,HCA
-
Manifold-Constrained Hyper-Connections,mHC
- 用于增强传统 residual connection
-
Muon optimizer
- 用于提升训练收敛速度和稳定性
-
系统级训练与推理优化
- MoE fused kernel
- TileLang
- deterministic kernel
- tensor-level checkpointing
- hybrid ZeRO
- two-stage contextual parallelism
- heterogeneous KV cache
- on-disk KV cache storage
- FP4 / FP8 mixed precision
可以把整体结构粗略理解成:
Input Tokens
|
Token Embedding
|
[Hybrid Attention + mHC + DeepSeekMoE] × N
|
MTP / LM Head
|
Output Tokens
其中每一层大致可以理解为:
Hidden States
|
Hybrid Attention
|---- CSA: compressed KV cache + sparse attention
|---- HCA: heavily compressed KV cache + dense attention
|
mHC / enhanced residual connection
|
DeepSeekMoE FFN
|
Output Hidden States
这说明 DeepSeek-V4 的重点不是单个技巧,而是把模型架构、优化器、训练系统、推理系统和长上下文任务结合到一起。
4. 核心技术一:Hybrid Attention Architecture
DeepSeek-V4 最重要的结构升级之一是 Hybrid Attention Architecture,可以翻译为“混合注意力架构”。它由两个部分组成:
Hybrid Attention Architecture = CSA + HCA
其中:
| 模块 | 英文全称 | 中文理解 | 主要思想 |
|---|---|---|---|
| CSA | Compressed Sparse Attention | 压缩稀疏注意力 | 沿序列维度压缩 KV cache,然后执行 DeepSeek Sparse Attention |
| HCA | Heavily Compressed Attention | 重压缩注意力 / 高度压缩注意力 | 对 KV cache 进行更强压缩,但保留 dense attention |
这里一定要注意:CSA 和 HCA 讨论的是 KV cache 的压缩,不是简单把原始文本压缩成摘要。
KV cache 是大模型推理中的核心缓存结构。随着上下文长度增加,KV cache 会占用越来越多显存。对于百万 token 上下文,KV cache 的存储和访问成本会变得非常高。因此,DeepSeek-V4 的思路不是让每一个 token 都以完整 KV 形式长期保留,而是对 KV cache 进行压缩、稀疏选择和系统级管理。
5. 核心技术二:CSA,Compressed Sparse Attention
5.1 CSA 的基本思想
CSA 的全称是 Compressed Sparse Attention,可以翻译为“压缩稀疏注意力”。官方技术报告中的关键思想可以概括为:
CSA compresses the KV caches along the sequence dimension and then performs DeepSeek Sparse Attention.
对应中文是:
CSA 沿序列维度压缩 KV cache,然后执行 DeepSeek Sparse Attention。
也就是说,CSA 有两个动作:
- 沿序列维度压缩 KV cache;
- 在压缩后的 KV cache 上执行 DeepSeek Sparse Attention,简称 DSA。
传统 attention 的形式是:
Attention(Q, K, V) = softmax(QK^T / sqrt(d))V
如果序列长度是 n,那么 QK^T 的规模大致是 n × n。对于 1M token 来说,直接做 dense attention 成本会非常高。
CSA 的直觉是:
原始 KV cache
|
沿序列维度压缩
|
compressed KV cache
|
稀疏选择重要位置
|
执行 sparse attention
可以用阅读长文档的过程做类比:近处内容需要更细粒度地看,远处内容可以压缩成更粗粒度的表示,而真正参与注意力计算的是更重要、更相关的部分。CSA 的目标不是完整保留所有 token-level 细节,而是在长上下文中尽量降低计算量和缓存开销。
5.2 Dense Attention、Sparse Attention 与 CSA 的计算例子
为了理解 DeepSeek-V4 中的 Compressed Sparse Attention,CSA,我们先从最基础的 Q、K、V 计算开始。注意,这里的例子是为了帮助理解 attention 的计算逻辑,不是 DeepSeek-V4 的官方实现。
第一步:从 token 向量到 Q、K、V
假设一句话有 3 个 token:
token 1 = 我
token 2 = 喜欢
token 3 = 学习
经过 token embedding 之后,每个 token 会变成一个 hidden vector。为了方便计算,假设每个向量只有 2 维:
h1 = [1, 0] # “我”
h2 = [0, 1] # “喜欢”
h3 = [1, 1] # “学习”
在 attention 中,每个 hidden vector 会分别经过三个 projection matrix,也就是三个投影矩阵:
Q = h × WQ
K = h × WK
V = h × WV
其中:
| 矩阵 | 中文含义 | 作用 |
|---|---|---|
| WQ | Query 投影矩阵 | 把 hidden vector 变成 Query |
| WK | Key 投影矩阵 | 把 hidden vector 变成 Key |
| WV | Value 投影矩阵 | 把 hidden vector 变成 Value |
projection matrix 可以理解为“线性变换矩阵”或“映射矩阵”。它的作用是把同一个 token 的 hidden vector 映射成三种不同用途的向量:Q 用于查询,K 用于匹配,V 用于提供真正被汇总的信息内容。
为了让计算更简单,这里先假设三个 projection matrix 都是单位矩阵:
WQ = WK = WV = I
所以:
Q1 = K1 = V1 = h1 = [1, 0]
Q2 = K2 = V2 = h2 = [0, 1]
Q3 = K3 = V3 = h3 = [1, 1]
这里需要记住一句话:
Q 和 K 用来计算“应该关注谁”,V 才是真正被加权汇总的信息内容。
第二步:Dense Attention 如何计算
现在我们以 token 3,也就是“学习”为当前 token,计算它应该关注哪些 token。
当前 token 的 Query 是:
Q3 = [1, 1]
所有 token 的 Key 是:
K1 = [1, 0]
K2 = [0, 1]
K3 = [1, 1]
所有 token 的 Value 是:
V1 = [1, 0]
V2 = [0, 1]
V3 = [1, 1]
attention score 的计算公式是:
score = Q · K / sqrt(d)
这里向量维度 d = 2,所以:
sqrt(d) = sqrt(2) ≈ 1.414
分别计算 token 3 对三个 token 的关注分数:
Q3 · K1 = [1, 1] · [1, 0] = 1
score_31 = 1 / 1.414 ≈ 0.707
Q3 · K2 = [1, 1] · [0, 1] = 1
score_32 = 1 / 1.414 ≈ 0.707
Q3 · K3 = [1, 1] · [1, 1] = 2
score_33 = 2 / 1.414 ≈ 1.414
所以原始 scores 是:
scores = [0.707, 0.707, 1.414]
接着对 scores 做 softmax:
softmax 的数学表示可以写成:
softmax ( x i ) = e x i ∑ j = 1 n e x j \text{softmax}(x_i)=\frac{e^{x_i}}{\sum_{j=1}^{n}e^{x_j}} softmax(xi)=∑j=1nexjexi
softmax([0.707, 0.707, 1.414])
≈ [0.248, 0.248, 0.503]
得到 attention 权重:
| 被关注 token | score | attention weight |
|---|---|---|
| token 1:我 | 0.707 | 0.248 |
| token 2:喜欢 | 0.707 | 0.248 |
| token 3:学习 | 1.414 | 0.503 |
最后用这些权重加权求和 V:
output_3 = 0.248 × V1 + 0.248 × V2 + 0.503 × V3
这里的output3是第 3 个 token 的 attention 输出向量,捏可以理解为是学习了上下文的表示向量。
代入:
V1 = [1, 0]
V2 = [0, 1]
V3 = [1, 1]
计算:
output_3 = 0.248 × [1, 0]
+ 0.248 × [0, 1]
+ 0.503 × [1, 1]
= [0.248, 0]
+ [0, 0.248]
+ [0.503, 0.503]
= [0.751, 0.751]
这就是 dense attention 的计算结果。
Dense attention 的特点是:
所有可见 token 都参与 attention 计算。
在这个例子中,token 1、token 2、token 3 全部参与最终输出:
output_3 = 0.248V1 + 0.248V2 + 0.503V3
也就是说,dense attention 是“全部都看”,只是不同 token 的权重不同。
第三步:Sparse Attention 如何计算
现在我们把 dense attention 改成 sparse attention。
前面 dense attention 已经算出了 token 3 对三个 token 的 attention score:
scores = [0.707, 0.707, 1.414]
这三个分数分别对应:
token 1:“我” score = 0.707
token 2:“喜欢” score = 0.707
token 3:“学习” score = 1.414
这里要先明确一个容易误解的地方:
top-1 不是指“第 1 个 token”
top-1 是指“分数排名第 1 的 token”
也就是说,top-1 表示只保留 attention score 最大的那个位置。
在当前例子中,三个分数是:
[0.707, 0.707, 1.414]
其中最大的是:
1.414
它对应的是:
token 3:“学习”
所以如果我们只保留 top-1,也就是只保留分数最高的 token,那么只保留 token 3:
top-1 = token 3
此时 token 1 和 token 2 虽然原来也有 attention score,但是在 sparse attention 中会被屏蔽掉,不再参与后面的 V 加权求和。
为了实现这种“屏蔽”,通常会把不保留的位置设为负无穷:
sparse_scores = [-∞, -∞, 1.414]
这里的 -∞ 可以理解为:
这个位置被屏蔽了
这个位置后面不会分到注意力权重
为什么设成 -∞ 之后 softmax 权重会变成 0 呢?
因为 softmax 的数学形式是:
softmax(x_i) = e^(x_i) / sum(e^(x_j))
如果某个位置是 -∞,那么:
e^(-∞) = 0
所以被设为 -∞ 的位置,在 softmax 之后权重就是 0。
因此:
softmax([-∞, -∞, 1.414]) = [0, 0, 1]
这表示:
token 1 的注意力权重 = 0
token 2 的注意力权重 = 0
token 3 的注意力权重 = 1
所以输出变成:
sparse_output_3 = 0 × V1 + 0 × V2 + 1 × V3
= V3
前面我们设定:
V3 = [1, 1]
所以:
sparse_output_3 = [1, 1]
这就是 sparse attention top-1 的计算过程。
和 dense attention 对比一下:
Dense Attention:
output_3 = 0.248 × V1 + 0.248 × V2 + 0.503 × V3
dense attention 中,token 1、token 2、token 3 的 V 都参与了最终输出。
而 sparse attention top-1 中:
Sparse Attention top-1:
sparse_output_3 = 1 × V3
只有分数最高的 token 3 参与了最终输出。
直观理解就是:
Dense Attention:
token 3 同时看 token 1、token 2、token 3。
Sparse Attention top-1:
token 3 只看 attention score 最高的那个 token。
在这个例子中,分数最高的是 token 3,所以只看 token 3。
接着看 top-2 的情况。
top-2 的意思是:
只保留 attention score 排名前 2 的 token
原始 scores 仍然是:
scores = [0.707, 0.707, 1.414]
其中分数最高的是 token 3:
token 3: 1.414
第二高的是 token 1 或 token 2,因为它们的分数相同:
token 1: 0.707
token 2: 0.707
这里为了方便演示,我们假设保留 token 1 和 token 3,屏蔽 token 2。
于是 sparse scores 变成:
sparse_scores = [0.707, -∞, 1.414]
注意,这里不是直接使用 dense attention 里面原来的权重:
dense attention 原来的权重是 [0.248, 0.248, 0.503]
而是要对新的 sparse_scores 重新做 softmax:
softmax([0.707, -∞, 1.414])
因为 token 2 被设成了 -∞,所以 token 2 的权重变成 0。剩下 token 1 和 token 3 的权重会重新归一化,并且它们加起来等于 1。
softmax 后大约是:
softmax([0.707, -∞, 1.414])
≈ [0.330, 0.000, 0.670]
这表示:
token 1 的注意力权重 ≈ 0.330
token 2 的注意力权重 = 0.000
token 3 的注意力权重 ≈ 0.670
所以输出为:
sparse_output_3 = 0.330 × V1 + 0.000 × V2 + 0.670 × V3
因为 token 2 的权重是 0,所以可以省略:
sparse_output_3 = 0.330 × V1 + 0.670 × V3
代入前面的 Value 向量:
V1 = [1, 0]
V3 = [1, 1]
得到:
sparse_output_3 = 0.330 × [1, 0] + 0.670 × [1, 1]
= [0.330, 0] + [0.670, 0.670]
= [1.000, 0.670]
所以 top-2 sparse attention 的输出是:
sparse_output_3 = [1.000, 0.670]
这里最关键的是理解 sparse attention 的两个动作:
第一步:根据 attention score 选择 top-k 个重要位置
第二步:只对保留下来的位置重新做 softmax 和 V 加权求和
因此,sparse attention 不是简单地把 dense attention 里面的小权重删掉,而是先屏蔽掉不重要的位置,再在剩下的位置之间重新分配注意力权重。
所以 dense attention 和 sparse attention 的区别可以总结为:
| 方法 | 参与计算的 token | softmax 范围 | 输出方式 | 直观理解 |
|---|---|---|---|---|
| Dense Attention | 所有 token 都参与 | 对所有 score 做 softmax | 所有 V 加权求和 | 全部都看 |
| Sparse Attention top-1 | 只保留分数最高的 1 个 token | 只对保留位置重新 softmax | 只使用 1 个 V | 只看最重要的一个 |
| Sparse Attention top-2 | 只保留分数最高的 2 个 token | 只对保留位置重新 softmax | 只使用 2 个 V | 只看最重要的两个 |
一句话总结:
Dense attention 是所有 token 都看;
Sparse attention 是先选出 attention score 最高的 top-k 个 token,
然后只让这些 token 参与后续的 softmax 和 V 加权求和。
第四步:Compressed Sparse Attention 如何计算
现在进一步理解 DeepSeek-V4 中的 Compressed Sparse Attention,CSA。
普通 sparse attention 是:
原始 K/V
↓
选择重要位置
↓
对重要位置做 attention
而 Compressed Sparse Attention 的直觉是:
原始 KV cache
↓
先沿序列维度压缩 KV cache
↓
得到 compressed K/V
↓
再在 compressed K/V 上做 sparse attention
也就是说,CSA 比普通 sparse attention 多了一个关键步骤:
compression
它不是直接在完整 KV cache 上做稀疏选择,而是先把 KV cache 沿序列维度压缩,再在压缩后的表示中选择重要位置。
举一个简化例子。假设原来有 8 个历史 token 的 KV cache:
K1, K2, K3, K4, K5, K6, K7, K8
V1, V2, V3, V4, V5, V6, V7, V8
如果压缩比例是 2,可以把每两个 token 压缩成一个 compressed KV:
CK1 = compress(K1, K2)
CK2 = compress(K3, K4)
CK3 = compress(K5, K6)
CK4 = compress(K7, K8)
CV1 = compress(V1, V2)
CV2 = compress(V3, V4)
CV3 = compress(V5, V6)
CV4 = compress(V7, V8)
原来有 8 个 KV 位置,现在变成 4 个 compressed KV 位置:
原始长度: 8
压缩后长度: 4
然后当前 token 的 Query 不再和 8 个完整 Key 全部匹配,而是和 4 个 compressed Key 匹配:
Q · CK1
Q · CK2
Q · CK3
Q · CK4
假设得到的分数是:
compressed_scores = [2.0, 0.5, 1.5, -1.0]
如果只保留 top-2,那么选择:
CK1: 2.0
CK3: 1.5
其余位置丢掉:
sparse_compressed_scores = [2.0, -∞, 1.5, -∞]
softmax 后:
softmax([2.0, -∞, 1.5, -∞])
≈ [0.622, 0.000, 0.378, 0.000]
最终输出是:
output = 0.622 × CV1 + 0.378 × CV3
这就是 Compressed Sparse Attention 的直觉:
先把长 KV cache 压短,
再只选择压缩表示中的重要位置参与 attention。
第五步:Dense Attention、Sparse Attention 与 CSA 对比
可以把三者放在一起看:
| 方法 | 是否压缩 KV | 是否稀疏选择 | 参与 attention 的对象 |
|---|---|---|---|
| Dense Attention | 否 | 否 | 所有原始 K/V |
| Sparse Attention | 否 | 是 | 部分原始 K/V |
| Compressed Sparse Attention | 是 | 是 | 部分 compressed K/V |
如果用一句话理解:
Dense Attention:所有原始 K/V 都看。
Sparse Attention:只看部分重要的原始 K/V。
Compressed Sparse Attention:先压缩 K/V,再只看部分重要的 compressed K/V。
第六步:和 HCA 的区别
DeepSeek-V4 中除了 CSA,还有 HCA,也就是 Heavily Compressed Attention。
HCA 的全称是:
Heavily Compressed Attention
可以翻译为:
重压缩注意力
或者:
高度压缩注意力
这里的关键词有两个:
Heavily Compressed:更强程度压缩
Attention:注意力
所以 HCA 的核心直觉是:
先把 KV cache 压缩得更短,
然后在这个更短的 compressed K/V 上做 dense attention。
也就是说,HCA 不是直接在原始 K/V 上做 dense attention,而是在 强压缩之后的 K/V 上做 dense attention。
它的流程可以写成:
原始 KV cache
↓
更强程度压缩
↓
得到更短的 compressed K/V
↓
当前 Query 和所有 compressed Key 计算分数
↓
对所有 compressed positions 做 softmax
↓
加权求和所有 compressed Value
这里最容易误解的是:
HCA 里的 dense attention
不是看所有原始 K/V,
而是看所有压缩后的 compressed K/V。
也就是说,HCA 的 dense 是指:
在 compressed K/V 空间里全部都看
不是指:
在原始百万 token 的 K/V 空间里全部都看
这两者差别很大。
先回顾一下 CSA 的做法。
CSA 是:
原始 KV cache
↓
压缩
↓
得到 compressed K/V
↓
再做 sparse attention,只选择 top-k 重要位置
也就是说,CSA 有两个减少计算量的动作:
第一步:压缩 K/V,让 K/V 位置变少
第二步:稀疏选择,只保留 top-k 重要 compressed KV 位置
如果原来有 8 个 KV 位置:
(K1,V1), (K2,V2), (K3,V3), (K4,V4),
(K5,V5), (K6,V6), (K7,V7), (K8,V8)
CSA 可以先压缩成 4 个 compressed KV 位置:
(CK1,CV1), (CK2,CV2), (CK3,CV3), (CK4,CV4)
然后再从这 4 个里面选择 top-k。
例如 compressed scores 是:
compressed_scores = [2.0, 0.5, 1.5, -1.0]
如果 CSA 保留 top-2,那么只选:
CK1 / CV1
CK3 / CV3
屏蔽掉:
CK2 / CV2
CK4 / CV4
于是 sparse compressed scores 是:
sparse_compressed_scores = [2.0, -∞, 1.5, -∞]
softmax 后大约是:
[0.622, 0.000, 0.378, 0.000]
最终输出是:
output = 0.622 × CV1 + 0.378 × CV3
也就是说,CSA 在压缩之后,仍然只使用一部分 compressed Value。
所以 CSA 的直觉是:
压缩后再挑重点。
HCA 的思路不一样。
HCA 也会压缩 K/V,但是压缩得更狠。压缩完成之后,HCA 不再做 top-k sparse selection,而是对所有压缩后的 K/V 做 dense attention。
举一个简化例子。
假设原来还是有 8 个 KV 位置:
(K1,V1), (K2,V2), (K3,V3), (K4,V4),
(K5,V5), (K6,V6), (K7,V7), (K8,V8)
CSA 的教学例子里,我们可能把 8 个位置压缩成 4 个:
(CK1,CV1), (CK2,CV2), (CK3,CV3), (CK4,CV4)
而 HCA 更强压缩,可能把 8 个位置压缩成 2 个:
(HK1,HV1), (HK2,HV2)
这里为了区分,可以用:
HK = HCA compressed Key
HV = HCA compressed Value
如果用相邻分块压缩作为教学例子,可以写成:
HK1 = compress(K1, K2, K3, K4)
HK2 = compress(K5, K6, K7, K8)
HV1 = compress(V1, V2, V3, V4)
HV2 = compress(V5, V6, V7, V8)
这表示:
HK1 / HV1 代表 token 1 到 token 4 这一大段的压缩表示
HK2 / HV2 代表 token 5 到 token 8 这一大段的压缩表示
注意,这里的压缩比 CSA 更强。
CSA 中可能是:
2 个 token 压缩成 1 个 compressed KV
HCA 中可能是:
4 个 token 压缩成 1 个 compressed KV
这就是 HCA 里 Heavily Compressed 的含义:
压缩得更狠,压缩后的 KV 序列更短。
接下来 HCA 做 dense attention。
假设当前 Query 是 Q。
压缩后有两个 compressed Key:
HK1, HK2
那么 HCA 会计算:
Q · HK1
Q · HK2
假设得到分数:
hca_scores = [1.2, 0.8]
HCA 不会像 CSA 那样只保留 top-1,也不会把某些位置设成 -∞。
HCA 会直接对所有 compressed scores 做 softmax:
softmax([1.2, 0.8])
≈ [0.599, 0.401]
然后加权求和所有 compressed Value:
hca_output = 0.599 × HV1 + 0.401 × HV2
这里两个 compressed Value 都参与了输出。
所以 HCA 的 dense attention 是:
在压缩后的 KV 位置上全部都看。
不是:
在原始 KV 位置上全部都看。
这一点非常关键。
可以把 CSA 和 HCA 的完整过程放在一起对比。
假设原来有 8 个原始 KV 位置。
CSA 可以理解为:
原始 KV:
8 个位置
↓
中等程度压缩
↓
4 个 compressed KV 位置
↓
sparse attention
↓
只选择 top-2 个 compressed KV 位置参与输出
所以 CSA 实际使用的是:
2 个 compressed KV 位置
HCA 可以理解为:
原始 KV:
8 个位置
↓
更强程度压缩
↓
2 个 heavily compressed KV 位置
↓
dense attention
↓
2 个 heavily compressed KV 位置全部参与输出
所以 HCA 实际使用的是:
所有 heavily compressed KV 位置
也就是说:
CSA 是“压缩后再挑一部分看”。
HCA 是“压缩得更短,然后压缩后的都看”。
为什么 HCA 压缩得更狠之后还可以做 dense attention?
因为 dense attention 的成本主要取决于 K/V 位置数量。
如果原始 K/V 很长,dense attention 很贵。
例如原来有:
1,000,000 个原始 KV 位置
如果直接 dense attention,当前 Query 要看:
1,000,000 个 Key
这会非常贵。
但是 HCA 先把 KV cache 强压缩。
例如把:
1,000,000 个原始 KV 位置
压缩成:
10,000 个 heavily compressed KV 位置
那么 HCA 的 dense attention 实际是在这 10,000 个 compressed KV 上做,而不是在 1,000,000 个原始 KV 上做。
所以计算从:
Q 看 1,000,000 个原始 Key
变成:
Q 看 10,000 个 heavily compressed Key
虽然 HCA 在 compressed KV 上是 dense 的,但因为 compressed KV 的长度已经大幅缩短,所以 dense 的成本仍然可以接受。
换句话说:
HCA 不是靠 sparse selection 省计算,
而是靠更强的 KV 压缩省计算。
用复杂度直觉看也更清楚。
假设原始 KV 长度是:
n
如果直接 dense attention,那么当前 Query 要访问:
n 个 K/V 位置
如果 CSA 先压缩到:
n_csa
然后只选 top-k,那么实际参与 V 加权求和的位置大约是:
k
所以 CSA 的直觉是:
n → n_csa → k
如果 HCA 更强压缩到:
n_hca
然后在 n_hca 上做 dense attention,那么实际参与 V 加权求和的位置是:
n_hca
所以 HCA 的直觉是:
n → n_hca
其中通常可以理解为:
n_hca < n_csa
也就是说,HCA 压缩后的序列更短。
因此:
CSA:
先压缩到一个中等长度,再通过 top-k 进一步减少参与计算的位置。
HCA:
直接压缩到更短长度,然后在这个短序列上全部参与计算。
可以用一个更大的例子说明。
假设原始 KV cache 有:
1,000,000 个 KV 位置
CSA 可以粗略理解为:
1,000,000 个原始 KV
↓
压缩成 100,000 个 compressed KV
↓
sparse attention 只选 top-k 个重要位置
如果 top-k 是 1,024,那么 CSA 最后真正参与加权求和的是:
1,024 个 compressed Value
也就是说,CSA 的特点是:
保留较多 compressed 候选位置,
但最后只选择其中最相关的少数位置。
HCA 可以粗略理解为:
1,000,000 个原始 KV
↓
更强压缩成 10,000 个 heavily compressed KV
↓
dense attention 使用这 10,000 个 heavily compressed KV
也就是说,HCA 的特点是:
候选位置本身已经非常少,
所以可以把这些压缩后的位置全部看一遍。
这里的 100,000、10,000、1,024 都是为了帮助理解的示意数字,不代表官方固定参数。
为什么 DeepSeek-V4 需要同时有 CSA 和 HCA?
因为二者解决的问题不完全一样。
CSA 更像是:
从较细粒度的压缩表示里,挑出和当前 Query 最相关的重点信息。
它适合保留一些细节相关性。
例如,长文档或长代码里某些片段和当前问题高度相关,CSA 可以通过 top-k 选择把它们挑出来。
HCA 更像是:
提供一个更粗粒度、更全局的上下文背景。
它把很长的 KV cache 压缩得更短,然后在这个较短的全局压缩表示上做 dense attention。虽然每个 compressed KV 更粗糙,但它能让模型看到更完整的压缩后全局信息。
所以可以这样理解:
CSA 更偏向“找重点细节”。
HCA 更偏向“保留全局背景”。
二者是互补关系,不是谁完全替代谁。
再用读书来类比。
Dense attention 像是:
把整本书每一句都仔细看。
Sparse attention 像是:
从整本书里挑最相关的几句话看。
CSA 像是:
先把整本书按段落压缩成段落摘要,
再从这些段落摘要里挑最相关的几个看。
HCA 像是:
先把整本书压缩成更短的章节摘要,
然后把所有章节摘要都看一遍。
所以:
CSA:摘要比较细,但只挑重点看。
HCA:摘要更粗,但压缩后的摘要全部看。
这个类比虽然不等于真实实现,但可以帮助理解二者的差异。
因此,CSA 和 HCA 的区别可以总结为:
| 模块 | 压缩程度 | Attention 方式 | 实际访问对象 | 直观理解 |
|---|---|---|---|---|
| CSA | 中等压缩 | Sparse attention | top-k 个 compressed K/V | 压缩后只看重点 |
| HCA | 更强压缩 | Dense attention | 所有 heavily compressed K/V | 压缩得更狠,但压缩后的都看 |
更具体地说:
CSA = compression + sparse attention
意思是:
先压缩 K/V,
再从压缩后的 K/V 中选 top-k 个重要位置。
而:
HCA = heavier compression + dense attention
意思是:
先对 K/V 做更强压缩,
再让所有压缩后的 K/V 位置都参与 attention。
所以中文可以写成:
CSA = 压缩 + 稀疏注意力
HCA = 更强压缩 + 压缩空间内的密集注意力
这里最好把 HCA 写成“压缩空间内的密集注意力”,因为它不是对原始百万 token 做 dense attention,而是在更短的 heavily compressed KV 上做 dense attention。
一句话总结:
CSA 是先把 KV cache 压短,再从压缩后的 K/V 里挑重点;
HCA 是把 KV cache 压得更短,然后让压缩后的 K/V 全部参与 attention。
因此,HCA 的 dense attention 并不和高效率矛盾,因为它 dense 的对象已经不是原始长序列,而是强压缩后的短 K/V 序列。
最后总结
在 attention 中,Q、K、V 的作用可以这样记:
Q:当前 token 想找什么
K:每个 token 提供什么索引
V:真正被拿走的信息内容
Dense attention 的流程是:
Q 匹配所有 K
↓
softmax 得到所有位置的权重
↓
加权求和所有 V
Sparse attention 的流程是:
Q 匹配 K
↓
只保留重要 K
↓
只加权对应的 V
Compressed Sparse Attention 的流程是:
先压缩 KV cache
↓
Q 匹配 compressed K
↓
只保留重要 compressed K
↓
只加权对应的 compressed V
一句话总结:
Dense attention 是“所有位置都参与计算”;sparse attention 是“只让重要位置参与计算”;Compressed Sparse Attention 则是在长上下文场景下先压缩 KV cache,再在压缩后的表示中做稀疏注意力,从而降低计算和存储压力。
5.3 CSA 的 toy-level 代码理解
下面给出一个非官方 toy-level CSA 实现。它的核心步骤是:先压缩 K/V,再计算 Q 和压缩后 K 的相似度,最后只保留 top-k 位置参与加权求和。
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class SequenceCompressor(nn.Module):
def __init__(self, dim: int, compression_ratio: int = 4):
super().__init__()
if compression_ratio <= 0:
raise ValueError("compression_ratio must be positive")
self.dim = dim
self.compression_ratio = compression_ratio
self.proj = nn.Linear(dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.dim() != 3:
raise ValueError("x must be [batch, seq_len, dim]")
b, n, d = x.shape
if d != self.dim:
raise ValueError(f"expected dim={self.dim}, got dim={d}")
r = self.compression_ratio
pad_len = (r - n % r) % r
if pad_len > 0:
pad = torch.zeros(b, pad_len, d, device=x.device, dtype=x.dtype)
x = torch.cat([x, pad], dim=1)
n2 = x.shape[1]
x = x.view(b, n2 // r, r, d).mean(dim=2)
return self.proj(x)
class ToyCompressedSparseAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 4,
compression_ratio: int = 4,
topk: int = 8,
):
super().__init__()
if dim % num_heads != 0:
raise ValueError("dim must be divisible by num_heads")
if topk <= 0:
raise ValueError("topk must be positive")
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.topk = topk
self.q_proj = nn.Linear(dim, dim)
self.k_compressor = SequenceCompressor(dim, compression_ratio)
self.v_compressor = SequenceCompressor(dim, compression_ratio)
self.k_proj = nn.Linear(dim, dim)
self.v_proj = nn.Linear(dim, dim)
self.o_proj = nn.Linear(dim, dim)
def split_heads(self, x: torch.Tensor) -> torch.Tensor:
b, n, d = x.shape
x = x.view(b, n, self.num_heads, self.head_dim)
return x.transpose(1, 2)
def merge_heads(self, x: torch.Tensor) -> torch.Tensor:
b, h, n, d = x.shape
x = x.transpose(1, 2).contiguous()
return x.view(b, n, h * d)
def forward(self, x: torch.Tensor) -> torch.Tensor:
q = self.split_heads(self.q_proj(x))
k_base = self.k_compressor(x)
v_base = self.v_compressor(x)
k = self.split_heads(self.k_proj(k_base))
v = self.split_heads(self.v_proj(v_base))
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
topk = min(self.topk, scores.shape[-1])
topk_values, topk_indices = torch.topk(scores, k=topk, dim=-1)
sparse_scores = torch.full_like(scores, float("-inf"))
sparse_scores.scatter_(-1, topk_indices, topk_values)
attn = F.softmax(sparse_scores, dim=-1)
out = torch.matmul(attn, v)
out = self.merge_heads(out)
return self.o_proj(out)
if __name__ == "__main__":
x = torch.randn(2, 1024, 256)
csa = ToyCompressedSparseAttention(
dim=256,
num_heads=8,
compression_ratio=4,
topk=16,
)
y = csa(x)
print("input shape:", x.shape)
print("output shape:", y.shape)
6. 核心技术三:HCA,Heavily Compressed Attention
6.1 HCA 的基本思想
HCA 的全称是 Heavily Compressed Attention,可以翻译为“重压缩注意力”或“高度压缩注意力”。与 CSA 不同,HCA 的特点是对 KV cache 做更强压缩,但仍然保留 dense attention。
可以粗略理解为:
原始 KV cache
|
更激进的压缩
|
更短的 compressed KV cache
|
dense attention
CSA 和 HCA 的区别可以这样理解:
| 模块 | KV 压缩程度 | Attention 方式 | 更适合承担的作用 |
|---|---|---|---|
| CSA | 中等压缩 | Sparse attention | 保留关键细节和相关片段 |
| HCA | 更强压缩 | Dense attention | 提供全局语义背景 |
CSA 更像是“从压缩后的长上下文中找重点”,HCA 更像是“保留一个高度压缩的全局记忆”。二者不是替代关系,而是互补关系。
6.2 HCA 的 toy-level 代码理解
HCA 和 CSA 的主要区别是:CSA 是压缩后再稀疏选择,而 HCA 是压缩得更狠,但在压缩后的 K/V 上仍然做 dense attention。
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class SequenceCompressor(nn.Module):
def __init__(self, dim: int, compression_ratio: int):
super().__init__()
if compression_ratio <= 0:
raise ValueError("compression_ratio must be positive")
self.dim = dim
self.compression_ratio = compression_ratio
self.proj = nn.Linear(dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, n, d = x.shape
r = self.compression_ratio
pad_len = (r - n % r) % r
if pad_len > 0:
pad = torch.zeros(b, pad_len, d, device=x.device, dtype=x.dtype)
x = torch.cat([x, pad], dim=1)
n2 = x.shape[1]
x = x.view(b, n2 // r, r, d).mean(dim=2)
return self.proj(x)
class ToyHeavilyCompressedAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 4, compression_ratio: int = 16):
super().__init__()
if dim % num_heads != 0:
raise ValueError("dim must be divisible by num_heads")
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q_proj = nn.Linear(dim, dim)
self.k_compressor = SequenceCompressor(dim, compression_ratio)
self.v_compressor = SequenceCompressor(dim, compression_ratio)
self.k_proj = nn.Linear(dim, dim)
self.v_proj = nn.Linear(dim, dim)
self.o_proj = nn.Linear(dim, dim)
def split_heads(self, x):
b, n, d = x.shape
x = x.view(b, n, self.num_heads, self.head_dim)
return x.transpose(1, 2)
def merge_heads(self, x):
b, h, n, d = x.shape
x = x.transpose(1, 2).contiguous()
return x.view(b, n, h * d)
def forward(self, x):
q = self.split_heads(self.q_proj(x))
compressed_k = self.k_compressor(x)
compressed_v = self.v_compressor(x)
k = self.split_heads(self.k_proj(compressed_k))
v = self.split_heads(self.v_proj(compressed_v))
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, v)
out = self.merge_heads(out)
return self.o_proj(out)
if __name__ == "__main__":
x = torch.randn(2, 1024, 256)
hca = ToyHeavilyCompressedAttention(
dim=256,
num_heads=8,
compression_ratio=16,
)
y = hca(x)
print("input shape:", x.shape)
print("output shape:", y.shape)
这个例子体现了 HCA 的核心直觉:
原始序列 x
↓
更强压缩 K/V
↓
Q attend compressed K/V
↓
dense attention
它不是 DeepSeek-V4 的官方实现,只是帮助理解“更强压缩 + dense attention”的结构思想。
7. 核心技术四:mHC,Manifold-Constrained Hyper-Connections
7.1 mHC 的基本思想
mHC 的全称是 Manifold-Constrained Hyper-Connections,可以翻译为“流形约束超连接”。
这个名字比较抽象,可以先拆成两部分理解:
Hyper-Connections = 超连接,也就是比普通 residual connection 更丰富的连接方式
Manifold-Constrained = 流形约束,也就是给这些连接方式加上稳定性约束
所以,mHC 可以先粗略理解为:
mHC = 有约束的多通道残差连接
要理解 mHC,需要先理解普通 residual connection。
普通 Transformer 中常见的 residual connection 是:
x_new = x_old + f(x_old)
其中:
x_old = 这一层输入进来的原始 hidden state
f(x_old) = attention 或 MLP 处理后得到的新信息
x_new = 加上新信息后的 hidden state
它的意思是:
新的表示 = 原来的表示 + 这一层学到的修改量
举个简单例子:
x_old = [1.0, 2.0]
f(x_old) = [0.1, -0.2]
那么:
x_new = x_old + f(x_old)
= [1.0, 2.0] + [0.1, -0.2]
= [1.1, 1.8]
可以看到,输出不是完全替换原来的 x_old,而是在原来的基础上做了一点修改。
这就是 residual connection 的核心:
不要让每一层都完全重写 hidden state,
而是在原始表示的基础上补充新信息。
可以用改论文来类比:
x_old = 原来的论文
f(x_old) = 修改意见
x_new = 原来的论文 + 修改意见
也就是说,residual connection 会保留原来的信息,同时加入当前层学到的新信息。
这种结构有两个重要好处。
第一个好处是保留信息。
如果没有 residual connection,每一层都可能把前面的信息覆盖掉。模型越深,早期信息越容易丢失。
有 residual connection 之后,原始信息可以通过这条路径继续往后传:
x_old 直接传到 x_new
第二个好处是训练更稳定。
在深层神经网络中,梯度需要从后面一层一层传回前面。如果每一层都很复杂,梯度可能会变弱或不稳定。
residual connection 给梯度提供了一条更直接的通路:
后面的梯度
↓
通过 residual path 回到前面的 hidden state
所以深层 Transformer 才能更容易训练。
普通 residual connection 可以理解成一条主要的信息通道:
x0 → x1 → x2 → x3 → x4
每一层都沿着这条通道更新:
上一层 hidden state
↓
加上 attention 或 MLP 产生的新信息
↓
下一层 hidden state
这种方式稳定、简单,但也有一个问题:
信息主要沿着一条 residual stream 传播,连接方式比较单一。
也就是说,普通 residual connection 更像是一条主干道。信息可以稳定地往后走,但路线比较固定。
Hyper-Connections 可以理解为对普通 residual connection 的扩展。
普通 residual connection 是一条信息通道:
一条 residual stream:
x0 → x1 → x2 → x3
Hyper-Connections 则可以理解为多条信息通道:
多条 residual streams:
stream 1: s1_0 → s1_1 → s1_2 → s1_3
stream 2: s2_0 → s2_1 → s2_2 → s2_3
stream 3: s3_0 → s3_1 → s3_2 → s3_3
直观理解就是:
普通 residual connection:
一条路传信息。
Hyper-Connections:
多条路一起传信息。
这样做的好处是,模型的信息流动方式更丰富。
普通 residual connection 里,每一层主要只能沿着同一条 hidden state 通道继续更新。而 Hyper-Connections 允许模型保留多条 residual streams,让不同信息可以在不同通道中传播。
可以把它类比成写论文时保留多个版本:
普通 residual connection:
只保留一个主版本,每次都在这个版本上修改。
Hyper-Connections:
同时保留多个版本,不同版本可以记录不同方向的信息。
例如:
stream 1 可能更偏向保留原始语义信息
stream 2 可能更偏向保留上下文交互信息
stream 3 可能更偏向保留深层抽象信息
这只是帮助理解的类比,不代表真实模型中每条 stream 一定有这样明确的语义分工。
Hyper-Connections 的目标是:
让深层模型拥有更丰富的信息流动路径。
但是,Hyper-Connections 也会带来一个问题:连接方式变多以后,信息流可能变得不稳定。
普通 residual connection 虽然简单,但它有一个很重要的稳定性来源:
x_old 可以比较直接地传到 x_new
也就是说,模型至少有一条接近“原样保留”的路径。
这条路径可以理解为:
identity path
中文可以叫:
恒等映射路径
它的意思是:
即使当前层学到的新信息 f(x) 不够好,
原始信息 x 也可以继续传下去。
但是 Hyper-Connections 增加了更多连接以后,多个信息通道之间可能会互相混合。
比如:
stream 1 的信息流到 stream 2
stream 2 的信息流到 stream 3
stream 3 的信息又混回 stream 1
如果这种混合没有约束,就可能出现两个问题。
第一个问题是信息被过度放大:
某些通道的信息越来越强
↓
数值不稳定
↓
训练变困难
第二个问题是信息被过度削弱:
某些通道的信息越来越弱
↓
重要信息被冲淡
↓
深层模型难以保留早期信息
所以 Hyper-Connections 的矛盾是:
连接更多,表达能力更强;
但是连接太自由,稳定性可能变差。
mHC 的作用就是解决这个矛盾。
mHC 不是简单地增加更多连接,也不是让多条通道随便混合,而是:
允许多条 residual streams 存在,
但要求这些连接方式满足一定的稳定性约束。
这里的 Manifold-Constrained 可以先不用按很复杂的数学理解。对于入门阅读来说,可以先把它理解成:
给连接方式加规则。
也就是说,mHC 的直观含义是:
Hyper-Connections 提供更多信息通道;
Manifold-Constrained 给这些通道之间的连接加规则;
mHC 让信息可以多路径流动,但不至于乱流。
可以用交通来类比:
普通 residual connection:
只有一条主路,稳定,但路线少。
Hyper-Connections:
开了很多条路,路线更多,但如果没有规则,容易混乱。
mHC:
开很多条路,同时加交通规则,让信息能多路径流动,但整体仍然稳定。
这个类比能帮助理解 mHC 的核心思想:
不是只追求更多连接,
而是追求有约束、更稳定的连接。
因此,mHC 和普通 residual connection 的关系可以这样理解:
| 结构 | 信息通道 | 连接特点 | 优点 | 风险 |
|---|---|---|---|---|
| Residual Connection | 一条主要通道 | x_old 直接加到输出上 | 简单、稳定、容易训练 | 信息流动方式较单一 |
| Hyper-Connections | 多条通道 | 多个 residual streams 之间可以交互 | 表达能力更强,信息路径更多 | 连接太自由可能不稳定 |
| mHC | 有约束的多条通道 | 多通道连接受到约束 | 同时增强表达能力和稳定性 | 结构更复杂 |
从 DeepSeek-V4 的角度看,CSA 和 HCA 主要解决的是长上下文 attention 的计算和 KV cache 压力;而 mHC 解决的是深层模型中信息怎么稳定传播的问题。
可以这样区分:
CSA / HCA:
解决“长上下文中 attention 怎么算得更省”。
mHC:
解决“很深的模型中信息怎么传得更稳”。
这里还要注意,mHC 不能简单等同于普通 gated residual。
普通 gated residual 通常可以写成:
x_new = x_old + gate × f(x_old)
它的意思是给当前层新增信息加一个门控,控制 f(x_old) 加多少。
但是 mHC 的重点不是简单控制一个 f(x_old) 加多少,而是扩展 residual connection 的连接结构,让信息可以通过多个 residual streams 传播,并且对这些连接加约束。
所以:
Gated residual:
主要是在一条 residual stream 上控制新增信息的比例。
mHC:
关注多条 residual streams 之间如何稳定连接。
因此,博客里的 toy gated residual 只能用来帮助理解“增强残差路径”的动机,不能说它就是 mHC 的真实实现。
更准确的表述应该是:
本文后面的 ToyGatedResidual 不是 mHC。
它只是用一个简单代码例子说明:在普通 residual connection 之外,可以加入额外机制来调节信息流。
真实 mHC 涉及 Hyper-Connections 和 manifold constraint,不能用几行 gated residual 代码完整复现。
一句话总结:
Residual connection 是一条稳定的信息直通路径;
Hyper-Connections 把这条路径扩展成多条信息通道;
mHC 则是在多条通道之间加入约束规则,
让模型既能获得更丰富的信息流动方式,又能保持深层训练的稳定性。
所以,mHC 的重点不是 attention 怎么算,也不是 KV cache 怎么压缩,而是:
深层 Transformer 中,信息如何在层与层之间更稳定、更灵活地传播。
7.2 mHC 的 toy-level 代码理解
mHC 不是普通 gated residual,也不是简单的残差相加。真实 mHC 涉及更复杂的连接约束和表示传播方式。为了帮助理解“增强 residual connection”这个动机,可以先写一个极简的 gated residual 作为类比。
普通 residual connection 是:
x = x + f(x)
toy gated residual 可以写成:
x = x + gate(x) * f(x)
代码如下:
import torch
import torch.nn as nn
class ToyGatedResidual(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.gate = nn.Sequential(
nn.Linear(dim, dim),
nn.Sigmoid(),
)
def forward(self, x: torch.Tensor, fx: torch.Tensor) -> torch.Tensor:
return x + self.gate(x) * fx
class ToyBlockWithEnhancedResidual(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.ffn = nn.Sequential(
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim),
)
self.residual = ToyGatedResidual(dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
fx = self.ffn(self.norm(x))
return self.residual(x, fx)
if __name__ == "__main__":
x = torch.randn(2, 128, 256)
block = ToyBlockWithEnhancedResidual(dim=256)
y = block(x)
print("input shape:", x.shape)
print("output shape:", y.shape)
需要强调的是:
ToyGatedResidual 不是 mHC。
ToyGatedResidual 只是为了帮助理解“增强残差路径”的动机。
更准确地说,mHC 的意义在于:当模型很深、上下文很长、模块很复杂时,模型需要更稳定、更灵活的层间信息流。普通 residual connection 是直接相加,而 mHC 试图让层间连接具有更强的表达能力和更好的训练稳定性。
8. 核心技术五:Muon Optimizer
8.1 Muon Optimizer 的基本思想
DeepSeek-V4 引入了 Muon optimizer,用于提升训练收敛速度和稳定性。
对于大模型训练来说,优化器不是细节,而是核心工程之一。尤其是 DeepSeek-V4 这种模型同时包含超大规模 MoE、长上下文注意力压缩、多精度训练、多种并行策略以及后训练中的 SFT / RL / distillation,训练稳定性非常重要。
从论文角度看,Muon optimizer 的作用可以概括为:
更快收敛 + 更稳定训练
不过,本文不会尝试复现 Muon optimizer,因为完整训练 DeepSeek-V4 级别模型涉及大规模分布式训练系统,不适合用一个简单脚本模拟。
8.2 Optimizer 在训练流程中的位置
优化器和模型结构不同,它不是一个简单模块,不能像 attention 或 MoE 那样用几十行代码准确复现。因此,本文不实现 Muon optimizer。为了帮助理解优化器在训练中的位置,下面用 AdamW 写一个普通训练步骤示意。
import torch
import torch.nn as nn
model = nn.Linear(256, 32000)
optimizer = torch.optim.AdamW(
model.parameters(),
lr=1e-4,
weight_decay=0.01,
)
input_x = torch.randn(4, 256)
target = torch.randint(0, 32000, (4,))
criterion = nn.CrossEntropyLoss()
logits = model(input_x)
loss = criterion(logits, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("loss:", loss.item())
这里用 AdamW 只是演示 optimizer 在训练流程中的作用,不代表 DeepSeek-V4 使用 AdamW,也不代表 Muon 的真实实现。
9. DeepSeek-V4 的效率来自哪里?
DeepSeek-V4 在 1M-token context 场景下,相比 DeepSeek-V3.2 显著降低了单 token 推理 FLOPs 和 KV cache 占用。这背后的原因不是一个技巧,而是多个层面的共同作用。
9.1 注意力层面的优化
传统长上下文 attention 的主要瓶颈是:
长序列 attention 计算成本高
KV cache 占用大
KV cache 读写开销大
DeepSeek-V4 通过 CSA 和 HCA 对 KV cache 进行压缩,并结合稀疏注意力和强压缩 dense attention,从架构层面降低长上下文成本。
9.2 MoE 层面的优化
MoE 的核心思想是:模型总参数很多,但每次前向只激活部分专家。
例如:
DeepSeek-V4-Pro: 1.6T total / 49B activated
DeepSeek-V4-Flash: 284B total / 13B activated
这说明 MoE 的重点是用更大的参数容量提升模型能力,同时控制每次前向计算的实际成本。
9.3 MoE 的 toy-level 代码理解
MoE,Mixture-of-Experts,可以理解为“专家混合”。普通 FFN 是所有 token 都经过同一个 MLP:
x → FFN → output
MoE 则是先用 router 判断每个 token 应该交给哪些 expert:
x → router → top-k experts → weighted sum
下面是一个极简 toy MoE 实现:
import torch
import torch.nn as nn
import torch.nn.functional as F
class ToyExpert(nn.Module):
def __init__(self, dim: int, hidden_dim: int):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x):
return self.net(x)
class ToyMoE(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
num_experts: int = 4,
top_k: int = 2,
):
super().__init__()
if top_k > num_experts:
raise ValueError("top_k must be <= num_experts")
self.num_experts = num_experts
self.top_k = top_k
self.router = nn.Linear(dim, num_experts)
self.experts = nn.ModuleList(
[ToyExpert(dim, hidden_dim) for _ in range(num_experts)]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, n, d = x.shape
router_logits = self.router(x)
router_probs = F.softmax(router_logits, dim=-1)
topk_probs, topk_indices = torch.topk(
router_probs,
k=self.top_k,
dim=-1,
)
topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
output = torch.zeros_like(x)
for expert_id, expert in enumerate(self.experts):
expert_out = expert(x)
weight = torch.zeros(b, n, device=x.device, dtype=x.dtype)
for k in range(self.top_k):
weight = weight + torch.where(
topk_indices[:, :, k] == expert_id,
topk_probs[:, :, k],
torch.zeros_like(topk_probs[:, :, k]),
)
output = output + expert_out * weight.unsqueeze(-1)
return output
if __name__ == "__main__":
x = torch.randn(2, 16, 256)
moe = ToyMoE(
dim=256,
hidden_dim=1024,
num_experts=4,
top_k=2,
)
y = moe(x)
print("input shape:", x.shape)
print("output shape:", y.shape)
这个 toy MoE 体现了 MoE 的核心逻辑:
router 给每个 token 分配 expert
↓
只激活 top-k expert
↓
把多个 expert 的输出加权求和
这就是为什么 MoE 模型可以拥有很大的总参数量,但每次实际激活的参数量相对较小。
9.4 精度层面的优化
DeepSeek-V4 官方模型卡中列出的精度包括:
Base model: FP8 Mixed
Instruct model: FP4 + FP8 Mixed
其中,instruct 模型使用 FP4 + FP8 mixed precision,可以进一步降低存储和计算成本。
9.5 系统工程层面的优化
DeepSeek-V4 技术报告还提到一系列基础设施优化,例如:
- single fused kernel for MoE modules;
- TileLang;
- deterministic kernel libraries;
- tensor-level checkpointing;
- hybrid ZeRO;
- two-stage contextual parallelism;
- heterogeneous KV cache;
- on-disk KV cache storage。
这说明大模型效率不是只靠一个公式,而是来自模型架构、优化器、训练系统、推理系统、精度策略和硬件适配的共同作用。
10. 预训练与后训练流程
DeepSeek-V4 的预训练规模非常大。官方技术报告中提到:
DeepSeek-V4-Flash: 32T tokens
DeepSeek-V4-Pro: 33T tokens
也就是说,两个模型都在超过 32T 的高质量 token 上进行了预训练。
后训练方面,DeepSeek-V4 使用了两阶段范式:
阶段一:独立培养不同领域专家
阶段二:通过 on-policy distillation 进行统一模型整合
可以粗略理解为:先针对数学、代码、agent、指令跟随等领域训练不同专家能力,再通过统一蒸馏,把不同领域能力整合到一个模型中。
这对 agent 任务尤其重要,因为 agent 任务通常不是单一步骤,而是包含任务理解、计划生成、工具调用、文件读取、代码修改、测试反馈、错误修复和多轮反思。因此,DeepSeek-V4 的后训练目标不是只提升普通问答,而是进一步强化长程推理和 agentic capabilities。
10.1 MTP 的基本思想
MTP 的全称是 Multi-Token Prediction,可以翻译为“多 token 预测”。普通语言模型训练时,通常是预测下一个 token:
输入: 我 喜欢
预测: 学习
也就是:
predict next token
MTP 的直觉是:模型不仅预测下一个 token,还可以同时预测后面多个 token,从而提供更丰富的训练信号。
例如:
输入: 我 喜欢
预测: 学习 深度 模型
10.2 MTP 的 toy-level 代码理解
下面是一个非常简化的 toy MTP head:
import torch
import torch.nn as nn
class ToyMTPHead(nn.Module):
def __init__(self, dim: int, vocab_size: int, num_future_tokens: int = 3):
super().__init__()
self.num_future_tokens = num_future_tokens
self.heads = nn.ModuleList(
[
nn.Linear(dim, vocab_size)
for _ in range(num_future_tokens)
]
)
def forward(self, hidden_states: torch.Tensor):
logits_list = []
for head in self.heads:
logits = head(hidden_states)
logits_list.append(logits)
return logits_list
if __name__ == "__main__":
batch_size = 2
seq_len = 8
dim = 256
vocab_size = 32000
hidden_states = torch.randn(batch_size, seq_len, dim)
mtp_head = ToyMTPHead(
dim=dim,
vocab_size=vocab_size,
num_future_tokens=3,
)
logits_list = mtp_head(hidden_states)
for i, logits in enumerate(logits_list):
print(f"predict t+{i+1} logits shape:", logits.shape)
如果要计算 toy MTP loss,可以这样理解:
import torch
import torch.nn.functional as F
def toy_mtp_loss(logits_list, input_ids):
total_loss = 0.0
valid_heads = 0
for i, logits in enumerate(logits_list):
shift = i + 1
if input_ids.shape[1] <= shift:
continue
pred_logits = logits[:, :-shift, :].contiguous()
labels = input_ids[:, shift:].contiguous()
loss = F.cross_entropy(
pred_logits.view(-1, pred_logits.size(-1)),
labels.view(-1),
)
total_loss = total_loss + loss
valid_heads += 1
return total_loss / max(valid_heads, 1)
if __name__ == "__main__":
batch_size = 2
seq_len = 8
vocab_size = 32000
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
logits_list = [
torch.randn(batch_size, seq_len, vocab_size),
torch.randn(batch_size, seq_len, vocab_size),
torch.randn(batch_size, seq_len, vocab_size),
]
loss = toy_mtp_loss(logits_list, input_ids)
print("toy mtp loss:", loss.item())
MTP 的直觉可以总结为:
普通 LM:
当前位置只预测下一个 token。
MTP:
当前位置同时预测未来多个 token。
这可以增加训练信号密度,也有助于模型学习更长跨度的 token 关系。这里的代码只是 toy-level 理解,不代表 DeepSeek-V4 的真实 MTP 实现。
11. 为什么 DeepSeek-V4 对 Agent 很重要?
DeepSeek 官方发布中特别强调了 agent 能力。这很合理,因为 agent 任务天然依赖长上下文。
一个典型代码修复 agent 的上下文可能包含:
用户问题
仓库结构
相关源码文件
错误日志
测试输出
历史修改
工具调用轨迹
中间推理状态
失败尝试
最终补丁
如果上下文窗口不足,模型很容易丢失关键信息。例如,忘记之前读过的文件,忘记测试失败原因,忘记自己已经尝试过的修复方案,对不同日志之间的关系判断错误,或者无法从长轨迹中定位真正的 bug。
百万 token 上下文的价值在于,它可以让 agent 在更完整的任务轨迹中进行推理,而不是不断依赖外部检索和截断摘要。这对代码智能体、科研智能体、文档智能体、医学文献分析、法律证据链分析等场景都非常重要。
12. 非官方 PyTorch 推测复现说明
下面进入完整 toy 模型部分。
再次强调:
本文代码不是 DeepSeek 官方实现。
本文代码不是 DeepSeek-V4 真实源码。
本文代码不能复现 DeepSeek-V4 的真实性能。
本文代码只用于理解 CSA、HCA、MoE、MTP 和 Hybrid Attention 的结构直觉。
本文完整 toy 模型会组合以下部分:
- Toy CSA:先压缩 K/V,再做 top-k sparse attention;
- Toy HCA:更强压缩 K/V,再做 dense attention;
- Toy Hybrid Attention:用 gate 融合 CSA 和 HCA 输出;
- Toy MoE:用 router 选择 top-k experts;
- Toy Gated Residual:用作 mHC 的非官方类比;
- Toy MTP Head:同时预测未来多个 token。
为了保持代码简单,本文没有实现真实 DeepSeek Sparse Attention、真实 KV cache 管理、真实 mHC、真实 DeepSeekMoE、真实 MTP、FP4 / FP8、fused kernel、TileLang、分布式训练,也没有实现 causal mask 的严格长上下文推理逻辑。
13. 完整 Toy 代码
下面给出完整代码。建议保存为:
toy_deepseek_v4_style_model.py
然后直接运行:
python toy_deepseek_v4_style_model.py
完整代码如下:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class SequenceCompressor(nn.Module):
def __init__(self, dim: int, compression_ratio: int = 4):
super().__init__()
if compression_ratio <= 0:
raise ValueError("compression_ratio must be positive")
self.dim = dim
self.compression_ratio = compression_ratio
self.proj = nn.Linear(dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.dim() != 3:
raise ValueError("x must be [batch, seq_len, dim]")
b, n, d = x.shape
if d != self.dim:
raise ValueError(f"expected dim={self.dim}, got dim={d}")
r = self.compression_ratio
pad_len = (r - n % r) % r
if pad_len > 0:
pad = torch.zeros(b, pad_len, d, device=x.device, dtype=x.dtype)
x = torch.cat([x, pad], dim=1)
n2 = x.shape[1]
x = x.view(b, n2 // r, r, d).mean(dim=2)
return self.proj(x)
class ToyCompressedSparseAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 4,
compression_ratio: int = 4,
topk: int = 8,
):
super().__init__()
if dim % num_heads != 0:
raise ValueError("dim must be divisible by num_heads")
if topk <= 0:
raise ValueError("topk must be positive")
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.topk = topk
self.q_proj = nn.Linear(dim, dim)
self.k_compressor = SequenceCompressor(dim, compression_ratio)
self.v_compressor = SequenceCompressor(dim, compression_ratio)
self.k_proj = nn.Linear(dim, dim)
self.v_proj = nn.Linear(dim, dim)
self.o_proj = nn.Linear(dim, dim)
def split_heads(self, x: torch.Tensor) -> torch.Tensor:
b, n, d = x.shape
x = x.view(b, n, self.num_heads, self.head_dim)
return x.transpose(1, 2)
def merge_heads(self, x: torch.Tensor) -> torch.Tensor:
b, h, n, d = x.shape
x = x.transpose(1, 2).contiguous()
return x.view(b, n, h * d)
def forward(self, x: torch.Tensor) -> torch.Tensor:
q = self.split_heads(self.q_proj(x))
k_base = self.k_compressor(x)
v_base = self.v_compressor(x)
k = self.split_heads(self.k_proj(k_base))
v = self.split_heads(self.v_proj(v_base))
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
topk = min(self.topk, scores.shape[-1])
topk_values, topk_indices = torch.topk(scores, k=topk, dim=-1)
sparse_scores = torch.full_like(scores, float("-inf"))
sparse_scores.scatter_(-1, topk_indices, topk_values)
attn = F.softmax(sparse_scores, dim=-1)
out = torch.matmul(attn, v)
out = self.merge_heads(out)
return self.o_proj(out)
class ToyHeavilyCompressedAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 4,
compression_ratio: int = 16,
):
super().__init__()
if dim % num_heads != 0:
raise ValueError("dim must be divisible by num_heads")
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q_proj = nn.Linear(dim, dim)
self.k_compressor = SequenceCompressor(dim, compression_ratio)
self.v_compressor = SequenceCompressor(dim, compression_ratio)
self.k_proj = nn.Linear(dim, dim)
self.v_proj = nn.Linear(dim, dim)
self.o_proj = nn.Linear(dim, dim)
def split_heads(self, x: torch.Tensor) -> torch.Tensor:
b, n, d = x.shape
x = x.view(b, n, self.num_heads, self.head_dim)
return x.transpose(1, 2)
def merge_heads(self, x: torch.Tensor) -> torch.Tensor:
b, h, n, d = x.shape
x = x.transpose(1, 2).contiguous()
return x.view(b, n, h * d)
def forward(self, x: torch.Tensor) -> torch.Tensor:
q = self.split_heads(self.q_proj(x))
k_base = self.k_compressor(x)
v_base = self.v_compressor(x)
k = self.split_heads(self.k_proj(k_base))
v = self.split_heads(self.v_proj(v_base))
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, v)
out = self.merge_heads(out)
return self.o_proj(out)
class ToyHybridAttention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 4,
csa_compression: int = 4,
hca_compression: int = 16,
topk: int = 8,
):
super().__init__()
self.csa = ToyCompressedSparseAttention(
dim=dim,
num_heads=num_heads,
compression_ratio=csa_compression,
topk=topk,
)
self.hca = ToyHeavilyCompressedAttention(
dim=dim,
num_heads=num_heads,
compression_ratio=hca_compression,
)
self.gate = nn.Sequential(
nn.Linear(dim, dim),
nn.Sigmoid(),
)
self.out_proj = nn.Linear(dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
csa_out = self.csa(x)
hca_out = self.hca(x)
gate = self.gate(x)
y = gate * csa_out + (1.0 - gate) * hca_out
return self.out_proj(y)
class ToyGatedResidual(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.gate = nn.Sequential(
nn.Linear(dim, dim),
nn.Sigmoid(),
)
def forward(self, x: torch.Tensor, fx: torch.Tensor) -> torch.Tensor:
return x + self.gate(x) * fx
class ToyExpert(nn.Module):
def __init__(self, dim: int, hidden_dim: int):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, dim),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class ToyMoE(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
num_experts: int = 4,
top_k: int = 2,
):
super().__init__()
if top_k > num_experts:
raise ValueError("top_k must be <= num_experts")
self.num_experts = num_experts
self.top_k = top_k
self.router = nn.Linear(dim, num_experts)
self.experts = nn.ModuleList(
[ToyExpert(dim, hidden_dim) for _ in range(num_experts)]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, n, d = x.shape
router_logits = self.router(x)
router_probs = F.softmax(router_logits, dim=-1)
topk_probs, topk_indices = torch.topk(
router_probs,
k=self.top_k,
dim=-1,
)
topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
output = torch.zeros_like(x)
for expert_id, expert in enumerate(self.experts):
expert_out = expert(x)
weight = torch.zeros(b, n, device=x.device, dtype=x.dtype)
for k in range(self.top_k):
weight = weight + torch.where(
topk_indices[:, :, k] == expert_id,
topk_probs[:, :, k],
torch.zeros_like(topk_probs[:, :, k]),
)
output = output + expert_out * weight.unsqueeze(-1)
return output
class ToyTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 4,
moe_hidden_dim: int = 1024,
num_experts: int = 4,
top_k: int = 2,
):
super().__init__()
self.norm1 = nn.LayerNorm(dim)
self.attn = ToyHybridAttention(dim=dim, num_heads=num_heads)
self.norm2 = nn.LayerNorm(dim)
self.moe = ToyMoE(
dim=dim,
hidden_dim=moe_hidden_dim,
num_experts=num_experts,
top_k=top_k,
)
self.resid1 = ToyGatedResidual(dim)
self.resid2 = ToyGatedResidual(dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
attn_out = self.attn(self.norm1(x))
x = self.resid1(x, attn_out)
moe_out = self.moe(self.norm2(x))
x = self.resid2(x, moe_out)
return x
class ToyMTPHead(nn.Module):
def __init__(self, dim: int, vocab_size: int, num_future_tokens: int = 3):
super().__init__()
self.num_future_tokens = num_future_tokens
self.heads = nn.ModuleList(
[nn.Linear(dim, vocab_size) for _ in range(num_future_tokens)]
)
def forward(self, hidden_states: torch.Tensor):
logits_list = []
for head in self.heads:
logits = head(hidden_states)
logits_list.append(logits)
return logits_list
class ToyDeepSeekV4StyleModel(nn.Module):
def __init__(
self,
vocab_size: int,
dim: int = 256,
num_layers: int = 2,
num_heads: int = 8,
max_seq_len: int = 2048,
moe_hidden_dim: int = 1024,
num_experts: int = 4,
top_k: int = 2,
num_future_tokens: int = 3,
):
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
self.max_seq_len = max_seq_len
self.token_emb = nn.Embedding(vocab_size, dim)
self.pos_emb = nn.Embedding(max_seq_len, dim)
self.blocks = nn.ModuleList(
[
ToyTransformerBlock(
dim=dim,
num_heads=num_heads,
moe_hidden_dim=moe_hidden_dim,
num_experts=num_experts,
top_k=top_k,
)
for _ in range(num_layers)
]
)
self.norm = nn.LayerNorm(dim)
self.lm_head = nn.Linear(dim, vocab_size, bias=False)
self.mtp_head = ToyMTPHead(
dim=dim,
vocab_size=vocab_size,
num_future_tokens=num_future_tokens,
)
def forward(self, input_ids: torch.Tensor):
if input_ids.dim() != 2:
raise ValueError("input_ids must be [batch, seq_len]")
b, n = input_ids.shape
if n > self.max_seq_len:
raise ValueError(f"seq_len={n} exceeds max_seq_len={self.max_seq_len}")
positions = torch.arange(n, device=input_ids.device).unsqueeze(0).expand(b, n)
x = self.token_emb(input_ids) + self.pos_emb(positions)
for block in self.blocks:
x = block(x)
hidden_states = self.norm(x)
next_token_logits = self.lm_head(hidden_states)
mtp_logits_list = self.mtp_head(hidden_states)
return next_token_logits, mtp_logits_list
def toy_mtp_loss(logits_list, input_ids):
total_loss = 0.0
valid_heads = 0
for i, logits in enumerate(logits_list):
shift = i + 1
if input_ids.shape[1] <= shift:
continue
pred_logits = logits[:, :-shift, :].contiguous()
labels = input_ids[:, shift:].contiguous()
loss = F.cross_entropy(
pred_logits.view(-1, pred_logits.size(-1)),
labels.view(-1),
)
total_loss = total_loss + loss
valid_heads += 1
return total_loss / max(valid_heads, 1)
def main():
torch.manual_seed(42)
vocab_size = 32000
batch_size = 2
seq_len = 128
dim = 256
num_layers = 2
num_heads = 8
model = ToyDeepSeekV4StyleModel(
vocab_size=vocab_size,
dim=dim,
num_layers=num_layers,
num_heads=num_heads,
max_seq_len=2048,
moe_hidden_dim=1024,
num_experts=4,
top_k=2,
num_future_tokens=3,
)
input_ids = torch.randint(
low=0,
high=vocab_size,
size=(batch_size, seq_len),
)
next_token_logits, mtp_logits_list = model(input_ids)
print("input_ids shape:", input_ids.shape)
print("next_token_logits shape:", next_token_logits.shape)
for i, logits in enumerate(mtp_logits_list):
print(f"mtp logits t+{i+1} shape:", logits.shape)
loss = toy_mtp_loss(mtp_logits_list, input_ids)
print("toy mtp loss:", loss.item())
if __name__ == "__main__":
main()
预期输出类似:
input_ids shape: torch.Size([2, 128])
next_token_logits shape: torch.Size([2, 128, 32000])
mtp logits t+1 shape: torch.Size([2, 128, 32000])
mtp logits t+2 shape: torch.Size([2, 128, 32000])
mtp logits t+3 shape: torch.Size([2, 128, 32000])
toy mtp loss: ...
14. 代码结构解释
上面的完整 toy 代码主要包含以下模块。
14.1 SequenceCompressor
SequenceCompressor 用 mean pooling 模拟沿序列维度压缩:
[batch, seq_len, dim]
|
按 compression_ratio 分块
|
mean pooling
|
[batch, compressed_len, dim]
如果 seq_len = 1024,compression_ratio = 4,那么压缩后的长度大约是 256。真实 DeepSeek-V4 的 KV cache 压缩方式会复杂得多,这里只是为了展示“沿序列维度压缩”的基本直觉。
14.2 ToyCompressedSparseAttention
ToyCompressedSparseAttention 用来模拟 CSA 的基本思想:
原始序列 x
|
生成 Q
|
压缩 x 得到 compressed K/V
|
Q attend compressed K/V
|
top-k 稀疏选择
它体现的是:不让每个 token attend 到完整历史,而是在压缩后的表示中选择重要位置。
14.3 ToyHeavilyCompressedAttention
ToyHeavilyCompressedAttention 用来模拟 HCA 的基本思想。它使用更大的压缩比例,然后在更短的 compressed K/V 上做 dense attention。
代码差异可以简单概括为:
CSA: 压缩 K/V + top-k sparse attention
HCA: 更强压缩 K/V + dense attention
14.4 ToyHybridAttention
ToyHybridAttention 把 Toy CSA 和 Toy HCA 组合起来:
output = gate * CSA(x) + (1 - gate) * HCA(x)
真实 DeepSeek-V4 的融合方式不一定是这样。这里的 gate 只是为了帮助理解“稀疏细节”和“全局压缩语义”的互补关系。
14.5 ToyMoE
ToyMoE 用 router 为每个 token 选择 top-k experts,再把 expert 输出加权求和。它模拟的是 MoE 的核心思想:模型可以拥有多个 expert,但每次只激活其中一部分。
14.6 ToyGatedResidual
DeepSeek-V4 使用的是 mHC,即 Manifold-Constrained Hyper-Connections。本文代码中的 ToyGatedResidual 不是 mHC,只是为了帮助理解“增强残差路径”这个动机。
这一点必须明确:
ToyGatedResidual 不是 mHC。
ToyGatedResidual 只是为了帮助理解增强 residual connection 的一种 toy 写法。
14.7 ToyMTPHead
ToyMTPHead 用多个 head 同时预测未来多个 token。它模拟的是 MTP 的基本直觉:当前位置不仅预测下一个 token,也可以预测更远的未来 token,从而提供更密集的训练信号。
15. 复杂度直觉分析
假设原始序列长度为 n,压缩比例为 r。
普通 dense attention 的复杂度可以粗略理解为:
O(n²)
如果把 K/V 压缩到 n/r,那么 attention 复杂度可以粗略变成:
O(n × n/r)
如果再做 top-k 稀疏选择,那么复杂度可以进一步接近:
O(n × k)
当然,真实 DeepSeek-V4 的效率不能只用这个公式解释,因为还涉及:
- KV cache 的实际存储布局;
- 稀疏 attention 的 kernel 实现;
- MoE expert routing;
- GPU / NPU 内存带宽;
- distributed communication;
- FP4 / FP8 mixed precision;
- fused kernel;
- heterogeneous KV cache;
- on-disk KV cache。
因此,本文的复杂度分析只是直觉说明,不是对官方系统的精确复现。
16. 本文 toy 代码和官方 DeepSeek-V4 的差距
本文代码和官方 DeepSeek-V4 差距非常大,主要体现在:
| 项目 | 本文 toy 代码 | 官方 DeepSeek-V4 |
|---|---|---|
| 模型规模 | 几层小模型 | 284B / 1.6T MoE |
| Attention | 简化压缩 + top-k | CSA + HCA + DSA |
| KV cache | 无真实缓存系统 | 复杂 KV cache 压缩与管理 |
| Residual | 简单 gated residual | mHC |
| FFN | Toy MoE | DeepSeekMoE |
| MTP | Toy MTP head | 官方 MTP 机制 |
| 优化器 | 未训练 | Muon optimizer |
| 精度 | FP32 / 默认 PyTorch | FP8 Mixed / FP4 + FP8 Mixed |
| Kernel | PyTorch eager | fused kernel / TileLang 等 |
| 训练 | 无真实训练 | 32T / 33T tokens 级别预训练 |
| 推理 | toy forward | 1M context 高效推理系统 |
所以,本文代码只能叫:
非官方 toy-level 推测实现
不能叫:
DeepSeek-V4 官方源码复现
也不能叫:
DeepSeek-V4 完整复现
17. 我对 DeepSeek-V4 的理解
DeepSeek-V4 的意义可以从三个角度理解。
17.1 长上下文从“能放进去”走向“能高效用起来”
很多模型都在扩大上下文窗口,但上下文长度变长以后,真正困难的是:
推理成本
KV cache 显存
长程信息利用能力
DeepSeek-V4 的重点是通过 CSA、HCA 和系统优化,让百万 token 上下文不只是形式上支持,而是更接近工程可用。
17.2 MoE 继续成为大模型扩展的重要路线
DeepSeek-V4-Pro 总参数达到 1.6T,但激活参数为 49B。这说明 MoE 的核心优势仍然很明显:
总容量大
激活成本相对可控
对于需要更强知识、推理、代码和 agent 能力的大模型来说,MoE 仍然是一条重要路线。
17.3 Agent 会成为长上下文模型的重要应用场景
DeepSeek-V4 官方发布中特别提到 agent 能力。这说明长上下文不只是为了“读长文档”,更是为了支撑复杂 agent workflow。
未来很多模型能力可能不只体现在单轮问答,而是体现在:
- 能不能读完整仓库;
- 能不能持续调用工具;
- 能不能根据测试结果修复代码;
- 能不能在长任务中保持状态;
- 能不能从长轨迹中定位错误原因。
这也是百万 token context 的真正价值。
18. 总结
DeepSeek-V4 是一次面向百万 token 上下文的系统级模型升级。
它的核心不是简单“参数更大”,而是通过以下组合提升长上下文效率:
Hybrid Attention Architecture
= CSA + HCA
CSA
= KV cache compression + DeepSeek Sparse Attention
HCA
= heavier KV cache compression + dense attention
DeepSeekMoE
= large total parameters + limited activated parameters
MTP
= multi-token prediction
mHC
= enhanced residual connections
Muon optimizer
= faster convergence + improved training stability
System optimization
= fused kernel + TileLang + ZeRO + checkpointing + heterogeneous KV cache + on-disk KV cache
如果用一句话总结 DeepSeek-V4:
DeepSeek-V4 面向百万 token 长上下文和 agentic intelligence,通过 KV cache 压缩、稀疏注意力、MoE、MTP、mHC、Muon optimizer 和系统级推理优化,共同降低超长上下文推理的计算与存储成本。
本文的 toy PyTorch 代码只是为了帮助理解 Hybrid Attention、MoE、MTP 和增强残差路径的基本思路,不能代表官方实现。
更多推荐

所有评论(0)