How to Optimize LLM Reasoning with Power Sampling: Principles, Code, and Examples
As of this writing (25.11.25), power sampling is regarded as an LLM reasoning-optimization technique that can rival RL fine-tuning.
RL fine-tuning aligns a model's outputs with human expectations, which in effect sharpens the LLM's output distribution.
A power distribution is a natural, built-in form of such sharpening.
The paper below shows that power sampling can effectively improve an LLM's reasoning output without any RL fine-tuning.
Reasoning with Sampling: Your Base Model is Smarter Than You Think
https://arxiv.org/abs/2510.14901
Here we walk through its reference implementation to explore and analyze how the reasoning optimization works.
https://github.com/BitMakerMan/Power-Sampling-Training
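Before diving into the code, the toy example below illustrates why a power distribution behaves like a sharpened version of the base distribution: raising probabilities to the power α and renormalizing concentrates mass on the most likely outcomes. This is an illustration only; the probability values are invented for demonstration.
import numpy as np

def sharpen(p, alpha):
    """Power distribution: p_i^alpha / sum_j p_j^alpha."""
    q = np.asarray(p, dtype=float) ** alpha
    return q / q.sum()

p = [0.5, 0.3, 0.15, 0.05]          # a toy next-token distribution
for alpha in (1.0, 2.0, 4.0):
    print(alpha, np.round(sharpen(p, alpha), 3))
# alpha=1 leaves p unchanged; larger alpha concentrates mass on the most likely
# outcome, which is the "natural sharpening" that power sampling targets at the
# whole-sequence level.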
1 Power Sampling
Power Sampling is an algorithm that improves LLM reasoning using Metropolis-Hastings (M-H) sampling, a technique from computational statistics.
1.1 Computing the M-H Acceptance Probability
The M-H acceptance probability used in power sampling is computed as follows.
P(accept) = min(1, exp(α * (log P_proposal - log P_current)))
α is the sharpening parameter, typically in the range 2.0 to 6.0.
P_proposal and P_current denote the probabilities of the proposal sequence and the current sequence, respectively.
The proposal is derived from the current sequence by resampling part of it (see the block-wise procedure in Section 1.2).
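As an illustration (not part of the repository), the snippet below evaluates this formula for a few hypothetical log-probability gaps; the numbers are invented purely to show how a larger α amplifies small differences between the proposal and the current sequence.
import math

def acceptance_prob(log_p_proposal: float, log_p_current: float, alpha: float) -> float:
    """Metropolis-Hastings acceptance probability under the sharpened (power) target."""
    log_ratio = alpha * (log_p_proposal - log_p_current)
    # Clamping at 0 before exponentiating is equivalent to min(1, exp(...)) and avoids overflow.
    return math.exp(min(0.0, log_ratio))

# Hypothetical scores: the proposal is 0.5 nats worse / 0.5 nats better than the current text.
for alpha in (2.0, 4.0, 6.0):
    print(f"alpha={alpha}: "
          f"P(accept worse)={acceptance_prob(-10.5, -10.0, alpha):.3f}, "
          f"P(accept better)={acceptance_prob(-9.5, -10.0, alpha):.3f}")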
1.2 Block-wise Resampling Procedure
The power sampling optimization loop is essentially repeated LLM inference; the steps are outlined below, and a schematic sketch of the loop follows the list.
1) Initialization
The LLM generates an initial output init_output from the prompt; this is the first version of the text.
2) Text blocking and proposal scoring
The current text is divided into blocks. For each block, the probability score of the current version and the probability score of the proposed replacement text are computed.
3) Metropolis-Hastings accept/reject
If the proposal scores higher, it is more likely to be accepted.
Occasionally a worse proposal is also accepted, which prevents the optimization path from getting stuck in a local optimum.
4) The resampling process is repeated many times, each iteration modifying only the text inside one block.
In general, more iterations produce higher-quality output.
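The sketch below is a schematic outline of this loop, not the repository's implementation; the callables propose and log_prob are hypothetical stand-ins supplied by the caller for the model calls shown in Section 2.
import math
import random

# propose(prefix_tokens, n) -> list of n new tokens sampled after `prefix_tokens`
# log_prob(tokens)          -> total log-probability of `tokens` under the model
def power_sampling_outline(propose, log_prob, prompt_tokens,
                           alpha=4.0, block_size=16, steps=10):
    # 1) Initialization: one ordinary sampled generation.
    current = propose(prompt_tokens, block_size)

    for _ in range(steps):
        # 2) Pick a block and propose a replacement conditioned on its prefix.
        start = random.randint(0, max(0, len(current) - block_size))
        new_block = propose(prompt_tokens + current[:start], block_size)
        proposal = current[:start] + new_block[:block_size] + current[start + block_size:]

        # 3) Metropolis-Hastings accept/reject on the sharpened (power) target.
        log_ratio = alpha * (log_prob(proposal) - log_prob(current))
        if random.random() < math.exp(min(0.0, log_ratio)):
            current = proposal  # occasionally accepts a worse proposal too

    # 4) Return the refined continuation after all iterations.
    return current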
2 Walking Through the Sampling Code
This section walks through the power sampling code to explore and analyze the optimization process.
2.1 Generating the Proposal
Sampling text from the proposal distribution simply means calling the model again in sampling mode to regenerate text.
First, the input is a prefix of the current text, current_seq[:start_pos].
Second, with do_sample=True, sampling from the proposal distribution amounts to calling the model to generate proposal_block.
Then current_seq[:start_pos] and proposal_block are concatenated (together with the remaining suffix) to form the newly sampled text proposal_seq.
Sample code is shown below.
# Create context (prefix)
if start_pos > 0:
    context_ids = current_seq[:start_pos].unsqueeze(0)
else:
    # Use original prompt as context if we're at the beginning
    context_ids = inputs['input_ids']

# Generate proposal block
with torch.no_grad():
    proposal_block = self.model.generate(
        context_ids,
        max_new_tokens=block_size,
        do_sample=True,
        temperature=temperature,
        pad_token_id=self.tokenizer.eos_token_id,
        num_return_sequences=1
    )

# Extract only the newly generated part (everything past the context)
new_tokens = proposal_block[0][context_ids.size(-1):]

# Create proposal sequence: prefix + resampled block + remaining suffix
# (current_seq[end_pos:] is simply empty, with the right dtype, when the block
#  reaches the end of the sequence)
proposal_seq = torch.cat([
    current_seq[:start_pos],
    new_tokens[:block_size],
    current_seq[end_pos:]
])
2.2 Computing the Distribution Score
Computing a distribution score is a single model call that evaluates the log-likelihood of an input sequence.
First, the input is the token-id sequence to be scored, e.g. current_seq or proposal_seq.
Second, scoring means computing the probability the model assigns to that id sequence, much like a supervised-learning setup: the sequence is passed to the model both as input_ids and as labels, i.e. as input and as supervision target.
The model returns the negative log-likelihood of input_ids against labels; since the score is a log-likelihood, the loss is negated, giving -outputs.loss.
Sample code is shown below.
def compute_log_probability(self, sequence: torch.Tensor) -> float:
    """
    Compute the log probability of a sequence under the model.

    Args:
        sequence: Token sequence tensor

    Returns:
        Log probability sum
    """
    with torch.no_grad():
        inputs = sequence.unsqueeze(0) if sequence.dim() == 1 else sequence
        attention_mask = torch.ones_like(inputs)
        outputs = self.model(
            input_ids=inputs,
            attention_mask=attention_mask,
            labels=inputs
        )
        # outputs.loss is the mean token cross-entropy, so scaling by the
        # sequence length recovers (approximately) the summed log-likelihood.
        return -outputs.loss.item() * inputs.size(-1)
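One subtlety: outputs.loss from a Hugging Face causal LM is the mean cross-entropy over the shifted targets (one fewer than the sequence length), so multiplying by inputs.size(-1) only approximates the summed log-likelihood. As a hedged alternative sketch (not part of the repository), assuming a standard transformers causal LM that exposes .logits, the exact per-token sum can be computed as follows.
import torch
import torch.nn.functional as F

def exact_log_probability(model, sequence: torch.Tensor) -> float:
    """Sum of log P(token_t | tokens_<t) over the sequence (excluding the first token)."""
    with torch.no_grad():
        inputs = sequence.unsqueeze(0) if sequence.dim() == 1 else sequence
        logits = model(input_ids=inputs).logits               # (1, T, vocab)
        log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)  # predictions for positions 1..T-1
        targets = inputs[:, 1:]                               # the tokens actually observed there
        token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        return token_log_probs.sum().item()
Since the M-H ratio depends only on the difference between two such scores, either variant works in practice; the exact sum simply removes the length-scaling approximation.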
2.3 M-H Accept/Reject Step
The M-H accept/reject step decides whether to keep the newly generated proposal_seq.
First, compute the scores of the original sequence current_seq and the new sequence proposal_seq.
Second, plug them into the power-sampling acceptance formula to obtain the acceptance probability of proposal_seq.
P(accept) = min(1, exp(α * (log P_proposal - log P_current)))
Then draw a uniform random number to decide whether to accept proposal_seq:
if P(random) < P(accept), proposal_seq is accepted; otherwise it is rejected.
Note that a higher P(accept) makes acceptance more likely, but not guaranteed.
Occasionally a worse proposal is accepted, which prevents the optimization path from getting stuck in a local optimum.
log_p_current = self.compute_log_probability(current_seq)
log_p_proposal = self.compute_log_probability(proposal_seq)

# Metropolis acceptance criterion with sharpening
log_acceptance_ratio = alpha * (log_p_proposal - log_p_current)
acceptance_prob = min(1.0, torch.exp(torch.tensor(log_acceptance_ratio)).item())

# Accept or reject proposal
if random.random() < acceptance_prob:
    current_seq = proposal_seq
    if show_progress:
        iterator.set_postfix({"accept": f"{acceptance_prob:.3f}", "status": "✓"})
else:
    if show_progress:
        iterator.set_postfix({"accept": f"{acceptance_prob:.3f}", "status": "✗"})
2.4 The Full Sampling Procedure
Putting the pieces above together, the complete power sampling implementation is shown below.
import torch
import torch.nn.functional as F
import random
import numpy as np
from typing import Optional, Dict, Any
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


class PowerSampler:
    """
    Implements Power Sampling algorithm for improved LLM reasoning.

    Power Sampling uses Metropolis-Hastings algorithm to resample text blocks,
    improving logical coherence and reasoning capabilities.
    """

    def __init__(
        self,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        device: Optional[str] = None
    ):
        """
        Initialize the Power Sampler.

        Args:
            model: The causal language model to use
            tokenizer: The tokenizer corresponding to the model
            device: Device to run on (auto-detect if None)
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Set padding token if not present
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
    def compute_log_probability(self, sequence: torch.Tensor) -> float:
        """
        Compute the log probability of a sequence under the model.

        Args:
            sequence: Token sequence tensor

        Returns:
            Log probability sum
        """
        with torch.no_grad():
            inputs = sequence.unsqueeze(0) if sequence.dim() == 1 else sequence
            attention_mask = torch.ones_like(inputs)
            outputs = self.model(
                input_ids=inputs,
                attention_mask=attention_mask,
                labels=inputs
            )
            # outputs.loss is the mean token cross-entropy, so scaling by the
            # sequence length recovers (approximately) the summed log-likelihood.
            return -outputs.loss.item() * inputs.size(-1)
    def power_sample(
        self,
        prompt: str,
        alpha: float = 4.0,
        block_size: int = 192,
        steps: int = 10,
        max_len: int = 2048,
        temperature: float = 1.0,
        show_progress: bool = False
    ) -> str:
        """
        Execute Power Sampling algorithm.

        Args:
            prompt: Input prompt text
            alpha: Sharpening factor (recommended: 4.0)
            block_size: Size of text blocks to resample (recommended: T/16)
            steps: Number of Metropolis-Hastings iterations (recommended: 10)
            max_len: Maximum sequence length
            temperature: Sampling temperature
            show_progress: Show progress bar

        Returns:
            Generated text with improved reasoning
        """
        # Tokenize prompt
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_len - block_size
        ).to(self.device)

        # Generate initial sequence
        with torch.no_grad():
            initial_output = self.model.generate(
                **inputs,
                max_new_tokens=block_size,
                do_sample=True,
                temperature=temperature,
                pad_token_id=self.tokenizer.eos_token_id,
                num_return_sequences=1
            )

        # Extract generated sequence (remove prompt part)
        prompt_length = inputs['input_ids'].size(-1)
        current_seq = initial_output[0][prompt_length:]

        # Metropolis-Hastings iterations
        iterator = tqdm(range(steps), desc="Power Sampling") if show_progress else range(steps)
        for step in iterator:
            # Select random block to resample
            if current_seq.size(-1) <= block_size:
                # If sequence is shorter than block_size, regenerate whole thing
                start_pos = 0
                end_pos = current_seq.size(-1)
            else:
                start_pos = random.randint(0, current_seq.size(-1) - block_size)
                end_pos = start_pos + block_size

            # Create context (prefix)
            if start_pos > 0:
                context_ids = current_seq[:start_pos].unsqueeze(0)
            else:
                # Use original prompt as context if we're at the beginning
                context_ids = inputs['input_ids']

            # Generate proposal block
            with torch.no_grad():
                proposal_block = self.model.generate(
                    context_ids,
                    max_new_tokens=block_size,
                    do_sample=True,
                    temperature=temperature,
                    pad_token_id=self.tokenizer.eos_token_id,
                    num_return_sequences=1
                )

            # Extract only the newly generated part (everything past the context)
            new_tokens = proposal_block[0][context_ids.size(-1):]

            # Create proposal sequence: prefix + resampled block + remaining suffix
            # (current_seq[end_pos:] is simply empty, with the right dtype, when the
            #  block reaches the end of the sequence)
            proposal_seq = torch.cat([
                current_seq[:start_pos],
                new_tokens[:block_size],
                current_seq[end_pos:]
            ])

            # Ensure we don't exceed max length
            if proposal_seq.size(-1) > max_len - prompt_length:
                proposal_seq = proposal_seq[:max_len - prompt_length]
                current_seq = current_seq[:max_len - prompt_length]

            # Compute acceptance probability
            try:
                log_p_current = self.compute_log_probability(current_seq)
                log_p_proposal = self.compute_log_probability(proposal_seq)

                # Metropolis acceptance criterion with sharpening
                log_acceptance_ratio = alpha * (log_p_proposal - log_p_current)
                acceptance_prob = min(1.0, torch.exp(torch.tensor(log_acceptance_ratio)).item())

                # Accept or reject proposal
                if random.random() < acceptance_prob:
                    current_seq = proposal_seq
                    if show_progress:
                        iterator.set_postfix({"accept": f"{acceptance_prob:.3f}", "status": "✓"})
                else:
                    if show_progress:
                        iterator.set_postfix({"accept": f"{acceptance_prob:.3f}", "status": "✗"})
            except Exception as e:
                # If probability computation fails, keep current sequence
                if show_progress:
                    iterator.set_postfix({"error": str(e)[:20]})
                continue

        # Combine prompt and generated sequence
        final_sequence = torch.cat([inputs['input_ids'][0], current_seq])

        # Decode and return
        return self.tokenizer.decode(final_sequence, skip_special_tokens=True)
    def batch_power_sample(
        self,
        prompts: list,
        alpha: float = 4.0,
        block_size: int = 192,
        steps: int = 10,
        max_len: int = 2048,
        temperature: float = 1.0,
        show_progress: bool = False
    ) -> list:
        """
        Execute Power Sampling on multiple prompts.

        Args:
            prompts: List of input prompts
            alpha: Sharpening factor
            block_size: Size of text blocks to resample
            steps: Number of Metropolis-Hastings iterations
            max_len: Maximum sequence length
            temperature: Sampling temperature
            show_progress: Show progress bar

        Returns:
            List of generated texts
        """
        results = []
        iterator = tqdm(prompts, desc="Batch Power Sampling") if show_progress else prompts

        for prompt in iterator:
            result = self.power_sample(
                prompt=prompt,
                alpha=alpha,
                block_size=block_size,
                steps=steps,
                max_len=max_len,
                temperature=temperature,
                show_progress=False
            )
            results.append(result)

        return results
def load_model_and_tokenizer(model_name: str, device: Optional[str] = None):
    """
    Convenience function to load model and tokenizer.

    Args:
        model_name: Hugging Face model name
        device: Device to load on

    Returns:
        Tuple of (model, tokenizer)
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto" if device == "cuda" else None,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32
    )

    if device != "cuda":
        model = model.to(device)

    return model, tokenizer
def power_sample(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    alpha: float = 4.0,
    block_size: int = 192,
    steps: int = 10,
    max_len: int = 2048,
    temperature: float = 1.0,
    device: Optional[str] = None,
    show_progress: bool = False
) -> str:
    """
    Convenience function for single-call power sampling.

    Args:
        model: The causal language model
        tokenizer: The tokenizer
        prompt: Input prompt text
        alpha: Sharpening factor (recommended: 4.0)
        block_size: Size of text blocks to resample
        steps: Number of Metropolis-Hastings iterations
        max_len: Maximum sequence length
        temperature: Sampling temperature
        device: Device to run on
        show_progress: Show progress bar

    Returns:
        Generated text with improved reasoning
    """
    sampler = PowerSampler(model, tokenizer, device)
    return sampler.power_sample(
        prompt=prompt,
        alpha=alpha,
        block_size=block_size,
        steps=steps,
        max_len=max_len,
        temperature=temperature,
        show_progress=show_progress
    )
https://github.com/BitMakerMan/Power-Sampling-Training/blob/main/src/power_sampling.py
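For completeness, here is a minimal usage sketch of the code above. It assumes the repository's src/power_sampling.py is importable and uses EleutherAI/gpt-neo-125M, the same small model as the test in Section 3; the parameter values are illustrative only.
from power_sampling import load_model_and_tokenizer, power_sample  # assumes src/ is on sys.path

# Load a small causal LM (illustrative choice; any Hugging Face causal LM should work).
model, tokenizer = load_model_and_tokenizer("EleutherAI/gpt-neo-125M")

# One call runs: initial generation + `steps` rounds of block resampling with M-H accept/reject.
text = power_sample(
    model=model,
    tokenizer=tokenizer,
    prompt="What is artificial intelligence?",
    alpha=4.0,        # sharpening exponent
    block_size=32,    # tokens per resampled block
    steps=5,          # Metropolis-Hastings iterations
    max_len=160,
    show_progress=True,
)
print(text)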
3 Comparison Test
This section runs a comparison test to show power sampling taking effect.
The code relies only on basic libraries such as torch and transformers, and runs with essentially no debugging required.
3.1 Test Program
The program below compares standard generation with power sampling output.
Test question: What is artificial intelligence?
The sharpening parameter is set to α = 2.0, 4.0, and 6.0 to show different degrees of power-sampling optimization.
import sys
import os

# Add src directory to path
sys.path.append(os.path.join("/data/apps/llm/Power-Sampling-Training", 'src'))

try:
    from power_sampling import load_model_and_tokenizer, power_sample
    print("✅ Successfully imported Power Sampling modules")
except ImportError as e:
    print(f"Import error: {e}")
    sys.exit(1)

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
def print_header(title):
    """Print formatted header"""
    print(f"\n{'='*60}")
    print(f"🎯 {title}")
    print(f"{'='*60}")


def print_comparison(original, improved, prompt):
    """Print side-by-side comparison"""
    print(f"\n📝 PROMPT: {prompt}")
    print("-" * 60)
    print(f"🔸 STANDARD GENERATION:")
    print(f" {original}")
    print()
    print(f"✨ POWER SAMPLING (Improved):")
    print(f" {improved}")
    print()
    print("💡 IMPROVEMENTS:")
    # Simple analysis of improvements
    if len(improved) > len(original):
        print(" ✓ More detailed response")
    if "because" in improved.lower() or "therefore" in improved.lower():
        print(" ✓ Better logical connections")
    if "step" in improved.lower() or "first" in improved.lower():
        print(" ✓ Structured reasoning")
    if improved.count('.') > original.count('.'):
        print(" ✓ More complete sentences")
    print("-" * 60)
def demonstrate_basic_generation(model, tokenizer):
    """Show basic LLM generation without Power Sampling"""
    print_header("1. Standard LLM Generation (Without Power Sampling)")

    prompt = "What is artificial intelligence?"
    print(f"Prompt: {prompt}")
    print("This shows how a normal language model generates text...")

    # Standard generation
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=80,
            temperature=1.0,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    standard_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Response: {standard_response}")
    return standard_response
def demonstrate_power_sampling(model, tokenizer):
    """Show Power Sampling in action"""
    print_header("2. Power Sampling in Action")

    prompt = "What is artificial intelligence?"
    print(f"Same prompt: {prompt}")
    print("Power Sampling applies Metropolis-Hastings to improve coherence...")

    # Power Sampling with different parameters
    print("\n🔧 Testing different Power Sampling parameters:")

    for alpha in [2.0, 4.0, 6.0]:
        print(f"\n--- Alpha = {alpha} (Sharpening Factor) ---")
        response = power_sample(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            alpha=alpha,
            steps=3,
            block_size=32,
            max_len=120,
            show_progress=False
        )
        print(f"Response: {response[:200]}..." if len(response) > 200 else f"Response: {response}")
def demonstrate_step_by_step(model, tokenizer):
    """Show how Power Sampling works step by step"""
    print_header("3. How Power Sampling Works - Step by Step")

    prompt = "Explain why the sky is blue"
    print(f"Prompt: {prompt}")

    print("\n🔄 Power Sampling Process:")
    print("1. Generate initial text")
    print("2. Divide text into blocks")
    print("3. Propose alternatives for each block")
    print("4. Accept/reject based on probability improvement")
    print("5. Repeat for multiple iterations")

    # Run with progress to show the steps
    print(f"\nRunning Power Sampling with visible steps...")
    response = power_sample(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        alpha=4.0,
        steps=5,
        block_size=24,
        max_len=150,
        show_progress=True  # This will show the progress
    )

    print(f"\n✅ Final Improved Response:")
    print(f" {response}")
def compare_different_prompts(model, tokenizer):
    """Compare Power Sampling on different types of prompts"""
    print_header("4. Power Sampling on Different Question Types")

    prompts = [
        "What is machine learning?",
        "How does photosynthesis work?",
        "Why do we dream?",
        "Explain gravity simply"
    ]

    for i, prompt in enumerate(prompts, 1):
        print(f"\n📍 Test {i}: {prompt}")

        # Get standard response
        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=60,
                temperature=1.0,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        standard = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Get Power Sampling response
        improved = power_sample(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            alpha=4.0,
            steps=3,
            block_size=20,
            max_len=100,
            show_progress=False
        )

        print_comparison(standard, improved, prompt)
def main():
    print(f"\n🔄 Loading model...")
    model_path = "EleutherAI/gpt-neo-125M"
    print(f"🚀 Loading model from: {model_path}")
    model, tokenizer = load_model_and_tokenizer(model_path)
    print("✅ Model loaded successfully!")

    try:
        # Show standard generation
        standard_response = demonstrate_basic_generation(model, tokenizer)

        # Show Power Sampling
        demonstrate_power_sampling(model, tokenizer)
    except Exception as e:
        print(f"Error during demonstration: {e}")
        print("This is normal - LLM generation can sometimes fail.")
        print("Try running the demo again!")


if __name__ == "__main__":
    main()
3.2 Output Analysis
1) Standard output
Standard LLM Generation (Without Power Sampling)
What is artificial intelligence?
I’m currently writing on a topic I’d be happy to discuss. AI makes life easier and gives you the best possible information for your needs. The first thing you do when you are using AI, is to set up a database of your data, as it will show you the current state of the world and what is happening. The result is that with more and more
Because EleutherAI/gpt-neo-125M is a small 125M-parameter GPT model, the output not only repeats the question but also lacks logical coherence.
2) Power sampling output
The Power Sampling outputs for α = 2.0, 4.0, and 6.0 are shown below.
At α = 2.0 and 4.0 the opening is still somewhat scattered and misses the main point.
At α = 6.0 the output quickly gets to the point, mentioning AI and recent technological breakthroughs.
--- Alpha = 2.0 (Sharpening Factor) ---
Response: What is artificial intelligence? What are the uses for artificial intelligence?"There are several uses for the artificial intelligence. A good example is a new device, called a "brainwave
--- Alpha = 4.0 (Sharpening Factor) ---
Response: What is artificial intelligence?.I find I’m often the “computer scientist” when I’m actually just thinking outside the box. When it says
--- Alpha = 6.0 (Sharpening Factor) ---
Response: What is artificial intelligence?In this page, you will find out about the latest emerging technologies used to support AI. The latest technological breakthroughs that have been reported in the past
3.3 Full Output
The complete output of the demo program is shown below.
============================================================
🎯 1. Standard LLM Generation (Without Power Sampling)
============================================================
Prompt: What is artificial intelligence?
This shows how a normal language model generates text...
Response: What is artificial intelligence?I’m currently writing on a topic I’d be happy to discuss. AI makes life easier and gives you the best possible information for your needs. The first thing you do when you are using AI, is to set up a database of your data, as it will show you the current state of the world and what is happening. The result is that with more and more
============================================================
🎯 2. Power Sampling in Action
============================================================
Same prompt: What is artificial intelligence?
Power Sampling applies Metropolis-Hastings to improve coherence...

🔧 Testing different Power Sampling parameters:
--- Alpha = 2.0 (Sharpening Factor) ---
Response: What is artificial intelligence? What are the uses for artificial intelligence?"There are several uses for the artificial intelligence. A good example is a new device, called a "brainwave
--- Alpha = 4.0 (Sharpening Factor) ---
Response: What is artificial intelligence?.I find I’m often the “computer scientist” when I’m actually just thinking outside the box. When it says
--- Alpha = 6.0 (Sharpening Factor) ---
Response: What is artificial intelligence?In this page, you will find out about the latest emerging technologies used to support AI. The latest technological breakthroughs that have been reported in the past
References
---
Power-Sampling-Training
https://github.com/BitMakerMan/Power-Sampling-Training
power_sampling
https://github.com/BitMakerMan/Power-Sampling-Training/blob/main/src/power_sampling.py
understand_power_sampling
reasoning-with-sampling
https://github.com/aakaran/reasoning-with-sampling
Reasoning with Sampling: Your Base Model is Smarter Than You Think
https://arxiv.org/abs/2510.14901
Language Models are Injective and Hence Invertible
https://arxiv.org/abs/2510.15511
如何理解MCMC的延续 M-H采样 (How to understand M-H sampling as an extension of MCMC)