Muon 优化器深度剖析(二):数学公式与 MuonClip 变体
Muon 优化器深度剖析续:数学公式与 MuonClip 变体
大家好,继续上一篇博客深入剖析 Muon 优化器(一):从基础原理到 Kimi K2 大模型的应用的讨论。今天,我们深入到 Muon 优化器的数学核心,以及它在 Kimi K2 中演化出的 MuonClip 变体。作为续篇,我会聚焦于公式推导、算法细节和设计考虑。如果你对基础概念已熟知,这里将提供更严谨的数学分析;如果你是初学者,我会用比喻和逐步解释,让内容易懂。Muon 的美妙在于它将几何约束融入优化,让训练更高效,但大规模应用需额外稳定机制——这就是 MuonClip 的由来。让我们一步步拆解。
第一部分:Muon 优化器的数学基础与推导
浅显入门:从“几何约束”到公式
Muon 不是简单梯度下降,而是把权重矩阵视为“几何对象”。想象权重 W 是一个地图,输入 x 是起点,输出 y = Wx 是终点。传统优化器如 Adam 只看“坡度”(梯度),Muon 则确保每步更新“平衡”(orthogonal),像在地图上均匀分布路径,避免某些方向被忽略。
核心思想:更新矩阵应保持“谱范数约束”(spectral norm constraint),确保输出变化可控。这源于二阶优化的简化,避免计算昂贵的 Hessian 矩阵。
深入推导:从 RMS 范数到正交化更新
Muon 的推导从线性层的“度量”(metrization)开始。我们用根均方(RMS)范数衡量向量大小:
∥ v ∥ RMS = 1 d ∑ i = 1 d v i 2 = 1 d ∥ v ∥ 2 \|v\|_{\text{RMS}} = \sqrt{\frac{1}{d} \sum_{i=1}^d v_i^2} = \sqrt{\frac{1}{d}} \|v\|_2 ∥v∥RMS=d1i=1∑dvi2=d1∥v∥2
其中 d 是维度。这让“稠密”向量(如元素平均大小为 1)的 RMS 为 1。
对于权重矩阵 W(形状 [fan-out, fan-in]),我们定义 RMS-to-RMS 算子范数:
∥ W ∥ RMS → RMS = fan-in fan-out ∥ W ∥ ∗ \|W\|_{\text{RMS} \to \text{RMS}} = \sqrt{\frac{\text{fan-in}}{\text{fan-out}}} \|W\|_* ∥W∥RMS→RMS=fan-outfan-in∥W∥∗
其中 ∥ W ∥ ∗ \|W\|_* ∥W∥∗ 是谱范数(最大奇异值)。
现在,考虑权重更新 ΔW 对输出变化的影响: Δ y = Δ W x Δy = ΔW x Δy=ΔWx。我们希望界定:
∥ Δ y ∥ RMS ≤ ∥ Δ W ∥ RMS → RMS ⋅ ∥ x ∥ RMS \|\Delta y\|_{\text{RMS}} \leq \| \Delta W \|_{\text{RMS} \to \text{RMS}} \cdot \|x\|_{\text{RMS}} ∥Δy∥RMS≤∥ΔW∥RMS→RMS⋅∥x∥RMS
假设 ∥ x ∥ RMS ≤ 1 \|x\|_{\text{RMS}} ≤ 1 ∥x∥RMS≤1,则控制 ∥ Δ W ∥ RMS → RMS ≤ η \| ΔW \|_{\text{RMS} \to \text{RMS}} ≤ η ∥ΔW∥RMS→RMS≤η(步长)即可稳定输出。
优化问题转为:最小化损失线性近似 ⟨ ∇ W L , Δ W ∇_W \mathcal{L}, ΔW ∇WL,ΔW⟩,受限于上述范数:
min Δ W ⟨ ∇ W L , Δ W ⟩ s.t. ∥ Δ W ∥ RMS → RMS ≤ η \min_{\Delta W} \langle \nabla_W \mathcal{L}, \Delta W \rangle \quad \text{s.t.} \quad \| \Delta W \|_{\text{RMS} \to \text{RMS}} \leq \eta ΔWmin⟨∇WL,ΔW⟩s.t.∥ΔW∥RMS→RMS≤η
假设 ∇ W L = U Σ V T ∇_W \mathcal{L} = U Σ V^T ∇WL=UΣVT(SVD 分解),解为:
Δ W = − η fan-out fan-in U V T \Delta W = - \eta \sqrt{\frac{\text{fan-out}}{\text{fan-in}}} U V^T ΔW=−ηfan-infan-outUVT
这本质上是梯度的“对偶化”(dualizing):保留奇异向量 U V T U V^T UVT,丢弃奇异值 Σ Σ Σ,实现正交化(orthogonalization)。 这让更新更均匀,捕捉低秩梯度中的信息。
Muon 的算法公式:Newton-Schulz 迭代
直接 SVD 太慢,Muon 用 Newton-Schulz (NS) 迭代近似正交化。令 G 为更新矩阵(从 SGD-momentum 得来),目标是:
Ortho ( G ) = arg min O { ∥ O − G ∥ F : O T O = I 或 O O T = I } \text{Ortho}(G) = \arg\min_O \{ \|O - G\|_F : O^T O = I \text{ 或 } O O^T = I \} Ortho(G)=argOmin{∥O−G∥F:OTO=I 或 OOT=I}
NS 迭代公式:
X k + 1 = a X k + b ( X k X k T ) X k + c ( X k X k T ) 2 X k X_{k+1} = a X_k + b (X_k X_k^T) X_k + c (X_k X_k^T)^2 X_k Xk+1=aXk+b(XkXkT)Xk+c(XkXkT)2Xk
系数 a=3.4445, b=-4.7750, c=2.0315(五阶多项式近似,确保奇异值收敛到 1)。初始 X = G / ( ∥ G ∥ F + ε ) X = G / (\|G\|_F + ε) X=G/(∥G∥F+ε),迭代 5 次。
完整 Muon 更新(带动量):
- 动量: M t = μ M t − 1 + G t M_t = μ M_{t-1} + G_t Mt=μMt−1+Gt( G t G_t Gt 是梯度)
- 正交化: O t = N S ( M t ) O_t = NS(M_t) Ot=NS(Mt)
- 更新: W t = W t − 1 − η O t W_t = W_{t-1} - η O_t Wt=Wt−1−ηOt
考虑:NS 迭代高效(仅 5% 开销),但需归一化 G 以确保奇异值在 [0,1] 内收敛。这让 Muon 在小模型上加速 2 倍,但大规模时需额外调整。
第二部分:大规模扩展:添加权重衰减与更新缩放
浅显解释:为什么需要调整?
Muon 在小模型上完美,但万亿参数时,更新可能太“小”(低 RMS,导致表达力不足)或太“大”(不稳定)。添加权重衰减像“刹车”,防止参数爆炸;更新缩放像“放大镜”,确保不同形状矩阵的更新一致。
数学细节与公式
从 Lemma:对于形状 [A,B] 的全秩矩阵,Muon 更新 RMS 为 ( 1 / m a x ( A , B ) ) \sqrt{(1/max(A,B))} (1/max(A,B))。 为一致性,缩放 O t O_t Ot:
O t = NS ( M t ) ⋅ max ( A , B ) ⋅ 0.2 O_t = \text{NS}(M_t) \cdot \sqrt{\max(A,B)} \cdot 0.2 Ot=NS(Mt)⋅max(A,B)⋅0.2
(0.2 匹配 AdamW 的 RMS)。
添加权重衰减(AdamW 风格):
W t = W t − 1 − η ( O t + λ W t − 1 ) W_t = W_{t-1} - \eta (O_t + \lambda W_{t-1}) Wt=Wt−1−η(Ot+λWt−1)
其中 λ 是衰减率。
实验考虑:这些调整让 Muon 复用 AdamW 的超参数,无需重调。在 800M 模型上,带调整的 Muon 损失更低,过拟合更少。
第三部分:MuonClip 变体——Kimi K2 的稳定利器
浅显入门:MuonClip 是“安全网”
Muon 的正交化放大梯度,导致注意力 logits(QK 点积)爆炸,像开车太快撞墙。MuonClip 添加“剪枝”(clip),监控并缩放 Q/K 权重,防止崩溃。Kimi K2 用 15.5 万亿令牌训练零不稳定,全靠它。
深入数学:MuonClip 的公式与整合
MuonClip 整合 Muon 与 QK-Clip。注意力公式:
Q h = X W q h , K h = X W k h \mathbf{Q}^h = \mathbf{X} \mathbf{W}_q^h, \quad \mathbf{K}^h = \mathbf{X} \mathbf{W}_k^h Qh=XWqh,Kh=XWkh
logits: ( 1 / √ d ) Q h K h T (1/√d) Q^h K^{h T} (1/√d)QhKhT
最大 logit:
S max h = 1 d max X ∈ B max i , j Q i h K j h ⊤ S_{\max}^h = \frac{1}{\sqrt{d}} \max_{\mathbf{X} \in B} \max_{i,j} \mathbf{Q}_i^h \mathbf{K}_j^{h \top} Smaxh=d1X∈Bmaxi,jmaxQihKjh⊤
如果 S m a x h > τ S_{max}^h > τ Smaxh>τ(阈值,如 100),计算 γ h = m i n ( 1 , τ / S m a x h ) γ_h = min(1, τ / S_{max}^h) γh=min(1,τ/Smaxh),然后缩放:
W q h ← γ h α W q h , W k h ← γ h 1 − α W k h \mathbf{W}_q^h \leftarrow \gamma_h^\alpha \mathbf{W}_q^h, \quad \mathbf{W}_k^h \leftarrow \gamma_h^{1-\alpha} \mathbf{W}_k^h Wqh←γhαWqh,Wkh←γh1−αWkh
α ≈ 0.5(平衡缩放)。对于 MLA(多头潜在注意力),仅缩放头特定组件。
整合流程:
- 执行 Muon 更新(如上,带衰减和 RMS 匹配)。
- 后处理:应用 QK-Clip,仅当 logits 超阈值时干预(Kimi K2 中仅 13% 头触发,早中期后稳定)。
考虑:Muon 的高秩更新(uniform pressure)提升令牌效率,但易致参数增长和 logits 爆炸(>1000)。QK-Clip 是负反馈机制,精确缩放而不扭曲分布,比软帽(soft-cap)或 QK-Norm 更适 MLA。结果:稳定训练,损失无峰值,性能超 GPT-4。
结语:Muon 与 MuonClip 的启示
Muon 的几何视角(正交化 + 范数约束)重塑优化,公式简洁却强大;MuonClip 的 clip 机制则解决实际痛点,让大规模训练可行。未来,或许结合更多二阶近似或自适应阈值。感谢阅读,如果你有公式疑问,欢迎讨论!🚀
后记
2025年8月14日早上7:06于上海,在Grok 4大模型辅助下完成。
更多推荐


所有评论(0)