langchain基础教程(8)---使用LangGraph的搭建混合型的RAG
按照流程图, 使用add_edge 连接节点和节点,并行节点也使用add_edge,条件分支要使用add_conditional_edges。状态是一个类型化字典(TypedDict),用于存储各个节点的结果,加上输入问题 一共9个。全局的 “数据容器”, 用于在节点间传递信息(如上下文、工具结果、中间节点的结果数据)节点间的 “流转规则”, 决定了 “执行完当前节点后,下一步该去哪里”(节点0A
在前面的文章中,提到了混合型的RAG;这里要用langGraph实现一个混合型的RAG
复习一下LangGraph的基础:
1. 节点(Nodes) 任务的执行单元, 具体的逻辑或功能
节点类型:
工具调用节点:调用外部工具(如 API、数据库、计算器),例如 “调用天气 API 获取实时温度”。
LLM 节点:调用大语言模型(如 GPT-4、Claude)生成文本,例如 “根据用户问题生成回答初稿”。
逻辑处理节点:执行自定义代码(如数据清洗、格式校验),例如 “检查 LLM 输出是否包含敏感词”。
人工节点:暂停流程等待人工输入(如 “需人工审核高风险内容”)。
特性:节点是 “黑盒”,仅通过输入(状态中的数据)和输出(更新状态)与外部交互,便于模块化复用。
2. 边(Edges) 节点间的 “流转规则”, 决定了 “执行完当前节点后,下一步该去哪里”
边的类型:
无条件边:固定指向某一节点,例如 “生成初稿后必走审核节点”。
条件边:根据状态中的数据动态选择下一个节点,例如 “若审核通过则进入输出节点,否则返回修改节点”。
循环边:允许从节点 A 跳转回节点 B 形成闭环,例如 “修改后重新进入审核节点,直到通过”。
3. 状态(State) 全局的 “数据容器”, 用于在节点间传递信息(如上下文、工具结果、中间节点的结果数据)
特性:
可持久化:支持内存、文件或数据库(如 Redis)存储,确保流程中断后可恢复。
可扩展:通过类定义自定义状态结构(如 class AgentState: messages: list; tool_results: dict)。
现在开始这个项目,项目流程图如下
flowchart LR
A[输入问题] --> B[预处理(清洗/关键词提取)]
B --> C[语义检索(FAISS)]
B --> D[关键词检索(BM25)]
C & D --> E[结果合并+去重]
E --> F[重排序(CrossEncoder)]
F --> G{检索结果是否足够?}
G -- 不足 --> B
G -- 足够 --> H[生成回答(LLM+上下文)]
H --> I{回答是否合规?}
I -- 不合规 --> H
I -- 合规 --> J[返回最终回答]

第一步定义节点:
(节点0 A是输入问题,没有任何功能不能算个节点)
节点1 问题预处理节点(即B)
节点2 语义检索©
节点3 关键词检索(D)
节点4 结果合并去重(E)
节点5 重排序节点(F)
节点6 校验检索结果(G)
节点7 生成回答节点(H)
节点8 检查合规节点(I)
(节点9 输出回答,这里是直接输出,也不能算节点)
第二步定义边:
按照流程图, 使用add_edge 连接节点和节点, 并行节点也使用add_edge, 条件分支要使用add_conditional_edges
第三步定义状态:
状态是一个类型化字典(TypedDict),用于存储各个节点的结果,加上输入问题 一共9个
如果还需要其他结果,还可以继续定义,只要定义好,每个变量代表什么意思就行了
class RAGState(TypedDict):
user_question: str | None # 用户原始问题
cleaned_question: str | None # 节点1 : 问题预处理后的结果
semantic_docs: list | None # 节点2: 语义检索的结果
keyword_docs: list | None # 节点3:关键词检索结果
merged_docs: list | None # 节点4:结果合并去重后的文档
reranked_docs: list | None # 节点5 重排序节点产生的结果
retrieve_flag: bool | None # 节点6: 校验检索结果是否足够合规的结果((True/False)
generated_answer: str | None #节点7:生成回答节点的结果
answer_flag: bool | None # 节点8 检查合规节点结果(True/False)
以下是一个完整的demo
# 前置条件 设置一个向量数据库
import chromadb
# 模拟知识库(中文场景)
from langchain_huggingface import HuggingFaceEmbeddings
embedding_model = HuggingFaceEmbeddings(
model_name=r"D:\MyWork\GitProject\LLM\my_llm\.model\embedding\bge-small-zh-v1.5",
model_kwargs={
"device": "cpu",
"trust_remote_code": True,
},
encode_kwargs={"normalize_embeddings": True},
)
KNOWLEDGE_BASE = [
"智能水杯支持水温实时检测,范围0-100℃",
"智能水杯续航可达72小时,Type-C充电接口",
"智能水杯可连接手机APP,记录每日饮水量",
"智能水杯采用食品级316不锈钢材质,安全无毒",
"智能水杯具备防漏设计,适合户外出行使用",
"智能水杯的APP支持iOS和Android双系统",
"智能水杯充电时间约2小时,充满后自动断电",
"智能水杯的电池容量为500mAh,续航稳定"
]
def set_chroma_db():
chroma_client = chromadb.PersistentClient(
path="./.temp/chroma" # 数据会保存在当前目录的 chroma_memory_data 文件夹中
)
collection_clent = chroma_client.get_or_create_collection(name="my_rag")
counter = 0
id_list = [str(counter + i + 1) for i in range(len(KNOWLEDGE_BASE))]
embedding_results = embedding_model.embed_documents(KNOWLEDGE_BASE)
collection_clent.add(ids=id_list, documents=KNOWLEDGE_BASE, embeddings=embedding_results)
# ===================== 1. 定义全局状态 是一个类型化字典(TypedDict)=====================
from typing import TypedDict
class RAGState(TypedDict):
user_question: str | None # 用户原始问题
cleaned_question: str | None # 节点1 : 问题预处理后的结果
semantic_docs: list | None # 节点2: 语义检索的结果
keyword_docs: list | None # 节点3:关键词检索结果
merged_docs: list | None # 节点4:结果合并去重后的文档
reranked_docs: list | None # 节点5 重排序节点产生的结果
retrieve_flag: bool | None # 节点6: 校验检索结果是否足够合规的结果((True/False)
generated_answer: str | None #节点7:生成回答节点的结果
answer_flag: bool | None # 节点8 检查合规节点结果(True/False)
rerank_scores: list | None # 额外增加字段:重排序得分, 用于评估
# ========================================== 2. 定义节点函数====================================
# 节点1:预处理问题(清洗/关键词提取等)
import re
def preprocess_question(state: RAGState) -> RAGState:
question = state["user_question"]
# 删除所有无关的标点、特殊符号和首尾空白,保留有效文本(中文、字母、数字、中间空白)
cleaned_question = re.sub(r'[^\w\s]', '', question).strip()
print(f"【节点1: 预处理】清洗后的问题:{cleaned_question}")
return {"cleaned_question": cleaned_question}
# 节点2:语义检索
def semantic_retrieval(state: RAGState) -> RAGState:
question = state["cleaned_question"]
# 从向量数据库中检索相关文档
chroma_client = chromadb.PersistentClient(
path="./.temp/chroma" # 数据会保存在当前目录的 chroma_memory_data 文件夹中
)
collection_clent = chroma_client.get_or_create_collection(name="my_rag")
results = collection_clent.query(
query_embeddings=embedding_model.embed_documents([question]),
include=["documents"]
)
semantic_docs = results["documents"][0]
print(f"【节点2: 语义检索】结果:{semantic_docs}")
return {"semantic_docs": semantic_docs}
# 节点3:关键词检索
def keyword_retrieval(state: RAGState) -> RAGState:
question = state["cleaned_question"]
# BM25关键词检索相关性,当前成熟的方案是用ES, 这里这里只做简化演示
bm25_tokenizer = lambda text: re.findall(r'\w+', text.lower())
tokens = bm25_tokenizer(question)
tokenized_docs = [bm25_tokenizer(doc) for doc in KNOWLEDGE_BASE]
from rank_bm25 import BM25Okapi
bm25 = BM25Okapi(tokenized_docs)
scores = bm25.get_scores(tokens)
top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:4] # 多取1个
keyword_docs = [KNOWLEDGE_BASE[i] for i in top_indices]
print(f"【节点3: 关键词检索】结果:{keyword_docs}")
return {"keyword_docs": keyword_docs}
# 节点4:合并+去重
# 这个节点有俩输入, langGraph 会等前面两个节点都完成才会调用这个节点
def merge_docs(state: RAGState) -> RAGState:
merged = []
seen = set()
# 合并语义+关键词检索结果,去重
for doc in state["semantic_docs"] + state["keyword_docs"]:
if doc not in seen:
seen.add(doc)
merged.append(doc)
print(f"【节点4: 合并去重】结果(共{len(merged)}条):{merged}")
return {"merged_docs":merged}
# 节点5:重排序节点
# 从排序节点用到了重排序模型
def rerank_docs(state: RAGState) -> RAGState:
question = state["cleaned_question"]
docs = state["merged_docs"]
# 边界处理:无文档直接返回空
if not docs:
print("【重排序】无文档可排序,返回空")
return {"reranked_docs": [], "rerank_scores": []}
# 1. 构造重排序输入:(问题, 文档) 成对输入
rerank_pairs = [(question, doc) for doc in docs]
# 2. 计算相关性得分(BGE Reranker输出0~1的得分,越高越相关)
reranker_model_path = r"D:\MyWork\GitProject\LLM\my_llm\.model\reranker\bge-reranker-base"
from FlagEmbedding import FlagReranker
rerank_model = FlagReranker(model_name_or_path=reranker_model_path,normalize=True,device="cpu")
scores = rerank_model.compute_score(rerank_pairs)
print("【重排序】相关性得分:", scores)
# print(f"【重排序】原始得分:{[(doc, round(score, 3)) for doc, score in zip(docs, scores)]}")
# 3. 过滤低得分文档(≥阈值)
RERANK_THRESHOLD = 0.5 # 重排序得分阈值:只保留≥0.5的文档(可根据场景调整)
RERANK_THRESHOLD = 0.5 # 重排序得分阈值:只保留≥0.5的文档(可根据场景调整)
RERANK_TOP_K = 3 # 最终保留的TopK文档数
MIN_RERANK_NUM = 1 # 原始结果数不足时,保留的TopK文档数
filtered_docs_scores = [(doc, score) for doc, score in zip(docs, scores) if score >= RERANK_THRESHOLD]
if not filtered_docs_scores:
print(f"【重排序】无文档达到得分阈值(≥{RERANK_THRESHOLD}),保留原始Top{MIN_RERANK_NUM}")
filtered_docs_scores = [(doc, score) for doc, score in zip(docs, scores)][:MIN_RERANK_NUM]
# 4. 按得分降序排序,取TopK
filtered_docs_scores.sort(key=lambda x: x[1], reverse=True)
reranked_docs = [doc for doc, _ in filtered_docs_scores[:RERANK_TOP_K]]
rerank_scores = [round(score, 3) for _, score in filtered_docs_scores[:RERANK_TOP_K]]
print(f"【节点5: 重排序】最终结果(得分≥{RERANK_THRESHOLD},Top{RERANK_TOP_K}):")
for doc, score in zip(reranked_docs, rerank_scores):
print(f" - {doc}(得分:{score})")
return {
"reranked_docs": reranked_docs,
"rerank_scores": rerank_scores # 保存得分,方便后续调试/监控
}
# 节点6:校验检索结果=仅作演示 只判断数量
def check_retrieval(state: RAGState) -> str:
# if len(state["reranked_docs"]) >= 1:
# state["retrieve_flag"] = True
# print(f"【节点6: 检索校验】结果足够({len(state['reranked_docs'])}条),进入生成环节")
# return "generate_answer"
# else:
# state["retrieve_flag"] = False
# print(f"【节点6: 检索校验】结果不足,重新检索")
# return "preprocess_question"
# 这里我们人为制造一个循环,模拟检索结果不够的情况
import random
if random.random() > 0.8:
state["retrieve_flag"] = True
print(f"【节点6: 检索校验】结果足够({len(state['reranked_docs'])}条),进入生成环节")
return "generate_answer"
else:
state["retrieve_flag"] = False
print(f"【节点6: 检索校验】结果不足,重新检索")
return "preprocess_question"
# 节点7:生成回答 结合查询到的数据 调用大模型 生成回答
def generate_answer(state: RAGState) -> RAGState:
question = state["user_question"]
docs = state["reranked_docs"]
# 构建上下文
context = "\n".join([f"- {doc}" for doc in docs])
prompt = f"基于以下高相关的上下文信息回答用户问题,仅使用上下文内容,不要编造任何信息:\n{context}\n用户问题:{question}\n简洁回答:"
print(f"【生成回答】输入:{prompt}")
deepseek_api_base = "https://api.deepseek.com"
deepseek_api_key = "sk-xxxxxxxxxxxxxxxxx"
deepseek_model = "deepseek-chat"
# 设置DeepSeek API Key和基础地址
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(
model=deepseek_model,
openai_api_key=deepseek_api_key,
openai_api_base=deepseek_api_base, # 关键配置
temperature=0.7,
)
# 生成回答
answer = llm.invoke(prompt) # | StrOutputParser()
print(f"【节点7: 生成回答】结果:{answer}")
return {"generated_answer": answer}
# 节点8:校验回答
# 演示简单合规校验, 正常可以使用大模型校验
def check_answer(state: RAGState) -> str:
answer = state["generated_answer"]
print(f"【回答校验】输入:{answer}")
# 简单合规校验:非空、长度≥5、无违规词
# if not answer or len(answer.content) < 5 or "编造" in answer or "不知道" in answer:
# state["answer_flag"] = False
# print(f"【节点8: 回答校验】不合规,重新生成")
# return "generate_answer"
# else:
# state["answer_flag"] = True
# print(f"【节点8: 回答校验】合规,流程结束")
# return "end"
# 这里我们人为制造一个循环
import random
if random.random() > 0.5:
state["answer_flag"] = False
print(f"【节点8: 回答校验】不合规,重新生成")
return "generate_answer"
else:
state["answer_flag"] = True
print(f"【节点8: 回答校验】合规,流程结束")
return "end"
# ===================== 3. 构建LangGraph工作流=====================
from langgraph.graph import StateGraph, END,START
graph_builder = StateGraph(RAGState)
# 添加节点
graph_builder.add_node("preprocess_question", preprocess_question)
graph_builder.add_node("semantic_retrieval", semantic_retrieval)
graph_builder.add_node("keyword_retrieval", keyword_retrieval)
graph_builder.add_node("merge_docs", merge_docs)
graph_builder.add_node("rerank_docs", rerank_docs)
graph_builder.add_node("generate_answer", generate_answer)
# 设置边
# 设置入口 输入直接到节点1 也可以使用 graph_builder.set_entry_point("preprocess_question")
graph_builder.add_edge(START, "preprocess_question")
# 问题预处理之后, 同时进入节点2 语义检索 和 节点3 关键词检索
graph_builder.add_edge("preprocess_question", "semantic_retrieval")
graph_builder.add_edge("preprocess_question", "keyword_retrieval")
#节点2和节点3完成后 会合并到节点4 (扇入:等待两个检索节点都完成,才执行合并)
graph_builder.add_edge("semantic_retrieval", "merge_docs")
graph_builder.add_edge("keyword_retrieval", "merge_docs")
# 节点4 合并+去重 完成后, 进入节点5 重排序
graph_builder.add_edge("merge_docs", "rerank_docs")
# 节点6 是一个路由节点, 控制流程进入那个节点
graph_builder.add_conditional_edges(
"rerank_docs",
check_retrieval,
{"generate_answer": "generate_answer", "preprocess_question": "preprocess_question"}
)
# 同上
graph_builder.add_conditional_edges(
"generate_answer",
check_answer,
{"end": END, "generate_answer": "generate_answer"}
)
rag_graph = graph_builder.compile()
if __name__ == '__main__':
user_question = "智能水杯的充电时间和续航分别是多少?"
initial_state = {"user_question": user_question}
# 执行
result = rag_graph.invoke(initial_state)
print("\n===================== 最终结果 =====================")
print(f"用户问题:{user_question}")
print(f"重排序后高相关文档:{result['reranked_docs']}")
print(f"重排序得分:{result['rerank_scores']}")
print(f"最终回答:{result['generated_answer'].content}")
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)