Baukit Library Tutorial: Monitoring and Modifying LLM Intermediate-Layer Outputs

  • Original project: https://github.com/davidbau/baukit

1. TraceDict Overview

TraceDict(model, layers=target_layers, retain_input=True, retain_output=True, edit_output=None, stop=False)
  • layers: the layers you want to monitor
    • a list whose elements are dotted module names of the form model.layers.{layer_num}.{specific_module_name} (see the example after this list)
  • retain_input / retain_output: whether to save the original input and output of the traced layers
  • edit_output: to edit a layer's output, pass a function here
    • the function must accept two arguments, (output, layer_name)
  • stop: stop the forward pass once the traced layers have finished running
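
For example, for a LLaMA-style model the layers argument is just a list of dotted module names, the same pattern that generate_hook_layers in section 2.3 produces:

target_layers = [f'model.layers.{i}.self_attn.o_proj' for i in range(16)]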

2. Using TraceDict

with TraceDict(model, layers=target_layers, retain_input=True, retain_output=True, edit_output=None, stop=False) as td:

2.1 Monitoring inputs and outputs

def monitor_layers(self, input_text: str, hook_layers: List[str], target_layer_idx: int = 5):
    input_tokens = self.tokenizer(input_text, return_tensors='pt').to(self.device)
    results = {}
    with TraceDict(self.model, layers=hook_layers, retain_input=True, retain_output=True) as td:
        model_output = self.model(**input_tokens)
        target_layer = hook_layers[target_layer_idx]
        before_layer = td[target_layer].input   # input entering the target layer
        after_layer = td[target_layer].output   # output leaving the target layer
        # after_attn = td['model.layers.5.self_attn.q_proj'].output  # raises an error: hook_layers does not include self_attn.q_proj
        results = {
            'before_layer_shape': before_layer.shape,
            'after_layer_shape': after_layer.shape,
        }
    return results
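
Since TraceDict behaves like a dictionary mapping each traced layer name to a Trace object, you can also collect every hooked layer's retained output after the forward pass. A minimal sketch, assuming model, hook_layers, and input_tokens are prepared as above:

with TraceDict(model, layers=hook_layers, retain_output=True) as td:
    model(**input_tokens)
# each entry of td holds the tensors captured for one layer
all_outputs = {name: td[name].output for name in hook_layers}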

2.2 Modifying a layer's output

  • Define an output-editing function. The key part is the inner closure:

  • By baukit's convention, the inner function must accept two arguments, output and layer_name:

    • output: the output of the intermediate layer being traced
    • layer_name: the name of the module the forward pass has currently reached
  • The outer wrapper can take any parameters you like, as long as it ultimately returns the inner function

          def wrap_func(edit_layer, device, idx=-1):
            def add_func(output, layer_name):
                current_layer = int(layer_name.split(".")[2])
                if current_layer == edit_layer:  # reached the layer we want to edit
                    print("output_sum", output.sum())
                    # create a random perturbation with the same shape as output
                    perturbation = torch.randn_like(output) * 0.1
                    output += perturbation.to(device)
                    print("output_sum", output.sum())
                return output
            return add_func
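
To apply the editor, pass the returned closure to TraceDict via edit_output. A minimal sketch, assuming model, tokenizer, and num_layers are already loaded (section 2.3 wires this up in full):

edit_fn = wrap_func(edit_layer=5, device='cuda')
hook_layers = [f'model.layers.{l}.self_attn.o_proj' for l in range(num_layers)]
with TraceDict(model, layers=hook_layers, edit_output=edit_fn):
    inputs = tokenizer('Hello, how are you?', return_tensors='pt').to('cuda')
    model(**inputs)  # add_func fires at layer 5's o_proj during this forward pass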
    

2.3 Complete monitoring and editing script

import torch
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
from baukit import TraceDict
from typing import List, Dict, Any

class ModelHandler:
    
    def __init__(self, model_path: str, device: str = "cuda"):
        self.model_path = model_path
        self.device = device
        self.model = None
        self.tokenizer = None
    
    def load_model(self):
        print(f"Loading model: {self.model_path}")
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path, 
            torch_dtype=torch.float16
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        print("Model loaded!")
        return self.model, self.tokenizer
    
    def monitor_layers(self, input_text: str, hook_layers: List[str], target_layer_idx: int = 5):
        
        input_tokens = self.tokenizer(input_text, return_tensors='pt').to(self.device)
        
        results = {}
        with TraceDict(self.model, layers=hook_layers, retain_input=True, retain_output=True) as rep:
            model_output = self.model(**input_tokens)
            
            target_layer = hook_layers[target_layer_idx]
            before_layer = rep[target_layer].input   # input entering the target layer
            after_layer = rep[target_layer].output   # output leaving the target layer
            
            results = {
                'before_layer_shape': before_layer.shape,
                'after_layer_shape': after_layer.shape,
            }

        return results
    
        

def generate_hook_layers(layer_type: str, num_layers: int = 16):
    
    layer_indices = list(range(num_layers))
    if layer_type == "attn_o":
        return [f'model.layers.{l}.self_attn.o_proj' for l in layer_indices]
    elif layer_type == "attn_q":
        return [f'model.layers.{l}.self_attn.q_proj' for l in layer_indices]
    elif layer_type == "attn_k":
        return [f'model.layers.{l}.self_attn.k_proj' for l in layer_indices]
    elif layer_type == "attn_v":
        return [f'model.layers.{l}.self_attn.v_proj' for l in layer_indices]
    elif layer_type == "mlp_gate":
        return [f'model.layers.{l}.mlp.gate_proj' for l in layer_indices]
    elif layer_type == "mlp_up":
        return [f'model.layers.{l}.mlp.up_proj' for l in layer_indices]
    elif layer_type == "mlp_down":
        return [f'model.layers.{l}.mlp.down_proj' for l in layer_indices]
    else:
        raise ValueError(f"Unsupported layer type: {layer_type}")

def wrap_func(edit_layer, device, idx=-1):
    def add_func(output, layer_name):
        current_layer = int(layer_name.split(".")[2])
        if current_layer == edit_layer:  # reached the layer we want to edit
            print("output_sum", output.sum())
            # create a random perturbation with the same shape as output
            perturbation = torch.randn_like(output) * 0.1
            output += perturbation.to(device)
            print("output_sum",output.sum())
        return output
    return add_func


def main():
    parser = argparse.ArgumentParser(description='Model layer monitoring and editing tool')
    parser.add_argument('--model_path', type=str, default='YOUR_MODEL_PATH', 
                       help='Model path')
    parser.add_argument('--input_text', type=str, default='Hello, how are you?', 
                       help='Input text')
    parser.add_argument('--layer_type', type=str, default='attn_o', 
                       choices=['attn_o', 'attn_q', 'attn_k', 'attn_v', 'mlp_gate', 'mlp_up', 'mlp_down'],
                       help='Layer type to monitor')
    parser.add_argument('--num_layers', type=int, default=16, 
                       help='Total number of layers')
    parser.add_argument('--target_layer_idx', type=int, default=5, 
                       help='Layer index to analyze')
    parser.add_argument('--mode', type=str, default='edit', choices=['monitor', 'edit'],
                       help='Running mode: monitor or edit')
    parser.add_argument('--device', type=str, default='cuda', help='device')
    
    args = parser.parse_args()
    
    # Create model handler
    handler = ModelHandler(args.model_path, args.device)
    handler.load_model()
    
    # Generate hook layers
    hook_layers = generate_hook_layers(args.layer_type, args.num_layers)
    
    print(f"Input text: {args.input_text}")  # Input text
    print(f"Monitor layer type: {args.layer_type}")
    print(f"Monitor layer number: {len(hook_layers)}")
    print(f"Target layer index: {args.target_layer_idx}")
    
    if args.mode == 'monitor':
        # Monitor mode
        results = handler.monitor_layers(args.input_text, hook_layers, args.target_layer_idx)
        
        print(f"\n=== Monitor results ===")
        print(f"Target layer: {hook_layers[args.target_layer_idx]}")
        print(f"Layer input shape: {results['before_layer_shape']}")  # Dimension is [bsz, num_tokens, dim_model]
        print(f"Layer output shape: {results['after_layer_shape']}")  # Dimension is [bsz, num_tokens, dim_model]
        
    elif args.mode == 'edit':
        # Edit mode example: add noise to the target layer
        intervention_fn = wrap_func(args.target_layer_idx, handler.model.device)

        # reuse the hook layers generated above (respects --layer_type)
        with TraceDict(handler.model, layers=hook_layers, edit_output=intervention_fn):
            input_tokens = handler.tokenizer(args.input_text, return_tensors='pt').to(handler.device)
            handler.model(**input_tokens)

if __name__ == "__main__":
    main()
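
Assuming the script above is saved as, say, trace_demo.py (the filename is arbitrary), it can be run in either mode:

python trace_demo.py --model_path /path/to/model --mode monitor --layer_type attn_o --target_layer_idx 5
python trace_demo.py --model_path /path/to/model --mode edit --target_layer_idx 5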