在前面的文章中,提到了混合型的RAG;这里要用langGraph实现一个混合型的RAG
复习一下LangGraph的基础:
1. 节点(Nodes) 任务的执行单元, 具体的逻辑或功能

节点类型: 
	工具调用节点:调用外部工具(如 API、数据库、计算器),例如 “调用天气 API 获取实时温度”。
	LLM 节点:调用大语言模型(如 GPT-4、Claude)生成文本,例如 “根据用户问题生成回答初稿”。
	逻辑处理节点:执行自定义代码(如数据清洗、格式校验),例如 “检查 LLM 输出是否包含敏感词”。
	人工节点:暂停流程等待人工输入(如 “需人工审核高风险内容”)。
特性:节点是 “黑盒”,仅通过输入(状态中的数据)和输出(更新状态)与外部交互,便于模块化复用。

2. 边(Edges) 节点间的 “流转规则”, 决定了 “执行完当前节点后,下一步该去哪里”

边的类型: 
	无条件边:固定指向某一节点,例如 “生成初稿后必走审核节点”。
	条件边:根据状态中的数据动态选择下一个节点,例如 “若审核通过则进入输出节点,否则返回修改节点”。
	循环边:允许从节点 A 跳转回节点 B 形成闭环,例如 “修改后重新进入审核节点,直到通过”。

3. 状态(State) 全局的 “数据容器”, 用于在节点间传递信息(如上下文、工具结果、中间节点的结果数据)

特性:
	可持久化:支持内存、文件或数据库(如 Redis)存储,确保流程中断后可恢复。
	可扩展:通过类定义自定义状态结构(如 class AgentState: messages: list; tool_results: dict)。

现在开始这个项目,项目流程图如下

flowchart LR
    A[输入问题] --> B[预处理(清洗/关键词提取)]
    B --> C[语义检索(FAISS)]
    B --> D[关键词检索(BM25)]
    C & D --> E[结果合并+去重]
    E --> F[重排序(CrossEncoder)]
    F --> G{检索结果是否足够?}
    G -- 不足 --> B
    G -- 足够 --> H[生成回答(LLM+上下文)]
    H --> I{回答是否合规?}
    I -- 不合规 --> H
    I -- 合规 --> J[返回最终回答]

在这里插入图片描述
第一步定义节点:
(节点0 A是输入问题,没有任何功能不能算个节点)
节点1 问题预处理节点(即B)
节点2 语义检索©
节点3 关键词检索(D)
节点4 结果合并去重(E)
节点5 重排序节点(F)
节点6 校验检索结果(G)
节点7 生成回答节点(H)
节点8 检查合规节点(I)
(节点9 输出回答,这里是直接输出,也不能算节点)

第二步定义:
按照流程图, 使用add_edge 连接节点和节点, 并行节点也使用add_edge, 条件分支要使用add_conditional_edges

第三步定义状态:
状态是一个类型化字典(TypedDict),用于存储各个节点的结果,加上输入问题 一共9个
如果还需要其他结果,还可以继续定义,只要定义好,每个变量代表什么意思就行了

class RAGState(TypedDict):
    user_question: str | None         # 用户原始问题
    cleaned_question: str   | None     # 节点1 : 问题预处理后的结果
    semantic_docs: list  | None         # 节点2: 语义检索的结果
    keyword_docs: list   | None         # 节点3:关键词检索结果
    merged_docs: list    | None         # 节点4:结果合并去重后的文档
    reranked_docs: list   | None        # 节点5  重排序节点产生的结果
    retrieve_flag: bool    | None       # 节点6: 校验检索结果是否足够合规的结果((True/False)
    generated_answer: str   | None      #节点7:生成回答节点的结果
    answer_flag: bool    | None         # 节点8	检查合规节点结果(True/False)

以下是一个完整的demo

# 前置条件 设置一个向量数据库

import chromadb
# 模拟知识库(中文场景)
from langchain_huggingface import HuggingFaceEmbeddings
embedding_model = HuggingFaceEmbeddings(
    model_name=r"D:\MyWork\GitProject\LLM\my_llm\.model\embedding\bge-small-zh-v1.5",
    model_kwargs={
        "device": "cpu",
        "trust_remote_code": True,
    },
    encode_kwargs={"normalize_embeddings": True},
)

KNOWLEDGE_BASE = [
    "智能水杯支持水温实时检测,范围0-100℃",
    "智能水杯续航可达72小时,Type-C充电接口",
    "智能水杯可连接手机APP,记录每日饮水量",
    "智能水杯采用食品级316不锈钢材质,安全无毒",
    "智能水杯具备防漏设计,适合户外出行使用",
    "智能水杯的APP支持iOS和Android双系统",
    "智能水杯充电时间约2小时,充满后自动断电",
    "智能水杯的电池容量为500mAh,续航稳定"
    ]
def set_chroma_db():
    chroma_client = chromadb.PersistentClient(
        path="./.temp/chroma"  # 数据会保存在当前目录的 chroma_memory_data 文件夹中
    )
    collection_clent = chroma_client.get_or_create_collection(name="my_rag")

    counter = 0
    id_list = [str(counter + i + 1) for i in range(len(KNOWLEDGE_BASE))]
    embedding_results = embedding_model.embed_documents(KNOWLEDGE_BASE)
    collection_clent.add(ids=id_list, documents=KNOWLEDGE_BASE, embeddings=embedding_results)









# ===================== 1. 定义全局状态 是一个类型化字典(TypedDict)=====================
from typing import TypedDict
class RAGState(TypedDict):
    user_question: str | None         # 用户原始问题
    cleaned_question: str   | None     # 节点1 : 问题预处理后的结果
    semantic_docs: list  | None         # 节点2: 语义检索的结果
    keyword_docs: list   | None         # 节点3:关键词检索结果
    merged_docs: list    | None         # 节点4:结果合并去重后的文档
    reranked_docs: list   | None        # 节点5  重排序节点产生的结果
    retrieve_flag: bool    | None       # 节点6: 校验检索结果是否足够合规的结果((True/False)
    generated_answer: str   | None      #节点7:生成回答节点的结果
    answer_flag: bool    | None         # 节点8	检查合规节点结果(True/False)
    rerank_scores: list    | None       # 额外增加字段:重排序得分, 用于评估



# ========================================== 2. 定义节点函数====================================

# 节点1:预处理问题(清洗/关键词提取等)
import re
def preprocess_question(state: RAGState) -> RAGState:
    question = state["user_question"]

    # 删除所有无关的标点、特殊符号和首尾空白,保留有效文本(中文、字母、数字、中间空白)
    cleaned_question = re.sub(r'[^\w\s]', '', question).strip()
    print(f"【节点1: 预处理】清洗后的问题:{cleaned_question}")
    return {"cleaned_question": cleaned_question}

# 节点2:语义检索
def semantic_retrieval(state: RAGState) -> RAGState:
    question = state["cleaned_question"]

    # 从向量数据库中检索相关文档
    chroma_client = chromadb.PersistentClient(
        path="./.temp/chroma"  # 数据会保存在当前目录的 chroma_memory_data 文件夹中
    )
    collection_clent = chroma_client.get_or_create_collection(name="my_rag")
    results = collection_clent.query(
         query_embeddings=embedding_model.embed_documents([question]),   
         include=["documents"]
    )
    semantic_docs = results["documents"][0]
    print(f"【节点2: 语义检索】结果:{semantic_docs}")
    return {"semantic_docs": semantic_docs}

# 节点3:关键词检索

def keyword_retrieval(state: RAGState) -> RAGState:
    question = state["cleaned_question"]

    # BM25关键词检索相关性,当前成熟的方案是用ES, 这里这里只做简化演示
    bm25_tokenizer = lambda text: re.findall(r'\w+', text.lower())
    tokens = bm25_tokenizer(question)
    tokenized_docs = [bm25_tokenizer(doc) for doc in KNOWLEDGE_BASE]
    from rank_bm25 import BM25Okapi
    bm25 = BM25Okapi(tokenized_docs)
    scores = bm25.get_scores(tokens)
    top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:4]  # 多取1个
    keyword_docs = [KNOWLEDGE_BASE[i] for i in top_indices]
    print(f"【节点3: 关键词检索】结果:{keyword_docs}")
    return {"keyword_docs": keyword_docs}



# 节点4:合并+去重
# 这个节点有俩输入, langGraph 会等前面两个节点都完成才会调用这个节点
def merge_docs(state: RAGState) -> RAGState:
    merged = []
    seen = set()
    # 合并语义+关键词检索结果,去重
    for doc in state["semantic_docs"] + state["keyword_docs"]:
        if doc not in seen:
            seen.add(doc)
            merged.append(doc)
    print(f"【节点4: 合并去重】结果(共{len(merged)}条):{merged}")
    return {"merged_docs":merged}


# 节点5:重排序节点
# 从排序节点用到了重排序模型
def rerank_docs(state: RAGState) -> RAGState:
    question = state["cleaned_question"]
    docs = state["merged_docs"]
    
    # 边界处理:无文档直接返回空
    if not docs:
        print("【重排序】无文档可排序,返回空")
        return {"reranked_docs": [], "rerank_scores": []}
    
    # 1. 构造重排序输入:(问题, 文档) 成对输入
    rerank_pairs = [(question, doc) for doc in docs]
    
    # 2. 计算相关性得分(BGE Reranker输出0~1的得分,越高越相关)
    reranker_model_path = r"D:\MyWork\GitProject\LLM\my_llm\.model\reranker\bge-reranker-base"
    from FlagEmbedding import FlagReranker
    rerank_model = FlagReranker(model_name_or_path=reranker_model_path,normalize=True,device="cpu")
    scores = rerank_model.compute_score(rerank_pairs)
    print("【重排序】相关性得分:", scores)
   # print(f"【重排序】原始得分:{[(doc, round(score, 3)) for doc, score in zip(docs, scores)]}")
    
    # 3. 过滤低得分文档(≥阈值)
    RERANK_THRESHOLD = 0.5  # 重排序得分阈值:只保留≥0.5的文档(可根据场景调整)
    RERANK_THRESHOLD = 0.5  # 重排序得分阈值:只保留≥0.5的文档(可根据场景调整)
    RERANK_TOP_K = 3        # 最终保留的TopK文档数
    MIN_RERANK_NUM = 1      # 原始结果数不足时,保留的TopK文档数
    filtered_docs_scores = [(doc, score) for doc, score in zip(docs, scores) if score >= RERANK_THRESHOLD]
    if not filtered_docs_scores:
        print(f"【重排序】无文档达到得分阈值(≥{RERANK_THRESHOLD}),保留原始Top{MIN_RERANK_NUM}")
        filtered_docs_scores = [(doc, score) for doc, score in zip(docs, scores)][:MIN_RERANK_NUM]
    
    # 4. 按得分降序排序,取TopK
    filtered_docs_scores.sort(key=lambda x: x[1], reverse=True)
    reranked_docs = [doc for doc, _ in filtered_docs_scores[:RERANK_TOP_K]]
    rerank_scores = [round(score, 3) for _, score in filtered_docs_scores[:RERANK_TOP_K]]
    
    print(f"【节点5:  重排序】最终结果(得分≥{RERANK_THRESHOLD},Top{RERANK_TOP_K}):")
    for doc, score in zip(reranked_docs, rerank_scores):
        print(f"  - {doc}(得分:{score})")
    return {
        "reranked_docs": reranked_docs,
       "rerank_scores": rerank_scores  # 保存得分,方便后续调试/监控
    }

# 节点6:校验检索结果=仅作演示  只判断数量
def check_retrieval(state: RAGState) -> str:
    # if len(state["reranked_docs"]) >= 1:
    #     state["retrieve_flag"] = True
    #     print(f"【节点6: 检索校验】结果足够({len(state['reranked_docs'])}条),进入生成环节")
    #     return "generate_answer"
    # else:
    #     state["retrieve_flag"] = False
    #     print(f"【节点6: 检索校验】结果不足,重新检索")
    #     return "preprocess_question"

    # 这里我们人为制造一个循环,模拟检索结果不够的情况
    import random
    if random.random() > 0.8:
        state["retrieve_flag"] = True
        print(f"【节点6: 检索校验】结果足够({len(state['reranked_docs'])}条),进入生成环节")
        return "generate_answer"
    else:
        state["retrieve_flag"] = False
        print(f"【节点6: 检索校验】结果不足,重新检索")
        return "preprocess_question"


# 节点7:生成回答    结合查询到的数据 调用大模型  生成回答
def generate_answer(state: RAGState) -> RAGState:
    question = state["user_question"]
    docs = state["reranked_docs"]
    # 构建上下文
    context = "\n".join([f"- {doc}" for doc in docs])
    prompt = f"基于以下高相关的上下文信息回答用户问题,仅使用上下文内容,不要编造任何信息:\n{context}\n用户问题:{question}\n简洁回答:"
    print(f"【生成回答】输入:{prompt}")
    deepseek_api_base = "https://api.deepseek.com"
    deepseek_api_key = "sk-xxxxxxxxxxxxxxxxx"
    deepseek_model = "deepseek-chat"

    # 设置DeepSeek API Key和基础地址
    from langchain_openai import ChatOpenAI
    llm = ChatOpenAI(
        model=deepseek_model,
        openai_api_key=deepseek_api_key,
        openai_api_base=deepseek_api_base,  # 关键配置
        temperature=0.7,
    )
    # 生成回答
    answer = llm.invoke(prompt) # | StrOutputParser()
    print(f"【节点7: 生成回答】结果:{answer}")
    return {"generated_answer": answer}


# 节点8:校验回答 
# 演示简单合规校验, 正常可以使用大模型校验
def check_answer(state: RAGState) -> str:
    answer = state["generated_answer"]
    print(f"【回答校验】输入:{answer}")
    # 简单合规校验:非空、长度≥5、无违规词
    # if not answer or len(answer.content) < 5 or "编造" in answer or "不知道" in answer:
    #     state["answer_flag"] = False
    #     print(f"【节点8:  回答校验】不合规,重新生成")
    #     return "generate_answer"
    # else:
    #     state["answer_flag"] = True
    #     print(f"【节点8: 回答校验】合规,流程结束")
    #     return "end"

    # 这里我们人为制造一个循环
    import random
    if random.random() > 0.5:
        state["answer_flag"] = False
        print(f"【节点8:  回答校验】不合规,重新生成")
        return "generate_answer"
    else:
        state["answer_flag"] = True
        print(f"【节点8: 回答校验】合规,流程结束")
        return "end"

# ===================== 3. 构建LangGraph工作流=====================
from langgraph.graph import StateGraph, END,START
graph_builder = StateGraph(RAGState)
# 添加节点
graph_builder.add_node("preprocess_question", preprocess_question)
graph_builder.add_node("semantic_retrieval", semantic_retrieval)
graph_builder.add_node("keyword_retrieval", keyword_retrieval)
graph_builder.add_node("merge_docs", merge_docs)
graph_builder.add_node("rerank_docs", rerank_docs)
graph_builder.add_node("generate_answer", generate_answer)

# 设置边

# 设置入口  输入直接到节点1   也可以使用 graph_builder.set_entry_point("preprocess_question")
graph_builder.add_edge(START, "preprocess_question")

# 问题预处理之后, 同时进入节点2 语义检索 和 节点3 关键词检索
graph_builder.add_edge("preprocess_question", "semantic_retrieval")
graph_builder.add_edge("preprocess_question", "keyword_retrieval")

#节点2和节点3完成后 会合并到节点4 (扇入:等待两个检索节点都完成,才执行合并)
graph_builder.add_edge("semantic_retrieval", "merge_docs")
graph_builder.add_edge("keyword_retrieval", "merge_docs")

# 节点4 合并+去重 完成后, 进入节点5 重排序
graph_builder.add_edge("merge_docs", "rerank_docs")

# 节点6 是一个路由节点,  控制流程进入那个节点
graph_builder.add_conditional_edges(
    "rerank_docs",
    check_retrieval,
    {"generate_answer": "generate_answer", "preprocess_question": "preprocess_question"}
)

# 同上
graph_builder.add_conditional_edges(
    "generate_answer",
    check_answer,
    {"end": END, "generate_answer": "generate_answer"}
)

rag_graph = graph_builder.compile()


if __name__ == '__main__':
    user_question = "智能水杯的充电时间和续航分别是多少?"
    initial_state = {"user_question": user_question}
    # 执行
    result = rag_graph.invoke(initial_state)
    print("\n===================== 最终结果 =====================")
    print(f"用户问题:{user_question}")
    print(f"重排序后高相关文档:{result['reranked_docs']}")
    print(f"重排序得分:{result['rerank_scores']}")
    print(f"最终回答:{result['generated_answer'].content}")
Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐