As of this writing (25.11.25), power sampling is regarded as an LLM reasoning-optimization technique that can rival RL tuning.

RL tuning aligns a model's outputs with human expectations; the effect is equivalent to sharpening the LLM's output distribution.

A power distribution is the canonical example of a naturally sharpened distribution.
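Concretely, raising the base model's distribution p(x) to a power α > 1 and renormalizing yields the power distribution targeted here (written informally; the paper's exact notation may differ):

p_α(x) ∝ p(x)^α, with α > 1

Probability mass concentrates on the sequences the base model already rates most likely, which is precisely a sharpening of the original distribution.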

The following paper shows that power sampling, without any RL tuning, can also effectively improve an LLM's reasoning output.

Reasoning with Sampling: Your Base Model is Smarter Than You Think

https://arxiv.org/abs/2510.14901

This post walks through an implementation of the algorithm to explore and analyze how it optimizes the reasoning process.

https://github.com/BitMakerMan/Power-Sampling-Training

1 Power Sampling

Power Sampling is an algorithm that improves LLM reasoning using Metropolis-Hastings (M-H) sampling, a technique from computational statistics.

1.1 Computing the Acceptance Probability

Below is the formula for the M-H acceptance probability used in power sampling.

P(accept) = min(1, exp(α * (log P_proposal - log P_current)))

α is the sharpening parameter and typically takes values between 2.0 and 6.0.

P_proposal and P_current denote the model probability of the proposed sequence and of the current sequence, respectively.

The proposal is built on top of the current sequence: part of it is resampled from the model, as described in the next subsection.
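As a direct translation of this formula into code, here is a minimal sketch written for this post (not taken from the repository):

import math

def acceptance_prob(log_p_proposal: float, log_p_current: float, alpha: float = 4.0) -> float:
    """P(accept) = min(1, exp(alpha * (log P_proposal - log P_current)))."""
    log_ratio = alpha * (log_p_proposal - log_p_current)
    # Returning 1.0 directly when log_ratio >= 0 avoids overflow in exp().
    return 1.0 if log_ratio >= 0 else math.exp(log_ratio)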

1.2 Block-wise Resampling Procedure

The power-sampling optimization loop is essentially repeated LLM inference; the steps are outlined below.

1) Sampling initialization

Based on the prompt, the LLM generates an initial output (initial_output in the code); this is the first version of the text.

2) Text blocking and proposal scoring

The initial text is split into multiple blocks. For each block, two probability scores are computed: one for the current version of the block and one for the replacement text produced by the proposal distribution.

3) Metropolis-Hastings accept/reject sampling

If the proposal scores higher, it is more likely to be accepted.

Occasionally a worse result is still accepted, which keeps the optimization path from getting stuck in a local optimum.

4) The sampling process is repeated many times, each time resampling only the text inside one block.

In general, the more iterations, the better the quality of the generated text. A compact code sketch of this loop follows.
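Before turning to the real implementation, the following skeleton (written for this post; the generate and log_prob callables are illustrative placeholders supplied by the caller, not part of the repository API) shows how the four steps fit together:

import math
import random

def mh_block_resample(generate, log_prob, prompt_ids, alpha=4.0, block_size=32, steps=10):
    """Illustrative skeleton: generate(prefix_ids) returns newly sampled token ids,
    log_prob(ids) returns the model log-likelihood of a token-id list."""
    current = generate(prompt_ids)                                     # 1) initial sample
    for _ in range(steps):                                             # 4) repeat
        start = random.randint(0, max(0, len(current) - block_size))   # 2) pick a random block
        end = start + block_size
        new_block = generate(prompt_ids + current[:start])[:block_size]
        proposal = current[:start] + new_block + current[end:]
        log_ratio = alpha * (log_prob(proposal) - log_prob(current))   # sharpened M-H ratio
        if log_ratio >= 0 or random.random() < math.exp(log_ratio):    # 3) accept/reject
            current = proposal
    return current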

2 Sampling Code Examples

This section walks through the power-sampling code to explore and analyze the optimization process.

2.1 Generating the Proposal

Sampling text from the proposal distribution is really just calling the model again in sampling mode to regenerate a stretch of text.

First, the input is a truncated prefix of the current text, current_seq[:start_pos].

Next, with do_sample=True set, sampling from the proposal distribution simply means calling model.generate to produce proposal_block.

Then the prefix current_seq[:start_pos], the newly generated block, and the remaining suffix current_seq[end_pos:] are concatenated to form the newly sampled text proposal_seq.

The sampling code is shown below.

            # Create context (prefix)
            if start_pos > 0:
                context_ids = current_seq[:start_pos].unsqueeze(0)
            else:
                # Use original prompt as context if we're at the beginning
                context_ids = inputs['input_ids']

            # Generate proposal block
            with torch.no_grad():
                proposal_block = self.model.generate(
                    context_ids,
                    max_new_tokens=block_size,
                    do_sample=True,
                    temperature=temperature,
                    pad_token_id=self.tokenizer.eos_token_id,
                    num_return_sequences=1
                )

            # Extract only the newly generated part
            if start_pos == 0:
                new_tokens = proposal_block[0][context_ids.size(-1):]
            else:
                new_tokens = proposal_block[0][context_ids.size(-1):]

            # Create proposal sequence
            proposal_seq = torch.cat([
                current_seq[:start_pos],
                new_tokens[:block_size],
                current_seq[end_pos:] if end_pos < current_seq.size(-1) else torch.tensor([], dtype=torch.long, device=self.device)  # empty tensor must keep the token dtype for torch.cat
            ])

2.2 Computing the Distribution Score

Computing the distribution score is a single model call that evaluates the log-likelihood of a sequence.

First, the input is the token-id sequence to be scored, e.g. current_seq or proposal_seq.

Next, scoring a sequence means computing the probability the model assigns to that token-id sequence, much like a supervised-learning forward pass. Concretely:

The sequence is passed to the model both as input_ids and as labels, so the input doubles as the supervision target.

The model returns the average negative log-likelihood as outputs.loss; since the score should be a log-likelihood, the loss is negated and scaled by the sequence length, i.e. -outputs.loss.item() * inputs.size(-1).

The example code is as follows.

 def compute_log_probability(self, sequence: torch.Tensor) -> float:
        """
        Compute the log probability of a sequence under the model.

        Args:
            sequence: Token sequence tensor

        Returns:
            Log probability sum
        """
        with torch.no_grad():
            inputs = sequence.unsqueeze(0) if sequence.dim() == 1 else sequence
            attention_mask = torch.ones_like(inputs)

            outputs = self.model(
                input_ids=inputs,
                attention_mask=attention_mask,
                labels=inputs
            )

            return -outputs.loss.item() * inputs.size(-1)
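One detail worth noting: Hugging Face causal LMs shift the labels internally, so outputs.loss is the negative log-likelihood averaged over seq_len - 1 positions, and -loss * seq_len only approximates the summed token log-likelihood. The sketch below (written for this post, not part of the repository) computes the exact sum for comparison:

import torch
import torch.nn.functional as F

def token_logprob_sum(model, sequence: torch.Tensor) -> float:
    """Exact sum of log P(x_t | x_<t) for a 1-D token-id sequence."""
    with torch.no_grad():
        inputs = sequence.unsqueeze(0)                      # [1, T]
        logits = model(input_ids=inputs).logits             # [1, T, V]
        log_probs = F.log_softmax(logits[:, :-1], dim=-1)   # predictions for positions 1..T-1
        targets = inputs[:, 1:].unsqueeze(-1)               # tokens actually observed next
        return log_probs.gather(-1, targets).sum().item()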

2.3 M-H Accept/Reject Sampling

M-H accept/reject sampling is simply the process of deciding whether to accept the newly generated proposal_seq.

First, the scores of the original sequence current_seq and the new sequence proposal_seq are computed.

Next, the power-sampling acceptance formula is used to compute the acceptance probability of proposal_seq.

P(accept) = min(1, exp(α * (log P_proposal - log P_current)))

Then a uniform random draw decides whether to accept proposal_seq: if P(random) < P(accept), the proposal is accepted; otherwise it is rejected.

Note that a higher P(accept) makes acceptance more likely, but it is not guaranteed.

A worse result is occasionally accepted, which keeps the optimization path from getting stuck in a local optimum.

                log_p_current = self.compute_log_probability(current_seq)
                log_p_proposal = self.compute_log_probability(proposal_seq)

                # Metropolis acceptance criterion with sharpening
                log_acceptance_ratio = alpha * (log_p_proposal - log_p_current)
                acceptance_prob = min(1.0, torch.exp(torch.tensor(log_acceptance_ratio)).item())

                # Accept or reject proposal
                if random.random() < acceptance_prob:
                    current_seq = proposal_seq
                    if show_progress:
                        iterator.set_postfix({"accept": f"{acceptance_prob:.3f}", "status": "✓"})
                else:
                    if show_progress:
                        iterator.set_postfix({"accept": f"{acceptance_prob:.3f}", "status": "✗"})

2.4 The Full Sampling Procedure

The code below combines the pieces above into the complete power-sampling procedure.

import torch
import torch.nn.functional as F
import random
import numpy as np
from typing import Optional, Dict, Any
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


class PowerSampler:
    """
    Implements Power Sampling algorithm for improved LLM reasoning.

    Power Sampling uses Metropolis-Hastings algorithm to resample text blocks,
    improving logical coherence and reasoning capabilities.
    """

    def __init__(
        self,
        model: AutoModelForCausalLM,
        tokenizer: AutoTokenizer,
        device: Optional[str] = None
    ):
        """
        Initialize the Power Sampler.

        Args:
            model: The causal language model to use
            tokenizer: The tokenizer corresponding to the model
            device: Device to run on (auto-detect if None)
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Set padding token if not present
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

    def compute_log_probability(self, sequence: torch.Tensor) -> float:
        """
        Compute the log probability of a sequence under the model.

        Args:
            sequence: Token sequence tensor

        Returns:
            Log probability sum
        """
        with torch.no_grad():
            inputs = sequence.unsqueeze(0) if sequence.dim() == 1 else sequence
            attention_mask = torch.ones_like(inputs)

            outputs = self.model(
                input_ids=inputs,
                attention_mask=attention_mask,
                labels=inputs
            )

            return -outputs.loss.item() * inputs.size(-1)

    def power_sample(
        self,
        prompt: str,
        alpha: float = 4.0,
        block_size: int = 192,
        steps: int = 10,
        max_len: int = 2048,
        temperature: float = 1.0,
        show_progress: bool = False
    ) -> str:
        """
        Execute Power Sampling algorithm.

        Args:
            prompt: Input prompt text
            alpha: Sharpening factor (recommended: 4.0)
            block_size: Size of text blocks to resample (recommended: T/16)
            steps: Number of Metropolis-Hastings iterations (recommended: 10)
            max_len: Maximum sequence length
            temperature: Sampling temperature
            show_progress: Show progress bar

        Returns:
            Generated text with improved reasoning
        """
        # Tokenize prompt
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_len - block_size
        ).to(self.device)

        # Generate initial sequence
        with torch.no_grad():
            initial_output = self.model.generate(
                **inputs,
                max_new_tokens=block_size,
                do_sample=True,
                temperature=temperature,
                pad_token_id=self.tokenizer.eos_token_id,
                num_return_sequences=1
            )

        # Extract generated sequence (remove prompt part)
        prompt_length = inputs['input_ids'].size(-1)
        current_seq = initial_output[0][prompt_length:]

        # Metropolis-Hastings iterations
        iterator = tqdm(range(steps), desc="Power Sampling") if show_progress else range(steps)

        for step in iterator:
            # Select random block to resample
            if current_seq.size(-1) <= block_size:
                # If sequence is shorter than block_size, regenerate whole thing
                start_pos = 0
                end_pos = current_seq.size(-1)
            else:
                start_pos = random.randint(0, current_seq.size(-1) - block_size)
                end_pos = start_pos + block_size

            # Create context (prefix)
            if start_pos > 0:
                context_ids = current_seq[:start_pos].unsqueeze(0)
            else:
                # Use original prompt as context if we're at the beginning
                context_ids = inputs['input_ids']

            # Generate proposal block
            with torch.no_grad():
                proposal_block = self.model.generate(
                    context_ids,
                    max_new_tokens=block_size,
                    do_sample=True,
                    temperature=temperature,
                    pad_token_id=self.tokenizer.eos_token_id,
                    num_return_sequences=1
                )

            # Extract only the newly generated part
            if start_pos == 0:
                new_tokens = proposal_block[0][context_ids.size(-1):]
            else:
                new_tokens = proposal_block[0][context_ids.size(-1):]

            # Create proposal sequence
            proposal_seq = torch.cat([
                current_seq[:start_pos],
                new_tokens[:block_size],
                current_seq[end_pos:] if end_pos < current_seq.size(-1) else torch.tensor([], dtype=torch.long, device=self.device)  # empty tensor must keep the token dtype for torch.cat
            ])

            # Ensure we don't exceed max length
            if proposal_seq.size(-1) > max_len - prompt_length:
                proposal_seq = proposal_seq[:max_len - prompt_length]
                current_seq = current_seq[:max_len - prompt_length]

            # Compute acceptance probability
            try:
                log_p_current = self.compute_log_probability(current_seq)
                log_p_proposal = self.compute_log_probability(proposal_seq)

                # Metropolis acceptance criterion with sharpening
                log_acceptance_ratio = alpha * (log_p_proposal - log_p_current)
                acceptance_prob = min(1.0, torch.exp(torch.tensor(log_acceptance_ratio)).item())

                # Accept or reject proposal
                if random.random() < acceptance_prob:
                    current_seq = proposal_seq
                    if show_progress:
                        iterator.set_postfix({"accept": f"{acceptance_prob:.3f}", "status": "✓"})
                else:
                    if show_progress:
                        iterator.set_postfix({"accept": f"{acceptance_prob:.3f}", "status": "✗"})

            except Exception as e:
                # If probability computation fails, keep current sequence
                if show_progress:
                    iterator.set_postfix({"error": str(e)[:20]})
                continue

        # Combine prompt and generated sequence
        final_sequence = torch.cat([inputs['input_ids'][0], current_seq])

        # Decode and return
        return self.tokenizer.decode(final_sequence, skip_special_tokens=True)

    def batch_power_sample(
        self,
        prompts: list,
        alpha: float = 4.0,
        block_size: int = 192,
        steps: int = 10,
        max_len: int = 2048,
        temperature: float = 1.0,
        show_progress: bool = False
    ) -> list:
        """
        Execute Power Sampling on multiple prompts.

        Args:
            prompts: List of input prompts
            alpha: Sharpening factor
            block_size: Size of text blocks to resample
            steps: Number of Metropolis-Hastings iterations
            max_len: Maximum sequence length
            temperature: Sampling temperature
            show_progress: Show progress bar

        Returns:
            List of generated texts
        """
        results = []
        iterator = tqdm(prompts, desc="Batch Power Sampling") if show_progress else prompts

        for prompt in iterator:
            result = self.power_sample(
                prompt=prompt,
                alpha=alpha,
                block_size=block_size,
                steps=steps,
                max_len=max_len,
                temperature=temperature,
                show_progress=False
            )
            results.append(result)

        return results


def load_model_and_tokenizer(model_name: str, device: Optional[str] = None):
    """
    Convenience function to load model and tokenizer.

    Args:
        model_name: Hugging Face model name
        device: Device to load on

    Returns:
        Tuple of (model, tokenizer)
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        dtype="auto",
        device_map="auto" if device == "cuda" else None,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32
    )

    if device != "cuda":
        model = model.to(device)

    return model, tokenizer


def power_sample(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    alpha: float = 4.0,
    block_size: int = 192,
    steps: int = 10,
    max_len: int = 2048,
    temperature: float = 1.0,
    device: Optional[str] = None,
    show_progress: bool = False
) -> str:
    """
    Convenience function for single-call power sampling.

    Args:
        model: The causal language model
        tokenizer: The tokenizer
        prompt: Input prompt text
        alpha: Sharpening factor (recommended: 4.0)
        block_size: Size of text blocks to resample
        steps: Number of Metropolis-Hastings iterations
        max_len: Maximum sequence length
        temperature: Sampling temperature
        device: Device to run on
        show_progress: Show progress bar

    Returns:
        Generated text with improved reasoning
    """
    sampler = PowerSampler(model, tokenizer, device)
    return sampler.power_sample(
        prompt=prompt,
        alpha=alpha,
        block_size=block_size,
        steps=steps,
        max_len=max_len,
        temperature=temperature,
        show_progress=show_progress
    )

https://github.com/BitMakerMan/Power-Sampling-Training/blob/main/src/power_sampling.py
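For reference, a minimal end-to-end call using the module above might look like this (the model name and prompt simply mirror those used in the test program below; any causal LM should work):

from power_sampling import load_model_and_tokenizer, power_sample

# Load a small causal LM and run power sampling on a single prompt.
model, tokenizer = load_model_and_tokenizer("EleutherAI/gpt-neo-125M")
text = power_sample(
    model=model,
    tokenizer=tokenizer,
    prompt="What is artificial intelligence?",
    alpha=4.0,        # sharpening factor
    block_size=32,    # tokens resampled per M-H step
    steps=5,          # number of M-H iterations
    show_progress=True,
)
print(text)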

3 Comparison Test

This section runs a comparison test to illustrate power sampling in action.

The code relies only on basic libraries such as torch and transformers, and runs with almost no debugging.

3.1 Test Program

It compares standard generation against power-sampling output; the example program is shown below.

Test question: What is artificial intelligence?

The sharpening parameter is varied over α = 2.0, 4.0, 6.0 to show different strengths of power-sampling optimization.

import sys
import os

# Add src directory to path
sys.path.append(os.path.join("/data/apps/llm/Power-Sampling-Training", 'src'))

try:
    from power_sampling import load_model_and_tokenizer, power_sample
    print("✅ Successfully imported Power Sampling modules")
except ImportError as e:
    print(f"Import error: {e}")
    sys.exit(1)

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def print_header(title):
    """Print formatted header"""
    print(f"\n{'='*60}")
    print(f"🎯 {title}")
    print(f"{'='*60}")

def print_comparison(original, improved, prompt):
    """Print side-by-side comparison"""
    print(f"\n📝 PROMPT: {prompt}")
    print("-" * 60)

    print(f"🔸 STANDARD GENERATION:")
    print(f"   {original}")
    print()

    print(f"✨ POWER SAMPLING (Improved):")
    print(f"   {improved}")
    print()

    print("💡 IMPROVEMENTS:")

    # Simple analysis of improvements
    if len(improved) > len(original):
        print("   ✓ More detailed response")
    if "because" in improved.lower() or "therefore" in improved.lower():
        print("   ✓ Better logical connections")
    if "step" in improved.lower() or "first" in improved.lower():
        print("   ✓ Structured reasoning")
    if improved.count('.') > original.count('.'):
        print("   ✓ More complete sentences")

    print("-" * 60)

def demonstrate_basic_generation(model, tokenizer):
    """Show basic LLM generation without Power Sampling"""
    print_header("1. Standard LLM Generation (Without Power Sampling)")

    prompt = "What is artificial intelligence?"
    print(f"Prompt: {prompt}")
    print("This shows how a normal language model generates text...")

    # Standard generation
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=80,
            temperature=1.0,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    standard_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"Response: {standard_response}")

    return standard_response

def demonstrate_power_sampling(model, tokenizer):
    """Show Power Sampling in action"""
    print_header("2. Power Sampling in Action")

    prompt = "What is artificial intelligence?"
    print(f"Same prompt: {prompt}")
    print("Power Sampling applies Metropolis-Hastings to improve coherence...")

    # Power Sampling with different parameters
    print("\n🔧 Testing different Power Sampling parameters:")

    for alpha in [2.0, 4.0, 6.0]:
        print(f"\n--- Alpha = {alpha} (Sharpening Factor) ---")

        response = power_sample(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            alpha=alpha,
            steps=3,
            block_size=32,
            max_len=120,
            show_progress=False
        )

        print(f"Response: {response[:200]}..." if len(response) > 200 else f"Response: {response}")

def demonstrate_step_by_step(model, tokenizer):
    """Show how Power Sampling works step by step"""
    print_header("3. How Power Sampling Works - Step by Step")

    prompt = "Explain why the sky is blue"
    print(f"Prompt: {prompt}")
    print("\n🔄 Power Sampling Process:")
    print("1. Generate initial text")
    print("2. Divide text into blocks")
    print("3. Propose alternatives for each block")
    print("4. Accept/reject based on probability improvement")
    print("5. Repeat for multiple iterations")

    # Run with progress to show the steps
    print(f"\nRunning Power Sampling with visible steps...")

    response = power_sample(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        alpha=4.0,
        steps=5,
        block_size=24,
        max_len=150,
        show_progress=True  # This will show the progress
    )

    print(f"\n✅ Final Improved Response:")
    print(f"   {response}")

def compare_different_prompts(model, tokenizer):
    """Compare Power Sampling on different types of prompts"""
    print_header("4. Power Sampling on Different Question Types")

    prompts = [
        "What is machine learning?",
        "How does photosynthesis work?",
        "Why do we dream?",
        "Explain gravity simply"
    ]

    for i, prompt in enumerate(prompts, 1):
        print(f"\n📍 Test {i}: {prompt}")

        # Get standard response
        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=60,
                temperature=1.0,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        standard = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Get Power Sampling response
        improved = power_sample(
            model=model,
            tokenizer=tokenizer,
            prompt=prompt,
            alpha=4.0,
            steps=3,
            block_size=20,
            max_len=100,
            show_progress=False
        )

        print_comparison(standard, improved, prompt)

    
def main():
    # Load the demo model (an online Hugging Face checkpoint).
    model_path = "EleutherAI/gpt-neo-125M"
    print(f"\n🚀 Loading model: {model_path}")
    model, tokenizer = load_model_and_tokenizer(model_path)
    print("✅ Model loaded successfully!")

    try:
        # Show standard generation
        standard_response = demonstrate_basic_generation(model, tokenizer)

        # Show Power Sampling
        demonstrate_power_sampling(model, tokenizer)        
    except Exception as e:
        print(f"Error during demonstration: {e}")
        print("This is normal - LLM generation can sometimes fail.")
        print("Try running the demo again!")

if __name__ == "__main__":
    main()

https://github.com/BitMakerMan/Power-Sampling-Training/blob/main/examples/understand_power_sampling.py

3.2 Output Analysis

1) Standard output

Standard LLM Generation (Without Power Sampling)

What is artificial intelligence?

I’m currently writing on a topic I’d be happy to discuss. AI makes life easier and gives you the best possible information for your needs. The first thing you do when you are using AI, is to set up a database of your data, as it will show you the current state of the world and what is happening. The result is that with more and more

Because the model is the small 125M-parameter GPT model EleutherAI/gpt-neo-125M, the output not only repeats the question but also lacks logical coherence.

2) Power-sampling output

The Power Sampling outputs for α = 2.0, 4.0, and 6.0 are shown below.

At α = 2.0 and 4.0, the opening of the output still rambles and fails to get to the point.

At α = 6.0, the output gets to the point quickly, mentioning AI and technological breakthroughs.

--- Alpha = 2.0 (Sharpening Factor) ---
Response: What is artificial intelligence? What are the uses for artificial intelligence?"

There are several uses for the artificial intelligence. A good example is a new device, called a "brainwave

--- Alpha = 4.0 (Sharpening Factor) ---
Response: What is artificial intelligence?.

I find I’m often the “computer scientist” when I’m actually just thinking outside the box. When it says

--- Alpha = 6.0 (Sharpening Factor) ---
Response: What is artificial intelligence?

In this page, you will find out about the latest emerging technologies used to support AI. The latest technological breakthroughs that have been reported in the past

3.3 Full Output

The complete output of the example program is shown below.

============================================================
🎯 1. Standard LLM Generation (Without Power Sampling)
============================================================
Prompt: What is artificial intelligence?
This shows how a normal language model generates text...
Response: What is artificial intelligence?

I’m currently writing on a topic I’d be happy to discuss. AI makes life easier and gives you the best possible information for your needs. The first thing you do when you are using AI, is to set up a database of your data, as it will show you the current state of the world and what is happening. The result is that with more and more

============================================================
🎯 2. Power Sampling in Action
============================================================
Same prompt: What is artificial intelligence?
Power Sampling applies Metropolis-Hastings to improve coherence...

🔧 Testing different Power Sampling parameters:

--- Alpha = 2.0 (Sharpening Factor) ---
Response: What is artificial intelligence? What are the uses for artificial intelligence?"

There are several uses for the artificial intelligence. A good example is a new device, called a "brainwave

--- Alpha = 4.0 (Sharpening Factor) ---
Response: What is artificial intelligence?.

I find I’m often the “computer scientist” when I’m actually just thinking outside the box. When it says

--- Alpha = 6.0 (Sharpening Factor) ---
Response: What is artificial intelligence?

In this page, you will find out about the latest emerging technologies used to support AI. The latest technological breakthroughs that have been reported in the past

reference

---

Power-Sampling-Training

https://github.com/BitMakerMan/Power-Sampling-Training

power_sampling

https://github.com/BitMakerMan/Power-Sampling-Training/blob/main/src/power_sampling.py

understand_power_sampling

https://github.com/BitMakerMan/Power-Sampling-Training/blob/main/examples/understand_power_sampling.py

reasoning-with-sampling

https://github.com/aakaran/reasoning-with-sampling

Reasoning with Sampling: Your Base Model is Smarter Than You Think

https://arxiv.org/abs/2510.14901

Language Models are Injective and Hence Invertible

https://arxiv.org/abs/2510.15511

Understanding M-H sampling as a continuation of MCMC

https://blog.csdn.net/liliang199/article/details/154741874
