本文档基于字节跳动(扣子) Coze Studio 实际源码分析,全面解析知识库RAG多种方式检索的技术实现。涵盖向量检索、全文检索、结构化数据检索的并行处理机制,以及基于 RRF 算法的智能结果融合策略。

🚀项目地址:https://github.com/coze-dev/coze-studio

📋 目录


一、检索服务架构概览

1.1 核心功能模块

1.1.1 各检索方式功能
  • 向量检索:基于语义相似度的智能检索,理解用户问题的含义
  • ES检索:基于关键词的传统检索,精确匹配特定词汇
  • NL2SQL检索:自然语言转SQL的数据库查询,处理结构化数据
  • 混合检索:多种检索方式的智能融合,提高检索覆盖度
1.1.2 重排序与融合功能
  • RRF重排序:基于多种检索结果的智能重排序算法
  • 分数过滤:过滤低质量检索结果,提高结果质量
  • 结果融合:将不同来源的结果智能合并,去除重复
  • 元数据丰富:为检索结果添加文档信息、来源说明等
1.1.3 链式处理架构
  • 并行检索:同时执行多种检索方式,提高响应速度
  • 重排序处理:对检索结果进行智能排序和过滤
  • 结果打包:格式化最终结果,便于后续处理
  • 错误处理:优雅处理各种检索失败情况

二、RAG检索流程

2.1 检索主入口时序图

用户 检索主入口 向量检索 ES检索 NL2SQL检索 重排序 结果打包 发送查询请求 参数验证 上下文构建 链式处理开始 查询重写 并行检索阶段 向量检索 返回向量结果 ES检索 返回ES结果 NL2SQL检索 返回SQL结果 par [并行执行] 结果融合阶段 重排序处理 返回排序结果 结果打包 返回最终结果 返回检索结果 完整流程耗时优化 用户 检索主入口 向量检索 ES检索 NL2SQL检索 重排序 结果打包

时序图说明:

  • 并行检索:向量检索、ES检索、NL2SQL检索同时进行,提高响应速度
  • 结果融合:多种检索结果通过重排序算法智能融合
  • 错误隔离:单个检索方式失败不影响其他检索方式
  • 性能优化:通过并行处理显著减少总体响应时间

2.2 检索主入口实现

检索主入口执行流程:

  1. 参数验证:检查请求参数的有效性
  2. 上下文构建:创建检索上下文,包含文档、策略等信息
  3. 处理链构建:组装查询重写、并行检索、重排序、结果打包的处理链
  4. 并行检索:同时执行向量检索、ES检索、NL2SQL检索
  5. 结果融合:通过重排序融合多种检索结果
  6. 结果打包:将最终结果转换为响应格式
// backend/domain/knowledge/service/retrieve.go
func (k *knowledgeSVC) Retrieve(ctx context.Context, request *RetrieveRequest) (response *RetrieveResponse, err error) {
	if request == nil {
		return nil, errorx.New(errno.ErrKnowledgeInvalidParamCode, errorx.KV("msg", "request is nil"))
	}
	if len(request.Query) == 0 {
		return &knowledgeModel.RetrieveResponse{}, nil
	}
	
	// 1. 构建检索上下文
	retrieveContext, err := k.newRetrieveContext(ctx, request)
	if err != nil {
		return nil, err
	}
	if len(retrieveContext.Documents) == 0 {
		return &knowledgeModel.RetrieveResponse{}, nil
	}
	
	// 2. 构建处理链:查询重写 -> 并行检索 -> 重排序 -> 结果打包
	chain := compose.NewChain[*RetrieveContext, []*knowledgeModel.RetrieveSlice]()
	rewriteNode := compose.InvokableLambda(k.queryRewriteNode)
	vectorRetrieveNode := compose.InvokableLambda(k.vectorRetrieveNode)
	EsRetrieveNode := compose.InvokableLambda(k.esRetrieveNode)
	Nl2SqlRetrieveNode := compose.InvokableLambda(k.nl2SqlRetrieveNode)
	passRequestContextNode := compose.InvokableLambda(k.passRequestContext)
	reRankNode := compose.InvokableLambda(k.reRankNode)
	packResult := compose.InvokableLambda(k.packResults)
	
	// 3. 并行检索节点:同时执行向量、ES、NL2SQL检索
	parallelNode := compose.NewParallel().
		AddLambda("vectorRetrieveNode", vectorRetrieveNode).
		AddLambda("esRetrieveNode", EsRetrieveNode).
		AddLambda("nl2SqlRetrieveNode", Nl2SqlRetrieveNode).
		AddLambda("passRequestContext", passRequestContextNode)

	// 4. 编译并执行处理链
	r, err := chain.
		AppendLambda(rewriteNode).           // 第1步:查询重写
		AppendParallel(parallelNode).        // 第2步:并行检索
		AppendLambda(reRankNode).            // 第3步:重排序
		AppendLambda(packResult).            // 第4步:结果打包
		Compile(ctx)
	if err != nil {
		logs.CtxErrorf(ctx, "compile chain failed: %v", err)
		return nil, errorx.New(errno.ErrKnowledgeBuildRetrieveChainFailCode, errorx.KV("msg", err.Error()))
	}
	
	// 5. 执行检索流程
	output, err := r.Invoke(ctx, retrieveContext)
	if err != nil {
		return nil, err
	}
	
	return &knowledgeModel.RetrieveResponse{
		Slices: output,
	}, nil
}
2.2.1 检索上下文管理

检索上下文的作用:

  • 参数管理:统一管理检索参数和配置
  • 状态跟踪:跟踪检索过程中的状态变化
  • 结果传递:在不同处理节点间传递结果
// backend/domain/knowledge/service/interface.go
// 检索上下文
type RetrieveContext struct {
	Ctx              context.Context
	OriginQuery      string                   // 原始 query
	RewrittenQuery   *string                  // 改写后的 query, 如果没有改写,就是 nil, 会在执行过程中添加上去
	ChatHistory      []*schema.Message        // 如果没有对话历史或者不需要历史,则为 nil
	KnowledgeIDs     sets.Set[int64]          // 本次检索涉及的知识库id
	KnowledgeInfoMap map[int64]*KnowledgeInfo // 知识库id到文档id的映射
	// 召回策略
	Strategy *entity.RetrievalStrategy
	// 检索涉及的 document 信息
	Documents []*model.KnowledgeDocument
	// 用于 nl2sql 和 message to query 的 chat model
	ChatModel chatmodel.BaseChatModel
}

// 构建检索上下文
func (k *knowledgeSVC) newRetrieveContext(ctx context.Context, req *RetrieveRequest) (*RetrieveContext, error) {
	// 1. 验证策略参数
	if req.Strategy == nil {
		return nil, errorx.New(errno.ErrKnowledgeInvalidParamCode, errorx.KV("msg", "strategy is required"))
	}
	
	// 2. 准备RAG文档:过滤启用的知识库和文档
	knowledgeIDSets := sets.FromSlice(req.KnowledgeIDs)
	docIDSets := sets.FromSlice(req.DocumentIDs)
	enableDocs, enableKnowledge, err := k.prepareRAGDocuments(ctx, docIDSets.ToSlice(), knowledgeIDSets.ToSlice())
	if err != nil {
		logs.CtxErrorf(ctx, "prepare rag documents failed: %v", err)
		return nil, err
	}
	if len(enableDocs) == 0 {
		return &RetrieveContext{}, nil
	}
	
	// 3. 构建知识库信息映射
	knowledgeInfoMap := make(map[int64]*KnowledgeInfo)
	for _, kn := range enableKnowledge {
		if knowledgeInfoMap[kn.ID] == nil {
			knowledgeInfoMap[kn.ID] = &KnowledgeInfo{}
			knowledgeInfoMap[kn.ID].DocumentType = knowledgeModel.DocumentType(kn.FormatType)
			knowledgeInfoMap[kn.ID].DocumentIDs = []int64{}
		}
	}
	
	// 4. 关联文档到知识库
	for _, doc := range enableDocs {
		info, found := knowledgeInfoMap[doc.KnowledgeID]
		if !found {
			continue
		}
		info.DocumentIDs = append(info.DocumentIDs, doc.ID)
		if info.DocumentType == knowledgeModel.DocumentTypeTable && info.TableColumns == nil && doc.TableInfo != nil {
			info.TableColumns = doc.TableInfo.Columns
		}
	}

	// 5. 创建聊天模型(如果配置了)
	var cm chatmodel.BaseChatModel
	if req.ChatModelProtocol != nil && req.ChatModelConfig != nil {
		cm, err = k.modelFactory.CreateChatModel(ctx, ptr.From(req.ChatModelProtocol), req.ChatModelConfig)
		if err != nil {
			return nil, errorx.New(errno.ErrKnowledgeInvalidParamCode,
				errorx.KV("msg", "invalid retriever chat model protocol or config"))
		}
	}

	// 6. 构建检索上下文
	resp := RetrieveContext{
		Ctx:              ctx,
		OriginQuery:      req.Query,
		ChatHistory:      append(req.ChatHistory, schema.UserMessage(req.Query)),
		KnowledgeIDs:     knowledgeIDSets,
		KnowledgeInfoMap: knowledgeInfoMap,
		Strategy:         req.Strategy,
		Documents:        enableDocs,
		ChatModel:        cm,
	}
	return &resp, nil
}
2.2.1 查询重写节点实现

查询重写机制深度解析:

查询重写是RAG系统中的重要优化步骤,它通过分析聊天历史来优化用户查询,提高检索的准确性和相关性。

核心功能定位:

  • 上下文理解:基于聊天历史理解用户查询的上下文
  • 查询优化:将简单查询转换为更精确的检索查询
  • 意图识别:识别用户的真实检索意图
  • 相关性提升:通过重写提高检索结果的相关性

执行条件:

  1. 有聊天历史:必须存在聊天历史记录
  2. 启用重写功能:策略中EnableQueryRewrite为true
  3. 重写器可用:系统配置了查询重写器

重写流程:

  1. 历史检查:检查是否存在聊天历史
  2. 功能检查:检查是否启用查询重写功能
  3. 模型调用:使用大语言模型进行查询重写
  4. 结果应用:将重写后的查询应用到检索流程
// backend/domain/knowledge/service/retrieve.go
func (k *knowledgeSVC) queryRewriteNode(ctx context.Context, req *RetrieveContext) (newRetrieveContext *RetrieveContext, err error) {
	// 1. 检查聊天历史:如果没有历史记录,直接返回原始请求
	if len(req.ChatHistory) == 0 {
		return req, nil
	}
	
	// 2. 检查重写功能:如果未启用重写或重写器不可用,直接返回
	if !req.Strategy.EnableQueryRewrite || k.rewriter == nil {
		return req, nil
	}
	
	// 3. 构建重写选项:如果指定了聊天模型则使用指定模型
	var opts []messages2query.Option
	if req.ChatModel != nil {
		opts = append(opts, messages2query.WithChatModel(req.ChatModel))
	}
	
	// 4. 执行查询重写:基于聊天历史重写用户查询
	rewrittenQuery, err := k.rewriter.MessagesToQuery(ctx, req.ChatHistory, opts...)
	if err != nil {
		logs.CtxErrorf(ctx, "rewrite query failed: %v", err)
		return req, nil  // 重写失败时返回原始请求
	}
	
	// 5. 应用重写结果:将重写后的查询设置到请求中
	req.RewrittenQuery = &rewrittenQuery
	return req, nil
}

2.3 检索策略配置

检索策略参数说明:

// backend/api/model/crossdomain/knowledge/knowledge.go
type RetrievalStrategy struct {
    TopK      *int64   // 1-10 default 3,最终返回数量
    MinScore  *float64 // 0.01-0.99 default 0.5,最小匹配分数
    MaxTokens *int64   // 最大token数

    SelectType         SelectType // 调用方式
    SearchType         SearchType // 搜索策略
    EnableQueryRewrite bool       // 是否启用查询重写
    EnableRerank       bool       // 是否启用重排序
    EnableNL2SQL       bool       // 是否启用NL2SQL
}

type SearchType int64

const (
    SearchTypeSemantic SearchType = 0 // 语义搜索:仅向量检索
    SearchTypeFullText SearchType = 1 // 全文搜索:仅ES检索  
    SearchTypeHybrid   SearchType = 2 // 混合搜索:向量+ES
)

type SelectType int64

const (
    SelectTypeAuto     = 0 // 自动调用
    SelectTypeOnDemand = 1 // 按需调用
)
2.3.1 检索参数配置表
参数名称 默认值 取值范围 作用说明 影响
TopK 3 1-10 最终返回结果数量 控制输出数量
MinScore 0.5 0.01-0.99 最小匹配分数阈值 过滤低质量结果
MaxTokens - 动态 最大token数限制 控制处理复杂度
SearchType Semantic Semantic/FullText/Hybrid 搜索策略选择 决定检索方式
EnableQueryRewrite false true/false 是否启用查询重写 优化查询质量
EnableRerank false true/false 是否启用重排序 提升结果质量
EnableNL2SQL false true/false 是否启用NL2SQL 支持结构化查询
2.3.2 多种检索方式数据量对比表
检索方式 默认检索数量 检索原理 适用场景 优势 劣势
向量检索 4条 向量相似度匹配 语义搜索 语义理解强、容错性好 计算开销大
ES检索 10条 倒排索引+TF-IDF 全文搜索 精确匹配、速度快 语义理解弱
NL2SQL检索 动态 自然语言转SQL 表格数据 结构化查询 仅限表格数据
混合检索 向量4条+ES10条 多路并行+重排序 综合场景 全面覆盖 计算复杂度高
2.3.3 重排序阶段
  • 合并候选:向量4条 + ES 10条 = 最多14条候选
  • TopK筛选:根据TopK参数选择最终数量(默认3条)
  • MinScore过滤:根据MinScore参数过滤低分结果(默认0.5)
2.3.4 最终输出
  • 最大数量:由TopK参数控制(默认3条)
  • 质量过滤:由MinScore参数控制(默认0.5)
  • 实际数量:可能少于TopK,取决于匹配质量

三、多种检索方式核心实现

3.1 向量检索实现

向量检索原理:

  1. 文本向量化:将用户问题和文档都转换为向量
  2. 相似度计算:计算向量之间的余弦相似度
  3. 排序返回:按相似度排序返回最相关的结果
3.1.1 向量检索常量定义
// backend/infra/impl/document/searchstore/milvus/consts.go
const (
    batchSize = 100
    topK      = 4  // 向量检索默认返回4条结果
)

数据量说明: 向量检索默认返回4条结果,这是基于语义相似度排序后的最相关文档。

3.1.2 向量检索入口

检索机制深度剖析:

  1. 检索策略判断:首先检查SearchType,如果是SearchTypeFullText(仅全文搜索),则跳过向量检索
  2. 向量存储管理器选择:遍历所有搜索存储管理器,找到类型为TypeVectorStore的Milvus管理器
  3. 通道检索调用:调用retrieveChannels函数进行实际的向量检索操作
// backend/domain/knowledge/service/retrieve.go
// 向量检索节点
func (k *knowledgeSVC) vectorRetrieveNode(ctx context.Context, req *RetrieveContext) (retrieveResult []*schema.Document, err error) {
    // 1. 检查搜索类型,如果是全文搜索则跳过向量检索
    if req.Strategy.SearchType == knowledgeModel.SearchTypeFullText {
        return nil, nil
    }
    
    // 2. 获取向量存储管理器
    var manager searchstore.Manager
    for i := range k.searchStoreManagers {
        m := k.searchStoreManagers[i]
        if m != nil && m.GetType() == searchstore.TypeVectorStore {
            manager = m
            break
        }
    }
    
    // 3. 检查向量存储是否可用
    if manager == nil {
        logs.CtxErrorf(ctx, "err:%s", errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", "未实现vectorStore")).Error())
        return nil, nil
    }
    
    // 4. 执行向量检索    详情参考本文章:四、多知识库并发检索机制
    retrieveResult, err = k.retrieveChannels(ctx, req, manager)
    if err != nil {
        logs.CtxErrorf(ctx, "retrieveChannels err:%s", err.Error())
    }
    
    return retrieveResult, nil
}

向量检索优势:

  • 语义理解:能理解问题的含义,不只是关键词匹配
  • 容错性强:即使用词不同,只要意思相近就能找到
  • 扩展性好:支持同义词、近义词等语义扩展
3.1.3 Milvus向量检索核心实现
// backend/infra/impl/document/searchstore/milvus/milvus_searchstore.go
func (m *milvusSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
	cli := m.config.Client
	emb := m.config.Embedding
	options := retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(topK)}, opts...)
	
	// 1. 获取集合描述信息
	desc, err := cli.DescribeCollection(ctx, client.NewDescribeCollectionOption(m.collectionName))
	
	// 2. 查询向量化处理
	if enableSparse {
		dense, sparse, err = emb.EmbedStringsHybrid(ctx, []string{query})  // 混合向量化
	} else {
		dense, err = emb.EmbedStrings(ctx, []string{query})  // 密集向量化
	}
	
	// 3. 构建搜索请求
	if enableSparse {
		// 混合搜索:同时使用密集向量和稀疏向量
		searchOption := client.NewHybridSearchOption(m.collectionName, ptr.From(options.TopK), annRequests...).
			WithPartitons(implSpecOptions.Partitions...).
			WithReranker(client.NewRRFReranker()).  // 使用RRF重排序
			WithOutputFields(outputFields...)
		result, err = cli.HybridSearch(ctx, searchOption)
	} else {
		// 纯向量搜索
		searchOption := client.NewSearchOption(m.collectionName, ptr.From(options.TopK), dv).
			WithPartitions(implSpecOptions.Partitions...).
			WithFilter(expr).
			WithOutputFields(outputFields...).
			WithSearchParam(mindex.MetricTypeKey, string(metricsType))
		result, err = cli.Search(ctx, searchOption)
	}
	
	// 4. 结果转换
	docs, err := m.resultSet2Document(result, scoreNormType)
	return docs, nil
}

向量检索算法原理:

  • 密集向量检索:使用余弦相似度或欧几里得距离计算查询向量与文档向量的相似度
  • 混合向量检索:同时使用密集向量(语义信息)和稀疏向量(关键词信息),通过RRF算法融合结果
  • 分区过滤:根据文档ID进行分区过滤,提高检索效率
  • 相似度计算:支持多种距离度量方式(IP、L2、COSINE等)

3.2 Elasticsearch检索实现

Elasticsearch检索原理:

  1. 倒排索引:建立词汇到文档的映射
  2. TF-IDF算法:计算词汇在文档中的重要性
  3. 布尔查询:支持AND、OR、NOT等逻辑运算3.2.1 ES检索常量定义
3.2.1 Elasticsearch检索常量
// backend/infra/impl/document/searchstore/elasticsearch/consts.go
const (
    topK = 10  // Elasticsearch检索默认返回10条结果
)

数据量说明: Elasticsearch检索默认返回10条结果,这是基于关键词匹配和TF-IDF算法排序后的文档。

3.2.2 Elasticsearch检索入口

检索机制深度剖析:

  1. 检索策略判断:检查SearchType,如果是SearchTypeSemantic(仅语义搜索),则跳过Elasticsearch检索
  2. 文本存储管理器选择:找到类型为TypeTextStore的Elasticsearch管理器
  3. 通道检索调用:调用retrieveChannels函数进行实际的Elasticsearch检索操作
// backend/domain/knowledge/service/retrieve.go
// ES检索节点
func (k *knowledgeSVC) esRetrieveNode(ctx context.Context, req *RetrieveContext) (retrieveResult []*schema.Document, err error) {
    // 1. 检查搜索类型,如果是语义搜索则跳过ES检索
    if req.Strategy.SearchType == knowledgeModel.SearchTypeSemantic {
        return nil, nil
    }
    
    // 2. 获取文本存储管理器
    var manager searchstore.Manager
    for i := range k.searchStoreManagers {
        m := k.searchStoreManagers[i]
        if m != nil && m.GetType() == searchstore.TypeTextStore {
            manager = m
            break
        }
    }
    
    // 3. 检查ES存储是否可用
    if manager == nil {
        logs.CtxErrorf(ctx, "err:%s", errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", "未实现esStore")).Error())
        return nil, nil
    }
    
    // 4. 执行ES检索    详情参考本文章:四、多知识库并发检索机制
    retrieveResult, err = k.retrieveChannels(ctx, req, manager)
    if err != nil {
        logs.CtxErrorf(ctx, "retrieveChannels err:%s", err.Error())
    }
    
    return retrieveResult, nil
}

ES检索特点:

  • 精确匹配:能找到包含特定关键词的文档
  • 过滤支持:支持复杂的过滤条件
  • 快速检索:基于倒排索引,检索速度极快
3.2.3 Elasticsearch检索核心实现
// backend/infra/impl/document/searchstore/elasticsearch/elasticsearch_searchstore.go
func (e *esSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
	var (
		cli   = e.config.Client
		index = e.indexName
		options = retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(topK)}, opts...)
		req = &es.Request{
			Query: &es.Query{
				Bool: &es.BoolQuery{},
			},
			Size: options.TopK,
		}
	)

	// 1. 构建查询条件
	if implSpecOptions.MultiMatch == nil {
		// 单字段匹配
		req.Query.Bool.Must = append(req.Query.Bool.Must,
			es.NewMatchQuery(searchstore.FieldTextContent, query))
	} else {
		// 多字段匹配
		req.Query.Bool.Must = append(req.Query.Bool.Must,
			es.NewMultiMatchQuery(implSpecOptions.MultiMatch.Fields, query,
				"best_fields", es.Or))
	}

	// 2. 应用DSL过滤条件
	dsl, err := searchstore.LoadDSL(options.DSLInfo)
	if err = e.travDSL(req.Query, dsl); err != nil {
		return nil, err
	}

	// 3. 设置分数阈值
	if options.ScoreThreshold != nil {
		req.MinScore = options.ScoreThreshold
	}

	// 4. 执行搜索
	resp, err := cli.Search(ctx, index, req)
	
	// 5. 解析结果
	docs, err := e.parseSearchResult(resp)
	return docs, nil
}

ES检索算法原理:

  • 倒排索引查询:基于关键词的倒排索引快速定位包含查询词的文档
  • TF-IDF评分:计算词频-逆文档频率,评估文档与查询的相关性
  • 多字段匹配:支持在多个字段中同时搜索,使用best_fields策略
  • 布尔查询:支持复杂的AND、OR、NOT逻辑组合
  • 分数阈值过滤:只返回分数超过阈值的文档

3.3 NL2SQL检索实现

3.3.1 NL2SQL检索机制深度剖析

NL2SQL检索算法原理:

  • 语义理解:通过大语言模型理解用户查询的语义意图
  • SQL生成:将自然语言查询转换为标准的SQL查询语句
  • 表结构映射:将逻辑表名映射到物理表名,字段名映射到数据库字段
  • 精确匹配:基于结构化数据的精确查询,返回完全匹配的结果
  • 并行查询:支持对多个表格同时进行NL2SQL查询

NL2SQL执行逻辑详解:

  1. 输入处理阶段
    • 接收用户自然语言查询和聊天历史
    • 识别表格文档并提取表结构信息
    • 构建表格描述字符串(包含表名、字段名、字段类型、是否必填等)
  2. 模型调用阶段
    • 使用Jinja2模板将表格描述和用户查询组合成提示
    • 调用配置的大语言模型(如GPT、Claude等)
    • 模型返回JSON格式的SQL语句和错误信息
  3. SQL处理阶段
    • 解析模型返回的JSON结果
    • 添加切片ID列到SQL语句中
    • 构建表名和字段的映射关系(逻辑名到物理名的转换)
  4. SQL执行阶段
    • 使用SQL解析器修改SQL语句(替换表名和字段名)
    • 在关系数据库中执行修改后的SQL
    • 获取查询结果集
  5. 结果转换阶段
    • 将SQL查询结果转换为统一的文档格式
    • 为每个结果分配默认分数1(表示精确匹配)
    • 返回文档列表供后续重排序使用

数据量说明:

NL2SQL检索返回的结果数量取决于SQL查询的结果行数,每条结果默认分数为1,表示结构化数据的精确匹配。

3.3.2 NL2SQL检索入口
  1. 表格文档识别:遍历所有文档,识别类型为DocumentTypeTable的表格文档
  2. 并行处理:使用errgroup并发处理多个表格文档的NL2SQL查询
  3. 自然语言转SQL:调用NL2SQL模型将用户查询转换为SQL语句
  4. SQL执行:在关系数据库中执行生成的SQL查询
  5. 结果转换:将SQL查询结果转换为统一的文档格式
// backend/domain/knowledge/service/retrieve.go
// NL2SQL检索节点:处理表格文档的自然语言转SQL查询
// 参数:
//   - ctx: 上下文,用于超时控制和日志记录
//   - req: 检索上下文,包含用户查询、文档列表、策略配置等
// 返回:
//   - retrieveResult: 检索结果文档列表
//   - err: 错误信息
func (k *knowledgeSVC) nl2SqlRetrieveNode(ctx context.Context, req *RetrieveContext) (retrieveResult []*schema.Document, err error) {
	hasTable := false  // 标记是否存在表格文档
	var tableDocs []*model.KnowledgeDocument  // 存储所有表格文档
	
	// 遍历所有文档,识别类型为DocumentTypeTable的表格文档
    for _, doc := range req.Documents {
        if doc.DocumentType == int32(knowledgeModel.DocumentTypeTable) {
            hasTable = true
            tableDocs = append(tableDocs, doc)
        }
    }
    
	// 构建NL2SQL选项,如果指定了聊天模型则使用指定的模型
    var opts []nl2sql.Option
    if req.ChatModel != nil {
        opts = append(opts, nl2sql.WithChatModel(req.ChatModel))
    }
    
	// 如果存在表格文档且启用了NL2SQL功能,则执行NL2SQL查询
    if hasTable && req.Strategy.EnableNL2SQL {
		mu := sync.Mutex{}  // 互斥锁,用于保护并发写入结果
		
		// 使用errgroup并发处理多个表格文档的NL2SQL查询
        eg, ctx := errgroup.WithContext(ctx)
		eg.SetLimit(len(tableDocs))  // 设置并发限制为表格文档数量
		res := make([]*schema.Document, 0)  // 存储所有查询结果
        
		// 为每个表格文档启动一个goroutine进行并发查询
        for i := range tableDocs {
            t := i
            eg.Go(func() error {
                doc := tableDocs[t]
				// 对单个表格文档执行NL2SQL查询
                docs, execErr := k.nl2SqlExec(ctx, doc, req, opts)
                if execErr != nil {
                    logs.CtxErrorf(ctx, "nl2sql exec failed: %v", execErr)
                    return errorx.New(errno.ErrKnowledgeNL2SqlExecFailCode, errorx.KV("msg", execErr.Error()))
                }
				// 线程安全地添加查询结果
                mu.Lock()
                res = append(res, docs...)
                mu.Unlock()
                return nil
            })
        }
        
		// 等待所有goroutine完成
        err = eg.Wait()
        if err != nil {
            logs.CtxErrorf(ctx, "nl2sql exec failed: %v", err)
            return nil, nil
        }
        return res, nil
    } else {
		// 如果没有表格文档或未启用NL2SQL,返回空结果
        return nil, nil
    }
}
3.3.3 NL2SQL核心实现
// backend/domain/knowledge/service/retrieve.go
// NL2SQL执行函数:将自然语言查询转换为SQL并执行
// 参数:
//   - ctx: 上下文
//   - doc: 表格文档,包含表结构信息
//   - retrieveCtx: 检索上下文,包含用户查询历史
//   - opts: NL2SQL选项,如聊天模型配置
// 返回:
//   - retrieveResult: 查询结果文档列表
//   - err: 错误信息
func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocument, retrieveCtx *RetrieveContext, opts []nl2sql.Option) (
    retrieveResult []*schema.Document, err error) {
    
    // 步骤1: 调用NL2SQL模型生成SQL
    // 将表格文档打包成TableSchema格式,传入聊天历史和选项
    sql, err := k.nl2Sql.NL2SQL(ctx, retrieveCtx.ChatHistory, []*document.TableSchema{packNL2SqlRequest(doc)}, opts...)
    if err != nil {
        logs.CtxErrorf(ctx, "nl2sql failed: %v", err)
        return nil, err
    }
    
    // 步骤2: 添加切片ID列到SQL中,用于标识数据来源
    sql = addSliceIdColumn(sql)
    
    // 步骤3: 构建表名和字段映射关系
    // 将逻辑表名映射到物理表名,字段名映射到数据库字段名
    replaceMap := map[string]sqlparsercontract.TableColumn{}
    replaceMap[doc.Name] = sqlparsercontract.TableColumn{
        NewTableName: ptr.Of(doc.TableInfo.PhysicalTableName),  // 物理表名
        ColumnMap: map[string]string{
            pkID: consts.RDBFieldID,  // 主键字段映射
        },
    }
    
    // 遍历表格的所有列,建立字段名映射关系
    for i := range doc.TableInfo.Columns {
        if doc.TableInfo.Columns[i] == nil {
            continue
        }
        // 跳过RDB字段ID,避免重复映射
        if doc.TableInfo.Columns[i].Name == consts.RDBFieldID {
            continue
        }
        // 将逻辑字段名映射到物理字段名
        replaceMap[doc.Name].ColumnMap[doc.TableInfo.Columns[i].Name] = convert.ColumnIDToRDBField(doc.TableInfo.Columns[i].ID)
    }
    
    // 步骤4: 解析并修改SQL(表名和字段名映射)
    // 使用SQL解析器将逻辑表名和字段名替换为物理名称
    parsedSQL, err := sqlparser.NewSQLParser().ParseAndModifySQL(sql, replaceMap)
    if err != nil {
        logs.CtxErrorf(ctx, "parse sql failed: %v", err)
        return nil, err
    }
    
    // 步骤5: 执行SQL查询
    // 在关系数据库中执行修改后的SQL语句
    resp, err := k.rdb.ExecuteSQL(ctx, &rdb.ExecuteSQLRequest{
        SQL: parsedSQL,
    })
    if err != nil {
        logs.CtxErrorf(ctx, "execute sql failed: %v", err)
        return nil, err
    }
    
    // 步骤6: 转换结果为文档格式
    // 将数据库查询结果转换为统一的文档格式,便于后续处理
    for i := range resp.ResultSet.Rows {
        // 从结果行中提取ID字段
        id, ok := resp.ResultSet.Rows[i][consts.RDBFieldID].(int64)
        if !ok {
            logs.CtxWarnf(ctx, "convert id failed, row: %v", resp.ResultSet.Rows[i])
            return nil, errors.New("convert id failed")
        }
        
        // 创建文档对象,设置ID和元数据
        d := &schema.Document{
            ID:       strconv.FormatInt(id, 10),  // 将ID转换为字符串
            Content:  "",  // NL2SQL结果内容为空,因为结构化数据不需要文本内容
            MetaData: map[string]any{},  // 元数据为空
        }
        d.WithScore(1)  // NL2SQL结果默认分数为1,表示精确匹配
        retrieveResult = append(retrieveResult, d)
    }
    
    return retrieveResult, nil
}
3.3.4 NL2SQL模型详细实现
// backend/infra/impl/document/nl2sql/builtin/nl2sql.go
// NL2SQL模型核心实现:将自然语言转换为SQL语句
// 参数:
//   - ctx: 上下文
//   - messages: 聊天历史消息列表
//   - tables: 表格结构信息列表
//   - opts: 可选的NL2SQL配置选项
// 返回:
//   - sql: 生成的SQL语句
//   - err: 错误信息
func (n *n2s) NL2SQL(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
    // 初始化NL2SQL选项,默认使用配置的聊天模型
    o := &nl2sql.Options{ChatModel: n.cm}
    for _, opt := range opts {
        opt(o)  // 应用自定义选项
    }

    // 检查聊天模型是否已配置
    if o.ChatModel == nil {
        return "", fmt.Errorf("[NL2SQL] chat model not configured")
    }

    // 构建处理链:输入处理 -> 模板渲染 -> 模型调用 -> 结果解析
    c := compose.NewChain[*nl2sqlInput, string]().
        // 第一步:输入处理Lambda,构建表格描述信息
        AppendLambda(compose.InvokableLambda(func(ctx context.Context, input *nl2sqlInput) (output map[string]any, err error) {
            // 检查表格信息是否为空
            if len(input.tables) == 0 {
                return nil, errors.New("table meta is empty")
            }
            
            // 构建表格描述信息字符串
            tableDesc := strings.Builder{}
            for _, table := range input.tables {
                // 添加表名和表描述
                tableDesc.WriteString(fmt.Sprintf(defaultTableFmt, table.Name, table.Comment))
                // 添加每个字段的信息(字段名、描述、类型、是否必填)
                for _, column := range table.Columns {
                    tableDesc.WriteString(fmt.Sprintf(defaultColumnFmt, column.Name, column.Description, column.Type.String(), !column.Nullable))
                }
            }
            
            // 返回模板渲染所需的数据
            return map[string]interface{}{
                "messages":     input.messages,     // 聊天历史
                "table_schema": tableDesc.String(), // 表格结构描述
            }, nil
        })).
        // 第二步:使用Jinja2模板渲染提示
        AppendChatTemplate(n.tpl).
        // 第三步:调用大语言模型生成SQL
        AppendChatModel(o.ChatModel).
        // 第四步:结果解析Lambda,解析模型返回的JSON格式结果
        AppendLambda(compose.InvokableLambda(func(ctx context.Context, msg *schema.Message) (sql string, err error) {
            // 解析模型返回的JSON格式结果
            var promptResp *promptResponse
            if err := json.Unmarshal([]byte(msg.Content), &promptResp); err != nil {
                logs.CtxWarnf(ctx, "unmarshal failed: %v", err)
                return "", err
            }
            
            // 检查是否成功生成SQL
            if promptResp.SQL == "" {
                logs.CtxInfof(ctx, "no sql generated, err_code: %v, err_msg: %v", promptResp.ErrCode, promptResp.ErrMsg)
                return "", errors.New(promptResp.ErrMsg)
            }
            
            return promptResp.SQL, nil
        }))

    // 编译处理链
    r, err := c.Compile(ctx)
    if err != nil {
        return "", err
    }

    // 准备输入数据
    input := &nl2sqlInput{
        messages: messages,  // 聊天历史
        tables:   tables,    // 表格结构
    }

    // 执行处理链并返回生成的SQL
    return r.Invoke(ctx, input)
}
3.3.5 NL2SQL提示模板
// backend/conf/prompt/nl2sql_template_jinja2.json
[
  {
    "role": "system",
    "content": "# Role: NL2SQL Consultant\n\n## Goals\nTranslate natural language statements into SQL queries in MySQL standard. Follow the Constraints and return only a JSON always.\n\n## Format\n- JSON format only. JSON contains field \"sql\" for generated SQL, filed \"err_code\" for reason type, field \"err_msg\" for detail reason (prefer more than 10 words)\n- Don't use \"```json\" markdown format\n\n## Skills\n- Good at Translate natural language statements into SQL queries in MySQL standard.\n\n## Define\n\"err_code\" Reason Type Define:\n- 0 means you generated a SQL\n- 3002 means you cannot generate a SQL because of timeout\n- 3003 means you cannot generate a SQL because of table schema missing\n- 3005 means you cannot generate a SQL because of some term is ambiguous\n\n## Example\nQ: Help me implement NL2SQL.\n​.table schema description: ​​CREATE TABLE `sales_records` (\\n  `sales_id` bigint(20) unsigned NOT NULL COMMENT 'id of sales person',\\n  `product_id` bigint(64) COMMENT 'id of product',\\n  `sale_date` datetime(3) COMMENT 'sold date and time',\\n  `quantity_sold` int(11) COMMENT 'sold amount',\\n  PRIMARY KEY (`sales_id`)\\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='销售记录表';\n​.natural language description of the SQL requirement:  ​​​​查询上月的销量总额第一名的销售员和他的销售总额\nA: {\n  \"sql\":\"SELECT sales_id, SUM(quantity_sold) AS total_sales FROM sales_records WHERE MONTH(sale_date) = MONTH(CURRENT_DATE - INTERVAL 1 MONTH) AND YEAR(sale_date) = YEAR(CURRENT_DATE - INTERVAL 1 MONTH) GROUP BY sales_id ORDER BY total_sales DESC LIMIT 1\",\n  \"err_code\":0,\n  \"err_msg\":\"SQL query generated successfully\"\n}"
  },
  {
    "role": "user",
    "content": "help me implement NL2SQL.\ntable schema description:{{table_schema}}\nnatural language description of the SQL requirement: {{messages}}."
  }
]
3.3.6 表格数据打包函数
// backend/domain/knowledge/service/retrieve.go
// 表格数据打包函数:将知识库文档转换为NL2SQL所需的表格结构
// 参数:
//   - doc: 知识库文档,包含表格信息
// 返回:
//   - *document.TableSchema: 表格结构信息,包含表名、描述、字段列表
func packNL2SqlRequest(doc *model.KnowledgeDocument) *document.TableSchema {
    res := &document.TableSchema{}  // 创建表格结构对象
    
    
    if doc.TableInfo == nil {
        return res  
    }
    
    // 设置表格基本信息
    res.Name = doc.TableInfo.VirtualTableName    // 虚拟表名(逻辑表名)
    res.Comment = doc.TableInfo.TableDesc        // 表格描述
    res.Columns = []*document.Column{}           // 初始化字段列表
    
    // 遍历表格的所有列,构建字段信息
    for _, column := range doc.TableInfo.Columns {
        // 跳过RDB字段ID,因为这是系统内部字段,不需要暴露给NL2SQL
        if column.Name == consts.RDBFieldID {
            continue
        }
        
        // 添加字段信息到表格结构中
        res.Columns = append(res.Columns, &document.Column{
            Name:        column.Name,        // 字段名
            Type:        column.Type,        // 字段类型
            Description: column.Description, // 字段描述
            Nullable:    !column.Indexing,  // 是否可为空(与索引字段相反)
            IsPrimary:   false,             // 是否为主键(这里统一设为false)
        })
    }
    
    return res
}
3.3.7 SQL解析器实现
// backend/infra/impl/sqlparser/sql_parser.go
// SQL解析器:解析和修改SQL语句,实现表名和字段名的映射
// 参数:
//   - sql: 原始SQL语句
//   - tableColumns: 表名和字段的映射关系
// 返回:
//   - string: 修改后的SQL语句
//   - error: 错误信息
func (p *Impl) ParseAndModifySQL(sql string, tableColumns map[string]sqlparser.TableColumn) (string, error) {
    
    if len(tableColumns) == 0 {
        return sql, nil
    }

    
    for originalTableName, tableColumn := range tableColumns {
        
        if originalTableName == "" {
            return "", fmt.Errorf("original TableName must be non-empty")
        }
        
        
        if tableColumn.ColumnMap != nil {
            for key, value := range tableColumn.ColumnMap {
                if (key == "") != (value == "") {
                    return "", fmt.Errorf("ColumnMap key and value must be either both empty or both non-empty")
                }
            }
        }
    }

    // 使用TiDB解析器解析SQL语句
    stmt, err := p.parser.ParseOneStmt(sql, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
        if err != nil {
        return "", fmt.Errorf("failed to parse SQL: %v", err)
    }

    // 第一遍遍历:收集所有表别名
    aliasCollector := NewAliasCollector()
    stmt.Accept(aliasCollector)

    // 检查别名冲突:原始表名不能与别名相同
    for originalTableName, _ := range tableColumns {
        if _, ok := aliasCollector.tableAliases[originalTableName]; ok {
            return "", fmt.Errorf("alisa table name should not equal with origin table name")
        }
    }

    // 第二遍遍历:使用收集的别名信息修改AST(抽象语法树)
    modifier := NewSQLModifier(tableColumns, aliasCollector.tableAliases)
    stmt.Accept(modifier)

    // 将修改后的AST转换回SQL字符串
    var sb strings.Builder
    // 设置格式化标志:使用单引号、移除字符集前缀
    flags := format.RestoreStringSingleQuotes | format.RestoreStringWithoutCharset
    restoreCtx := format.NewRestoreCtx(flags, &sb)
    err = stmt.Restore(restoreCtx)
    if err != nil {
        return "", fmt.Errorf("failed to restore SQL: %v", err)
    }

    return sb.String(), nil
}
3.3.8 RDB执行SQL实现
// backend/infra/impl/rdb/mysql.go
// RDB执行SQL:在关系数据库中执行SQL查询
// 参数:
//   - ctx: 上下文
//   - req: SQL执行请求,包含SQL语句和参数
// 返回:
//   - *rdb.ExecuteSQLResponse: 执行结果
//   - error: 错误信息
func (m *mysqlService) ExecuteSQL(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) {
    
    if req == nil {
        return nil, fmt.Errorf("invalid request")
    }

    logs.CtxInfof(ctx, "[ExecuteSQL] req is %v", req)

    var processedSQL string      // 处理后的SQL语句
    var processedParams []interface{}  // 处理后的参数
    var err error

    // 根据SQL类型处理参数
    if req.SQLType == entity2.SQLType_Raw {
        // 原始SQL类型:不处理参数,直接使用原始SQL
        processedSQL = req.SQL
        processedParams = nil
    } else {
        // 参数化SQL类型:处理切片参数
        processedSQL, processedParams, err = m.processSliceParams(req.SQL, req.Params)
        if err != nil {
            return nil, fmt.Errorf("failed to process parameters: %v", err)
        }
    }

    // 获取SQL操作类型(SELECT/INSERT/UPDATE/DELETE)
    operation, err := sqlparser.NewSQLParser().GetSQLOperation(processedSQL)
    if err != nil {
        return nil, err
    }
    
    // 处理非SELECT操作(INSERT/UPDATE/DELETE)
    if operation != sqlparsercontract.OperationTypeSelect {
        // 执行非查询SQL,返回影响的行数
        result := m.db.WithContext(ctx).Exec(processedSQL, processedParams...)
        if result.Error != nil {
            return nil, fmt.Errorf("failed to execute SQL: %v", result.Error)
        }

        // 构建结果集(非查询操作没有数据行)
        resultSet := &entity2.ResultSet{
            Columns:      []string{},                    // 空列名
            Rows:         []map[string]interface{}{},   // 空数据行
            AffectedRows: result.RowsAffected,          // 影响的行数
        }

        return &rdb.ExecuteSQLResponse{
            ResultSet: resultSet,
        }, nil
    }

    // 处理SELECT查询操作
    // 执行查询并获取结果集
    rows, err := m.db.WithContext(ctx).Raw(processedSQL, processedParams...).Rows()
    if err != nil {
        return nil, fmt.Errorf("failed to execute SQL: %v", err)
    }
    defer rows.Close()  // 确保结果集被关闭

    // 获取列名信息
    columns, err := rows.Columns()
    if err != nil {
        return nil, fmt.Errorf("failed to get columns: %v", err)
    }

    // 初始化结果集
    resultSet := &entity2.ResultSet{
        Columns: columns,  // 设置列名
        Rows:    make([]map[string]interface{}, 0),  // 初始化数据行列表
    }

    // 遍历查询结果的每一行
    for rows.Next() {
        // 为当前行准备值容器
        values := make([]interface{}, len(columns))
        valuePtrs := make([]interface{}, len(columns))
        for i := range values {
            valuePtrs[i] = &values[i]  // 创建指向值的指针
        }

        // 扫描当前行数据到值容器
        if err := rows.Scan(valuePtrs...); err != nil {
            return nil, fmt.Errorf("failed to scan row: %v", err)
        }

        // 将当前行数据转换为map格式
        rowData := make(map[string]interface{})
        for i, col := range columns {
            rowData[col] = values[i]  // 列名 -> 值的映射
        }
        resultSet.Rows = append(resultSet.Rows, rowData)
    }

   
    if err := rows.Err(); err != nil {
        return nil, fmt.Errorf("error while reading rows: %v", err)
    }

    // 返回查询结果
    return &rdb.ExecuteSQLResponse{
        ResultSet: resultSet,
    }, nil
}

四、多知识库并发检索机制

多知识库并发检索机制深度解析:

retrieveChannels函数是ES检索和向量检索的核心实现,负责统一处理多个知识库的并发检索。这个函数在整个RAG系统中扮演着"检索调度器"的角色,它的主要作用是:

核心功能定位:

  • 统一入口:为ES检索和向量检索提供统一的处理入口
  • 并发调度:管理多个知识库的并发检索任务
  • 资源控制:通过限制并发数避免系统资源过载
  • 结果聚合:安全地合并来自不同知识库的检索结果

在RAG流程中的位置:

用户查询 → 2. 查询重写 → 3. 多知识库并发检索 → 4. 结果重排序 → 5. 结果打包

设计优势:

  • 性能优化:通过并发检索显著提升响应速度
  • 容错机制:单个知识库失败不影响整体检索
  • 资源管理:限制并发数避免系统过载
  • 精确检索:通过分区和DSL过滤确保检索精度

执行流程图:

开始: 接收检索请求
查询预处理
是否启用查询重写?
使用重写后的查询
使用原始查询
初始化并发控制
遍历知识库列表
构建DSL过滤条件
创建分区列表
分区列表为空?
跳过当前知识库
构建检索选项
还有知识库?
是表格文档?
添加多字段匹配
启动并发检索任务
获取搜索存储实例
执行检索操作
合并检索结果
所有任务完成?
返回最终结果

详细执行步骤:

  1. 知识库遍历:遍历所有相关的知识库,为每个知识库创建独立的检索任务
  2. 分区过滤:根据文档ID创建分区列表,限制检索范围,提高检索效率
  3. DSL条件构建:构建文档ID的过滤条件,确保只检索指定文档
  4. 多字段匹配:对于表格文档,支持在多个索引字段中进行匹配
  5. 并发检索:使用errgroup并发执行多个知识库的检索任务
  6. 结果合并:使用互斥锁安全地合并所有检索结果

核心优化策略:

  • 分区检索:通过文档ID分区,减少检索范围,提高效率
  • 并发处理:支持多个知识库的并发检索,最多同时处理2个知识库
  • 字段索引:只对设置了索引的字段进行匹配,避免无效检索
  • 错误隔离:单个知识库检索失败不影响其他知识库的检索
  • 结果去重:通过文档ID确保结果唯一性
// backend/domain/knowledge/service/retrieve.go
// 多知识库并发检索函数 - 负责多知识库并发检索的核心实现
func (k *knowledgeSVC) retrieveChannels(ctx context.Context, req *RetrieveContext, manager searchstore.Manager) (result []*schema.Document, err error) {
	// 1. 查询预处理:根据策略选择使用原始查询还是重写后的查询
	query := req.OriginQuery
	if req.Strategy.EnableQueryRewrite && req.RewrittenQuery != nil {
		query = *req.RewrittenQuery
	}
	
	// 2. 并发控制:使用互斥锁保护结果合并,限制最大并发数为2
	mu := sync.Mutex{}
	eg, ctx := errgroup.WithContext(ctx)
	eg.SetLimit(2) // 最多同时处理2个知识库,避免资源竞争
	
	// 3. 遍历所有知识库,为每个知识库创建独立的检索任务
	for knowledgeID, knowledgeInfo := range req.KnowledgeInfoMap {
		kid := knowledgeID
		info := knowledgeInfo
		collectionName := getCollectionName(kid) // 获取知识库对应的集合名称

		// 4. DSL条件构建:创建文档ID的过滤条件,确保只检索指定文档
		dsl := &searchstore.DSL{
			Op:    searchstore.OpIn,           // 使用IN操作符
			Field: "document_id",              // 过滤字段为document_id
			Value: knowledgeInfo.DocumentIDs,  // 过滤值为指定的文档ID列表
		}
		
		// 5. 分区过滤:根据文档ID创建分区列表,限制检索范围
		partitions := make([]string, 0, len(req.Documents))
		for _, doc := range req.Documents {
			if doc.KnowledgeID == kid {
				partitions = append(partitions, strconv.FormatInt(doc.ID, 10))
			}
		}
		if len(partitions) == 0 {
			continue // 如果没有相关文档,跳过当前知识库
		}
		
		// 6. 构建检索选项:设置分区键、分区列表和DSL过滤条件
		opts := []retriever.Option{
			searchstore.WithRetrieverPartitionKey(fieldNameDocumentID), // 设置分区键
			searchstore.WithPartitions(partitions),                     // 设置分区列表
			retriever.WithDSLInfo(dsl.DSL()),                          // 设置DSL过滤条件
		}
		
		// 7. 多字段匹配:对于表格文档,支持在多个索引字段中进行匹配
		if info.DocumentType == knowledgeModel.DocumentTypeTable && !k.enableCompactTable {
			var matchCols []string
			for _, col := range info.TableColumns {
				if col.Indexing { // 只对设置了索引的字段进行匹配
					matchCols = append(matchCols, getColName(col.ID))
				}
			}
			opts = append(opts, searchstore.WithMultiMatch(matchCols, query)) // 添加多字段匹配选项
		}
		
		// 8. 并发检索:使用goroutine并发执行检索任务
		eg.Go(func() error {
			// 获取当前知识库的搜索存储实例
			ss, err := manager.GetSearchStore(ctx, collectionName)
    if err != nil {
				return errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", err.Error()))
    }
    
			// 执行检索操作
			retrievedDocs, err := ss.Retrieve(ctx, query, opts...)
    if err != nil {
				return errorx.New(errno.ErrKnowledgeRetrieveExecFailCode, errorx.KV("msg", err.Error()))
			}
			
			// 9. 结果合并:使用互斥锁安全地合并检索结果
			mu.Lock()
			result = append(result, retrievedDocs...)
			mu.Unlock()
			return nil
		})
	}
	
	// 10. 等待所有并发任务完成,检查是否有错误
	if err = eg.Wait(); err != nil {
		return nil, err
	}
	return
}

五、重排序与结果融合

5.1 重排序机制

重排序机制深度剖析:

  1. 结果收集:从并行执行的结果映射中提取各种检索方式的结果
  2. 数据转换:将文档转换为重排序器需要的格式,包含文档和分数信息
  3. 策略路由:根据检索策略选择要重排序的结果集:
    • 语义搜索:只重排序向量检索结果
    • 全文搜索:只重排序ES检索结果
    • 混合搜索:重排序向量和ES检索结果
    • NL2SQL:如果启用,则包含NL2SQL结果
  4. 查询选择:选择使用原始查询或重写后的查询
  5. 重排序执行:调用重排序器对所有结果进行重新排序
  6. 分数过滤:过滤掉分数低于最小阈值的结果

重排序算法原理:

  • 多模态融合:将不同检索方式的结果进行统一排序
  • 语义相关性:基于查询与文档的语义相似度进行重排序
  • 分数归一化:将不同检索方式的分数进行标准化处理
  • 阈值过滤:通过最小分数阈值过滤低质量结果
  • 结果去重:去除重复的文档,确保结果唯一性
// backend/domain/knowledge/service/retrieve.go
func (k *knowledgeSVC) reRankNode(ctx context.Context, resultMap map[string]any) (retrieveResult []*schema.Document, err error) {
	// 首先获取下retrieve上下文
	retrieveCtx, ok := resultMap["passRequestContext"].(*RetrieveContext)
	if !ok {
		logs.CtxErrorf(ctx, "retrieve context is not found")
		return nil, errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", "retrieve context is not found"))
	}
	// 获取下向量化召回的接口
	vectorRetrieveResult, ok := resultMap["vectorRetrieveNode"].([]*schema.Document)
	if !ok {
		logs.CtxErrorf(ctx, "vector retrieve result is not found")
		vectorRetrieveResult = []*schema.Document{}
	}
	// 获取下es召回的接口
	esRetrieveResult, ok := resultMap["esRetrieveNode"].([]*schema.Document)
	if !ok {
		logs.CtxErrorf(ctx, "es retrieve result is not found")
		esRetrieveResult = []*schema.Document{}
	}
	// 获取下nl2sql召回的接口
	nl2SqlRetrieveResult, ok := resultMap["nl2SqlRetrieveNode"].([]*schema.Document)
	if !ok {
		logs.CtxErrorf(ctx, "nl2sql retrieve result is not found")
		nl2SqlRetrieveResult = []*schema.Document{}
	}

	docs2RerankData := func(docs []*schema.Document) []*rerank.Data {
		data := make([]*rerank.Data, 0, len(docs))
		for i := range docs {
			doc := docs[i]
			data = append(data, &rerank.Data{Document: doc, Score: doc.Score()})
		}
		return data
	}

	// 根据召回策略从不同渠道获取召回结果
	var retrieveResultArr [][]*rerank.Data
	if retrieveCtx.Strategy.EnableNL2SQL {
		// nl2sql结果
		retrieveResultArr = append(retrieveResultArr, docs2RerankData(nl2SqlRetrieveResult))
	}
	switch retrieveCtx.Strategy.SearchType {
	case knowledgeModel.SearchTypeSemantic:
		retrieveResultArr = append(retrieveResultArr, docs2RerankData(vectorRetrieveResult))
	case knowledgeModel.SearchTypeFullText:
		retrieveResultArr = append(retrieveResultArr, docs2RerankData(esRetrieveResult))
	case knowledgeModel.SearchTypeHybrid:
		retrieveResultArr = append(retrieveResultArr, docs2RerankData(vectorRetrieveResult))
		retrieveResultArr = append(retrieveResultArr, docs2RerankData(esRetrieveResult))
	default:
		retrieveResultArr = append(retrieveResultArr, docs2RerankData(vectorRetrieveResult))
	}

	query := retrieveCtx.OriginQuery
	if retrieveCtx.Strategy.EnableQueryRewrite && retrieveCtx.RewrittenQuery != nil {
		query = ptr.From(retrieveCtx.RewrittenQuery)
	}

	resp, err := k.reranker.Rerank(ctx, &rerank.Request{
		Query: query,
		Data:  retrieveResultArr,
		TopN:  retrieveCtx.Strategy.TopK,
	})
    if err != nil {
		logs.CtxErrorf(ctx, "rerank failed: %v", err)
        return nil, err
    }
    
	retrieveResult = make([]*schema.Document, 0, len(resp.SortedData))
	for _, item := range resp.SortedData {
		if item.Score < ptr.From(retrieveCtx.Strategy.MinScore) {
			continue
		}
		doc := item.Document
		doc.WithScore(item.Score)
		retrieveResult = append(retrieveResult, doc)
	}

	return retrieveResult, nil
}

5.2 RRF重排序实现

// backend/infra/impl/document/rerank/rrf/rrf.go
// RRF重排序器实现
func NewRRFReranker(k int64) rerank.Reranker {
	if k == 0 {
		k = 60
	}
	return &rrfReranker{k}
}

type rrfReranker struct {
	k int64
}

func (r *rrfReranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
	if req == nil || req.Data == nil || len(req.Data) == 0 {
		return nil, fmt.Errorf("invalid request: no data provided")
	}
	
	// 计算每个文档的RRF分数
	id2Score := make(map[string]float64)
	id2Data := make(map[string]*rerank.Data)
	
	for _, resultList := range req.Data {
		for rank := range resultList {
			result := resultList[rank]
			if result != nil && result.Document != nil {
				// RRF公式:1 / (k + rank)
				score := 1.0 / (float64(rank) + float64(r.k))
				if score > id2Score[result.Document.ID] {
					id2Score[result.Document.ID] = score
					id2Data[result.Document.ID] = result
				}
			}
		}
	}
	
	// 按分数排序
	var sorted []*rerank.Data
	for _, data := range id2Data {
		sorted = append(sorted, data)
	}
	sort.Slice(sorted, func(i, j int) bool {
		return id2Score[sorted[i].Document.ID] > id2Score[sorted[j].Document.ID]
	})
	
	// 限制返回数量
	topN := int64(len(sorted))
	if req.TopN != nil && ptr.From(req.TopN) != 0 && ptr.From(req.TopN) < topN {
		topN = ptr.From(req.TopN)
	}

	return &rerank.Response{SortedData: sorted[:topN]}, nil
}

RRF算法说明:

  • k参数:控制不同来源的权重,k越大,排名差异越小(默认60)
  • RRF公式1 / (k + rank),排名越靠前分数越高
  • 分数聚合:同一文档在不同来源中取RRF分数最高的作为最终分数(Coze Studio定制实现)
  • 排序输出:按RRF分数降序排列,返回TopN结果

RRF算法示例:

假设有三个检索来源(向量检索、ES检索、NL2SQL检索),每个来源返回以下结果:

来源 排名 文档ID 原始分数
向量检索 1 doc_001 0.95
向量检索 2 doc_002 0.87
向量检索 3 doc_003 0.76
ES检索 1 doc_002 0.92
ES检索 2 doc_001 0.88
ES检索 3 doc_004 0.81
NL2SQL检索 1 doc_003 0.89
NL2SQL检索 2 doc_005 0.85

RRF分数计算(k=60):

文档ID 向量检索RRF ES检索RRF NL2SQL检索RRF 最终RRF分数(取最高)
doc_001 1/(60+1) = 0.0164 1/(60+2) = 0.0161 - 0.0164
doc_002 1/(60+2) = 0.0161 1/(60+1) = 0.0164 - 0.0164
doc_003 1/(60+3) = 0.0159 - 1/(60+1) = 0.0164 0.0164
doc_004 - 1/(60+3) = 0.0159 - 0.0159
doc_005 - - 1/(60+2) = 0.0161 0.0161

最终排序结果:

  1. doc_001 (RRF: 0.0164)
  2. doc_002 (RRF: 0.0164)
  3. doc_003 (RRF: 0.0164)
  4. doc_005 (RRF: 0.0161)
  5. doc_004 (RRF: 0.0159)

算法特点:

  • 取最高分数:同一文档在不同来源中取RRF分数最高的作为最终分数
  • 排名敏感:排名越靠前,RRF分数越高
  • 去重处理:避免同一文档重复出现
  • 可调节性:通过k参数控制排名差异的影响程度

5.3 结果打包与格式化

结果打包的作用:

  1. 去重处理:去除重复的检索结果
  2. 元数据丰富:添加文档信息、来源说明等
  3. 数量控制:限制返回结果数量
  4. 格式统一:统一结果格式,便于后续处理
// backend/domain/knowledge/service/retrieve.go
// 结果打包节点
func (k *knowledgeSVC) packResults(ctx context.Context, retrieveResult []*schema.Document) (results []*knowledgeModel.RetrieveSlice, err error) {
	if len(retrieveResult) == 0 {
		return nil, nil
	}
	
	// 1. 收集所有相关的ID
	sliceIDs := make(sets.Set[int64])
	docIDs := make(sets.Set[int64])
	knowledgeIDs := make(sets.Set[int64])
	documentMap := map[int64]*model.KnowledgeDocument{}
	knowledgeMap := map[int64]*model.Knowledge{}
	sliceScoreMap := map[int64]float64{}
	
	for _, doc := range retrieveResult {
		id, err := strconv.ParseInt(doc.ID, 10, 64)
		if err != nil {
			logs.CtxErrorf(ctx, "convert id failed: %v", err)
			return nil, errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", "convert id failed"))
		}
		sliceIDs[id] = struct{}{}
		sliceScoreMap[id] = doc.Score()
	}
	
	// 2. 批量查询相关数据
	slices, err := k.sliceRepo.MGetSlices(ctx, sliceIDs.ToSlice())
	if err != nil {
		logs.CtxErrorf(ctx, "mget slices failed: %v", err)
		return nil, errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", err.Error()))
	}
	
	for _, slice := range slices {
		docIDs[slice.DocumentID] = struct{}{}
		knowledgeIDs[slice.KnowledgeID] = struct{}{}
	}
	
	// 3. 查询知识库和文档信息
	knowledgeModels, err := k.knowledgeRepo.FilterEnableKnowledge(ctx, knowledgeIDs.ToSlice())
	if err != nil {
		logs.CtxErrorf(ctx, "filter enable knowledge failed: %v", err)
		return nil, errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", err.Error()))
	}
	
	for _, kn := range knowledgeModels {
		knowledgeMap[kn.ID] = kn
	}
	
	documents, err := k.documentRepo.MGetByID(ctx, docIDs.ToSlice())
    if err != nil {
		logs.CtxErrorf(ctx, "mget documents failed: %v", err)
		return nil, errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", err.Error()))
	}
	
	for _, doc := range documents {
		documentMap[doc.ID] = doc
	}
	
	// 4. 构建结果
	results = []*knowledgeModel.RetrieveSlice{}
	for i := range slices {
		doc := documentMap[slices[i].DocumentID]
		kn := knowledgeMap[slices[i].KnowledgeID]
		
		sliceEntity := entity.Slice{
			Info: knowledgeModel.Info{
				ID:          slices[i].ID,
				CreatorID:   slices[i].CreatorID,
				SpaceID:     doc.SpaceID,
				AppID:       kn.AppID,
				CreatedAtMs: slices[i].CreatedAt,
				UpdatedAtMs: slices[i].UpdatedAt,
			},
			KnowledgeID:  slices[i].KnowledgeID,
			DocumentID:   slices[i].DocumentID,
			DocumentName: doc.Name,
			Sequence:     int64(slices[i].Sequence),
			ByteCount:    int64(len(slices[i].Content)),
			SliceStatus:  knowledgeModel.SliceStatus(slices[i].Status),
			CharCount:    int64(utf8.RuneCountInString(slices[i].Content)),
		}
		
		// 5. 添加额外信息
		docUri := documentMap[slices[i].DocumentID].URI
		var docURL string
		if len(docUri) != 0 {
			docURL, err = k.storage.GetObjectUrl(ctx, docUri)
			if err != nil {
				logs.CtxErrorf(ctx, "get object url failed: %v", err)
				return nil, errorx.New(errno.ErrKnowledgeGetObjectURLFailCode, errorx.KV("msg", err.Error()))
			}
		}
		
		sliceEntity.Extra = map[string]string{
			consts.KnowledgeName: kn.Name,
			consts.DocumentURL:   docURL,
		}
		
		// 6. 根据文档类型处理内容
		switch knowledgeModel.DocumentType(doc.DocumentType) {
		case knowledgeModel.DocumentTypeText:
			sliceEntity.RawContent = []*knowledgeModel.SliceContent{
				{Type: knowledgeModel.SliceContentTypeText, Text: ptr.Of(k.formatSliceContent(ctx, slices[i].Content))},
			}
		case knowledgeModel.DocumentTypeTable:
			// 表格数据处理逻辑
		case knowledgeModel.DocumentTypeImage:
			img := fmt.Sprintf(`<img src="" data-tos-key="%s">`, documentMap[slices[i].DocumentID].URI)
			sliceEntity.RawContent = []*knowledgeModel.SliceContent{
				{Type: knowledgeModel.SliceContentTypeText, Text: ptr.Of(k.formatSliceContent(ctx, img+slices[i].Content))},
			}
		default:
		}
		
		results = append(results, &knowledgeModel.RetrieveSlice{
			Slice: &sliceEntity,
			Score: sliceScoreMap[slices[i].ID],
		})
	}
	
	// 7. 更新命中计数
	err = k.sliceRepo.IncrementHitCount(ctx, sliceIDs.ToSlice())
    if err != nil {
		logs.CtxWarnf(ctx, "increment hit count failed: %v", err)
	}
	
	return results, nil
}
Logo

中国智能体开发者社区,聚焦智能体与大模型开发,提供前沿资讯、实用工具链、开源项目及行业案例。通过技术沙龙、开发者大赛等活动,促进经验交流与协作,助力开发者快速构建创新智能应用。

更多推荐