vllm 在多个数据集上进行推理
本文针对LLaMA-Factory的vllm推理脚本存在的两个痛点进行了优化:1)多数据集注册繁琐问题;2)切换数据集需重复加载模型问题。通过重构vllm_infer函数,将LLM对象作为参数传入而非内部创建,实现了模型单次加载多次使用的优化方案。代码实现包含两部分:1)vllm_infer.py负责处理数据集加载、批量推理和结果保存;2)主脚本完成模型初始化、参数设置并遍历数据集进行推理。该方案
·
简介
下述代码参考自:https://github.com/hiyouga/LLaMA-Factory/blob/main/scripts/vllm_infer.py
以前一直都用的llamafactory的 vllm_infer.py的推理脚本。使用 llamafactory微调完模型,再使用它的vllm_infer.py 脚本做推理预测。
但是使用过程中,遇到了两个痛点,所以我自己编写了下述代码:
- llamafactory多个数据集注册,人工逐个注册很麻烦。因为llamafactory用到的数据集都要在 dataset_info.json文件中进行注册。
- 每次数据集切换,LLM要重新重复加载;明明只是同一个模型,只是切换了数据集,该脚本就要重新加载模型。
我实现的代码,实现给 vllm_infer 函数传递 llm,而不是在 vllm_infer函数里面定义llm,这样就不需要重复创建llm。
代码实现
import os
import json
def vllm_infer(llm, sampling_params, dataset_file, output_file):
# 加载模型
# 遍历 jsonl 文件
prompts = []
labels = []
with open(dataset_file, "r", encoding="utf-8") as f:
for line in f:
item = json.loads(line)
prompt = item["instruction"] + item.get("input", "")
prompts.append(prompt)
labels.append(item["output"])
# 批量推理
results = llm.generate(prompts, sampling_params)
preds = [result.outputs[0].text for result in results]
with open(output_file, "w", encoding="utf-8") as f:
for text, pred, label in zip(prompts, preds, labels):
f.write(
json.dumps(
{"prompt": text, "predict": pred, "label": label},
ensure_ascii=False,
)
+ "\n"
)
print("*" * 70)
print(
f"{len(prompts)} total generated results have been saved at {output_file}."
)
print("*" * 70)
import os
import time
from vllm import LLM, SamplingParams
from vllm_infer import vllm_infer
output_dir = "output"
os.makedirs(output_dir, exist_ok=True)
model_dir = (
"Qwen/Qwen3-30B-A3B-Instruct-2507-FP8"
)
dataset_dir = "data/alpaca_dataset"
llm = LLM(
model=model_dir, max_model_len=2048, gpu_memory_utilization=0.95, dtype="float16"
)
# 设置采样参数
sampling_params = SamplingParams(
presence_penalty=1.0,
repetition_penalty=1.0,
temperature=0.7,
top_p=0.8,
top_k=20,
max_tokens=256,
skip_special_tokens=True,
min_p=0,
)
for file in os.listdir(dataset_dir):
output_file = os.path.join(output_dir, os.path.basename(file))
if file.endswith(".jsonl") and not os.path.exists(output_file):
dataset_file = os.path.join(dataset_dir, file)
vllm_infer(llm, sampling_params, dataset_file, output_file)
# time.sleep(300)
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)