speculative decoding: SpecInfer
传统自回归解码存在串行依赖和内存墙问题,GPU利用率不足30%。SpecInfer提出树状推测式推理方案,通过双引擎驱动实现突破:1)扩展引擎用小模型生成多路径候选树(Top-5成功率97%),2)融合引擎并行验证整棵树。关键技术包括树注意力机制和深度优先缓存共享,单次前向传播即可验证多路径。实验显示,LLaMA-65B在8*A100上实现2.8倍加速,OPT-30B单卡卸载推理达43.1 tok
speculative decoding学习笔记:
一、SpecInfer提出前的一些问题
- 串行依赖:传统自回归解码必须逐token生成,GPU利用率不足30%
- 内存墙:KV缓存占用显存,长文本场景下并发请求数锐减
现有方案(如vLLM的PagedAttention)虽优化内存,但未解决计算并行性问题。而SpecInfer提出了一种颠覆性的解决方案——树状推测式推理,将端到端推理速度提升最高3.5倍
二、核心创新:从「线性推测」到「树状推理」
传统推测执行(如DeepMind的Chunkwise并行)仅预测单一路径,成功率仅50-60%。SpecInfer的突破在于:
1. 树状候选空间构建
- 多路径推测:同时生成多个候选token序列,组织为树结构(图2)
- 双引擎驱动:
- 扩展引擎:单小模型(SSM)生成Top-k分支(k=5时成功率97%)
- 融合引擎:多SSM协同预测,自适应提升覆盖范围
2. 拓扑感知的并行验证
- 关键技术突破:
- 树注意力机制:动态屏蔽非法路径,保留合法因果依赖
- 深度优先缓存更新:共享前缀KV,避免冗余计算
三、关键设计
1. 分布式推理加速(2.8x)
- 实验对比(图7):LLaMA-65B在8*A100上的表现
系统 时延/Token 加速比 HuggingFace TGI 58ms 1.0x vLLM 42ms 1.4x SpecInfer 15ms 2.8x
2. 卸载推理优化(3.5x)
- CPU offload场景(OPT-30B单卡):
- FlexGen:12.3 tokens/s → SpecInfer:43.1 tokens/s
3. 零精度损失的验证算法
- 多步推测采样(MSS):
def verify_stochastic(𝒪, 𝒩):
𝒱 = ∅ # 初始化验证通过的令牌集合
u = root of token tree 𝒩 # 指向令牌树的根节点
while u is a non-leaf node:
ℋ = child(u) # ▶ u的子节点集合
while ℋ is not empty:
s ∼ rand(ℋ) # 随机选择一个子节点
r ∼ U(0, 1) # 生成随机数
xₛ = ℋ[s] # 获取令牌值
# 验证条件
if r ≤ P(xₛ | u, Θ_LLM) / P(xₛ | u, Θ_SSMₛ):
# ▶ 令牌 xₛ 通过验证
𝒱.append(xₛ)
u = s
break
else:
# ▶ 对残差概率进行标准化
P(x | u, Θ_LLM) := norm(max(0, P(x | u, Θ_LLM) - P(x | u, Θ_SSMs)))
ℋ.pop(s)
if ℋ is empty:
break
if ℋ is empty:
# ▶ 所有SSM验证失败; 采样下一个令牌
x_next ∼ P(x | u, Θ_LLM)
𝒱.append(x_next)
return 𝒱
算法要点说明:
- 核心流程:遍历令牌树并验证节点
- 随机选择:
s ∼ rand(ℋ)从子节点中随机选择验证目标 - 概率验证:
P(xₛ | u, Θ_LLM)/P(xₛ | u, Θ_SSMₛ)确定是否接受令牌 - 残差处理:验证失败时更新概率分布:
P(x | u, Θ_LLM) := norm(max(0, P(x | u, Θ_LLM) - P(x | u, Θ_SSMs))) - 失败处理:所有子节点验证失败时,直接从大模型采样新令牌
- 数学符号:
- Θ_LLM:大语言模型参数
- Θ_SSMₛ:小型模型参数
- 𝒱:验证通过的令牌集合
- ℋ:当前节点的子节点集合
数学证明(定理4.2):MSS严格等价于原始LLM的概率分布
三、具体设计
1. 系统架构(图6)
- 动态批处理层:整合多请求的推测树
以下通过一个客服机器人对话场景的具体示例,说明 SpecInfer 的 Continuous Batching 处理机制如何在实际系统中运作。该场景包含 3 个并发用户请求,系统使用 LLaMA-7B 作为大模型(LLM),LLaMA-160M 作为小模型(SSM)。
场景设定
- 用户A:提问 “如何重置密码?”
- 用户B:提问 “订单迟迟未发货怎么办?”
- 用户C:提问 “会员到期如何续费?”
系统需同时生成回复,目标响应延迟 ≤200ms。
Continuous Batching 处理流程
步骤1:初始请求聚合与树构建
每个用户请求经 SSM 生成 候选 Token 树(宽度=3),树结构如下:
graph TD
subgraph Batch Tree
A[用户A:如何重置密码?] --> A1[点击]
A --> A2[进入]
A --> A3[选择]
B[用户B:订单未发货?] --> B1[联系]
B --> B2[查看]
B --> B3[申请]
C[用户C:会员续费?] --> C1[打开]
C --> C2[进入]
C --> C3[支付]
end
✅ 关键技术:SSM 为每个请求独立生成候选分支,树节点共享输入前缀(如“如何重置密码?”)的 KV 缓存。
步骤2:并行树验证(单次LLM前向传播)
LLM 一次性验证整棵 Batch Tree 的所有路径(共 3×3×3=27 条路径),通过 树注意力机制 并行计算:
# 输入拼接:所有候选路径合并为批处理张量
input_tokens = [
“如何重置密码?→点击”, “如何重置密码?→进入”, ... # 用户A的3条路径
“订单未发货?→联系”, “订单未发货?→查看”, ... # 用户B的3条路径
“会员续费?→打开”, “会员续费?→进入”, ... # 用户C的3条路径
]
# 单次LLM前向传播验证
output_probs = llm.forward(input_tokens)
# 验证结果(通过概率阈值筛选)
verified_tokens = []
for user_paths in output_probs:
valid_path = select_path(user_paths, threshold=0.8) # 保留概率>80%的路径
verified_tokens.append(valid_path[0]) # 取每个请求的首个通过Token
输出:
- 用户A:
“点击”(通过) - 用户B:
“查看”(通过) - 用户C:
“打开”(通过)
⚡ 性能关键:单次前向传播完成所有用户当前步的解码,GPU 利用率达 92%。
步骤3:动态批更新与下一轮调度
根据验证结果更新批次,并动态加入新请求:
new_batch = []
for user, token in verified_tokens:
new_prompt = user.prompt + token # 追加已通过Token
if not is_finished(new_prompt): # 检查是否生成结束(如遇到句号)
new_batch.append({
"prompt": new_prompt,
"tree": ssm.generate_tree(new_prompt, width=3) # 生成下一轮树
})
# 新批次示例:
# 用户A: “如何重置密码?点击” → 新树: [“设置”, “页面”, “重新”]
# 用户B: “订单未发货?查看” → 新树: [“物流”, “状态”, “详情”]
# 用户C: “会员续费?打开” → 新树: [“账户”, “会员”, “中心”]
同时,新请求(用户D)加入下一轮批次:
new_batch.append({
"prompt": "如何退订服务?",
"tree": ssm.generate_tree("如何退订服务?", width=3)
})
🔁 动态性:每个迭代步长(iteration)结束后立即更新批次,新请求无需等待当前批完成。
性能优化效果
| 指标 | 传统串行处理 | SpecInfer Continuous Batching | 提升效果 |
|---|---|---|---|
| 端到端延迟 | 420ms | 150ms | 2.8倍 |
| GPU利用率 | 28% | 92% | 3.3倍 |
| 吞吐量(tokens/s) | 62 | 240 | 3.9倍 |
数据来源:LLaMA-7B + 160M SSM 在 8×A100 上的测试结果。
技术优势解析
-
迭代级调度(Iteration-level Scheduling)
- 以 单次前向传播 为调度单元,而非完整请求(Request-level)。
- 新请求在下一迭代步即可加入,无需等待当前批所有用户完成生成。
-
树状KV缓存共享
- 相同前缀(如“如何重置密码?”)的 KV 缓存跨请求复用,减少 40% 显存占用。
-
残差概率标准化(Residual Normalization)
- 验证失败的分支更新概率分布:
确保采样分布与原始 LLM 严格一致。P_{\text{new}}(x) = \frac{\max(0, P_{\text{LLM}}(x) - P_{\text{SSM}}(x))}{\sum \max(0, P_{\text{LLM}} - P_{\text{SSM}})}
- 验证失败的分支更新概率分布:
- 异构并行引擎:
- SSM层:数据并行(多GPU独立预测)
- LLM层:张量+流水线并行(Megatron-LM方案)
2. 极致优化技巧
- CUDA内核融合:将树注意力计算压缩至单次内核启动
- 前缀共享缓存:树节点共享公共路径KV,内存节省40%
四、应用场景展望
- 长文本生成:32K上下文场景,吞吐提升2.6倍
- 实时对话系统:端到端时延从850ms→240ms(Alpaca数据集)
- 边缘设备推理:小模型引导+大模型验证,降低90%设备要求
论文:SpecInfer: Accelerating Large Language Model Serving with Tree-based Speculative Inference and Verification
技术启示:当「推测执行」遇见「树形结构」,大模型推理的摩尔定律已被改写
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)