Muon 优化器深度剖析续:数学公式与 MuonClip 变体

大家好,继续上一篇博客深入剖析 Muon 优化器(一):从基础原理到 Kimi K2 大模型的应用的讨论。今天,我们深入到 Muon 优化器的数学核心,以及它在 Kimi K2 中演化出的 MuonClip 变体。作为续篇,我会聚焦于公式推导、算法细节和设计考虑。如果你对基础概念已熟知,这里将提供更严谨的数学分析;如果你是初学者,我会用比喻和逐步解释,让内容易懂。Muon 的美妙在于它将几何约束融入优化,让训练更高效,但大规模应用需额外稳定机制——这就是 MuonClip 的由来。让我们一步步拆解。

第一部分:Muon 优化器的数学基础与推导

浅显入门:从“几何约束”到公式

Muon 不是简单梯度下降,而是把权重矩阵视为“几何对象”。想象权重 W 是一个地图,输入 x 是起点,输出 y = Wx 是终点。传统优化器如 Adam 只看“坡度”(梯度),Muon 则确保每步更新“平衡”(orthogonal),像在地图上均匀分布路径,避免某些方向被忽略。

核心思想:更新矩阵应保持“谱范数约束”(spectral norm constraint),确保输出变化可控。这源于二阶优化的简化,避免计算昂贵的 Hessian 矩阵。

深入推导:从 RMS 范数到正交化更新

Muon 的推导从线性层的“度量”(metrization)开始。我们用根均方(RMS)范数衡量向量大小:
∥ v ∥ RMS = 1 d ∑ i = 1 d v i 2 = 1 d ∥ v ∥ 2 \|v\|_{\text{RMS}} = \sqrt{\frac{1}{d} \sum_{i=1}^d v_i^2} = \sqrt{\frac{1}{d}} \|v\|_2 vRMS=d1i=1dvi2 =d1 v2
其中 d 是维度。这让“稠密”向量(如元素平均大小为 1)的 RMS 为 1。

对于权重矩阵 W(形状 [fan-out, fan-in]),我们定义 RMS-to-RMS 算子范数:
∥ W ∥ RMS → RMS = fan-in fan-out ∥ W ∥ ∗ \|W\|_{\text{RMS} \to \text{RMS}} = \sqrt{\frac{\text{fan-in}}{\text{fan-out}}} \|W\|_* WRMSRMS=fan-outfan-in W
其中 ∥ W ∥ ∗ \|W\|_* W 是谱范数(最大奇异值)。

现在,考虑权重更新 ΔW 对输出变化的影响: Δ y = Δ W x Δy = ΔW x Δy=ΔWx。我们希望界定:
∥ Δ y ∥ RMS ≤ ∥ Δ W ∥ RMS → RMS ⋅ ∥ x ∥ RMS \|\Delta y\|_{\text{RMS}} \leq \| \Delta W \|_{\text{RMS} \to \text{RMS}} \cdot \|x\|_{\text{RMS}} ∥ΔyRMS∥ΔWRMSRMSxRMS
假设 ∥ x ∥ RMS ≤ 1 \|x\|_{\text{RMS}} ≤ 1 xRMS1,则控制 ∥ Δ W ∥ RMS → RMS ≤ η \| ΔW \|_{\text{RMS} \to \text{RMS}} ≤ η ∥ΔWRMSRMSη(步长)即可稳定输出。

优化问题转为:最小化损失线性近似 ⟨ ∇ W L , Δ W ∇_W \mathcal{L}, ΔW WL,ΔW⟩,受限于上述范数:
min ⁡ Δ W ⟨ ∇ W L , Δ W ⟩ s.t. ∥ Δ W ∥ RMS → RMS ≤ η \min_{\Delta W} \langle \nabla_W \mathcal{L}, \Delta W \rangle \quad \text{s.t.} \quad \| \Delta W \|_{\text{RMS} \to \text{RMS}} \leq \eta ΔWminWL,ΔWs.t.∥ΔWRMSRMSη
假设 ∇ W L = U Σ V T ∇_W \mathcal{L} = U Σ V^T WL=UΣVT(SVD 分解),解为:
Δ W = − η fan-out fan-in U V T \Delta W = - \eta \sqrt{\frac{\text{fan-out}}{\text{fan-in}}} U V^T ΔW=ηfan-infan-out UVT
这本质上是梯度的“对偶化”(dualizing):保留奇异向量 U V T U V^T UVT,丢弃奇异值 Σ Σ Σ,实现正交化(orthogonalization)。 这让更新更均匀,捕捉低秩梯度中的信息。

Muon 的算法公式:Newton-Schulz 迭代

直接 SVD 太慢,Muon 用 Newton-Schulz (NS) 迭代近似正交化。令 G 为更新矩阵(从 SGD-momentum 得来),目标是:
Ortho ( G ) = arg ⁡ min ⁡ O { ∥ O − G ∥ F : O T O = I  或  O O T = I } \text{Ortho}(G) = \arg\min_O \{ \|O - G\|_F : O^T O = I \text{ 或 } O O^T = I \} Ortho(G)=argOmin{OGF:OTO=I  OOT=I}
NS 迭代公式:
X k + 1 = a X k + b ( X k X k T ) X k + c ( X k X k T ) 2 X k X_{k+1} = a X_k + b (X_k X_k^T) X_k + c (X_k X_k^T)^2 X_k Xk+1=aXk+b(XkXkT)Xk+c(XkXkT)2Xk
系数 a=3.4445, b=-4.7750, c=2.0315(五阶多项式近似,确保奇异值收敛到 1)。初始 X = G / ( ∥ G ∥ F + ε ) X = G / (\|G\|_F + ε) X=G/(GF+ε),迭代 5 次。

完整 Muon 更新(带动量):

  1. 动量: M t = μ M t − 1 + G t M_t = μ M_{t-1} + G_t Mt=μMt1+Gt G t G_t Gt 是梯度)
  2. 正交化: O t = N S ( M t ) O_t = NS(M_t) Ot=NS(Mt)
  3. 更新: W t = W t − 1 − η O t W_t = W_{t-1} - η O_t Wt=Wt1ηOt

考虑:NS 迭代高效(仅 5% 开销),但需归一化 G 以确保奇异值在 [0,1] 内收敛。这让 Muon 在小模型上加速 2 倍,但大规模时需额外调整。

第二部分:大规模扩展:添加权重衰减与更新缩放

浅显解释:为什么需要调整?

Muon 在小模型上完美,但万亿参数时,更新可能太“小”(低 RMS,导致表达力不足)或太“大”(不稳定)。添加权重衰减像“刹车”,防止参数爆炸;更新缩放像“放大镜”,确保不同形状矩阵的更新一致。

数学细节与公式

从 Lemma:对于形状 [A,B] 的全秩矩阵,Muon 更新 RMS 为 ( 1 / m a x ( A , B ) ) \sqrt{(1/max(A,B))} (1/max(A,B)) 。 为一致性,缩放 O t O_t Ot
O t = NS ( M t ) ⋅ max ⁡ ( A , B ) ⋅ 0.2 O_t = \text{NS}(M_t) \cdot \sqrt{\max(A,B)} \cdot 0.2 Ot=NS(Mt)max(A,B) 0.2
(0.2 匹配 AdamW 的 RMS)。

添加权重衰减(AdamW 风格):
W t = W t − 1 − η ( O t + λ W t − 1 ) W_t = W_{t-1} - \eta (O_t + \lambda W_{t-1}) Wt=Wt1η(Ot+λWt1)
其中 λ 是衰减率。

实验考虑:这些调整让 Muon 复用 AdamW 的超参数,无需重调。在 800M 模型上,带调整的 Muon 损失更低,过拟合更少。

第三部分:MuonClip 变体——Kimi K2 的稳定利器

浅显入门:MuonClip 是“安全网”

Muon 的正交化放大梯度,导致注意力 logits(QK 点积)爆炸,像开车太快撞墙。MuonClip 添加“剪枝”(clip),监控并缩放 Q/K 权重,防止崩溃。Kimi K2 用 15.5 万亿令牌训练零不稳定,全靠它。

深入数学:MuonClip 的公式与整合

MuonClip 整合 Muon 与 QK-Clip。注意力公式:
Q h = X W q h , K h = X W k h \mathbf{Q}^h = \mathbf{X} \mathbf{W}_q^h, \quad \mathbf{K}^h = \mathbf{X} \mathbf{W}_k^h Qh=XWqh,Kh=XWkh
logits: ( 1 / √ d ) Q h K h T (1/√d) Q^h K^{h T} (1/√d)QhKhT
最大 logit:
S max ⁡ h = 1 d max ⁡ X ∈ B max ⁡ i , j Q i h K j h ⊤ S_{\max}^h = \frac{1}{\sqrt{d}} \max_{\mathbf{X} \in B} \max_{i,j} \mathbf{Q}_i^h \mathbf{K}_j^{h \top} Smaxh=d 1XBmaxi,jmaxQihKjh

如果 S m a x h > τ S_{max}^h > τ Smaxh>τ(阈值,如 100),计算 γ h = m i n ( 1 , τ / S m a x h ) γ_h = min(1, τ / S_{max}^h) γh=min(1,τ/Smaxh),然后缩放:
W q h ← γ h α W q h , W k h ← γ h 1 − α W k h \mathbf{W}_q^h \leftarrow \gamma_h^\alpha \mathbf{W}_q^h, \quad \mathbf{W}_k^h \leftarrow \gamma_h^{1-\alpha} \mathbf{W}_k^h WqhγhαWqh,Wkhγh1αWkh
α ≈ 0.5(平衡缩放)。对于 MLA(多头潜在注意力),仅缩放头特定组件。

整合流程:

  1. 执行 Muon 更新(如上,带衰减和 RMS 匹配)。
  2. 后处理:应用 QK-Clip,仅当 logits 超阈值时干预(Kimi K2 中仅 13% 头触发,早中期后稳定)。

考虑:Muon 的高秩更新(uniform pressure)提升令牌效率,但易致参数增长和 logits 爆炸(>1000)。QK-Clip 是负反馈机制,精确缩放而不扭曲分布,比软帽(soft-cap)或 QK-Norm 更适 MLA。结果:稳定训练,损失无峰值,性能超 GPT-4。

结语:Muon 与 MuonClip 的启示

Muon 的几何视角(正交化 + 范数约束)重塑优化,公式简洁却强大;MuonClip 的 clip 机制则解决实际痛点,让大规模训练可行。未来,或许结合更多二阶近似或自适应阈值。感谢阅读,如果你有公式疑问,欢迎讨论!🚀

后记

2025年8月14日早上7:06于上海,在Grok 4大模型辅助下完成。

Logo

中国智能体开发者社区,聚焦智能体与大模型开发,提供前沿资讯、实用工具链、开源项目及行业案例。通过技术沙龙、开发者大赛等活动,促进经验交流与协作,助力开发者快速构建创新智能应用。

更多推荐