Baukit Library Tutorial -- Monitoring and Modifying LLM Intermediate-Layer Outputs
- Original project: https://github.com/davidbau/baukit
1. TraceDict Overview
TraceDict(model, layers=target_layers, retain_input=True, retain_output=True, edit_output=None, stop=False)
- layers: the layers to monitor, given as a list whose elements are module names of the form model.layers.{layer_num}.{specific_module_name} (see the sketch after this list for one way to discover valid names)
- retain_input / retain_output: whether to keep each target layer's original input and output
- edit_output: to edit a layer's output, pass a function here; the function must accept the two arguments (output, layer)
- stop: halt the forward pass once the monitored layers have run
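A quick way to discover valid module names is to enumerate the model's submodules. A minimal sketch, assuming a LLaMA-style HuggingFace checkpoint (the model path is a placeholder):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("YOUR_MODEL_PATH")  # hypothetical path
# Names such as "model.layers.5.self_attn.o_proj" are exactly what
# TraceDict expects in its layers argument.
for name, _ in model.named_modules():
    if name.startswith("model.layers.5."):  # restrict to one block for readability
        print(name)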
2. Using TraceDict
with TraceDict(model, layers=target_layers, retain_input=True, retain_output=True, edit_output=None, stop=False) as td:
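With stop=True, the forward pass is cut short once the monitored layers have run, which saves compute when only intermediate activations are needed. A minimal sketch, assuming model and input_tokens are prepared as in the examples below:

with TraceDict(model, layers=['model.layers.5.self_attn.o_proj'], stop=True) as td:
    model(**input_tokens)  # the forward pass halts after layer 5's o_proj
hidden = td['model.layers.5.self_attn.o_proj'].output  # retained values remain accessible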
2.1 Monitoring inputs and outputs
def monitor_layers(self, input_text: str, hook_layers: List[str], target_layer_idx: int = 5):
    input_tokens = self.tokenizer(input_text, return_tensors='pt').to(self.device)
    with TraceDict(self.model, layers=hook_layers, retain_input=True, retain_output=True) as td:
        model_output = self.model(**input_tokens)
        target_layer = hook_layers[target_layer_idx]
        before_layer = td[target_layer].input    # input entering the target layer
        after_layer = td[target_layer].output    # output leaving the target layer
        # after_attn = td['model.layers.5.self_attn.q_proj'].output  # raises a KeyError, because hook_layers does not include self_attn.q_proj
        results = {
            'before_layer_shape': before_layer.shape,
            'after_layer_shape': after_layer.shape,
        }
    return results
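Calling the method on a short prompt returns the retained shapes; a usage sketch built on the ModelHandler class from section 2.3 (the model path is a placeholder):

handler = ModelHandler("YOUR_MODEL_PATH")  # hypothetical path
handler.load_model()
layers = [f'model.layers.{l}.self_attn.o_proj' for l in range(16)]
shapes = handler.monitor_layers("Hello, how are you?", layers, target_layer_idx=5)
print(shapes)  # both shapes are [bsz, num_tokens, dim_model]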
2.2 Modifying model outputs
- Define a function that modifies the output. The key part is the inner function.
- By baukit's convention, the inner function must accept the two parameters output and layer_name:
  - output: the intermediate layer's output
  - layer_name: the module the forward pass has currently reached
- The outer wrapper function's parameters can be anything you like, as long as it ultimately returns the inner function.
def wrap_func(edit_layer, device, idx=-1):
    def add_func(output, layer_name):
        current_layer = int(layer_name.split(".")[2])
        if current_layer == edit_layer:  # reached the layer to edit
            print("output_sum", output.sum())
            # build a random perturbation with the same shape as output
            perturbation = torch.randn_like(output) * 0.1
            output += perturbation.to(device)
            print("output_sum", output.sum())
        return output
    return add_func
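The wrapped function is then passed as edit_output; a minimal sketch of how the full script in section 2.3 applies it, assuming model, tokenizer, and device are already loaded:

target_layers = [f'model.layers.{l}.self_attn.o_proj' for l in range(16)]
intervention_fn = wrap_func(edit_layer=5, device=device)
with TraceDict(model, layers=target_layers, edit_output=intervention_fn):
    input_tokens = tokenizer("Hello, how are you?", return_tensors='pt').to(device)
    model(**input_tokens)  # layer 5's o_proj output is perturbed during the forward pass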
2.3 Complete monitoring and editing script
import torch
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
from baukit import TraceDict
from typing import List, Dict, Any
class ModelHandler:
    def __init__(self, model_path: str, device: str = "cuda"):
        self.model_path = model_path
        self.device = device
        self.model = None
        self.tokenizer = None

    def load_model(self):
        print(f"Loading model: {self.model_path}")
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            torch_dtype=torch.float16
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        print("Model loaded!")
        return self.model, self.tokenizer
    def monitor_layers(self, input_text: str, hook_layers: List[str], target_layer_idx: int = 5):
        input_tokens = self.tokenizer(input_text, return_tensors='pt').to(self.device)
        with TraceDict(self.model, layers=hook_layers, retain_input=True, retain_output=True) as rep:
            model_output = self.model(**input_tokens)
            target_layer = hook_layers[target_layer_idx]
            before_layer = rep[target_layer].input    # input entering the target layer
            after_layer = rep[target_layer].output    # output leaving the target layer
            results = {
                'before_layer_shape': before_layer.shape,
                'after_layer_shape': after_layer.shape,
            }
        return results
def generate_hook_layers(layer_type: str, num_layers: int = 16):
    layer_indices = list(range(num_layers))
    if layer_type == "attn_o":
        return [f'model.layers.{l}.self_attn.o_proj' for l in layer_indices]
    elif layer_type == "attn_q":
        return [f'model.layers.{l}.self_attn.q_proj' for l in layer_indices]
    elif layer_type == "attn_k":
        return [f'model.layers.{l}.self_attn.k_proj' for l in layer_indices]
    elif layer_type == "attn_v":
        return [f'model.layers.{l}.self_attn.v_proj' for l in layer_indices]
    elif layer_type == "mlp_gate":
        return [f'model.layers.{l}.mlp.gate_proj' for l in layer_indices]
    elif layer_type == "mlp_up":
        return [f'model.layers.{l}.mlp.up_proj' for l in layer_indices]
    elif layer_type == "mlp_down":
        return [f'model.layers.{l}.mlp.down_proj' for l in layer_indices]
    else:
        raise ValueError(f"Unsupported layer type: {layer_type}")
def wrap_func(edit_layer, device, idx=-1):
    def add_func(output, layer_name):
        current_layer = int(layer_name.split(".")[2])
        if current_layer == edit_layer:  # reached the layer to edit
            print("output_sum", output.sum())
            # build a random perturbation with the same shape as output
            perturbation = torch.randn_like(output) * 0.1
            output += perturbation.to(device)
            print("output_sum", output.sum())
        return output
    return add_func
def main():
    parser = argparse.ArgumentParser(description='Model layer monitoring and editing tool')
    parser.add_argument('--model_path', type=str, default='YOUR_MODEL_PATH',
                        help='Model path')
    parser.add_argument('--input_text', type=str, default='Hello, how are you?',
                        help='Input text')
    parser.add_argument('--layer_type', type=str, default='attn_o',
                        choices=['attn_o', 'attn_q', 'attn_k', 'attn_v', 'mlp_gate', 'mlp_up', 'mlp_down'],
                        help='Layer type to monitor')
    parser.add_argument('--num_layers', type=int, default=16,
                        help='Total number of layers')
    parser.add_argument('--target_layer_idx', type=int, default=5,
                        help='Layer index to analyze')
    parser.add_argument('--mode', type=str, default='edit', choices=['monitor', 'edit'],
                        help='Running mode: monitor or edit')
    parser.add_argument('--device', type=str, default='cuda', help='device')
    args = parser.parse_args()

    # Create model handler
    handler = ModelHandler(args.model_path, args.device)
    handler.load_model()

    # Generate hook layers
    hook_layers = generate_hook_layers(args.layer_type, args.num_layers)
    print(f"Input text: {args.input_text}")
    print(f"Monitor layer type: {args.layer_type}")
    print(f"Monitor layer number: {len(hook_layers)}")
    print(f"Target layer index: {args.target_layer_idx}")

    if args.mode == 'monitor':
        # Monitor mode
        results = handler.monitor_layers(args.input_text, hook_layers, args.target_layer_idx)
        print(f"\n=== Monitor results ===")
        print(f"Target layer: {hook_layers[args.target_layer_idx]}")
        print(f"Layer input shape: {results['before_layer_shape']}")   # [bsz, num_tokens, dim_model]
        print(f"Layer output shape: {results['after_layer_shape']}")   # [bsz, num_tokens, dim_model]
    elif args.mode == 'edit':
        # Edit mode example: add noise to the target layer,
        # reusing the hook_layers generated above from args.layer_type
        intervention_fn = wrap_func(args.target_layer_idx, handler.model.device)
        with TraceDict(handler.model, layers=hook_layers, edit_output=intervention_fn):
            input_tokens = handler.tokenizer(args.input_text, return_tensors='pt').to(handler.device)
            handler.model(**input_tokens)

if __name__ == "__main__":
    main()
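The script can then be run in either mode; a usage sketch, assuming it is saved as baukit_demo.py (a hypothetical filename):

python baukit_demo.py --model_path YOUR_MODEL_PATH --mode monitor --layer_type attn_o --target_layer_idx 5
python baukit_demo.py --model_path YOUR_MODEL_PATH --mode edit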