Table of Contents

1. Framework Study

1.1 Feature Support

1.2 Recall (two-way recall)

(1) Weighted keyword matching (core feature: Elasticsearch boost queries)

(2) Vector matching (core algorithm: Elasticsearch kNN search)

1.3 Ranking

(1) Keyword similarity

(2) Vector similarity

2. Source Code Reading

2.1 Summarize with an LLM that analyzes GitHub repositories

2.2 Tracing through the source code

2.2.1 api/ragflow_server.py (system startup, main process)

2.2.2 api/apps/conversation_app.py:completion() (API endpoint)

2.2.3 api/db/services/dialog_service.py:chat() (service layer)

2.2.4 rag/nlp/search.py:Dealer.retrieval()

2.2.5 rag/nlp/search.py:Dealer.search()

2.2.6 rerank

2.2.6.1 rag/nlp/search.py:Dealer.rerank()

2.2.6.2 rag/nlp/search.py:Dealer.rerank_by_model()

2.2.7 rag/nlp/query.py:question() ← self.qryr.question

3. RAGFlow Web API


1. Framework Study

1.1 Feature Support

        Supports multiple document/data types and multiple chunking methods.

1.2 Recall (two-way recall)

(1) Weighted keyword matching (core feature: Elasticsearch boost queries)

        Elasticsearch boost queries: deliberately amplify or attenuate the relevance score of particular terms. Use boost inside a query to pull scores up or down explicitly, and use function_score to fold business factors into the score; combining the two gives fine-grained control over Elasticsearch's ranking.

  • Query-level boost (the simplest)

        Add a boost value directly to a match / term / bool query: values greater than 1 promote the clause, values less than 1 demote it.
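
        A minimal sketch of a query-level boost written as a Python dict (the field names and query text are illustrative assumptions, not from RAGFlow):

Python
# Hypothetical example: matches in "title" count three times as much as
# matches in "content"; boost scales each clause's contribution to the score.
boosted_query = {
    "query": {
        "bool": {
            "should": [
                {"match": {"title": {"query": "vector search", "boost": 3.0}}},
                {"match": {"content": {"query": "vector search", "boost": 1.0}}},
            ]
        }
    }
}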

  • Fine-grained boost with function_score (advanced scenarios)

        When a plain multiplier is not enough, use function_score to turn business fields (sales, rating, distance, time decay, etc.) into extra scoring factors. E.g.: take the log of sales and multiply it into the original score, and apply a Gaussian decay so documents far from the 300-yuan price point are demoted. A sketch follows.
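
        A hedged sketch of that example as a query body (all field names are assumptions):

Python
# Hypothetical example: score = original_score * log1p(sales) * gauss(price),
# where the Gaussian decay is centered on a price of 300.
function_score_query = {
    "query": {
        "function_score": {
            "query": {"match": {"content": "wireless earbuds"}},
            "functions": [
                {"field_value_factor": {"field": "sales", "modifier": "log1p", "missing": 0}},
                {"gauss": {"price": {"origin": 300, "scale": 100}}},
            ],
            "score_mode": "multiply",  # combine the two factors by multiplication
            "boost_mode": "multiply",  # then multiply into the original query score
        }
    }
}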

(2) Vector matching (core algorithm: Elasticsearch kNN search)

        Elasticsearch kNN search: the core idea is hierarchy — localize the neighborhood layer by layer, then run a greedy (beam-style) search at the bottom layer. The graph index is built ahead of time to cut query latency; enable it at index creation with "knn": true.

        It trades 1%-5% accuracy for a 100-1000× speedup.

        Approximate kNN (HNSW) vs. exact kNN (brute force):

| Dimension           | Approximate kNN (HNSW)                                | Exact kNN (brute force)                         |
| ------------------- | ----------------------------------------------------- | ----------------------------------------------- |
| Algorithm           | HNSW (Hierarchical Navigable Small World graph)       | Linear scan over all vectors                    |
| Speed               | Milliseconds (once indexed)                           | Grows linearly with data size: seconds~minutes  |
| Recall              | >95% (tunable)                                        | 100%                                            |
| Resources           | Extra memory for the graph index                      | Only the raw vectors                            |
| Use cases           | Millions to billions of vectors, real-time retrieval  | Small datasets, or offline verification         |
| How to enable in ES | "knn": true at index creation + a dense_vector index  | script_score in the query, or "knn": false      |
| Tunable parameters  | ef_search, ef_construction, M                         | --                                              |

| Parameter       | Location       | Suggested value | Effect                                                               |
| --------------- | -------------- | --------------- | -------------------------------------------------------------------- |
| ef_search       | index settings | 100-200         | Nodes traversed at query time; higher = more accurate but slower     |
| ef_construction | index settings | 200-400         | Nodes traversed at build time; higher = slower build, higher recall  |
| M               | index settings | 16-64           | Max neighbors per node; larger = more memory and higher accuracy     |
| num_candidates  | query          | 1.5-2 × k       | Size of the candidate set                                            |
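
        For reference, a minimal sketch of a kNN search through the official Python client (the index name, field name, and vector are illustrative assumptions):

Python
# Hypothetical example of an Elasticsearch 8.x kNN search.
from elasticsearch import Elasticsearch

es = Elasticsearch("http://localhost:9200")
resp = es.search(
    index="chunks",
    knn={
        "field": "q_1024_vec",         # dense_vector field
        "query_vector": [0.1] * 1024,  # the embedded query
        "k": 10,                       # number of neighbors to return
        "num_candidates": 20,          # candidate pool, ~1.5-2x k
    },
    source=["content_with_weight", "docnm_kwd"],
)
print(resp["hits"]["hits"])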

1.3 Ranking

        A weighted combination of keyword similarity and vector similarity.

(1) Keyword similarity

(2) Vector similarity

        A dedicated rerank model may also be used; its input is the keyword/token sequence, not the raw text content (see the fusion sketch below).
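
        A minimal sketch of the weighted fusion, matching the tkweight / vtweight split seen later in Dealer.rerank() (the function name here is hypothetical):

Python
# With vector_similarity_weight = 0.3, the final score is
#   0.7 * token_similarity + 0.3 * vector_similarity.
def hybrid_score(token_sim: float, vector_sim: float, vector_weight: float = 0.3) -> float:
    return (1 - vector_weight) * token_sim + vector_weight * vector_sim

print(hybrid_score(0.6, 0.9))  # 0.69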

2. Source Code Reading

2.1 Summarize with an LLM that analyzes GitHub repositories

        You can ask it questions directly; it works quite well: https://deepwiki.com/infiniflow/ragflow/2-system-architecture

2.2 Tracing through the source code

2.2.1 api/ragflow_server.py (system startup, main process)

Python
# start the web-based server, initialize database tables and data, load plugins
run_simple(
    hostname=settings.HOST_IP,
    port=settings.HOST_PORT,
    application=app,
    threaded=True,
    use_reloader=RuntimeConfig.DEBUG,
    use_debugger=RuntimeConfig.DEBUG,
)

api/apps/__init__.py (Flask app, route registration): instantiates the web application (configuration) and handles user authentication

2.2.2 api/apps/conversation_app.py:completion() (API endpoint)

Python
@manager.route("/completion", methods=["POST"])  # noqa: F821
@login_required
@validate_request("conversation_id", "messages")
def completion():
    req = request.json
    msg = []
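    # keep only the dialogue turns: skip system messages and any leading assistant message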
    for m in req["messages"]:
        if m["role"] == "system":
            continue
        if m["role"] == "assistant" and not msg:
            continue
        msg.append(m)
    message_id = msg[-1].get("id")
    try:
        e, conv = ConversationService.get_by_id(req["conversation_id"])
        if not e:
            return get_data_error_result(message="Conversation not found!")
        conv.message = deepcopy(req["messages"])
        e, dia = DialogService.get_by_id(conv.dialog_id)
        if not e:
            return get_data_error_result(message="Dialog not found!")
        del req["conversation_id"]
        del req["messages"]

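        # normalize past references and append a fresh slot for this turn's citations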
        if not conv.reference:
            conv.reference = []
        else:
            for ref in conv.reference:
                if isinstance(ref, list):
                    continue
                ref["chunks"] = chunks_format(ref)

        if not conv.reference:
            conv.reference = []
        conv.reference.append({"chunks": [], "doc_aggs": []})

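        # SSE generator: stream incremental answers produced by the service-layer chat()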
        def stream():
            nonlocal dia, msg, req, conv
            try:
                for ans in chat(dia, msg, True, **req):
                    ans = structure_answer(conv, ans, message_id, conv.id)
                    yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
                ConversationService.update_by_id(conv.id, conv.to_dict())
            except Exception as e:
                traceback.print_exc()
                yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
            yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

        if req.get("stream", True):
            resp = Response(stream(), mimetype="text/event-stream")
            resp.headers.add_header("Cache-control", "no-cache")
            resp.headers.add_header("Connection", "keep-alive")
            resp.headers.add_header("X-Accel-Buffering", "no")
            resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
            return resp

        else:
            answer = None
            for ans in chat(dia, msg, **req):
                answer = structure_answer(conv, ans, message_id, conv.id)
                ConversationService.update_by_id(conv.id, conv.to_dict())
                break
            return get_json_result(data=answer)
    except Exception as e:
        return server_error_response(e)

Python
from api.db.services.conversation_service import ConversationService, structure_answer
from api.db.services.dialog_service import DialogService, ask, chat

2.2.3 api/db/services/dialog_service.py:chat() (service layer)

Python
def chat(dialog, messages, stream=True, **kwargs):
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
    if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
        for ans in chat_solo(dialog, messages, stream):
            yield ans
        return

    chat_start_ts = timer()

    if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text":
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    max_tokens = llm_model_config.get("max_tokens", 8192)

    check_llm_ts = timer()

    langfuse_tracer = None
    trace_context = {}
    langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=dialog.tenant_id)
    if langfuse_keys:
        langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
        if langfuse.auth_check():
            langfuse_tracer = langfuse
            trace_id = langfuse_tracer.create_trace_id()
            trace_context = {"trace_id": trace_id}

    check_langfuse_tracer_ts = timer()
    kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog)
    toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools")
    if toolcall_session and tools:
        chat_mdl.bind_tools(toolcall_session, tools)
    bind_models_ts = timer()

    retriever = settings.retrievaler
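    # use the last three user turns as the question for retrieval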
    questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
    if "doc_ids" in messages[-1]:
        attachments = messages[-1]["doc_ids"]
    prompt_config = dialog.prompt_config
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    # try to use sql if field mapping is good to go
    if field_map:
        logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
        if ans:
            yield ans
            return

    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge":
            continue
        if p["key"] not in kwargs and not p["optional"]:
            raise KeyError("Miss parameter: " + p["key"])
        if p["key"] not in kwargs:
            prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")

    if len(questions) > 1 and prompt_config.get("refine_multiturn"):
        questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
    else:
        questions = questions[-1:]

    if prompt_config.get("cross_languages"):
        questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]

    if prompt_config.get("keyword", False):
        questions[-1] += keyword_extraction(chat_mdl, questions[-1])

    refine_question_ts = timer()

    thought = ""
    kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}

    if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
        knowledges = []
    else:
        tenant_ids = list(set([kb.tenant_id for kb in kbs]))
        knowledges = []
        if prompt_config.get("reasoning", False):
            reasoner = DeepResearcher(
                chat_mdl,
                prompt_config,
                partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3),
            )

            for think in reasoner.thinking(kbinfos, " ".join(questions)):
                if isinstance(think, str):
                    thought = think
                    knowledges = [t for t in think.split("\n") if t]
                elif stream:
                    yield think
        else:
            if embd_mdl:
                kbinfos = retriever.retrieval(
                    " ".join(questions),
                    embd_mdl,
                    tenant_ids,
                    dialog.kb_ids,
                    1,
                    dialog.top_n,
                    dialog.similarity_threshold,
                    dialog.vector_similarity_weight,
                    doc_ids=attachments,
                    top=dialog.top_k,
                    aggs=False,
                    rerank_mdl=rerank_mdl,
                    rank_feature=label_question(" ".join(questions), kbs),
                )
            if prompt_config.get("tavily_api_key"):
                tav = Tavily(prompt_config["tavily_api_key"])
                tav_res = tav.retrieve_chunks(" ".join(questions))
                kbinfos["chunks"].extend(tav_res["chunks"])
                kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
            if prompt_config.get("use_kg"):
                ck = settings.kg_retrievaler.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, LLMType.CHAT))
                if ck["content_with_weight"]:
                    kbinfos["chunks"].insert(0, ck)

            knowledges = kb_prompt(kbinfos, max_tokens)

    logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

    retrieval_ts = timer()
    if not knowledges and prompt_config.get("empty_response"):
        empty_res = prompt_config["empty_response"]
        yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)}
        return {"answer": prompt_config["empty_response"], "reference": kbinfos}

    kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
    gen_conf = dialog.llm_setting

    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
    prompt4citation = ""
    if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
        prompt4citation = citation_prompt()
    msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"])
    used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
    assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
    prompt = msg[0]["content"]

    if "max_tokens" in gen_conf:
        gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count)

Python
from api import settings

api/settings.py:

Python
def init_settings():
    retrievaler = search.Dealer(docStoreConn)

Python
from rag.nlp import search

2.2.4 rag/nlp/search.py:Dealer.retrieval()

Python
# recall + rerank + assemble the returned fields
def retrieval(self, question, embd_mdl, tenant_ids, kb_ids, page, page_size, similarity_threshold=0.2,
              vector_similarity_weight=0.3, top=1024, doc_ids=None, aggs=True,
              rerank_mdl=None, highlight=False,
              rank_feature: dict | None = {PAGERANK_FLD: 10}):
    ranks = {"total": 0, "chunks": [], "doc_aggs": {}}# 返回chunk的结构
    if not question:
        return ranks

    RERANK_LIMIT = 64
    RERANK_LIMIT = int(RERANK_LIMIT//page_size + ((RERANK_LIMIT%page_size)/(page_size*1.) + 0.5)) * page_size if page_size>1 else 1
    if RERANK_LIMIT < 1: ## when page_size is very large the RERANK_LIMIT will be 0.
        RERANK_LIMIT = 1
    req = {"kb_ids": kb_ids, "doc_ids": doc_ids, "page": math.ceil(page_size*page/RERANK_LIMIT), "size": RERANK_LIMIT,
           "question": question, "vector": True, "topk": top,
           "similarity": similarity_threshold,
           "available_int": 1}


    if isinstance(tenant_ids, str):
        tenant_ids = tenant_ids.split(",")

    sres = self.search(req, [index_name(tid) for tid in tenant_ids],
                       kb_ids, embd_mdl, highlight, rank_feature=rank_feature)

    if rerank_mdl and sres.total > 0:  # use the specified rerank model if given, otherwise the default reranker
        sim, tsim, vsim = self.rerank_by_model(rerank_mdl,
                                               sres, question, 1 - vector_similarity_weight,
                                               vector_similarity_weight,
                                               rank_feature=rank_feature)
    else:
        sim, tsim, vsim = self.rerank(
            sres, question, 1 - vector_similarity_weight, vector_similarity_weight,
            rank_feature=rank_feature)
    # Already paginated in search function
    # filter documents that meet the similarity threshold and return the requested page
    idx = np.argsort(sim * -1)[(page - 1) * page_size:page * page_size]  # np.argsort sorts ascending, so negate sim to rank by similarity descending
    dim = len(sres.query_vector)
    vector_column = f"q_{dim}_vec"
    zero_vector = [0.0] * dim
    sim_np = np.array(sim)
    if doc_ids:  # if doc_ids are given, the threshold is forced to 0, so the user-supplied similarity_threshold is ignored
        similarity_threshold = 0
    filtered_count = (sim_np >= similarity_threshold).sum()    
    ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
    for i in idx:
        if sim[i] < similarity_threshold:
            break

        id = sres.ids[i]
        chunk = sres.field[id]
        dnm = chunk.get("docnm_kwd", "")
        did = chunk.get("doc_id", "")

        if len(ranks["chunks"]) >= page_size:
            if aggs:
                if dnm not in ranks["doc_aggs"]:
                    ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
                ranks["doc_aggs"][dnm]["count"] += 1
                continue
            break

        position_int = chunk.get("position_int", [])
        d = {
            "chunk_id": id,
            "content_ltks": chunk["content_ltks"],
            "content_with_weight": chunk["content_with_weight"],
            "doc_id": did,
            "docnm_kwd": dnm,
            "kb_id": chunk["kb_id"],
            "important_kwd": chunk.get("important_kwd", []),
            "image_id": chunk.get("img_id", ""),
            "similarity": sim[i],
            "vector_similarity": vsim[i],
            "term_similarity": tsim[i],
            "vector": chunk.get(vector_column, zero_vector),
            "positions": position_int,
            "doc_type_kwd": chunk.get("doc_type_kwd", "")
        }
        if highlight and sres.highlight:
            if id in sres.highlight:
                d["highlight"] = rmSpace(sres.highlight[id])
            else:
                d["highlight"] = d["content_with_weight"]
        ranks["chunks"].append(d)
        if dnm not in ranks["doc_aggs"]:
            ranks["doc_aggs"][dnm] = {"doc_id": did, "count": 0}
        ranks["doc_aggs"][dnm]["count"] += 1
    ranks["doc_aggs"] = [{"doc_name": k,
                          "doc_id": v["doc_id"],
                          "count": v["count"]} for k,
                                                   v in sorted(ranks["doc_aggs"].items(),
                                                               key=lambda x: x[1]["count"] * -1)]
    ranks["chunks"] = ranks["chunks"][:page_size]

    return ranks

        Steps of retrieval():

  1. Initialize the return value: create a dict ranks to hold the retrieval results.
  2. Check for an empty question: if the question is empty, return the empty result directly.
  3. Compute the rerank limit: derive RERANK_LIMIT from page_size.
  4. Build the search request: assemble a request dict req with the retrieval parameters.
  5. Handle tenant IDs: if tenant_ids is a string, split it into a list.
  6. Run the search: call the search method and obtain the result sres.
  7. Rerank: if a rerank model rerank_mdl is supplied, rerank with it; otherwise use the default rerank method.
  8. Paginate: slice the results according to page and page_size.
  9. Filter by similarity: drop results below the similarity threshold.
  10. Assemble the output: build the final ranks dict from the filtered results.
  11. Return the retrieval results ranks.
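
        A runnable illustration of the sort-and-paginate step, and of why sim is negated (np.argsort sorts ascending):

Python
import numpy as np

sim = np.array([0.42, 0.91, 0.15, 0.77])
page, page_size = 1, 2
# negate so that argsort's ascending order becomes descending-by-similarity
idx = np.argsort(sim * -1)[(page - 1) * page_size : page * page_size]
print(idx)       # [1 3]  -> indices of the two highest-scoring chunks
print(sim[idx])  # [0.91 0.77]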

2.2.5 rag/nlp/search.py:Dealer.search()

Python
def search(self, req, idx_names: str | list[str],  # idx_names: the index name(s), one per tenant (built by index_name(tid) in retrieval())
           kb_ids: list[str],
           emb_mdl=None,
           highlight=False,
           rank_feature: dict | None = None
           ):
    filters = self.get_filters(req)  # extract filter keys/values from the request: dataset_id, doc_id, entity and other keyword fields
    orderBy = OrderByExpr()  # sort expression applied to the datastore query

    pg = int(req.get("page", 1)) - 1
    topk = int(req.get("topk", 1024))
    ps = int(req.get("size", topk))
    offset, limit = pg * ps, ps  # locate the requested page of recall results

    # src: the list of stored fields to return for each hit
    src = req.get("fields",
                  ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd", "position_int",
                   "doc_id", "page_num_int", "top_int", "create_timestamp_flt", "knowledge_graph_kwd",
                   "question_kwd", "question_tks", "doc_type_kwd",
                   "available_int", "content_with_weight", PAGERANK_FLD, TAG_FLD])
    kwds = set([])

    qst = req.get("question", "")
    q_vec = []
    if not qst:
        if req.get("sort"):
            orderBy.asc("page_num_int")
            orderBy.asc("top_int")
            orderBy.desc("create_timestamp_flt")
        res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
        total = self.dataStore.getTotal(res)
        logging.debug("Dealer.search TOTAL: {}".format(total))
    else:
        highlightFields = ["content_ltks", "title_tks"] if highlight else [] # hig hlight什么功能???内容和标题fen ci
        matchText, keywords = self.qryr.question(qst, min_match=0.3) # query匹配什么,关键词是什么???min_match控制着什么???
        if emb_mdl is None:
            matchExprs = [matchText]
            res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
                                        idx_names, kb_ids, rank_feature=rank_feature)  # keyword-only search against the datastore
            total = self.dataStore.getTotal(res)
            logging.debug("Dealer.search TOTAL: {}".format(total))
        else:
            matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))  # embed the query and wrap it in a dense-match expression; topk is the k of the kNN match, which is why it is an argument here
            q_vec = matchDense.embedding_data
            src.append(f"q_{len(q_vec)}_vec")

            fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"})
            matchExprs = [matchText, matchDense, fusionExpr]  # with an embedding model, the expressions are full-text + dense + fusion (text weighted 0.05, vector 0.95)

            res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
                                        idx_names, kb_ids, rank_feature=rank_feature)
            total = self.dataStore.getTotal(res)
            logging.debug("Dealer.search TOTAL: {}".format(total))

            # If result is empty, try again with lower min_match
            if total == 0:
                if filters.get("doc_id"):
                    res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
                    total = self.dataStore.getTotal(res)
                else:
                    matchText, _ = self.qryr.question(qst, min_match=0.1)
                    matchDense.extra_options["similarity"] = 0.17
                    res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
                                                orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
                    total = self.dataStore.getTotal(res)
                logging.debug("Dealer.search 2 TOTAL: {}".format(total))

        for k in keywords:
            kwds.add(k)
            for kk in rag_tokenizer.fine_grained_tokenize(k).split():
                if len(kk) < 2:
                    continue
                if kk in kwds:
                    continue
                kwds.add(kk)  # add fine-grained sub-tokens of each keyword

    logging.debug(f"TOTAL: {total}")
    ids = self.dataStore.getChunkIds(res)
    keywords = list(kwds)
    highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
    aggs = self.dataStore.getAggregation(res, "docnm_kwd")
    return self.SearchResult(
        total=total,
        ids=ids,
        query_vector=q_vec,
        aggregation=aggs,
        highlight=highlight,
        field=self.dataStore.getFields(res, src),
        keywords=keywords
    )

search() steps:

  1. Get filters: via self.get_filters(req).
  2. Set up sorting: create an OrderByExpr().
  3. Parse request parameters: page number, page size, the question, etc. from req.
  4. Define the source fields: the field list src to return.
  5. Process the question: from qst, build the query vector q_vec and the match expressions matchExprs.
  6. Execute the search: call self.dataStore.search with the query vector, match expressions, and filters.
  7. Process the results: total count, chunk IDs, highlights, aggregations, etc.
  8. Return: wrap everything in a SearchResult object.
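
A runnable illustration of the "weighted_sum" fusion declared in search() (the helper function is hypothetical; the actual fusion is carried out inside the datastore):

Python
# FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"}) weights the
# full-text score at 0.05 and the dense (vector) score at 0.95.
def weighted_sum(text_score: float, dense_score: float) -> float:
    return 0.05 * text_score + 0.95 * dense_score

print(weighted_sum(0.4, 0.8))  # 0.78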

2.2.6 rerank

2.2.6.1 rag/nlp/search.py:Dealer.rerank()

Python
def rerank(self, sres, query, tkweight=0.3,
           vtweight=0.7, cfield="content_ltks",
           rank_feature: dict | None = None
           ):
    _, keywords = self.qryr.question(query)
    vector_size = len(sres.query_vector)
    vector_column = f"q_{vector_size}_vec"
    zero_vector = [0.0] * vector_size
    ins_embd = []  # embeddings of the recalled chunks
    for chunk_id in sres.ids:
        vector = sres.field[chunk_id].get(vector_column, zero_vector)  # fetch each chunk's embedding
        if isinstance(vector, str):
            vector = [get_float(v) for v in vector.split("\t")]
        ins_embd.append(vector)
    if not ins_embd:
        return [], [], []

    for i in sres.ids:
        if isinstance(sres.field[i].get("important_kwd", []), str):
            sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]]
    ins_tw = []
    for i in sres.ids:
        content_ltks = list(OrderedDict.fromkeys(sres.field[i][cfield].split()))  # tokenize and deduplicate, preserving order
        title_tks = [t for t in sres.field[i].get("title_tks", "").split() if t]
        question_tks = [t for t in sres.field[i].get("question_tks", "").split() if t]
        important_kwd = sres.field[i].get("important_kwd", [])
        tks = content_ltks + title_tks * 2 + important_kwd * 5 + question_tks * 6  # repeat tokens to weight fields differently (title x2, keywords x5, questions x6): fields more indicative of relevance count more in the token similarity
        ins_tw.append(tks)

    ## For rank feature(tag_fea) scores.
    rank_fea = self._rank_feature_scores(rank_feature, sres)  # rank-feature (tag/pagerank) scores, added onto the final similarity

    sim, tksim, vtsim = self.qryr.hybrid_similarity(sres.query_vector,  # query embedding
                                                    ins_embd,  # chunk embeddings
                                                    keywords,  # keywords tokenized from the query
                                                    ins_tw, tkweight, vtweight)  # per-chunk token lists, token-similarity weight, vector-similarity weight

    return sim + rank_fea, tksim, vtsim

2.2.6.2 rag/nlp/search.py:Dealer.rerank_by_model()

Python
def rerank_by_model(self, rerank_mdl, sres, query, tkweight=0.3,
                    vtweight=0.7, cfield="content_ltks",
                    rank_feature: dict | None = None):
    _, keywords = self.qryr.question(query)

    for i in sres.ids:
        if isinstance(sres.field[i].get("important_kwd", []), str):
            sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]]
    ins_tw = []
    for i in sres.ids:
        content_ltks = sres.field[i][cfield].split()
        title_tks = [t for t in sres.field[i].get("title_tks", "").split() if t]
        important_kwd = sres.field[i].get("important_kwd", [])
        tks = content_ltks + title_tks + important_kwd
        ins_tw.append(tks)

    tksim = self.qryr.token_similarity(keywords, ins_tw)  # token similarity between the query's terms and each chunk's tokens
    vtsim, _ = rerank_mdl.similarity(query, [rmSpace(" ".join(tks)) for tks in ins_tw])  # rerank-model similarity between the query and each chunk's tokens rejoined as text
    ## For rank feature(tag_fea) scores.
    rank_fea = self._rank_feature_scores(rank_feature, sres)

    return tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim, tksim, vtsim

2.2.7 rag/nlp/query.py:question() ← self.qryr.question

To be continued...

3. RAGFlow Web API

shell
curl --request POST \
     --url http://{address}/api/v1/retrieval \
     --header 'Content-Type: application/json' \
     --header 'Authorization: Bearer <YOUR_API_KEY>' \
     --data '
    {
        "page": 1,  # return page 1 of the recalled results
        "page_size": 20,  # caps the number of returned chunks, takes precedence over the similarity threshold; if unset, defaults to top_k
        "similarity_threshold": 0.3,  # has no effect here: when document_ids are given, the threshold is forced to 0 (see retrieval())
        "vector_similarity_weight": 0.3,
        "top_k": 10,  # appears to have no effect ???
        "question": "我去过几次北京",  # "How many times have I been to Beijing?"
        "dataset_ids": ["462189d4561f11f09b760242ac160006"],
        "document_ids": ["fd588544645711f0b3c50242ac160006"]
    }'
