大模型RAG系统的多路召回检索方法以及优化
基于大模型私有化的RAG多路召回检索设计以及系统优化
·
写在前文
本文包含两个方向,一个流程是txt数据加载/预处理/向量并入库、一个流程为多路召回检索流程,文末也会给出多模态(图片)的检索方案,本文核心技术包含如下:
核心框架:Langchain
向量库(FAISS可替换Chromadb/Qdrant/LanceDB):FAISS、Elasticsearch
Embedding模型(二选一):nomic-embed-text、bge系列
LLMs模型选择(任意一个):Deepseek/Qwen/ChatGLM/LLama
多路召回的算法---可以定制多个算法:基于ES的BM25/基于FAISS的BM25(二选一)、向量检索(MMR、similarity_score_threshold、similarity三选一)
模型私有化(三选一):
本地模型服务化部署:Ollama
本地模型Local加载:使用原生的Transformers/HuggingFace
远程模型服务化部署:vLLM
重排序模型(二选一):采用Langchain的FlashrankRerank/bge-rerank
使用服务化API的方法(优先OpenAI):基于Langchain的OpenAI/Ollama
Chain的生成,我采用新版本LCEL管道拼接/RetrievalQA模块构建QA链
对外暴露接口:FastAPI
客户端采用:Java服务
图形界面采用:VUE
注意:本文单独的某一条线我测试通过,但是没有混合测试(比如,我可能选择了vllm+FlashrankRerank,但是关于bge-rerank我可能选择了ollama+bge-rerank);总之,没有测试完所有组合----实在太多,不过每个技术替换我都单独测试通过了;
本文默认已经安装好了vLLM、Ollama等环境并且已经下载好了LLMs模型以及Embedding模型(如果不清楚的,我之前有更新具体的安装环境)
文档加载&预处理类BaseBuilder
class BaseBuilder:
from langchain_core.documents import Document
from typing import List, Dict, Any
# step1 加载文档
def load_txt(self, file_path) -> list[Document]:
from langchain_community.document_loaders import TextLoader
return TextLoader(file_path=file_path, encoding='utf-8').load()
def load_pdf(self):
pass
def load_web(self):
pass
def load_mkdown(self):
pass
def load_docs(self, stop_word_file_path) -> list[Document]:
"""从txt文档中加载documents----txt文档内容要符合doc格式"""
with open(stop_word_file_path, 'r', encoding='utf-8') as f:
text_data = f.read()
return self.__parse_documents(text_data)
def load_documents(self,
store_file_path,
stop_word_file_path: str = None,
chune_size: int = 512,
merge: bool = True,
step_log: bool = True,
):
"""
:param store_file_path: 要解析的txt文档路径
:param stop_word_file_path: 停用词文件路径
:param chune_size: # 切分大小
:param merge: # 是否合并
:param step_log: # 是否保存step日志
:return:
"""
rows_data = self.load_txt(store_file_path)
stop_words = None
if stop_word_file_path:
stop_words = self.__stop_words(stop_word_file_path)
pre_text = self.__preprocess_text(rows_data[0].page_content, stop_words, merge=merge)
if step_log:
self.__doc2StepLogLocal(pre_text, filename=os.path.basename(store_file_path), theme="预处理日志")
documents = self.__chunk_documents(
[Document(page_content=pre_text, metadata=rows_data[0].metadata)],
chune_size=chune_size)
if step_log:
self.__doc2StepLogLocal(documents, filename=os.path.basename(store_file_path), theme="切分日志")
self.__doc2DB(documents, filename=os.path.basename(store_file_path))
return documents
# 加载停用词
def __stop_words(self, stop_word_file_path):
with open(stop_word_file_path, 'r', encoding='utf-8') as f:
stopwords = set(line.strip().lower() for line in f)
return stopwords
# step2 文本预处理
# 去除停用词 特殊符号
def __preprocess_text(self, text: str, stopwords: set, merge: bool = True) -> Union[str, List[str]]:
"""
优化后的文本预处理流程:
1. 统一换行符并替换特殊分隔符
2. 转换为全小写
3. 清除特殊符号和多余空白
4. 移除停用词(支持正则和分词两种模式)
# 目前统一按照“\n\n”进行分块。比如下面将分割为“xxxxx”、“BBBBB”...
xxxxx
xxxxx
BBBBB
BBBBB
CCCCC
CCCCC
"""
import re
# 阶段1:文本标准化
# 统一换行符为\n,替换连续换行符为分隔符
text = re.sub(r"\n{3,}", "SEGMENTATIONSYMBOL".lower(), text)
# 阶段2:符号处理
# 替换反斜杠并移除其他特殊符号
import string
text = (
text.replace("\\", '"')
.translate(str.maketrans('', '', string.punctuation))
)
# 阶段3:文本规范化
text = text.lower().strip()
# 阶段4:停用词处理
if stopwords:
stopwords_pattern = '|'.join(map(re.escape, stopwords))
text = re.sub(stopwords_pattern, '', text)
# 阶段5:是否合并Q&A
if merge:
text = text.replace("\n", "--answer--:")
return text
# 分词后去除停用词
def __cut(self, text: str, stopwords: set):
"""
推荐查询时用分词 、停用词;存库的时候不适合。
存库推荐只去除停用词即可,不然容易将语义分割。
:param text:
:param stopwords:
:return:
"""
import jieba
words = list(jieba.cut(text))
if stopwords:
return [word for word in words if word not in stopwords and word.strip()]
return words
# step3 切分文档
def __chunk_documents(self, raw_docs: List[Document], chune_size: int = 512) -> List[Document]:
"""优化后的文档分割"""
from langchain_text_splitters import RecursiveCharacterTextSplitter
# text_splitter = RecursiveCharacterTextSplitter(# 递归分割文本
from langchain_text_splitters import CharacterTextSplitter
text_splitter = CharacterTextSplitter( # 按字符分割
separator="SEGMENTATIONSYMBOL".lower(),
chunk_size=chune_size, # 如果使用QA时,最好使用0,如果使用一个块时,可以适当设置大小。
# 足够大的值以保证按分隔符分割。如果设置过大,系统会自动合并分割后小文本。对于每个分割后的片段,如果长度小于 _chunk_size,则加入 _good_splits 列表。如果长度大于 _chunk_size,则递归调用 _split_text 方法继续分割。
chunk_overlap=0, # 重复字符
keep_separator=False, # 是否保留分隔符
length_function=len,
is_separator_regex=False, # 启用正则模式
add_start_index=True, # 包含块的起始索引
strip_whitespace=True # 是否保留文档开始/结尾的空白
)
# 先分割文本,再动态生成元数据 --- 主要是为了设置不同的uid 如果直接设置到text_splitter.create_documents中可能会导致uid相同
split_docs = text_splitter.split_documents(raw_docs)
import time
from uuid import uuid4
# 为每个文档块生成独立的 metadata(包括 doc_id)
for doc in split_docs:
uid = str(uuid4()).replace('-', '')
doc.id = "文件名称_" + uid
doc.metadata['doc_id'] = "MySQL_" + uid,
doc.metadata['create_time'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
return split_docs
def __doc2DB(self,
documents: Union[List[Document]],
filename: str,
path: str = "./store/"):
os.makedirs(path, exist_ok=True)
with open(os.path.join(path, f"入库_{filename.split(".")[0]}_{datetime.now().strftime("%Y%m%d_%H%M%S")}.txt"),
'w', encoding='utf-8') as f:
f.write(f"{documents}\n")
def __doc2StepLogLocal(self,
documents: Union[str, List[Document], Document],
filename: str, theme: str, path: str = "./step_logs/"):
"""
将文档处理步骤的日志保存到本地文件中,支持多种输入类型(str / Document / List[Document])
:param documents: 可以是字符串、单个 Document 或多个 Document 的列表
:param filename: 原始文件名,用于构造日志文件路径
:param theme: 当前操作主题(例如“预处理日志”、“切分日志”)
:param path: 日志保存的基础路径,默认为 ./step_logs/
"""
# 确保目标目录存在
log_dir = os.path.join(path, filename)
os.makedirs(log_dir, exist_ok=True)
# 构建日志文件路径
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_file = os.path.join(log_dir, f"{theme}_{timestamp}.log")
with open(log_file, "w", encoding="utf-8") as f:
f.write("-" * 60 + "\n")
f.write(f"{theme} 日志记录\n")
f.write(f"时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write("-" * 60 + "\n\n")
# 处理输入类型
if isinstance(documents, str):
self._write_string_log(f, documents)
elif isinstance(documents, Document):
self._write_document_log(f, [documents])
elif isinstance(documents, list):
if all(isinstance(d, Document) for d in documents):
self._write_document_log(f, documents)
elif all(isinstance(d, str) for d in documents):
for doc in documents:
self._write_string_log(f, doc)
else:
raise ValueError("documents 列表中的元素必须是统一类型(str 或 Document)")
else:
raise TypeError(f"不支持的 documents 类型:{type(documents)}")
f.write("\n" + "-" * 60 + "\n")
f.write("日志结束\n")
f.write("-" * 60 + "\n")
def __write_string_log(self, f, content: str):
"""写入字符串类型的日志内容"""
f.write("Content:\n")
f.write(content)
f.write("\n" + "-" * 60 + "\n")
def __write_document_log(self, f, docs: List[Document]):
"""写入 Document 类型的日志内容"""
for i, doc in enumerate(docs):
f.write(f"Document {i + 1}:\n")
f.write(f"内容长度: {len(doc.page_content)} 字符\n")
f.write(f"DocID: {doc.id} \n")
f.write("内容预览:\n")
f.write(doc.page_content)
if doc.metadata:
f.write("Metadata:\n")
for k, v in doc.metadata.items():
f.write(f" {k}: {v}\n")
f.write("-" * 60 + "\n")
def __parse_documents(self, text):
import re
from ast import literal_eval
# 正则匹配每个Document片段
pattern = r"Document\(id='(.*?)', metadata=({.*?}), page_content='(.*?)'\)"
documents = []
for match in re.finditer(pattern, text, re.DOTALL):
id_str = match.group(1)
metadata_str = match.group(2)
page_content = match.group(3)
try:
# 解析metadata字典
metadata = literal_eval(metadata_str)
document = Document(
id=id_str,
page_content=page_content,
metadata=metadata,
)
documents.append(document)
except Exception as e:
print(f"解析错误: {e},跳过当前文档")
return documents
模型初始化类BuilderModel
# 加载LLMs和向量模型
class BuilderModel:
def get_embeddings(self, type: str):
if type == 'ollama':
from langchain_ollama import OllamaEmbeddings
embeddings = OllamaEmbeddings(model='nomic-embed-text:latest')
elif type == 'vllm':
embeddings = OpenAIEmbeddings(
model='qwen_embed',
base_url='http://localhost:6334/v1',
api_key='ltingzx'
)
elif type == 'sentence':
# 加载本地 Embedding 模型
# 先使用modelscopt或者HuggingFace的snapshot_download下载下来的模型和使用“SentenceTransformer”下载下来的模型目录可能不一样;
# 而langchain_huggingface的HuggingFaceEmbeddings包装了SentenceTransformer,所以不能直接加载会报错“Could not locate the configuration_hf_nomic_bert.py inside nomic-ai/nomic-bert-2048.”
# 因为 nomic-embed-text-v1.5 依赖于“nomic-bert-2048”中的“configuration_hf_nomic_bert.py、modeling_hf_nomic_bert.py”这两个文件
# 解决方法就是:1、直接使用SentenceTransformer下载(因为SentenceTransformer会同时下载nomic-embed-text-v1.5和nomic-bert-2048)
# 2、单独下载这两个文件,然后拼装成一样的目录即可;
### 比如下载后的目录“models--nomic-ai--nomic-bert-2048”放到同样的位置即可
from sentence_transformers import SentenceTransformer
embeddings = SentenceTransformer( # 原生的可能会被移除
# model_name_or_path会是使用该路径为模型名称...
# 如果要使用离线的,把模型名称改为本地全路径即可
model_name_or_path=r"D:\A4Project\LLM\nomic-ai\nomic-embed-text-v1.5",
cache_folder=r"D:\A4Project\LLM\nomic-ai\nomic-embed-text-v1.5",
local_files_only=True,
trust_remote_code=True
)
else:
from langchain_huggingface import HuggingFaceEmbeddings
embeddings = HuggingFaceEmbeddings( # 需要使用nomic-embed-text:latest向量模型,不然要报错。----要保持入库和检索的Embedding模型一致
model_name=r"D:\A4Project\LLM\nomic-ai\nomic-embed-text-v1.5",
cache_folder=r"D:\A4Project\LLM\nomic-ai\nomic-embed-text-v1.5",
model_kwargs={'trust_remote_code': True, 'device': 'cpu', 'local_files_only': True}, # 添加信任远程代码参数
encode_kwargs={'normalize_embeddings': False}
)
return embeddings
def get_llm(self, type: str):
if type == 'ollama':
llm = ChatOllama(model="qwen2.5:3b")
elif type == 'vllm':
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(
model='Deepseek-R1',
api_key='ltingzx',
base_url='http://localhost:6333/v1',
)
else:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
r"D:\A4Project\LLM\Qwen\Qwen3-4B",
trust_remote_code=True,
)
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
r"D:\A4Project\LLM\Qwen\Qwen3-4B",
use_cache=False, # 显式禁用缓存
low_cpu_mem_usage=True, # 低CPU使用
trust_remote_code=True, # 是否信任远程代码
# quantization_config=bits_config, # 添加量化配置 # 本地要错;
torch_dtype=torch.bfloat16,
device_map="cpu", # 自动分配设备
)
from transformers import pipeline
text_pipeline = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=512,
return_full_text=True,
temperature=0.95,
torch_dtype=torch.float16, # 使用半精度加速
# device_map="auto" # 与模型设备保持一致 --- CPU设备只能使用CPU
)
llm = text_pipeline
return llm
向量库初始化
抽象类VectorStorage
如果要实现除FAISS/ES以外的向量库,可以实现继承VectorStorage此类,然后重写CRUD方法
class VectorStorage(ABC):
@abstractmethod
def create_embed_document(self, documents: list[Document], batch_size: int = 128):
pass
@abstractmethod
def add_embed_document(self, documents: list[Document], batch_size: int = 128):
pass
@abstractmethod
def update_embed_document(self, documents: list[Document]):
pass
@abstractmethod
def delete_embed_document(self, doc_id: Union[str, list]):
pass
def _process_batches(
self,
documents: List[Document],
batch_size: int,
db_name: str,
add_func: Callable[[List[Document], int, int], None], # (当前批次, 起始索引, 结束索引)
save_refresh_func: Callable[[], None],
refresh_time: bool = False,
):
"""
通用的批量添加文档方法
:param documents: 文档列表
:param batch_size: 批次大小
:param add_func: 添加文档的函数
:param save_refresh_func: 可选的保存/刷新函数
:param db_name: 数据库名称(用于日志)
:param refresh_time: 是否实时刷新/保存 --- FAISS如果一次加载在本地数据可能太大,就需要一次性刷新,默认不实时刷新
"""
if not documents:
print(f"警告:{db_name}跳过处理空文档列表")
return
print(f'开始{db_name}向量数据库添加数据...')
total = len(documents)
batches = [documents[i:i + batch_size] for i in range(0, total, batch_size)]
for batch_idx, batch in enumerate(batches, 1):
current_batch_size = len(batch)
start_index = (batch_idx - 1) * batch_size
end_index = start_index + current_batch_size - 1
print(f'处理{db_name}第{batch_idx}批数据({current_batch_size}条) 索引[{start_index}-{end_index}]')
try:
add_func(batch, start_index, end_index)
if refresh_time:
save_refresh_func()
except Exception as e:
print(f"{db_name}第{batch_idx}批处理失败: {str(e)}")
raise
if not refresh_time:
save_refresh_func()
print(f'{db_name}向量数据库添加完成,共处理{len(documents)}条数据')
FAISS库
# Faiss库
class FaissStorage(VectorStorage):
def __init__(self, index_name, faiss_data_dir, embeddings):
from typing import Dict
self.embeddings = embeddings
self.faiss_instances: Dict[str, FAISS] = {}
self.index_name = index_name
self.faiss_data_dir = f"{faiss_data_dir}/{index_name}"
def create_embed_document(self, documents: list[Document], batch_size: int = 128):
# 创建的话可以使用
# faiss_store = FAISS.from_documents(
# documents=[dummy_doc],
# embedding=self.embeddings
# )
pass
def add_embed_document(self, documents: list[Document], batch_size: int = 128):
"""
向量数据库构建...
"""
faiss = self.get_faiss_instance()
def process_batch(batch, _, __):
faiss.add_documents(batch)
def save_refresh_func():
faiss.save_local(
folder_path=self.faiss_data_dir,
index_name=self.index_name
)
self._process_batches(
documents=documents,
batch_size=batch_size,
db_name=f"FAISS_{self.index_name}",
add_func=process_batch,
refresh_time=True,
save_refresh_func=save_refresh_func # 所有批次处理完成后统一保存
)
def update_embed_document(self, document: list[Document]):
# 没有修改。
# 只有删除然后重新添加
pass
def delete_embed_document(self, ids):
faiss_vector_store = self.get_faiss_instance()
try:
faiss_vector_store.delete(ids)
self.save(faiss_vector_store)
print(f"删除ids:{ids} 成功")
except Exception as e:
print(f"删除失败{e}")
def load_faiss_store(self, faiss_index_name: str):
print("加载FAISS向量库...")
return FAISS.load_local(
folder_path=self.faiss_data_dir,
embeddings=self.embeddings,
index_name=faiss_index_name, # 需与保存时一致
allow_dangerous_deserialization=True
)
def get_faiss_instance(self, recreate: bool = False) -> FAISS:
"""
获取或创建 FAISS 实例:
FAISS 的 save_local() 和 load_local() 会生成和读取的文件是这样的格式:
{folder_path}/{index_name}.faiss
{folder_path}/{index_name}.pkl
:param recreate: 是否强制重建实例
"""
# faiss_path = os.path.join(self.faiss_data_dir, self.index_name)
# if not recreate and faiss_path in self.faiss_instances:
# return self.faiss_instances[faiss_path]
faiss_index_path = os.path.join(self.faiss_data_dir, f"{self.index_name}.faiss")
faiss_pkl_path = os.path.join(self.faiss_data_dir, f"{self.index_name}.pkl")
if not recreate and faiss_index_path in self.faiss_instances:
return self.faiss_instances[faiss_index_path]
# 这儿应该是路径+路径,不然要报错“Error: 'f' failed: could not open local_db\ruozhiba\ruozhiba_v2.faiss for reading: No such file or directory”
# if os.path.exists(faiss_path):
if os.path.exists(faiss_index_path) and os.path.exists(faiss_pkl_path):
faiss_store = FAISS.load_local(
folder_path=self.faiss_data_dir,
embeddings=self.embeddings,
index_name=self.index_name, # 需与保存时一致
allow_dangerous_deserialization=True
)
else:
# 创建新索引时至少需要一个文档
dummy_doc = Document(page_content="dummy", metadata={"source": "init"})
faiss_store = FAISS.from_documents(
documents=[dummy_doc],
embedding=self.embeddings
)
os.makedirs(self.faiss_data_dir, exist_ok=True)
faiss_store.save_local(
folder_path=self.faiss_data_dir,
index_name=self.index_name
)
self.faiss_instances[faiss_index_path] = faiss_store
return faiss_store
def save(self, faiss):
faiss.save_local(
folder_path=self.faiss_data_dir,
index_name=self.index_name
)
def store_list(self):
faiss_vector_store = self.get_faiss_instance()
print(f"{self.faiss_data_dir}/{self.index_name}数量:{len(list(faiss_vector_store.docstore._dict.values()))}")
# print(faiss_vector_store.docstore._dict.keys())
# print(faiss_vector_store.docstore._dict.values())
print(f"{"-" * 10}")
ES库
# ES库
class EsStorage(VectorStorage):
def __init__(self,
index_name,
es_host,
es_port,
es_password,
es_username,
embeddings: Optional[Embeddings] = None,
strategy: str = "bm25"
):
# 实例缓存池
from typing import Dict
super().__init__()
self.embeddings = embeddings
self.es_client_url = f"http://{es_username}:{es_password}@{es_host}:{es_port}"
self.index_name = index_name
self.strategy_type = strategy
self.es_instances: Dict[str, ElasticsearchStore] = {}
def create_embed_document(self, documents: list[Document], batch_size: int = 128):
# es.add_documents(documents) # 向已存在的 ElasticsearchStore 实例关联的索引中追加文档。增量更新,向已有索引添加新文档。
pass
def add_embed_document(self, documents: list[Document], batch_size: int = 1024):
es = self.get_deep_es_instance()
def process_batch(batch, _, __):
es.add_documents(batch)
def save_refresh_func():
es.client.indices.refresh(index=self.index_name)
self._process_batches(
documents=documents,
batch_size=batch_size,
db_name=f"ES_{self.index_name}",
add_func=process_batch,
save_refresh_func=save_refresh_func # 所有批次处理完成后统一保存
)
def update_embed_document(self, documents: list[Document]):
es = self.get_deep_es_instance()
es.client.update(index=self.index_name, )
pass
def delete_embed_document(self, doc_id):
try:
es = self.get_deep_es_instance()
es.client.delete(
index=self.index_name,
# body={"query": {"match_all": {}}}
id=doc_id
)
return True
except Exception as e:
print(f"删除:{doc_id} 失败: {str(e)}")
return False
# 深度封装
def get_deep_es_instance(
self,
recreate: bool = False
) -> ElasticsearchStore:
"""
获取或创建 ES 实例(支持 BM25/向量检索)
:param index_name: 索引名称
:param strategy: 检索策略(bm25/dense)
:param recreate: 是否强制重建实例
"""
from langchain_elasticsearch import ElasticsearchStore
from langchain_community.vectorstores import DistanceStrategy
# 检查缓存
if not recreate and self.index_name in self.es_instances:
return self.es_instances[self.index_name]
# 策略选择
if self.strategy_type == "bm25":
# strategy=ElasticsearchStore.BM25RetrievalStrategy(), # 要使用BM25检索时使用 BM25Strategy/ --- 新版本弃用
strategy = BM25Strategy() # 设置检索策略为要使用BM25检索时使用 BM25Strategy/
distance_strategy = None # BM25 不需要距离策略
embedding = None
else:
strategy = DenseVectorStrategy()
distance_strategy = DistanceStrategy.COSINE # 设置距离相似性算法 COSINE(默认)/EUCLIDEAN_DISTANCE/DOT_PRODUCT;
embedding = self.embeddings
es_store = ElasticsearchStore(
es_url=self.es_client_url,
index_name=self.index_name,
embedding=embedding,
strategy=strategy,
distance_strategy=distance_strategy
)
# 缓存实例
self.es_instances[self.index_name] = es_store
return es_store
# 浅封装
def get_shallow_es_client(self):
from langchain_community.retrievers import ElasticSearchBM25Retriever
import elasticsearch
client = elasticsearch.Elasticsearch(self.es_client_url)
return ElasticSearchBM25Retriever(
client=client,
index_name=self.index_name,
)
# 原生客户端
def get_native_es_client(self):
from elasticsearch import Elasticsearch
# """在指定索引中搜索文档"""
# query = {
# "query": {
# "match": {
# "text": "根据国际能源署的2021年印度能源展望报告煤炭约占印度发电量的70该国是全球第二大煤炭消费国和进口国"
# }
# }
# }
# print(es.search(index=es_index_name, body=query))
return Elasticsearch(
hosts=[self.es_client_url],
# basic_auth=("username", "password"),
verify_certs=False, # 告诉客户端是否验证Elasticsearch服务器的TLS证书。在生产环境中,我们应该将其设置为True以确保安全的通信
# ca_certs='conf/http_ca.crt' # 当verify_certs=True时,这里指定了CA证书的路径,客户端将使用它来验证服务器证书的签名。
)
检索服务
基础服务RetrieverService
def log_retrieved_docs(ctx):
# print(f"[{ctx['msgid']}] [{ctx['query']}] Retrieved documents:[{ctx['content']['content']}]")
print(f"Retrieved documents:[{ctx}]")
return ctx # 确保返回原数据继续链式传递
def result_format(type, content, msgid):
return f"data: {json.dumps({
'type': type,
'content': content,
'msgid': msgid,
'timestamp': f'{int(datetime.now().timestamp())}'
}, ensure_ascii=False)} \n\n"
def generate_result(doc, msgid):
yield from result_format("chunk", doc, msgid)
yield from result_format("end", "", msgid)
ranker_dir = 'Rerank的重排序模型本地位置---可以使用默认的,但是如果使用默认的话需要VPN'
bge_model_path = 'bge-reranker模型位置'
class RetrieverService:
def __init__(self,
faiss_storage: FaissStorage,
es_storage: EsStorage,
llm,
embeddings):
self.faiss = faiss_storage
self.es = es_storage
self.llm = llm
self.embeddings = embeddings
self.llm_chain = {}
# 1. 单独检索 ================================================
def faiss_bm25_retriever(self, k=5) -> BM25Retriever:
"""FAISS BM25检索(需注意实际是从内存文档检索)"""
docs = self._get_faiss_documents()
return BM25Retriever.from_documents(
docs,
k=k, # 返回数量
k1=1.5, # 默认1.2,增大使高频词贡献更高
b=0.8 # 默认0.75,减小以降低文档长度影响
)
def faiss_vector_retriever(self, k=5) -> FAISS.as_retriever:
"""FAISS向量库单独检索"""
return self.faiss.get_faiss_instance().as_retriever(
search_type="similarity_score_threshold",
search_kwargs={
"k": k,
"score_threshold": 0.7
}
)
def es_bm25_retriever(self, k=5) -> ElasticsearchStore.as_retriever:
"""ES+BM25检索 --- 因为入库的时候我们使用的是BM25Strategy"""
return self.es.get_deep_es_instance().as_retriever(search_kwargs={"k": k})
# 2. 混合检索 ================================================
def ensemble_faiss_faiss_bm25(self, weights=(0.5, 0.5), k=5) -> EnsembleRetriever:
"""FAISS+FAISS_BM25"""
return self.ensemble_retriever([
self.faiss_vector_retriever(k),
self.faiss_bm25_retriever(k)
], weights=weights)
# return EnsembleRetriever(
# retrievers=[
# self.faiss_vector_retriever(k),
# self.faiss_bm25_retriever(k)
# ],
# weights=weights
# )
def ensemble_faiss_es_bm25(self, weights=(0.5, 0.5), k=5) -> EnsembleRetriever:
"""FAISS+ES_BM25"""
return self.ensemble_retriever([
self.faiss_vector_retriever(k),
self.es_bm25_retriever(k)
], weights=weights)
def ensemble_retriever(self, rerivevers: list[VectorStoreRetriever], weights=(0.5, 0.5)) -> EnsembleRetriever:
return EnsembleRetriever(
retrievers=rerivevers,
weights=weights
)
# 3. 重排序优化 ==============================================
def compressed_faiss_hybrid(self, model_name="ms-marco-MiniLM-L-12-v2", top_n=5) -> ContextualCompressionRetriever:
"""
FAISS+FAISS_BM25排序
默认模型 ms-marco-TinyBERT-L-2-v2 (约4MB)
最佳交叉编码器重排序器 ms-marco-MiniLM-L-12-v2 (约34MB)
最佳非交叉编码器重排序器 rank-T5-flan (约110MB)
支持100多种语言的多语言模型 ms-marco-MultiBERT-L-12 (约150MB)
微调的 ce-esci-MiniLM-L12-v2
大型上下文窗口和较快性能的 rank_zephyr_7b_v1_full (约4GB,4比特量化)
专用阿拉伯语重排序器 miniReranker_arabic_v1
:param model_name:
:param top_n:
:return:
"""
return self.flashrank_rerank(self.ensemble_faiss_faiss_bm25(), model_name, top_n)
# compressor = FlashrankRerank(client=self._get_rank(), model=model_name, top_n=top_n)
# base_retriever = self.ensemble_faiss_faiss_bm25()
# return ContextualCompressionRetriever(
# base_compressor=compressor,
# base_retriever=base_retriever
# )
def compressed_faiss_esbm25_hybrid(self, top_n=5) -> ContextualCompressionRetriever:
"""混合检索:FAISS+ES的BM25"""
return self.bge_rerank(self.ensemble_faiss_es_bm25(), top_n)
# 使用flashrank
def flashrank_rerank(self,
base_retriever: BaseRetriever,
model_name="ms-marco-MiniLM-L-12-v2",
top_n=5) -> ContextualCompressionRetriever:
compressor = FlashrankRerank(client=self._get_rank(), model=model_name, top_n=top_n)
return ContextualCompressionRetriever(
base_compressor=compressor,
base_retriever=base_retriever
)
# 使用bge-rerank
def bge_rerank(self,
base_retriever: BaseRetriever,
top_n: int = 5):
from langchain.retrievers.document_compressors import CrossEncoderReranker
compressor = CrossEncoderReranker(model=self._get_bge_rank(), top_n=top_n)
return ContextualCompressionRetriever(
base_compressor=compressor,
base_retriever=base_retriever
)
# 4. 多路召回系统 ============================================
class MultiRetriever(BaseRetriever):
# 显式声明 retrievers 字段
retrievers: List[BaseRetriever]
def __init__(self, retrievers: List[BaseRetriever]):
super().__init__(retrievers=retrievers)
self.retrievers = retrievers
def _get_relevant_documents(self, query: str, **kwargs) -> List[Document]:
from itertools import chain
return list(chain.from_iterable(
r.invoke(query, **kwargs)
for r in self.retrievers
))
# # 添加异步支持
# async def _aget_relevant_documents(self, query: str, **kwargs) -> List[Document]:
# from itertools import chain
# from asyncio import gather
# results = await gather(*(r.invoke(query, **kwargs) for r in self.retrievers))
# return list(chain.from_iterable(results))
# 5-9. LLM集成 ===============================================
def _build_qa_chain(self, retriever, chain_type="stuff"):
"""使用RetrievalQA构建链"""
from langchain.chains import RetrievalQA
return RetrievalQA.from_chain_type(
llm=self.llm,
chain_type=chain_type,
retriever=retriever,
return_source_documents=True,
chain_type_kwargs={
"prompt": PromptTemplate(
template="""综合以下多个来源的信息回答问题,如果来源中没有相关问题,请直接回复“不知道...”即可,切勿擅自添加或者额外理解回答:
来源文档:{context}
问题:{question}
整合回答:""",
input_variables=["context", "question"]
)
},
)
def _build_llm_chain(self, base_retriever):
# 处理检索结果的函数(将文档列表转换为字符串)
from langchain_core.runnables import RunnableLambda
# process_docs = RunnableLambda(lambda docs: "\n".join([doc.page_content for doc in docs]))
from langchain_core.runnables import RunnablePassthrough
prompt = """
请根据以下内容回答问题,在内容中“--answer--”分割符前面是我检索到的问题,分隔符后面的是对应问题的答案。
你需要判断我输入的问题与分隔符前面的问题相似度,如果是一样类型的问题,才参考后面的答案,回复。
内容中如果没有,那就回答“请咨询人工...”,内容中如果有其他不相干的内容,直接删除即可。
内容:{content}
问题:{query}
回答:
"""
prompt_template = ChatPromptTemplate.from_template(prompt)
from operator import itemgetter
chain = (
# RunnableLambda(log_retrieved_docs) | # 直接打印传递进来的参数
{
# 这个content和query会继续往下传递,直到prompt --->{content}、{query}
"content": RunnableLambda(lambda x: x[
"query"]) # 必须,不然要报错“TypeError: Expected a Runnable, callable or dict.Instead got an unsupported type: <class 'str'”
| base_retriever # 检索
| RunnableLambda(log_retrieved_docs) # 打印出检索到的文档,检索后未处理
# 先检索再处理文档
| RunnableLambda(lambda docs: "\n".join([doc.page_content for doc in docs])),
# | process_docs # 先检索再处理文档 --- 和上面方法二选一
# | RunnableLambda(log_retrieved_docs), # 打印出检索后的文档 --- 这里传递的仅仅是检索到的内容且预处理后的内容
"query": itemgetter("query"), # 直接传递用户原始问题
"msgid": RunnableLambda(lambda x: x["msgid"]), # 显示传递msgid --- 和itemgetter同样的效果
}
# | RunnableLambda(log_retrieved_docs) # 传递的是前面整个content、query、msgid的值到日志中
| prompt_template # 组合成完整 prompt
| self.llm # 传给大模型生成回答
# | RunnableLambda(log_retrieved_docs) # 传递的是LLM生成的内容 --- 但是在这一步以后,系统会同步返回---不推荐在这里打印日志
)
return chain
# 同步处理
def query_invoke(self, query: str, retriever: BaseRetriever) -> Dict:
qa = self._build_qa_chain(retriever)
return qa.invoke({"query": query})
# 异步处理
async def query_aInvoke(self, query: str, retriever: BaseRetriever) -> Dict:
qa = self._build_qa_chain(retriever)
return await qa.ainvoke({"query": query})
async def query_qa_aStream(self,
query: str,
msgid: str,
retriever: BaseRetriever = ensemble_faiss_es_bm25
) -> AsyncIterator:
# 如果这里使用QA链的话,好像无法流式返回
llm_chain = self._build_qa_chain(retriever)
result = ""
async for chunk in llm_chain.astream({"query": query}):
print(chunk)
result += chunk['result']
yield result_format("chunk", chunk['result'], msgid)
print(f"query:{query} msgid:{msgid} llm:{result} ")
# 异步流式处理
async def query_aStream(self,
query: str,
msgid: str,
retriever: BaseRetriever = ensemble_faiss_es_bm25
) -> AsyncIterator:
# 如果这里使用QA链的话,好像无法流式返回
# 确保为不同retriever创建独立的缓存键
cache_key = f"chain_{id(retriever)}"
# 获取或构建LLM链(线程安全方式)
llm_chain = self.llm_chain.get(cache_key)
if llm_chain is None:
llm_chain = self._build_llm_chain(retriever)
self.llm_chain[cache_key] = llm_chain # 缓存新创建的链
# 使用StringIO高效拼接结果(避免字符串不可变导致的多次内存分配)
result_buffer = io.StringIO()
try:
async for chunk in llm_chain.astream({"query": query, "msgid": msgid}):
result_buffer.write(chunk.content)
yield result_format("chunk", chunk.content, msgid)
except Exception as e:
# | File "D:\A2SoftwareData\MiniConda\Lib\site-packages\openai\_base_client.py", line 1616, in _request
# | raise APIConnectionError(request=request) from err
# | openai.APIConnectionError: Connection error.
# 异常处理:返回错误信息并记录日志
error_msg = f"Stream processing failed: {str(e)}"
logger.error(f"query:{query} msgid:{msgid} error:{error_msg}")
yield result_format("error", error_msg, msgid)
return
# 获取完整响应并记录日志
final_result = result_buffer.getvalue()
logger.info(f"query:'{query}' msgid:{msgid} response:'{final_result}'")
async def query_new_aStream(
self,
query: str,
msgid: str,
retriever: BaseRetriever = ensemble_faiss_es_bm25
) -> AsyncIterator[str]:
# 确保为不同retriever创建独立的缓存键
cache_key = f"chain_{id(retriever)}"
# 获取或构建LLM链
llm_chain = self.llm_chain.get(cache_key)
if llm_chain is None:
llm_chain = self._build_llm_chain(retriever)
self.llm_chain[cache_key] = llm_chain
# 初始化流管理器(单例模式)
if not hasattr(self, '_stream_manager'):
self._stream_manager = QueryStreamManager(
max_concurrent=128, # 最大并发流数
max_queue_size=512 # 最大等待队列长度
)
manager = self._stream_manager
# 检查队列是否已满
if len(manager.queue) >= manager.max_queue_size:
error_msg = "队列满了..."
manager.logger.error(f"query:{query} msgid:{msgid} rejected: {error_msg}")
yield result_format("error", error_msg, msgid)
return
# 使用StringIO高效拼接结果
result_buffer = io.StringIO()
# 创建任务并加入队列
queue_position = len(manager.queue) + 1
if queue_position > 1:
manager.logger.info(f"query:{query} msgid:{msgid} queued (position: {queue_position})")
# 将任务加入队列
manager.queue.append((query, msgid, result_buffer))
try:
# 等待信号量(控制并发)
await manager.semaphore.acquire()
# 从队列中取出任务
task_query, task_msgid, task_buffer = manager.queue.popleft()
manager.active_tasks += 1
# 处理流(带重试机制)
stream_processor = manager.process_stream(
llm_chain, task_query, task_msgid, task_buffer
)
try:
async for chunk in stream_processor:
yield chunk
finally:
manager.semaphore.release()
manager.active_tasks -= 1
except asyncio.CancelledError:
# 处理任务取消
manager.logger.warning(f"query:{query} msgid:{msgid} was cancelled")
yield result_format("error", "Request cancelled", msgid)
except Exception as e:
error_msg = f"Unexpected error: {str(e)}"
manager.logger.exception(f"query:{query} msgid:{msgid} error:{error_msg}")
yield result_format("error", error_msg, msgid)
finally:
# 确保从队列中移除当前任务(如果在队列中)
if (query, msgid, result_buffer) in manager.queue:
manager.queue.remove((query, msgid, result_buffer))
# 获取完整响应并记录日志
final_result = result_buffer.getvalue()
manager.logger.info(f"query:'{query}' msgid:{msgid} response:'{final_result}'")
# 辅助方法 ===================================================
def _get_faiss_documents(self) -> List[Document]:
"""获取FAISS中所有数据"""
return list(self.faiss.get_faiss_instance().docstore._dict.values())
def _get_rank(self):
"""获取重排序模型"""
return Ranker(cache_dir=ranker_dir)
def _get_bge_rank(self, top_n: int = 3):
"""
核心机制:使用BgeReranker模型对检索结果进行质量重排序
工作流程:原始检索 → 用专用模型重新打分 → 仅保留top_n结果
特点:不改变文档内容,但改变排序和结果数量(从k到top_n)
适用场景:需要提升Top结果准确性的场景
:param top_n: 最终结果数
[
Document(id='9f171b5b-d19e-4a6f-804d-99f1cd799dc1', metadata={'source': './data/ruozhiba_qa.txt', 'doc_id': 'c383d60012f242fda1f6a87a5045d414'}, page_content='为什么没人说ABCD型的成语?🤔这是因为中文成语一般都是四字成语,每个字都有其特定的含义,四个字合在一起构成一个完整的意思。而ABCD型的成语最常见,所以大家不会刻意强调。'),
Document(id='27439b5c-c111-4714-8257-8c4199b88097', metadata={'source': './data/ruozhiba_qa.txt', 'doc_id': '827bfc78f31e4369b3ce0bdcec815f2e'}, page_content='樟脑丸是我吃过最难吃的硬糖有奇怪的味道怎么还有人买樟脑丸并不是硬糖,而是一种常见的驱虫药,不能食用。虽然它的味道可能不太好,但是由于其有效的驱虫效果,所以仍然有很多人会购买。'),
Document(id='1d244b62-9502-4f8c-9b37-fa348c7f1dbd', metadata={'source': './data/ruozhiba_qa.txt', 'doc_id': 'b82db3c22fa94b3a85d18c7cb5f84a80'}, page_content='我算知道艾滋病人为什么受歧视了 全校组织无偿献血,一个艾滋病人都没来艾滋病人不能献血,因为他们的血液中含有HIV病毒,如果将这种血液输给其他人,会导致其他人也感染艾滋病。所以,艾滋病人没有参加无偿献血活动,是为了保护其他人的健康,而不是因为他们受到歧视。')
]
"""
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
return HuggingFaceCrossEncoder(
model_name=bge_model_path,
model_kwargs={'device': 'cpu'}
)
查询管理
class QueryStreamManager:
def __init__(self, max_concurrent: int = 5, max_queue_size: int = 20):
self.semaphore = asyncio.Semaphore(max_concurrent)
self.queue = deque()
self.max_queue_size = max_queue_size
self.active_tasks = 0
self.logger = logging.getLogger(self.__class__.__name__)
async def process_stream(self, llm_chain, query: str, msgid: str, result_buffer):
max_retries = 3
retry_delay = 1.0 # 初始重试延迟1秒
for attempt in range(max_retries):
try:
async for chunk in llm_chain.astream({"query": query, "msgid": msgid}):
content = chunk.content
result_buffer.write(content)
yield result_format("chunk", content, msgid)
return # 成功完成,退出函数
except Exception as e:
if attempt < max_retries - 1:
self.logger.warning(f"Stream error (attempt {attempt + 1}/{max_retries}): {str(e)}")
await asyncio.sleep(retry_delay)
retry_delay *= 2 # 指数退避
else:
error_msg = f"Stream processing failed after {max_retries} attempts: {str(e)}"
self.logger.error(f"query:{query} msgid:{msgid} error:{error_msg}")
yield result_format("error", error_msg, msgid)
raise
文本入库
准备入库文本,格式如下
本文会按照“\n\n”进行切分,然后将Q&A合并为一行,向量,入库;
Q
A
Q
A
Q
A
Q
A
Q
A
.....
准备停用词文本
本文只处理特殊符号---具体停用词,根据实际业务定
!
"
(
)
+
,
;
-
--
.
..
...
......
...................
/
//
:
://
::
<
=
>
>>
?
@
[
\
]
^
_
`
|
}
~
~~~~
·
×
×××
—
——
———
‘
’
’‘
”
”,
…
……
…………………………………………………③
′∈
′|
℃
Ⅲ
↑
→
∈[
∪φ∈
≈
①
②
②c
③
③]
④
⑤
⑥
⑦
⑧
⑨
⑩
──
■
▲
、
。
〈
〉
《
》
》),
」
『
』
【
】
〔
〕
〕〔
㈧
!
#
$
%
&
'
(
)
)÷(1-
)、
*
+
+ξ
++
,
,也
-
-β
--
-[*]-
.
/
0
0:2
1
1.
12%
2
2.3%
3
4
5
5:0
6
7
8
9
:
<
<±
<Δ
<λ
<φ
<<
=
=″
=☆
=(
=-
=[
={
>
>λ
?
@
A
LI
R.L.
ZXFITL
[
[①①]
[①②]
[①③]
[①④]
[①⑤]
[①⑥]
[①⑦]
[①⑧]
[①⑨]
[①A]
[①B]
[①C]
[①D]
[①E]
[①]
[①a]
[①c]
[①d]
[①e]
[①f]
[①g]
[①h]
[①i]
[①o]
[②
[②①]
[②②]
[②③]
[②④
[②⑤]
[②⑥]
[②⑦]
[②⑧]
[②⑩]
[②B]
[②G]
[②]
[②a]
[②b]
[②c]
[②d]
[②e]
[②f]
[②g]
[②h]
[②i]
[②j]
[③①]
[③⑩]
[③F]
[③]
[③a]
[③b]
[③c]
[③d]
[③e]
[③g]
[③h]
[④]
[④a]
[④b]
[④c]
[④d]
[④e]
[⑤]
[⑤]]
[⑤a]
[⑤b]
[⑤d]
[⑤e]
[⑤f]
[⑥]
[⑦]
[⑧]
[⑨]
[⑩]
[*]
[-
[]
]
]∧′=[
][
_
a]
b]
c]
e]
f]
ng昉
{
{-
|
}
}>
~
~±
~+
¥
±
÷
∞
≠
≤
≥
∈
∪
∩
⊂
⊃
⊆
⊇
∧
∨
¬
⇒
⇔
∴
∵
∫
∑
∏
√
∂
∇
Δ
Γ
Σ
Φ
Ψ
Ω
α
β
γ
δ
ε
ζ
η
θ
ι
κ
λ
μ
ν
ξ
π
ρ
σ
τ
υ
φ
χ
ψ
ω
∈[
∪φ∈
′∈
′|
*
**
***
******
🤔
ES入库
本文默认已经安装好了ES、FAISS等库
es_index_name = "索引名称"
es_host = "ip"
es_port = "端口"
es_password = "密码"
es_username = "账号"
def es_store(documents, embeddings):
# 加载ES
es_store = EsStorage(es_index_name, es_host, es_port, es_password, es_username, embeddings)
es_store.add_embed_document(documents)
es = es_store.get_deep_es_instance()
print(es.as_retriever().invoke("xxxxxx问题"))
print("*" * 100)
FAISS入库
faiss_index_name = "索引名称"
faiss_data_dir = "./本地FAISS库路径"
def faiss_store(documents, embeddings):
# 加载Faiss
faiss_store = FaissStorage(faiss_index_name, faiss_data_dir, embeddings)
# Faiss向量入库
faiss_store.add_embed_document(documents, 256)
faiss_store1 = FaissStorage(faiss_index_name, faiss_data_dir, embeddings)
faiss_store1.store_list()
print(faiss_store1.get_faiss_instance().as_retriever().invoke("心脏"))
main方法
store_file_path = "要入库的文档.txt"
stop_word_file_path = '停用词文档.txt'
async def main():
# 初始化向量模型
llm_type = "ollama"
embed_type = "ollama"
embeddings = BuilderModel().get_embeddings(embed_type)
"""
chune_size:切分块大小,严格意义上讲,不能为0,因为原始数据问题(切分后的长度为5~1000)都有,如果设置太大,系统会自动合并分割后太小的块---
但是我想要的就是一个完整的Q&A切分,所以设置0,设置0的话,系统会严格遵守我们的切分规则;
"""
documents = BaseBuilder().load_documents(store_file_path, stop_word_file_path, chune_size=0)
"""
在load_documents()中,我们可以选择是否保持日志(即切分后的document),一方面是为了后续查看我们切分规则是否合理,一方面是为了保持在使用多个不同的向量库时保持数据一致;
比如后续,如果我们ES数据丢失了,我们不需要再从0开始加载、切分文档---重新加载切分的话,我们会生成新的文档ID,导致我们在FAISS和ES库中的文档ID不一致;
如果我们开启了step_log(默认开启),那么执行完load_documents以后,会在当前文件夹下面生成“step_logs”、“store”两个文件夹
step_logs下面有:
预处理日志_xxxx.log,预处理完成以后的日志(去除停用词、特殊符号等等..)
切分日志_xxxx.log,查看切分后的日志,会生成文档id等信息
store下面就是我们生成的一个完整的documnets,直接入库的documents,专门为“load_docs()”来恢复数据使用;
"""
# documents = BaseBuilder().load_docs("./store/入库_xxxx_20250605_121932.txt")
es_store(documents, embeddings)
faiss_store(documents, embeddings)
if __name__ == '__main__':
# asyncio.run(main())
检索---测试全部
def start():
# 初始化组件
query = "只剩一个心脏了还能活吗"
builder = BuilderModel()
embeddings = builder.get_embeddings("ollama")
llm = builder.get_llm("ollama")
faiss_storage = FaissStorage(faiss_index_name, faiss_data_dir, embeddings)
# print(faiss_storage.get_faiss_instance().docstore._dict.values())
# print(f"原始FAISS:{faiss_storage.get_faiss_instance().as_retriever().invoke(query)}")
es_storage = EsStorage(es_index_name, es_host, es_port, es_password, es_username, embeddings)
# es_storage = None
# print(f"原始ES:{es_storage.get_deep_es_instance().as_retriever().invoke(query)}")
service = RetrieverService(faiss_storage, es_storage, llm, embeddings)
# 示例1:混合检索+重排序
faiss_bm25 = service.faiss_bm25_retriever()
# print(f"基于FAISS的BM25向量:{faiss_bm25.invoke(query)}")
faiss_embed = service.faiss_vector_retriever()
# print(f"FAISS向量:{faiss_embed.invoke(query)}")
es_bm25 = service.es_bm25_retriever()
# print(f"基于ES的BM25:{es_bm25.invoke(query)}")
faiss_embed_bm25 = service.ensemble_faiss_faiss_bm25()
# print(f"FAISS向量+FAISS的BM25:{faiss_embed_bm25.invoke(query)}")
faiss_embed_es_bm25 = service.ensemble_faiss_es_bm25()
# print(f"FAISS向量+ES的BM25:{faiss_embed_es_bm25.invoke(query)}")
faiss_embed_bm25_falshrerank = service.compressed_faiss_hybrid()
# print(f"重排序优化(FAISS+FAISS的BM25):{faiss_embed_bm25_falshrerank.invoke(query)}")
print(f"混合检索+FlashRerank重排序:{service.query_invoke(query, faiss_embed_bm25_falshrerank)}")
faiss_embed_bm25_bgererank = service.compressed_faiss_esbm25_hybrid()
print(f"混合检索+FlashRerank重排序:{service.query_invoke(query, faiss_embed_bm25_bgererank)}")
# print(f"重排序优化(FAISS+FAISS的BM25):{faiss_embed_bm25_bgererank.invoke(query)}")
# 示例2:异步流式处理
async def run_query():
async for chunk in service.query_aStream(query, faiss_embed_bm25_falshrerank):
print("-----------")
print(chunk)
time.sleep(10)
multi_retriever = service.MultiRetriever([
faiss_bm25,
faiss_embed,
es_bm25,
faiss_embed_bm25,
faiss_embed_es_bm25,
# faiss_embed_bm25_falshrerank,
# faiss_embed_es_bm25_falshrerank
])
retriever = service.flashrank_rerank(multi_retriever, top_n=5)
print(f"多路召回检索器:{multi_retriever.invoke(query)}")
print(f"多路召回检索器:{retriever.invoke(query)}")
start()
对外发布接口
import asyncio
import io
import json
import logging
import time
from abc import abstractmethod, ABC
from asyncio.log import logger
from collections import deque
from datetime import datetime
from typing import Union, Optional, List, Dict, Callable, AsyncIterator
import torch
from fastapi import FastAPI
from flashrank import Ranker
from langchain.retrievers import EnsembleRetriever, ContextualCompressionRetriever
from langchain_community.document_compressors import FlashrankRerank
from langchain_community.retrievers import BM25Retriever
from langchain_community.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_elasticsearch.vectorstores import BM25Strategy, ElasticsearchStore, DenseVectorStrategy
from langchain_ollama import ChatOllama
from langchain_openai import OpenAIEmbeddings
import os
from langchain_core.documents import Document
from slowapi.util import get_remote_address
from starlette.responses import StreamingResponse
app = FastAPI(title='xxxxx', version='1.0.0', description='xxxx检索')
# 添加 CORS --- 跨域 中间件
from fastapi.middleware.cors import CORSMiddleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 允许所有来源,生产环境建议指定具体域名
allow_credentials=True, # 允许携带凭证(如cookies)
allow_methods=["*"], # 允许所有HTTP方法(可选:["GET", "POST"]等)
allow_headers=["*"], # 允许所有HTTP头
)
# 使用slowapi限流
from slowapi import Limiter
limiter = Limiter(key_func=get_remote_address)
# 同步返回
@app.get("/llm/astream")
async def llm_astream(query: str, msgid: str):
print(f"llm_retriveter_astream请求开始:query:{query} msgid:{msgid}")
return StreamingResponse(
service.query_aStream(query, msgid, retrievers.get('multi_rerank')),
media_type="text/event-stream;charset=utf-8" # text/plain、text/event-stream;强制响应头charset=utf-8
)
@app.get("/llm/new_astream")
async def query_new_aStream(query: str, msgid: str):
print(f"llm_retriveter_astream请求开始:query:{query} msgid:{msgid}")
return StreamingResponse(
service.query_new_aStream(query, msgid, retrievers.get('multi_rerank')),
media_type="text/event-stream;charset=utf-8" # text/plain、text/event-stream;强制响应头charset=utf-8
)
retrievers = {}
@app.get("/retriever")
async def retriever(query: str, msgid: str, type: str):
print(f"检索请求开始:type:{type} query:{query} msgid:{msgid}")
retriever = retrievers.get(type)
doc = "\n".join([doc.page_content for doc in retriever.invoke(query)])
return StreamingResponse(result_format("chunk", doc, msgid), media_type='text/event-stream')
@app.get("/retriever_async")
# @limiter.limit("1/second")
async def retriever(query: str, msgid: str, type: str):
print(f"检索请求开始:type:{type} query:{query} msgid:{msgid} 时间:{datetime.now()}")
retriever = retrievers.get(type)
async def document_generator():
# 流式获取文档
docs = await asyncio.to_thread(retriever.invoke, query)
for i, doc in enumerate(docs):
# 分块发送文档内容
chunk = {
"chunk": doc.page_content,
"index": i,
"total": len(docs),
"msgid": msgid
}
yield json.dumps(chunk) + "\n\n"
return StreamingResponse(
document_generator(),
media_type='text/event-stream',
headers={
"X-Stream-Output": "true",
"Cache-Control": "no-cache"
}
)
def build_retriever(service):
retrievers['faiss25'] = service.faiss_bm25_retriever()
retrievers['faiss25_rerank'] = service.flashrank_rerank(service.faiss_bm25_retriever(), top_n=5)
retrievers['es25'] = service.es_bm25_retriever()
retrievers['es25_rerank'] = service.flashrank_rerank(service.es_bm25_retriever(), top_n=5)
retrievers['faiss'] = service.faiss_vector_retriever()
retrievers['faiss_rerank'] = service.flashrank_rerank(service.faiss_vector_retriever(), top_n=5)
retrievers['faiss_faiss25'] = service.ensemble_faiss_faiss_bm25()
retrievers['faiss_faiss25_rerank'] = service.flashrank_rerank(service.ensemble_faiss_faiss_bm25(), top_n=5)
retrievers['faiss_es25'] = service.ensemble_faiss_es_bm25()
retrievers['faiss_es25_rerank'] = service.flashrank_rerank(service.ensemble_faiss_es_bm25(), top_n=5)
retrievers['faiss_es25_bge_rerank'] = service.bge_rerank(service.ensemble_faiss_es_bm25(), top_n=5)
retrievers['multi'] = service.MultiRetriever([
service.compressed_faiss_hybrid(),
service.ensemble_faiss_es_bm25(),
service.ensemble_faiss_faiss_bm25(),
service.es_bm25_retriever(),
service.faiss_vector_retriever(),
service.faiss_bm25_retriever(),
service.compressed_faiss_esbm25_hybrid()
])
retrievers['multi_rerank'] = service.flashrank_rerank(retrievers.get('multi'), top_n=5)
if __name__ == '__main__':
import uvicorn
builder = BuilderModel()
embeddings = builder.get_embeddings("ollama")
llm = builder.get_llm("vllm")
faiss_storage = FaissStorage(faiss_index_name, faiss_data_dir, embeddings)
es_storage = EsStorage(es_index_name, es_host, es_port, es_password, es_username, embeddings)
service = RetrieverService(faiss_storage, es_storage, llm, embeddings)
build_retriever(service)
uvicorn.run(app, host='localhost', port=8000)
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)