简介

下述代码参考自:https://github.com/hiyouga/LLaMA-Factory/blob/main/scripts/vllm_infer.py

以前一直都用的llamafactory的 vllm_infer.py的推理脚本。使用 llamafactory微调完模型,再使用它的vllm_infer.py 脚本做推理预测。

但是使用过程中,遇到了两个痛点,所以我自己编写了下述代码:

  1. llamafactory多个数据集注册,人工逐个注册很麻烦。因为llamafactory用到的数据集都要在 dataset_info.json文件中进行注册。
  2. 每次数据集切换,LLM要重新重复加载;明明只是同一个模型,只是切换了数据集,该脚本就要重新加载模型。

我实现的代码,实现给 vllm_infer 函数传递 llm,而不是在 vllm_infer函数里面定义llm,这样就不需要重复创建llm。

代码实现

import os
import json


def vllm_infer(llm, sampling_params, dataset_file, output_file):
    # 加载模型
    # 遍历 jsonl 文件
    prompts = []
    labels = []
    with open(dataset_file, "r", encoding="utf-8") as f:
        for line in f:
            item = json.loads(line)
            prompt = item["instruction"] + item.get("input", "")
            prompts.append(prompt)
            labels.append(item["output"])

    # 批量推理
    results = llm.generate(prompts, sampling_params)

    preds = [result.outputs[0].text for result in results]

    with open(output_file, "w", encoding="utf-8") as f:
        for text, pred, label in zip(prompts, preds, labels):
            f.write(
                json.dumps(
                    {"prompt": text, "predict": pred, "label": label},
                    ensure_ascii=False,
                )
                + "\n"
            )

        print("*" * 70)
        print(
            f"{len(prompts)} total generated results have been saved at {output_file}."
        )
        print("*" * 70)

import os
import time
from vllm import LLM, SamplingParams
from vllm_infer import vllm_infer

output_dir = "output"
os.makedirs(output_dir, exist_ok=True)

model_dir = (
    "Qwen/Qwen3-30B-A3B-Instruct-2507-FP8"
)

dataset_dir = "data/alpaca_dataset"


llm = LLM(
    model=model_dir, max_model_len=2048, gpu_memory_utilization=0.95, dtype="float16"
)

# 设置采样参数
sampling_params = SamplingParams(
    presence_penalty=1.0,
    repetition_penalty=1.0,
    temperature=0.7,
    top_p=0.8,
    top_k=20,
    max_tokens=256,
    skip_special_tokens=True,
    min_p=0,
)

for file in os.listdir(dataset_dir):
    output_file = os.path.join(output_dir, os.path.basename(file))
    if file.endswith(".jsonl") and not os.path.exists(output_file):
        dataset_file = os.path.join(dataset_dir, file)
        vllm_infer(llm, sampling_params, dataset_file, output_file)
        # time.sleep(300)
Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐