全文 - Jet-Nemotron: Efficient Language Model with Post Neural Architecture Search
通过机器学习,训练出能够找到最优的 LLM 网络结构的超模型。摘要:我们推出了 Jet-Nemotron,一个新的混合架构语言模型系列,它在匹配或超越领先的全注意力模型精度的同时,显著提高了生成吞吐量。Jet-Nemotron 是使用后神经架构搜索(PostNAS)开发的,这是一种新颖的神经架构探索流程,能够实现高效的模型设计。与先前方法不同,PostNAS 从一个预训练的全注意力模型开始,并冻结
通过机器学习,训练出能够找到最优的 LLM 网络结构的超模型。

摘要:我们推出了 Jet-Nemotron,一个新的混合架构语言模型系列,它在匹配或超越领先的全注意力模型精度的同时,显著提高了生成吞吐量。Jet-Nemotron 是使用后神经架构搜索(PostNAS)开发的,这是一种新颖的神经架构探索流程,能够实现高效的模型设计。与先前方法不同,PostNAS 从一个预训练的全注意力模型开始,并冻结其 MLP 权重,从而允许高效探索注意力块的设计。该流程包括四个关键组成部分:(1) 学习最优的全注意力层放置与消除,(2) 线性注意力块选择,(3) 设计新的注意力块,以及 (4) 执行硬件感知的超参数搜索。我们的 Jet-Nemotron-2B 模型在一套全面的基准测试中,达到了与 Qwen3、Qwen2.5、Gemma3 和 Llama3.2 相当或更优的精度,同时实现了高达 53.6 倍的生成吞吐量加速和 6.1 倍的预填充加速。在 MMLU 和 MMLU-Pro 上,它的精度也高于近期先进的 MoE 全注意力模型,例如 DeepSeek-V3-Small 和 Moonlight,尽管这些模型规模更大(总参数量 150 亿,激活参数量 22 亿)。

图 1 | Jet-Nemotron 与最先进的高效语言模型对比。生成吞吐量在 NVIDIA H100 GPU 上、上下文长度为 64K token 的条件下测量。Jet-Nemotron-2B 在 MMLU-Pro 上比 Qwen3-1.7B-Base 精度更高,同时实现了 47 倍更高的生成吞吐量。Jet-Nemotron-4B 尽管模型规模更大,其生成吞吐量仍然高于所有参数少于 20 亿的全注意力模型。
1. 引言
语言模型 [1, 2, 3, 4, 5, 6, 7] 的迅速崛起标志着人工智能进入了一个变革时代,这些模型在广泛的任务中展现出卓越的准确性。然而,由于其巨大的计算和内存需求,它们的效率已成为一个重要问题。这个问题在长上下文生成和推理中尤为突出,其中自注意力机制 [8] 带来了 𝑂(𝑛²) 的计算复杂度并产生大量的键值(KV)缓存¹。
为了应对这一挑战,大量工作致力于通过开发复杂度降低到 𝑂(𝑛) 的注意力机制 [9, 10, 11, 12, 13, 14] 来设计更高效的 LM 架构。与此同时,也有重要工作专注于构建结合全注意力和线性注意力的混合模型,以在准确性和效率之间取得平衡 [15, 16, 17]。虽然这些模型相比全注意力架构提供了更高的效率,但它们的准确性仍然显著落后于最先进(SOTA)的全注意力模型,尤其是在具有挑战性的基准测试上,如 MMLU(-Pro) [18, 19]、数学推理 [20, 21, 22]、检索 [23, 24, 25]、编码 [26, 27, 28] 和长上下文任务 [29]。

本文介绍了 Jet-Nemotron,一个新的 LM 系列,它在匹配 SOTA 全注意力模型精度的同时,提供了卓越的效率。图 1 将 Jet-Nemotron 与先前的高效 LM 进行了比较。值得注意的是,Jet-Nemotron-2B 在 MMLU-Pro 上比 Qwen3-1.7B-Base [5] 精度更高,同时在 NVIDIA H100 GPU 上、64K 上下文长度下提供了 47 倍更高的生成吞吐量。
Jet-Nemotron 基于后神经架构搜索(PostNAS)构建,这是一种新颖的神经架构探索流程(图 2),能够快速设计高效的模型架构。与主流的 LM 架构设计方法不同,PostNAS 从一个预训练的全注意力模型开始,从中继承多层感知机(MLP)权重并在整个过程中保持冻结。这一策略显著降低了训练成本,同时仍允许对注意力块进行全面探索。该流程随后通过四个关键步骤系统地搜索最优的注意力块设计。
i) 全注意力放置与消除。在模型中保留少量全注意力层 [30] 对于在具有挑战性的任务(如检索)上保持高精度至关重要。然而,这些层的最佳放置位置仍不明确。在第 2.2 节中,我们介绍了一种新方法,通过训练一个“一次全包”的超网络 [31](图 4)来自动学习在何处使用全注意力层。最终学习到的放置策略在 MMLU 上的准确性方面显著优于常用的均匀放置策略(图 5,右)。
ii) 线性注意力块选择。在确定全注意力层的放置后,我们进行注意力块搜索以确定最优的线性注意力块(第 2.3 节)。得益于我们框架的低训练成本,我们可以在不同任务的准确性、训练效率和推理速度方面系统地评估现有的线性注意力块。重要的是,我们的方法无需依赖小型代理任务(例如训练微型 LM,如 5000 万或 1.5 亿参数),确保搜索结果能直接转化为最终模型精度的提升。此外,随着新的线性注意力块出现,我们的框架可以快速将它们与先前的设计进行评估,并在结果有希望时采纳它们。
iii) 新注意力块设计。我们的框架也促进了新注意力块的快速设计。添加卷积是增强线性注意力能力 [32] 的广泛使用策略。然而,先前的方法仅依赖于静态卷积核,缺乏动态调整卷积核特征提取模式的能力。在第 2.4 节中,我们介绍了一个新的线性注意力块,JetBlock(图 2,#3)。JetBlock 使用一个核生成器,根据输入生成动态因果卷积核,然后将其应用于值(V)令牌。此外,它移除了查询(Q)和键(K)上冗余的静态卷积,简化了计算。与先前的线性注意力块相比,JetBlock 以较小的开销展示了改进的准确性(表 1)。
iv) 硬件感知架构搜索。最后,在第 2.5 节中,我们引入了硬件感知架构搜索,以确定最优的架构超参数。传统上,参数数量被用作 LM 效率的代理指标。然而,参数数量与实际硬件上的生成效率并不直接相关。我们的硬件感知搜索发现了能够提供相似生成吞吐量的架构超参数,同时使用更多参数以实现更好的准确性(表 2)。
我们在一套全面的基准测试中评估 Jet-Nemotron,包括 MMLU(-Pro) [18, 19]、常识推理 [33, 34, 35, 36, 37, 38]、数学推理 [20, 21, 22, 39]、检索 [23, 24, 25]、编码 [26, 27, 28, 40] 和长上下文任务 [29]。我们的 Jet-Nemotron-2B 模型在所有基准测试中都匹配或超越了 SOTA 全注意力模型,如 Qwen2.5 [4]、Qwen3 [5]、Gemma3 [41, 42] 和 Llama3.2 [2],同时实现了显著更高的生成吞吐量。此外,在长上下文设置下,吞吐量增益更为显著(图 6)。例如,在 256K 上下文长度下,与 Qwen3-1.7B-Base 相比,Jet-Nemotron-2B 提供了 6.14 倍的预填充加速和 53.6 倍的解码加速。我们希望我们的高效 LM 系列(Jet-Nemotron)、我们的新线性注意力块(JetBlock)和我们的架构设计流程(PostNAS)将使社区受益,并加速下一代高效 LM 的开发和部署。我们在下面总结我们的主要贡献:
• 我们引入了 PostNAS,一种用于语言模型的新颖模型架构探索范式。通过重用预训练的 LLM,PostNAS 降低了与 LLM 架构探索相关的成本和风险,使得 LM 架构设计能够更快、更高效地创新。
• 我们为高效 LM 的架构设计提供了新颖的见解,例如注意力层对特定任务的重要性,以及发现 KV 缓存大小是比参数数量对生成吞吐量更关键的因素。
• 我们引入了一个新颖的线性注意力块 JetBlock,它将线性注意力与动态卷积和硬件感知架构搜索相结合。与先前的线性注意力块相比,它始终能带来显著的准确性提升,同时保持相当的生成吞吐量。
• 我们引入了 Jet-Nemotron,一个新颖的混合架构 LM 系列,它在广泛的任务上实现了卓越的准确性,并提供了比先前 SOTA 全注意力模型(如 Qwen2.5、Qwen3、Gemma3 和 Llama3.2)显著更高的生成吞吐量。凭借其强大的准确性和卓越的推理效率,Jet-Nemotron 为各种需要高效 LM 的应用提供了实际效益。
2. 方法
2.1. PostNAS 动机与路线图
由于预训练成本高昂,设计新的语言模型架构具有挑战性和风险。此外,计算资源和训练数据方面的巨大差距使得主要组织之外的研究人员难以匹配大型行业参与者开发的最先进全注意力模型的准确性 [4, 41, 2]。这种差距阻碍了语言模型架构设计的创新。
本文提出了一种开发新语言模型架构的替代策略。我们不从头开始预训练模型,而是在现有全注意力模型的基础上探索新颖架构。这种方法极大地减少了训练成本和数据需求。

虽然在此框架内设计的架构在从头训练时可能不会产生最佳结果,但我们认为它们仍然具有很高的价值。首先,如图 1 所示,它们可以立即在效率和准确性上带来超越最先进全注意力模型的收益,转化为实际效益,如改进服务和降低运营成本。其次,我们的框架可作为架构创新的快速测试平台。如果一个新设计在此环境中表现不佳,那么它在完整预训练中成功的可能性也很低 [43]。这种过滤机制有助于研究人员避免在无望的设计上浪费大量计算资源。
图 2 展示了 PostNAS 的路线图。从一个预训练的全注意力模型开始,它冻结 MLP 权重,并通过四个关键步骤以从粗到细的方式探索注意力块设计:全注意力放置与消除(第 2.2 节)、线性注意力块选择(第 2.3 节)、新注意力块设计(第 2.4 节)和硬件感知架构搜索(第 2.5 节)。图 3 显示了这些步骤带来的准确性改进分解。我们观察到所有基准测试都有显著的准确性提升:MMLU +5.3,数学 +8.4,检索 +7.8,常识推理 +3.2。
2.2. 全注意力放置与消除
引入少量全注意力层已成为提高准确性 [30, 16, 44, 17] 的常见策略。标准方法是在固定的层子集上均匀地应用全注意力,其余层使用线性注意力。然而,这种均匀策略并非最优,尤其是在我们从预训练全注意力模型开始的设置中。
为了解决这个问题,我们提出了一种自动方法,用于高效确定全注意力层的放置位置。整体方法如图 4 所示。我们通过为预训练的全注意力模型添加替代的线性注意力路径来构建一个“一次全包”的超网络 [45, 31]。在训练期间,我们在每一步随机采样一个活动路径,形成一个子网络,该子网络使用特征蒸馏损失 [46, 47, 48] 进行训练。
训练完成后,我们执行波束搜索 [49] 以确定在给定约束(例如,两个全注意力层)下全注意力层的最佳放置位置。搜索目标取决于任务:对于 MMLU,我们选择在正确答案上损失最低的配置(即最大化 −𝑙𝑜𝑠𝑠),而对于数学和检索任务,我们选择准确性最高的配置。如图 5(b) 所示,PostNAS 在准确性方面显著优于均匀放置。
图 5(a) 展示了 Qwen2.5-1.5B 的搜索结果。对于每一层,我们通过将该层配置为全注意力,同时将所有剩余层设置为线性注意力,从超网络中提取相应的子网络。我们评估每个子网络在给定任务上的准确性或损失,并使用热图可视化结果。我们的分析揭示了三个关键发现:
关键发现 1:在预训练的全注意力模型中,并非所有注意力层的贡献都相等。对于 MMLU,只有两个层表现出关键重要性,而对于检索任务,只有两到三个层特别关键。
关键发现 2:不同的注意力层贡献于不同的能力。对 MMLU 准确性关键的层不一定对检索任务重要。
关键发现 3:对于复杂任务(如数学推理),注意力重要性的模式变得更加复杂。幸运的是,为 MMLU 和检索确定的前几个关键层的组合已经涵盖了数学任务所需的大部分关键层。
除了这些关键发现,我们观察到在使用不同线性注意力操作时,搜索结果保持一致。在我们最终的实验中,为了简单和略微提高的训练吞吐量,我们在“一次全包”的超网络训练中使用了 GLA [11]。
2.3. 线性注意力块选择
基于发现的全注意力层放置,我们进行注意力块搜索,以确定最适合我们设置的线性注意力块。在我们的实验中,我们评估了六个 SOTA 线性注意力块,包括 RWKV7 [10]、RetNet [12]、Mamba2 [50]、GLA [11]、Deltanet [51] 和 Gated DeltaNet [32]。
在初步的效率分析后,我们观察到 RWKV7 的训练吞吐量显著低于其他线性注意力块,可能是由于核实现未达最优。因此,我们将其排除在训练实验之外。表 1 总结的结果表明,Gated DeltaNet 在评估的线性注意力块中实现了最佳的整体准确性。这归因于两个因素的结合:(1) 数据依赖的门控机制 [52],它动态控制模型是应该更关注当前令牌还是历史状态;(2) Delta 规则 [53],它用当前令牌的信息增量更新历史状态,以节省有限的状态内存。因此,我们在实验中采用 Gated DeltaNet。
2.4. 新注意力块设计
我们提出了一个新的线性注意力块 JetBlock,旨在通过将动态卷积 [54, 55] 融入线性注意力来增强模型的表达能力。卷积已被证明是在许多线性注意力块 [32, 56] 中实现强准确性的关键。然而,先前的工作通常使用静态卷积核,无法根据输入调整其特征提取模式。
为了解决这个限制,我们引入了一个核生成器模块,该模块基于输入特征动态产生卷积核。整体结构如图 2 (#3) 所示。该模块与 Q/K/V 投影层共享相同的输入,并以一个线性降维层开始以提高效率,使用 8 的降维比。应用 SiLU 激活函数 [57],随后是一个最终线性层,输出卷积核权重。我们采用 Gated DeltaNet 进行时间混合,因为如第 2.3 节讨论,与其他设计相比,它的性能最佳。

我们将动态卷积核应用于值(V)令牌,因为将它们应用于查询(Q)或键(K)令牌益处不大。此外,我们发现一旦对 V 应用了动态卷积,移除 Q 和 K 上的静态卷积对最终模型精度的影响可以忽略不计。我们在最终实验中采用了这种设计,因为它略微提高了效率。表 1 将 JetBlock 与先前的线性注意力块进行了比较。它在数学推理和检索任务上提供了比 Gated DeltaNet 更好的准确性,同时保持了相似的效率。
2.5. 硬件感知架构搜索
在最终确定宏观架构(特别是全注意力层的放置)并选择线性注意力块之后,我们执行硬件感知架构搜索以优化核心架构超参数,包括键/值维度和注意力头数量。
传统上,参数大小是用于指导模型架构设计的主要效率指标。然而,这种方法并非最优,因为参数数量与硬件效率并不直接相关。我们通过使用生成吞吐量作为选择架构超参数的直接目标来解决这个限制。我们发现:
关键发现 4:KV 缓存大小是影响长上下文和长生成长吞吐量的最关键因素。当 KV 缓存大小恒定时,具有不同参数数量的模型表现出相似的生成吞吐量(表 2)。
这是因为解码阶段通常是内存带宽受限而非计算受限。在长上下文场景中,KV 缓存通常比模型权重消耗更多内存。减小其大小可以减少每个解码步骤的内存传输时间,并允许更大的批处理大小,从而提高生成吞吐量。

基于发现 4,我们固定 KV 缓存大小以匹配原始设计,并对键维度、值维度和注意力头数量进行小规模网格搜索。表 2 总结了结果,其中所有变体使用相同的线性注意力块(即 Gated DeltaNet),但具有不同的配置。蓝色行代表我们的最终设计,而灰色行对应原始设计。我们的最终配置实现了与原始设计相当的生成吞吐量,同时包含了更多参数并提高了准确性。从表 1 可以看出,我们在 PostNAS 中的硬件感知搜索提升了 JetBlock 的准确性,同时保持了训练和推理吞吐量。
3. 实验
3.1. 设置
Jet-Nemotron 模型系列。我们构建了两个不同参数大小的 Jet-Nemotron 版本:Jet-Nemotron-2B 和 Jet-Nemotron-4B。我们使用检索任务来指导全注意力层的放置,并使用 MMLU 任务来指导滑动窗口注意力(SWA)层的放置。Jet-Nemotron-2B 基于 Qwen2.5-1.5B [4] 构建,包含两个用于检索任务的全注意力层(第 15 和 20 层)和两个用于像 MMLU 这样的多项选择题的滑动窗口注意力(SWA)层(第 21 和 22 层)。我们发现多项选择题主要依赖于 softmax 操作的模式匹配特性来将答案的知识路由到其选项。SWA 有效地保留了此类任务的准确性。剩余的注意力层被替换为 JetBlock。类似地,Jet-Nemotron-4B 基于 Qwen2.5-3B 构建,包含三个全注意力层(第 18, 21, 33 层)和七个 SWA 层(第 6, 17, 20, 22, 23, 26, 和 28 层)。我们在附录 A.1 中总结了最终的模型架构。

训练细节。训练包含两个阶段。在第一阶段,我们冻结 MLP 并使用蒸馏损失来训练模型。在第二阶段,我们执行全模型训练。在第一阶段,我们使用 Nemotron-CC [63] 和 Redstone-QA [64] 的组合作为预训练语料库,并将 Jet-Nemotron 模型训练了 500 亿个 token。这也是我们在第 2 节中执行 PostNAS 的设置。在第二阶段,我们在数据混合中加入了更多来自数学 [65] 和编码 [66, 67] 领域的高质量数据。随后,模型在 3500 亿个 token 上进行了训练。我们在附录 A.2 中总结了实验成本。
评估细节。我们在主流基准测试设置下评估 Jet-Nemotron:MMLU(-Pro) [18, 19]、数学推理 [18, 20, 21, 22]、常识推理 [33, 34, 35, 36, 37, 38]、检索 [23, 24, 25]、编码 [26, 27, 28, 40] 和长上下文任务 [29]。我们将我们的模型与最先进的全注意力模型 [2, 4, 5]、线性注意力模型 [10, 50] 和混合模型 [41, 44] 进行比较。我们对 GSM8K [22] 和 MATH [18] 采用 4-shot 评估,对 GPQA [20] 和 MMLU-Pro [19] 采用 5-shot 评估。我们使用 EvalPlus [40] 和 CRUXEval [28] 的官方实现进行编码任务。对于所有其他任务,我们使用零样本设置。所有评估均基于 LM-Evaluation-Harness [68]。
吞吐量测试平台。我们的吞吐量评估在 DGX H100 服务器上进行,该服务器配备 8 个 NVIDIA H100 GPU、2 个 Intel Xeon Platinum 8480C(112 核)CPU 和 2TB 内存。为了公平和一致的比较,我们采用了最新的可用软件版本。具体来说,我们的环境包括 Pytorch 2.7.0 和 Triton 3.3.0。我们使用 FlashAttention 2.7.4 [69] 实现全注意力块,并使用 Flash-Linear-Attention 0.2.1 [70] 实现线性注意力块。模型推理基于 Transformers 4.52.0 实现 [71]。除非明确说明,上下文长度为 64K,每个模型在单个 H100 GPU 上进行测试。我们在表 3 中报告了 64K 输入上下文的缓存大小。在测试吞吐量时,我们采用分块预填充 [72],并搜索分块大小以在 GPU 内存限制下最大化每个模型的批处理大小。通过这种方式,我们测量了设备上可实现的最高解码吞吐量。我们在附录 A.3 中列出了每个模型使用的批处理大小。
3.2. 准确性主要结果
MMLU(-Pro) 和 BBH 上的结果。表 3 将 Jet-Nemotron 与最先进的高效语言模型进行了比较。Jet-Nemotron-2B 实现了比 Qwen3-1.7B-Base 高 47 倍的吞吐量和 47 倍小的缓存大小,同时在 MMLU、MMLU-Pro 和 BBH 上提供了显著更好的准确性。Jet-Nemotron-2B 甚至优于近期的 MoE 模型,如具有更多激活参数(22 亿)和总参数(150 亿)大得多的 DeepSeek-V3-Small [6] 和 Moonlight [61]。当扩展到 40 亿参数时,Jet-Nemotron-4B 相对于 Qwen3-1.7B-Base 仍保持 21 倍的吞吐量优势。与其他线性注意力和混合模型相比,Jet-Nemotron 也实现了显著更高的准确性。

数学任务上的结果。表 4 报告了我们在数学任务上的结果。Jet-Nemotron-2B 实现了 49.6 的平均准确率,以 6.3 的优势超过 Qwen3-1.7B-Base,同时速度快 47 倍。相比之下,先前的线性注意力和混合模型在数学任务上远远落后于 Qwen3。
常识推理任务上的结果。表 5 总结了常识推理任务的结果。Qwen2.5 和 Qwen3 在这个领域相对较弱。然而,以 Qwen2.5-1.5B 为起点的 Jet-Nemotron-2B 仍然展示了强劲的结果,实现了 62.0 的平均准确率,优于所有基线模型。
检索任务上的结果。表 6 展示了检索任务的结果。Jet-Nemotron-2B 优于除 Qwen3-1.7B-Base 之外的所有基线。当扩展到 40 亿参数时,Jet-Nemotron-4B 实现了 76.2 的最佳平均准确率,同时与 Qwen3 相比仍保持 21 倍的加速。
编码任务上的结果。表 7 显示了编码任务的结果。Jet-Nemotron-2B 的平均准确率高于所有基线。Jet-Nemotron-4B 在所有编码任务中实现了更高的准确率,同时在生成吞吐量上相对于 Qwen3-1.7B-Base 等领先 LM 仍具有巨大优势。
长上下文任务上的结果。线性和混合架构的一个常见担忧是它们在长上下文任务上的准确性。在表 8 中,我们在 LongBench [29] 上评估了这一点,上下文长度高达 64K。我们的研究结果表明,带有两个全注意力层的 Jet-Nemotron-2B 实现了与 Qwen2.5-1.5B 和 Gemma3n-E2B 等领先模型相当的性能,而后两者具有更多的此类层。此外,我们的 Jet-Nemotron-4B 在生成吞吐量上提供 21 倍加速的同时,性能超过了 Qwen3-1.7B-Base。这些结果极大地推进了长上下文任务中效率-准确性权衡的前沿。
总结。结合之前的结果,Jet-Nemotron-2B 和 Jet-Nemotron-4B 在所有六个评估领域中的表现与先进的全注意力模型(Qwen3-1.7B-Base)相当甚至更好。凭借显著更少的全注意力层和更小的 KV 缓存大小,Jet-Nemotron-2B 和 Jet-Nemotron-4B 分别提供了比 Qwen3-1.7B-Base 高 47 倍和 21 倍的生成吞吐量。
3.3. 效率基准结果
图 6 显示了 Qwen3-1.7B-Base 和 Jet-Nemotron-2B 在不同上下文长度下的吞吐量比较。在预填充阶段,Jet-Nemotron-2B 在较短上下文长度(4K 和 8K)下最初比 Qwen3-1.7B-Base 快 1.14 和 1.15 倍。这可以通过设计更好优化的 JetBlock 核实现来进一步改进。随着上下文长度的增加,线性注意力的优势变得突出,使得 Jet-Nemotron-2B 在 256K 上下文长度下实现了 6.14 倍的加速。
在解码阶段,Jet-Nemotron-2B 始终以较大幅度优于 Qwen3-1.7B-Base。由于 Jet-Nemotron-2B 包含 2 个全注意力层,每组有 2 组键值状态,其相对于具有 28 个全注意力层(每层包含 8 组键值状态)的 Qwen3-1.7B-Base 的理论最大加速是 14 × 4 = 56 倍。在我们的吞吐量测试平台中,Jet-Nemotron-2B 在 4K 上下文长度下实现了 15.6 倍的加速,在 256K 上下文长度下实现了高达 53.6 倍的加速,几乎达到了理论上限。
4. 相关工作
大型语言模型(LLMs)功能强大但计算密集,这促使许多工作为 LLMs 构建高效的模型架构。一系列研究专注于设计高效的线性注意力块 [9, 10, 11, 12, 32, 50, 51, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 82, 84] 或对数线性注意力 [85] 块来替代全注意力块。正交地,另一系列研究尝试结合全注意力和线性注意力来构建混合模型 [13, 15, 16, 17, 44, 86, 87, 88]。这些工作通常专注于预训练设置,并且它们的准确性落后于领先的全注意力模型。最近,有一些努力致力于将具有全注意力的 LLMs 线性化,即用线性注意力替换全注意力 [89, 90, 91, 92, 93, 94, 95, 96]。然而,由于评估特定配置的开销很大,它们的模型架构优化不佳,因此其结果仍然不如 SOTA 全注意力模型。
我们的工作也与神经架构搜索(NAS)[45, 97, 98, 99, 100] 相关,这是一种探索架构设计空间和发现新颖模型结构的强大技术。特别是,硬件感知神经架构搜索 [45] 通过训练一个"一次全包"的超网络 [31],或利用逐层蒸馏 [101, 102] 等方法,使得开发针对目标硬件优化的专用模型架构成为可能。然而,由于预训练成本过高,NAS 在大型语言模型(LLMs)时代很少被应用。近期的努力主要集中于构建灵活的 LLM 架构 [103, 104],这些架构可以生成一系列具有不同深度和宽度的子网络,以适应不同的硬件平台。然而,这些子网络的架构主干保持不变,完全依赖于全注意力层。
5. 结论
我们推出了 Jet-Nemotron,一个新的混合架构语言模型系列,它在性能上超越了最先进的全注意力模型——包括 Qwen3、Qwen2.5、Gemma3 和 Llama3.2——同时提供了显著的效率提升,在 H100 GPU 上(256K 上下文长度,最大批处理大小)生成吞吐量提升高达 53.6 倍。Jet-Nemotron 的实现得益于两项关键创新:(1) 后神经架构搜索(Post Neural Architecture Search),一种高效的训练后架构适应流程,适用于任何预训练的 Transformer 模型;以及 (2) JetBlock,一种新颖的线性注意力块,其性能显著优于先前的设计,如 Mamba2、GLA 和 Gated DeltaNet。大量的实证结果表明,Jet-Nemotron 在广泛的基准测试中实现了重大的效率改进,同时没有牺牲准确性。此外,Jet-Nemotron 显著降低了与 LLM 架构探索相关的成本和风险,使得语言模型设计的创新更快、更高效。
A. 实验细节
A.1. 最终模型架构
最终的 Jet-Nemotron 模型由一系列块堆叠而成,每个块包含一个多层感知机(MLP)层和一个注意力层。注意力层从三种类型中选择一种:全注意力、滑动窗口注意力或 JetBlock。详细的架构配置如表 9 所示。
全注意力和滑动窗口注意力层使用分组查询注意力 [105],并按表 10 进行配置。对于滑动窗口注意力层,在 Jet-Nemotron-2B 中窗口大小设置为 1,152,在 Jet-Nemotron-4B 中设置为 2,048。
JetBlock 的配置如表 11 所示:
A.2. 实验成本
表 11 | JetBlock 的配置。
表 12 总结了 PostNAS 和训练 Jet-Nemotron-2B 模型的成本。我们并行使用了 32 个 H100 GPU。报告的 GPU 小时数已计入设备总数。
A.3. 吞吐量测量
在整个实验中,我们在单个 H100 GPU 上测量了 Jet-Nemotron 和基线模型可达到的最大预填充和解码吞吐量。这是通过调整分块预填充 [72] 中的分块大小来实现的,以在不牺牲预填充吞吐量的前提下最大化解码批处理大小。我们在表 13 中列出了每个模型的优化后批处理大小及相应的分块大小。预填充上下文长度为 64K。由于 KV 缓存内存主导了推理过程中的 GPU 使用量,通过减少每个序列的内存占用,更小的缓存允许更多序列被并行处理,从而极大地提升了生成吞吐量。
B. 补充结果
B.1. 关于训练数据的对照研究
为了排除训练数据的影响,我们在 Jet-Nemotron 的训练数据集上对基线模型(Qwen2.5、RWKV-7 和 Mamba-2)进行了持续预训练,以提供更全面的评估。表 14 的结果表明,Jet-Nemotron-2B 以显著优势优于所有这些经过微调的基线模型。
B.2. 在低端硬件上的吞吐量结果
我们在 NVIDIA Jetson Orin (32GB) 和 NVIDIA RTX 3090 GPU 上测量了 Jet-Nemotron-2B 和 Qwen2.5-1.5B 在 64K 上下文长度下的吞吐量。表 15 的结果显示,Jet-Nemotron-2B 在 Jetson Orin 和 RTX 3090 GPU 上分别比 Qwen2.5-1.5B 实现了 8.84 倍和 6.50 倍的加速。

参考:
https://arxiv.org/pdf/2508.15884v1
https://github.com/NVlabs/Jet-Nemotron
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)