【扣子源码分析】Coze Studio RAG 技术深度解析(二):知识库多种检索方式完整实现
本文基于字节跳动Coze Studio源码,深入解析知识库RAG(检索增强生成)系统的多模态检索实现。系统采用并行处理架构,支持向量检索、全文检索和结构化数据检索三种方式,通过RRF算法实现智能结果融合。
本文档基于字节跳动(扣子) Coze Studio 实际源码分析,全面解析知识库RAG多种方式检索的技术实现。涵盖向量检索、全文检索、结构化数据检索的并行处理机制,以及基于 RRF 算法的智能结果融合策略。
🚀项目地址:https://github.com/coze-dev/coze-studio
📋 目录
一、检索服务架构概览
1.1 核心功能模块
1.1.1 各检索方式功能
- 向量检索:基于语义相似度的智能检索,理解用户问题的含义
- ES检索:基于关键词的传统检索,精确匹配特定词汇
- NL2SQL检索:自然语言转SQL的数据库查询,处理结构化数据
- 混合检索:多种检索方式的智能融合,提高检索覆盖度
1.1.2 重排序与融合功能
- RRF重排序:基于多种检索结果的智能重排序算法
- 分数过滤:过滤低质量检索结果,提高结果质量
- 结果融合:将不同来源的结果智能合并,去除重复
- 元数据丰富:为检索结果添加文档信息、来源说明等
1.1.3 链式处理架构
- 并行检索:同时执行多种检索方式,提高响应速度
- 重排序处理:对检索结果进行智能排序和过滤
- 结果打包:格式化最终结果,便于后续处理
- 错误处理:优雅处理各种检索失败情况
二、RAG检索流程
2.1 检索主入口时序图
时序图说明:
- 并行检索:向量检索、ES检索、NL2SQL检索同时进行,提高响应速度
- 结果融合:多种检索结果通过重排序算法智能融合
- 错误隔离:单个检索方式失败不影响其他检索方式
- 性能优化:通过并行处理显著减少总体响应时间
2.2 检索主入口实现
检索主入口执行流程:
- 参数验证:检查请求参数的有效性
- 上下文构建:创建检索上下文,包含文档、策略等信息
- 处理链构建:组装查询重写、并行检索、重排序、结果打包的处理链
- 并行检索:同时执行向量检索、ES检索、NL2SQL检索
- 结果融合:通过重排序融合多种检索结果
- 结果打包:将最终结果转换为响应格式
// backend/domain/knowledge/service/retrieve.go
func (k *knowledgeSVC) Retrieve(ctx context.Context, request *RetrieveRequest) (response *RetrieveResponse, err error) {
if request == nil {
return nil, errorx.New(errno.ErrKnowledgeInvalidParamCode, errorx.KV("msg", "request is nil"))
}
if len(request.Query) == 0 {
return &knowledgeModel.RetrieveResponse{}, nil
}
// 1. 构建检索上下文
retrieveContext, err := k.newRetrieveContext(ctx, request)
if err != nil {
return nil, err
}
if len(retrieveContext.Documents) == 0 {
return &knowledgeModel.RetrieveResponse{}, nil
}
// 2. 构建处理链:查询重写 -> 并行检索 -> 重排序 -> 结果打包
chain := compose.NewChain[*RetrieveContext, []*knowledgeModel.RetrieveSlice]()
rewriteNode := compose.InvokableLambda(k.queryRewriteNode)
vectorRetrieveNode := compose.InvokableLambda(k.vectorRetrieveNode)
EsRetrieveNode := compose.InvokableLambda(k.esRetrieveNode)
Nl2SqlRetrieveNode := compose.InvokableLambda(k.nl2SqlRetrieveNode)
passRequestContextNode := compose.InvokableLambda(k.passRequestContext)
reRankNode := compose.InvokableLambda(k.reRankNode)
packResult := compose.InvokableLambda(k.packResults)
// 3. 并行检索节点:同时执行向量、ES、NL2SQL检索
parallelNode := compose.NewParallel().
AddLambda("vectorRetrieveNode", vectorRetrieveNode).
AddLambda("esRetrieveNode", EsRetrieveNode).
AddLambda("nl2SqlRetrieveNode", Nl2SqlRetrieveNode).
AddLambda("passRequestContext", passRequestContextNode)
// 4. 编译并执行处理链
r, err := chain.
AppendLambda(rewriteNode). // 第1步:查询重写
AppendParallel(parallelNode). // 第2步:并行检索
AppendLambda(reRankNode). // 第3步:重排序
AppendLambda(packResult). // 第4步:结果打包
Compile(ctx)
if err != nil {
logs.CtxErrorf(ctx, "compile chain failed: %v", err)
return nil, errorx.New(errno.ErrKnowledgeBuildRetrieveChainFailCode, errorx.KV("msg", err.Error()))
}
// 5. 执行检索流程
output, err := r.Invoke(ctx, retrieveContext)
if err != nil {
return nil, err
}
return &knowledgeModel.RetrieveResponse{
Slices: output,
}, nil
}
2.2.1 检索上下文管理
检索上下文的作用:
- 参数管理:统一管理检索参数和配置
- 状态跟踪:跟踪检索过程中的状态变化
- 结果传递:在不同处理节点间传递结果
// backend/domain/knowledge/service/interface.go
// 检索上下文
type RetrieveContext struct {
Ctx context.Context
OriginQuery string // 原始 query
RewrittenQuery *string // 改写后的 query, 如果没有改写,就是 nil, 会在执行过程中添加上去
ChatHistory []*schema.Message // 如果没有对话历史或者不需要历史,则为 nil
KnowledgeIDs sets.Set[int64] // 本次检索涉及的知识库id
KnowledgeInfoMap map[int64]*KnowledgeInfo // 知识库id到文档id的映射
// 召回策略
Strategy *entity.RetrievalStrategy
// 检索涉及的 document 信息
Documents []*model.KnowledgeDocument
// 用于 nl2sql 和 message to query 的 chat model
ChatModel chatmodel.BaseChatModel
}
// 构建检索上下文
func (k *knowledgeSVC) newRetrieveContext(ctx context.Context, req *RetrieveRequest) (*RetrieveContext, error) {
// 1. 验证策略参数
if req.Strategy == nil {
return nil, errorx.New(errno.ErrKnowledgeInvalidParamCode, errorx.KV("msg", "strategy is required"))
}
// 2. 准备RAG文档:过滤启用的知识库和文档
knowledgeIDSets := sets.FromSlice(req.KnowledgeIDs)
docIDSets := sets.FromSlice(req.DocumentIDs)
enableDocs, enableKnowledge, err := k.prepareRAGDocuments(ctx, docIDSets.ToSlice(), knowledgeIDSets.ToSlice())
if err != nil {
logs.CtxErrorf(ctx, "prepare rag documents failed: %v", err)
return nil, err
}
if len(enableDocs) == 0 {
return &RetrieveContext{}, nil
}
// 3. 构建知识库信息映射
knowledgeInfoMap := make(map[int64]*KnowledgeInfo)
for _, kn := range enableKnowledge {
if knowledgeInfoMap[kn.ID] == nil {
knowledgeInfoMap[kn.ID] = &KnowledgeInfo{}
knowledgeInfoMap[kn.ID].DocumentType = knowledgeModel.DocumentType(kn.FormatType)
knowledgeInfoMap[kn.ID].DocumentIDs = []int64{}
}
}
// 4. 关联文档到知识库
for _, doc := range enableDocs {
info, found := knowledgeInfoMap[doc.KnowledgeID]
if !found {
continue
}
info.DocumentIDs = append(info.DocumentIDs, doc.ID)
if info.DocumentType == knowledgeModel.DocumentTypeTable && info.TableColumns == nil && doc.TableInfo != nil {
info.TableColumns = doc.TableInfo.Columns
}
}
// 5. 创建聊天模型(如果配置了)
var cm chatmodel.BaseChatModel
if req.ChatModelProtocol != nil && req.ChatModelConfig != nil {
cm, err = k.modelFactory.CreateChatModel(ctx, ptr.From(req.ChatModelProtocol), req.ChatModelConfig)
if err != nil {
return nil, errorx.New(errno.ErrKnowledgeInvalidParamCode,
errorx.KV("msg", "invalid retriever chat model protocol or config"))
}
}
// 6. 构建检索上下文
resp := RetrieveContext{
Ctx: ctx,
OriginQuery: req.Query,
ChatHistory: append(req.ChatHistory, schema.UserMessage(req.Query)),
KnowledgeIDs: knowledgeIDSets,
KnowledgeInfoMap: knowledgeInfoMap,
Strategy: req.Strategy,
Documents: enableDocs,
ChatModel: cm,
}
return &resp, nil
}
2.2.1 查询重写节点实现
查询重写机制深度解析:
查询重写是RAG系统中的重要优化步骤,它通过分析聊天历史来优化用户查询,提高检索的准确性和相关性。
核心功能定位:
- 上下文理解:基于聊天历史理解用户查询的上下文
- 查询优化:将简单查询转换为更精确的检索查询
- 意图识别:识别用户的真实检索意图
- 相关性提升:通过重写提高检索结果的相关性
执行条件:
- 有聊天历史:必须存在聊天历史记录
- 启用重写功能:策略中
EnableQueryRewrite为true - 重写器可用:系统配置了查询重写器
重写流程:
- 历史检查:检查是否存在聊天历史
- 功能检查:检查是否启用查询重写功能
- 模型调用:使用大语言模型进行查询重写
- 结果应用:将重写后的查询应用到检索流程
// backend/domain/knowledge/service/retrieve.go
func (k *knowledgeSVC) queryRewriteNode(ctx context.Context, req *RetrieveContext) (newRetrieveContext *RetrieveContext, err error) {
// 1. 检查聊天历史:如果没有历史记录,直接返回原始请求
if len(req.ChatHistory) == 0 {
return req, nil
}
// 2. 检查重写功能:如果未启用重写或重写器不可用,直接返回
if !req.Strategy.EnableQueryRewrite || k.rewriter == nil {
return req, nil
}
// 3. 构建重写选项:如果指定了聊天模型则使用指定模型
var opts []messages2query.Option
if req.ChatModel != nil {
opts = append(opts, messages2query.WithChatModel(req.ChatModel))
}
// 4. 执行查询重写:基于聊天历史重写用户查询
rewrittenQuery, err := k.rewriter.MessagesToQuery(ctx, req.ChatHistory, opts...)
if err != nil {
logs.CtxErrorf(ctx, "rewrite query failed: %v", err)
return req, nil // 重写失败时返回原始请求
}
// 5. 应用重写结果:将重写后的查询设置到请求中
req.RewrittenQuery = &rewrittenQuery
return req, nil
}
2.3 检索策略配置
检索策略参数说明:
// backend/api/model/crossdomain/knowledge/knowledge.go
type RetrievalStrategy struct {
TopK *int64 // 1-10 default 3,最终返回数量
MinScore *float64 // 0.01-0.99 default 0.5,最小匹配分数
MaxTokens *int64 // 最大token数
SelectType SelectType // 调用方式
SearchType SearchType // 搜索策略
EnableQueryRewrite bool // 是否启用查询重写
EnableRerank bool // 是否启用重排序
EnableNL2SQL bool // 是否启用NL2SQL
}
type SearchType int64
const (
SearchTypeSemantic SearchType = 0 // 语义搜索:仅向量检索
SearchTypeFullText SearchType = 1 // 全文搜索:仅ES检索
SearchTypeHybrid SearchType = 2 // 混合搜索:向量+ES
)
type SelectType int64
const (
SelectTypeAuto = 0 // 自动调用
SelectTypeOnDemand = 1 // 按需调用
)
2.3.1 检索参数配置表
| 参数名称 | 默认值 | 取值范围 | 作用说明 | 影响 |
|---|---|---|---|---|
| TopK | 3 | 1-10 | 最终返回结果数量 | 控制输出数量 |
| MinScore | 0.5 | 0.01-0.99 | 最小匹配分数阈值 | 过滤低质量结果 |
| MaxTokens | - | 动态 | 最大token数限制 | 控制处理复杂度 |
| SearchType | Semantic | Semantic/FullText/Hybrid | 搜索策略选择 | 决定检索方式 |
| EnableQueryRewrite | false | true/false | 是否启用查询重写 | 优化查询质量 |
| EnableRerank | false | true/false | 是否启用重排序 | 提升结果质量 |
| EnableNL2SQL | false | true/false | 是否启用NL2SQL | 支持结构化查询 |
2.3.2 多种检索方式数据量对比表
| 检索方式 | 默认检索数量 | 检索原理 | 适用场景 | 优势 | 劣势 |
|---|---|---|---|---|---|
| 向量检索 | 4条 | 向量相似度匹配 | 语义搜索 | 语义理解强、容错性好 | 计算开销大 |
| ES检索 | 10条 | 倒排索引+TF-IDF | 全文搜索 | 精确匹配、速度快 | 语义理解弱 |
| NL2SQL检索 | 动态 | 自然语言转SQL | 表格数据 | 结构化查询 | 仅限表格数据 |
| 混合检索 | 向量4条+ES10条 | 多路并行+重排序 | 综合场景 | 全面覆盖 | 计算复杂度高 |
2.3.3 重排序阶段
- 合并候选:向量4条 + ES 10条 = 最多14条候选
- TopK筛选:根据TopK参数选择最终数量(默认3条)
- MinScore过滤:根据MinScore参数过滤低分结果(默认0.5)
2.3.4 最终输出
- 最大数量:由TopK参数控制(默认3条)
- 质量过滤:由MinScore参数控制(默认0.5)
- 实际数量:可能少于TopK,取决于匹配质量
三、多种检索方式核心实现
3.1 向量检索实现
向量检索原理:
- 文本向量化:将用户问题和文档都转换为向量
- 相似度计算:计算向量之间的余弦相似度
- 排序返回:按相似度排序返回最相关的结果
3.1.1 向量检索常量定义
// backend/infra/impl/document/searchstore/milvus/consts.go
const (
batchSize = 100
topK = 4 // 向量检索默认返回4条结果
)
数据量说明: 向量检索默认返回4条结果,这是基于语义相似度排序后的最相关文档。
3.1.2 向量检索入口
检索机制深度剖析:
- 检索策略判断:首先检查
SearchType,如果是SearchTypeFullText(仅全文搜索),则跳过向量检索 - 向量存储管理器选择:遍历所有搜索存储管理器,找到类型为
TypeVectorStore的Milvus管理器 - 通道检索调用:调用
retrieveChannels函数进行实际的向量检索操作
// backend/domain/knowledge/service/retrieve.go
// 向量检索节点
func (k *knowledgeSVC) vectorRetrieveNode(ctx context.Context, req *RetrieveContext) (retrieveResult []*schema.Document, err error) {
// 1. 检查搜索类型,如果是全文搜索则跳过向量检索
if req.Strategy.SearchType == knowledgeModel.SearchTypeFullText {
return nil, nil
}
// 2. 获取向量存储管理器
var manager searchstore.Manager
for i := range k.searchStoreManagers {
m := k.searchStoreManagers[i]
if m != nil && m.GetType() == searchstore.TypeVectorStore {
manager = m
break
}
}
// 3. 检查向量存储是否可用
if manager == nil {
logs.CtxErrorf(ctx, "err:%s", errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", "未实现vectorStore")).Error())
return nil, nil
}
// 4. 执行向量检索 详情参考本文章:四、多知识库并发检索机制
retrieveResult, err = k.retrieveChannels(ctx, req, manager)
if err != nil {
logs.CtxErrorf(ctx, "retrieveChannels err:%s", err.Error())
}
return retrieveResult, nil
}
向量检索优势:
- 语义理解:能理解问题的含义,不只是关键词匹配
- 容错性强:即使用词不同,只要意思相近就能找到
- 扩展性好:支持同义词、近义词等语义扩展
3.1.3 Milvus向量检索核心实现
// backend/infra/impl/document/searchstore/milvus/milvus_searchstore.go
func (m *milvusSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
cli := m.config.Client
emb := m.config.Embedding
options := retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(topK)}, opts...)
// 1. 获取集合描述信息
desc, err := cli.DescribeCollection(ctx, client.NewDescribeCollectionOption(m.collectionName))
// 2. 查询向量化处理
if enableSparse {
dense, sparse, err = emb.EmbedStringsHybrid(ctx, []string{query}) // 混合向量化
} else {
dense, err = emb.EmbedStrings(ctx, []string{query}) // 密集向量化
}
// 3. 构建搜索请求
if enableSparse {
// 混合搜索:同时使用密集向量和稀疏向量
searchOption := client.NewHybridSearchOption(m.collectionName, ptr.From(options.TopK), annRequests...).
WithPartitons(implSpecOptions.Partitions...).
WithReranker(client.NewRRFReranker()). // 使用RRF重排序
WithOutputFields(outputFields...)
result, err = cli.HybridSearch(ctx, searchOption)
} else {
// 纯向量搜索
searchOption := client.NewSearchOption(m.collectionName, ptr.From(options.TopK), dv).
WithPartitions(implSpecOptions.Partitions...).
WithFilter(expr).
WithOutputFields(outputFields...).
WithSearchParam(mindex.MetricTypeKey, string(metricsType))
result, err = cli.Search(ctx, searchOption)
}
// 4. 结果转换
docs, err := m.resultSet2Document(result, scoreNormType)
return docs, nil
}
向量检索算法原理:
- 密集向量检索:使用余弦相似度或欧几里得距离计算查询向量与文档向量的相似度
- 混合向量检索:同时使用密集向量(语义信息)和稀疏向量(关键词信息),通过RRF算法融合结果
- 分区过滤:根据文档ID进行分区过滤,提高检索效率
- 相似度计算:支持多种距离度量方式(IP、L2、COSINE等)
3.2 Elasticsearch检索实现
Elasticsearch检索原理:
- 倒排索引:建立词汇到文档的映射
- TF-IDF算法:计算词汇在文档中的重要性
- 布尔查询:支持AND、OR、NOT等逻辑运算3.2.1 ES检索常量定义
3.2.1 Elasticsearch检索常量
// backend/infra/impl/document/searchstore/elasticsearch/consts.go
const (
topK = 10 // Elasticsearch检索默认返回10条结果
)
数据量说明: Elasticsearch检索默认返回10条结果,这是基于关键词匹配和TF-IDF算法排序后的文档。
3.2.2 Elasticsearch检索入口
检索机制深度剖析:
- 检索策略判断:检查
SearchType,如果是SearchTypeSemantic(仅语义搜索),则跳过Elasticsearch检索 - 文本存储管理器选择:找到类型为
TypeTextStore的Elasticsearch管理器 - 通道检索调用:调用
retrieveChannels函数进行实际的Elasticsearch检索操作
// backend/domain/knowledge/service/retrieve.go
// ES检索节点
func (k *knowledgeSVC) esRetrieveNode(ctx context.Context, req *RetrieveContext) (retrieveResult []*schema.Document, err error) {
// 1. 检查搜索类型,如果是语义搜索则跳过ES检索
if req.Strategy.SearchType == knowledgeModel.SearchTypeSemantic {
return nil, nil
}
// 2. 获取文本存储管理器
var manager searchstore.Manager
for i := range k.searchStoreManagers {
m := k.searchStoreManagers[i]
if m != nil && m.GetType() == searchstore.TypeTextStore {
manager = m
break
}
}
// 3. 检查ES存储是否可用
if manager == nil {
logs.CtxErrorf(ctx, "err:%s", errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", "未实现esStore")).Error())
return nil, nil
}
// 4. 执行ES检索 详情参考本文章:四、多知识库并发检索机制
retrieveResult, err = k.retrieveChannels(ctx, req, manager)
if err != nil {
logs.CtxErrorf(ctx, "retrieveChannels err:%s", err.Error())
}
return retrieveResult, nil
}
ES检索特点:
- 精确匹配:能找到包含特定关键词的文档
- 过滤支持:支持复杂的过滤条件
- 快速检索:基于倒排索引,检索速度极快
3.2.3 Elasticsearch检索核心实现
// backend/infra/impl/document/searchstore/elasticsearch/elasticsearch_searchstore.go
func (e *esSearchStore) Retrieve(ctx context.Context, query string, opts ...retriever.Option) ([]*schema.Document, error) {
var (
cli = e.config.Client
index = e.indexName
options = retriever.GetCommonOptions(&retriever.Options{TopK: ptr.Of(topK)}, opts...)
req = &es.Request{
Query: &es.Query{
Bool: &es.BoolQuery{},
},
Size: options.TopK,
}
)
// 1. 构建查询条件
if implSpecOptions.MultiMatch == nil {
// 单字段匹配
req.Query.Bool.Must = append(req.Query.Bool.Must,
es.NewMatchQuery(searchstore.FieldTextContent, query))
} else {
// 多字段匹配
req.Query.Bool.Must = append(req.Query.Bool.Must,
es.NewMultiMatchQuery(implSpecOptions.MultiMatch.Fields, query,
"best_fields", es.Or))
}
// 2. 应用DSL过滤条件
dsl, err := searchstore.LoadDSL(options.DSLInfo)
if err = e.travDSL(req.Query, dsl); err != nil {
return nil, err
}
// 3. 设置分数阈值
if options.ScoreThreshold != nil {
req.MinScore = options.ScoreThreshold
}
// 4. 执行搜索
resp, err := cli.Search(ctx, index, req)
// 5. 解析结果
docs, err := e.parseSearchResult(resp)
return docs, nil
}
ES检索算法原理:
- 倒排索引查询:基于关键词的倒排索引快速定位包含查询词的文档
- TF-IDF评分:计算词频-逆文档频率,评估文档与查询的相关性
- 多字段匹配:支持在多个字段中同时搜索,使用best_fields策略
- 布尔查询:支持复杂的AND、OR、NOT逻辑组合
- 分数阈值过滤:只返回分数超过阈值的文档
3.3 NL2SQL检索实现
3.3.1 NL2SQL检索机制深度剖析
NL2SQL检索算法原理:
- 语义理解:通过大语言模型理解用户查询的语义意图
- SQL生成:将自然语言查询转换为标准的SQL查询语句
- 表结构映射:将逻辑表名映射到物理表名,字段名映射到数据库字段
- 精确匹配:基于结构化数据的精确查询,返回完全匹配的结果
- 并行查询:支持对多个表格同时进行NL2SQL查询
NL2SQL执行逻辑详解:
- 输入处理阶段:
- 接收用户自然语言查询和聊天历史
- 识别表格文档并提取表结构信息
- 构建表格描述字符串(包含表名、字段名、字段类型、是否必填等)
- 模型调用阶段:
- 使用Jinja2模板将表格描述和用户查询组合成提示
- 调用配置的大语言模型(如GPT、Claude等)
- 模型返回JSON格式的SQL语句和错误信息
- SQL处理阶段:
- 解析模型返回的JSON结果
- 添加切片ID列到SQL语句中
- 构建表名和字段的映射关系(逻辑名到物理名的转换)
- SQL执行阶段:
- 使用SQL解析器修改SQL语句(替换表名和字段名)
- 在关系数据库中执行修改后的SQL
- 获取查询结果集
- 结果转换阶段:
- 将SQL查询结果转换为统一的文档格式
- 为每个结果分配默认分数1(表示精确匹配)
- 返回文档列表供后续重排序使用
数据量说明:
NL2SQL检索返回的结果数量取决于SQL查询的结果行数,每条结果默认分数为1,表示结构化数据的精确匹配。
3.3.2 NL2SQL检索入口
- 表格文档识别:遍历所有文档,识别类型为
DocumentTypeTable的表格文档 - 并行处理:使用
errgroup并发处理多个表格文档的NL2SQL查询 - 自然语言转SQL:调用NL2SQL模型将用户查询转换为SQL语句
- SQL执行:在关系数据库中执行生成的SQL查询
- 结果转换:将SQL查询结果转换为统一的文档格式
// backend/domain/knowledge/service/retrieve.go
// NL2SQL检索节点:处理表格文档的自然语言转SQL查询
// 参数:
// - ctx: 上下文,用于超时控制和日志记录
// - req: 检索上下文,包含用户查询、文档列表、策略配置等
// 返回:
// - retrieveResult: 检索结果文档列表
// - err: 错误信息
func (k *knowledgeSVC) nl2SqlRetrieveNode(ctx context.Context, req *RetrieveContext) (retrieveResult []*schema.Document, err error) {
hasTable := false // 标记是否存在表格文档
var tableDocs []*model.KnowledgeDocument // 存储所有表格文档
// 遍历所有文档,识别类型为DocumentTypeTable的表格文档
for _, doc := range req.Documents {
if doc.DocumentType == int32(knowledgeModel.DocumentTypeTable) {
hasTable = true
tableDocs = append(tableDocs, doc)
}
}
// 构建NL2SQL选项,如果指定了聊天模型则使用指定的模型
var opts []nl2sql.Option
if req.ChatModel != nil {
opts = append(opts, nl2sql.WithChatModel(req.ChatModel))
}
// 如果存在表格文档且启用了NL2SQL功能,则执行NL2SQL查询
if hasTable && req.Strategy.EnableNL2SQL {
mu := sync.Mutex{} // 互斥锁,用于保护并发写入结果
// 使用errgroup并发处理多个表格文档的NL2SQL查询
eg, ctx := errgroup.WithContext(ctx)
eg.SetLimit(len(tableDocs)) // 设置并发限制为表格文档数量
res := make([]*schema.Document, 0) // 存储所有查询结果
// 为每个表格文档启动一个goroutine进行并发查询
for i := range tableDocs {
t := i
eg.Go(func() error {
doc := tableDocs[t]
// 对单个表格文档执行NL2SQL查询
docs, execErr := k.nl2SqlExec(ctx, doc, req, opts)
if execErr != nil {
logs.CtxErrorf(ctx, "nl2sql exec failed: %v", execErr)
return errorx.New(errno.ErrKnowledgeNL2SqlExecFailCode, errorx.KV("msg", execErr.Error()))
}
// 线程安全地添加查询结果
mu.Lock()
res = append(res, docs...)
mu.Unlock()
return nil
})
}
// 等待所有goroutine完成
err = eg.Wait()
if err != nil {
logs.CtxErrorf(ctx, "nl2sql exec failed: %v", err)
return nil, nil
}
return res, nil
} else {
// 如果没有表格文档或未启用NL2SQL,返回空结果
return nil, nil
}
}
3.3.3 NL2SQL核心实现
// backend/domain/knowledge/service/retrieve.go
// NL2SQL执行函数:将自然语言查询转换为SQL并执行
// 参数:
// - ctx: 上下文
// - doc: 表格文档,包含表结构信息
// - retrieveCtx: 检索上下文,包含用户查询历史
// - opts: NL2SQL选项,如聊天模型配置
// 返回:
// - retrieveResult: 查询结果文档列表
// - err: 错误信息
func (k *knowledgeSVC) nl2SqlExec(ctx context.Context, doc *model.KnowledgeDocument, retrieveCtx *RetrieveContext, opts []nl2sql.Option) (
retrieveResult []*schema.Document, err error) {
// 步骤1: 调用NL2SQL模型生成SQL
// 将表格文档打包成TableSchema格式,传入聊天历史和选项
sql, err := k.nl2Sql.NL2SQL(ctx, retrieveCtx.ChatHistory, []*document.TableSchema{packNL2SqlRequest(doc)}, opts...)
if err != nil {
logs.CtxErrorf(ctx, "nl2sql failed: %v", err)
return nil, err
}
// 步骤2: 添加切片ID列到SQL中,用于标识数据来源
sql = addSliceIdColumn(sql)
// 步骤3: 构建表名和字段映射关系
// 将逻辑表名映射到物理表名,字段名映射到数据库字段名
replaceMap := map[string]sqlparsercontract.TableColumn{}
replaceMap[doc.Name] = sqlparsercontract.TableColumn{
NewTableName: ptr.Of(doc.TableInfo.PhysicalTableName), // 物理表名
ColumnMap: map[string]string{
pkID: consts.RDBFieldID, // 主键字段映射
},
}
// 遍历表格的所有列,建立字段名映射关系
for i := range doc.TableInfo.Columns {
if doc.TableInfo.Columns[i] == nil {
continue
}
// 跳过RDB字段ID,避免重复映射
if doc.TableInfo.Columns[i].Name == consts.RDBFieldID {
continue
}
// 将逻辑字段名映射到物理字段名
replaceMap[doc.Name].ColumnMap[doc.TableInfo.Columns[i].Name] = convert.ColumnIDToRDBField(doc.TableInfo.Columns[i].ID)
}
// 步骤4: 解析并修改SQL(表名和字段名映射)
// 使用SQL解析器将逻辑表名和字段名替换为物理名称
parsedSQL, err := sqlparser.NewSQLParser().ParseAndModifySQL(sql, replaceMap)
if err != nil {
logs.CtxErrorf(ctx, "parse sql failed: %v", err)
return nil, err
}
// 步骤5: 执行SQL查询
// 在关系数据库中执行修改后的SQL语句
resp, err := k.rdb.ExecuteSQL(ctx, &rdb.ExecuteSQLRequest{
SQL: parsedSQL,
})
if err != nil {
logs.CtxErrorf(ctx, "execute sql failed: %v", err)
return nil, err
}
// 步骤6: 转换结果为文档格式
// 将数据库查询结果转换为统一的文档格式,便于后续处理
for i := range resp.ResultSet.Rows {
// 从结果行中提取ID字段
id, ok := resp.ResultSet.Rows[i][consts.RDBFieldID].(int64)
if !ok {
logs.CtxWarnf(ctx, "convert id failed, row: %v", resp.ResultSet.Rows[i])
return nil, errors.New("convert id failed")
}
// 创建文档对象,设置ID和元数据
d := &schema.Document{
ID: strconv.FormatInt(id, 10), // 将ID转换为字符串
Content: "", // NL2SQL结果内容为空,因为结构化数据不需要文本内容
MetaData: map[string]any{}, // 元数据为空
}
d.WithScore(1) // NL2SQL结果默认分数为1,表示精确匹配
retrieveResult = append(retrieveResult, d)
}
return retrieveResult, nil
}
3.3.4 NL2SQL模型详细实现
// backend/infra/impl/document/nl2sql/builtin/nl2sql.go
// NL2SQL模型核心实现:将自然语言转换为SQL语句
// 参数:
// - ctx: 上下文
// - messages: 聊天历史消息列表
// - tables: 表格结构信息列表
// - opts: 可选的NL2SQL配置选项
// 返回:
// - sql: 生成的SQL语句
// - err: 错误信息
func (n *n2s) NL2SQL(ctx context.Context, messages []*schema.Message, tables []*document.TableSchema, opts ...nl2sql.Option) (sql string, err error) {
// 初始化NL2SQL选项,默认使用配置的聊天模型
o := &nl2sql.Options{ChatModel: n.cm}
for _, opt := range opts {
opt(o) // 应用自定义选项
}
// 检查聊天模型是否已配置
if o.ChatModel == nil {
return "", fmt.Errorf("[NL2SQL] chat model not configured")
}
// 构建处理链:输入处理 -> 模板渲染 -> 模型调用 -> 结果解析
c := compose.NewChain[*nl2sqlInput, string]().
// 第一步:输入处理Lambda,构建表格描述信息
AppendLambda(compose.InvokableLambda(func(ctx context.Context, input *nl2sqlInput) (output map[string]any, err error) {
// 检查表格信息是否为空
if len(input.tables) == 0 {
return nil, errors.New("table meta is empty")
}
// 构建表格描述信息字符串
tableDesc := strings.Builder{}
for _, table := range input.tables {
// 添加表名和表描述
tableDesc.WriteString(fmt.Sprintf(defaultTableFmt, table.Name, table.Comment))
// 添加每个字段的信息(字段名、描述、类型、是否必填)
for _, column := range table.Columns {
tableDesc.WriteString(fmt.Sprintf(defaultColumnFmt, column.Name, column.Description, column.Type.String(), !column.Nullable))
}
}
// 返回模板渲染所需的数据
return map[string]interface{}{
"messages": input.messages, // 聊天历史
"table_schema": tableDesc.String(), // 表格结构描述
}, nil
})).
// 第二步:使用Jinja2模板渲染提示
AppendChatTemplate(n.tpl).
// 第三步:调用大语言模型生成SQL
AppendChatModel(o.ChatModel).
// 第四步:结果解析Lambda,解析模型返回的JSON格式结果
AppendLambda(compose.InvokableLambda(func(ctx context.Context, msg *schema.Message) (sql string, err error) {
// 解析模型返回的JSON格式结果
var promptResp *promptResponse
if err := json.Unmarshal([]byte(msg.Content), &promptResp); err != nil {
logs.CtxWarnf(ctx, "unmarshal failed: %v", err)
return "", err
}
// 检查是否成功生成SQL
if promptResp.SQL == "" {
logs.CtxInfof(ctx, "no sql generated, err_code: %v, err_msg: %v", promptResp.ErrCode, promptResp.ErrMsg)
return "", errors.New(promptResp.ErrMsg)
}
return promptResp.SQL, nil
}))
// 编译处理链
r, err := c.Compile(ctx)
if err != nil {
return "", err
}
// 准备输入数据
input := &nl2sqlInput{
messages: messages, // 聊天历史
tables: tables, // 表格结构
}
// 执行处理链并返回生成的SQL
return r.Invoke(ctx, input)
}
3.3.5 NL2SQL提示模板
// backend/conf/prompt/nl2sql_template_jinja2.json
[
{
"role": "system",
"content": "# Role: NL2SQL Consultant\n\n## Goals\nTranslate natural language statements into SQL queries in MySQL standard. Follow the Constraints and return only a JSON always.\n\n## Format\n- JSON format only. JSON contains field \"sql\" for generated SQL, filed \"err_code\" for reason type, field \"err_msg\" for detail reason (prefer more than 10 words)\n- Don't use \"```json\" markdown format\n\n## Skills\n- Good at Translate natural language statements into SQL queries in MySQL standard.\n\n## Define\n\"err_code\" Reason Type Define:\n- 0 means you generated a SQL\n- 3002 means you cannot generate a SQL because of timeout\n- 3003 means you cannot generate a SQL because of table schema missing\n- 3005 means you cannot generate a SQL because of some term is ambiguous\n\n## Example\nQ: Help me implement NL2SQL.\n.table schema description: CREATE TABLE `sales_records` (\\n `sales_id` bigint(20) unsigned NOT NULL COMMENT 'id of sales person',\\n `product_id` bigint(64) COMMENT 'id of product',\\n `sale_date` datetime(3) COMMENT 'sold date and time',\\n `quantity_sold` int(11) COMMENT 'sold amount',\\n PRIMARY KEY (`sales_id`)\\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='销售记录表';\n.natural language description of the SQL requirement: 查询上月的销量总额第一名的销售员和他的销售总额\nA: {\n \"sql\":\"SELECT sales_id, SUM(quantity_sold) AS total_sales FROM sales_records WHERE MONTH(sale_date) = MONTH(CURRENT_DATE - INTERVAL 1 MONTH) AND YEAR(sale_date) = YEAR(CURRENT_DATE - INTERVAL 1 MONTH) GROUP BY sales_id ORDER BY total_sales DESC LIMIT 1\",\n \"err_code\":0,\n \"err_msg\":\"SQL query generated successfully\"\n}"
},
{
"role": "user",
"content": "help me implement NL2SQL.\ntable schema description:{{table_schema}}\nnatural language description of the SQL requirement: {{messages}}."
}
]
3.3.6 表格数据打包函数
// backend/domain/knowledge/service/retrieve.go
// 表格数据打包函数:将知识库文档转换为NL2SQL所需的表格结构
// 参数:
// - doc: 知识库文档,包含表格信息
// 返回:
// - *document.TableSchema: 表格结构信息,包含表名、描述、字段列表
func packNL2SqlRequest(doc *model.KnowledgeDocument) *document.TableSchema {
res := &document.TableSchema{} // 创建表格结构对象
if doc.TableInfo == nil {
return res
}
// 设置表格基本信息
res.Name = doc.TableInfo.VirtualTableName // 虚拟表名(逻辑表名)
res.Comment = doc.TableInfo.TableDesc // 表格描述
res.Columns = []*document.Column{} // 初始化字段列表
// 遍历表格的所有列,构建字段信息
for _, column := range doc.TableInfo.Columns {
// 跳过RDB字段ID,因为这是系统内部字段,不需要暴露给NL2SQL
if column.Name == consts.RDBFieldID {
continue
}
// 添加字段信息到表格结构中
res.Columns = append(res.Columns, &document.Column{
Name: column.Name, // 字段名
Type: column.Type, // 字段类型
Description: column.Description, // 字段描述
Nullable: !column.Indexing, // 是否可为空(与索引字段相反)
IsPrimary: false, // 是否为主键(这里统一设为false)
})
}
return res
}
3.3.7 SQL解析器实现
// backend/infra/impl/sqlparser/sql_parser.go
// SQL解析器:解析和修改SQL语句,实现表名和字段名的映射
// 参数:
// - sql: 原始SQL语句
// - tableColumns: 表名和字段的映射关系
// 返回:
// - string: 修改后的SQL语句
// - error: 错误信息
func (p *Impl) ParseAndModifySQL(sql string, tableColumns map[string]sqlparser.TableColumn) (string, error) {
if len(tableColumns) == 0 {
return sql, nil
}
for originalTableName, tableColumn := range tableColumns {
if originalTableName == "" {
return "", fmt.Errorf("original TableName must be non-empty")
}
if tableColumn.ColumnMap != nil {
for key, value := range tableColumn.ColumnMap {
if (key == "") != (value == "") {
return "", fmt.Errorf("ColumnMap key and value must be either both empty or both non-empty")
}
}
}
}
// 使用TiDB解析器解析SQL语句
stmt, err := p.parser.ParseOneStmt(sql, mysql.UTF8MB4Charset, mysql.UTF8MB4GeneralCICollation)
if err != nil {
return "", fmt.Errorf("failed to parse SQL: %v", err)
}
// 第一遍遍历:收集所有表别名
aliasCollector := NewAliasCollector()
stmt.Accept(aliasCollector)
// 检查别名冲突:原始表名不能与别名相同
for originalTableName, _ := range tableColumns {
if _, ok := aliasCollector.tableAliases[originalTableName]; ok {
return "", fmt.Errorf("alisa table name should not equal with origin table name")
}
}
// 第二遍遍历:使用收集的别名信息修改AST(抽象语法树)
modifier := NewSQLModifier(tableColumns, aliasCollector.tableAliases)
stmt.Accept(modifier)
// 将修改后的AST转换回SQL字符串
var sb strings.Builder
// 设置格式化标志:使用单引号、移除字符集前缀
flags := format.RestoreStringSingleQuotes | format.RestoreStringWithoutCharset
restoreCtx := format.NewRestoreCtx(flags, &sb)
err = stmt.Restore(restoreCtx)
if err != nil {
return "", fmt.Errorf("failed to restore SQL: %v", err)
}
return sb.String(), nil
}
3.3.8 RDB执行SQL实现
// backend/infra/impl/rdb/mysql.go
// RDB执行SQL:在关系数据库中执行SQL查询
// 参数:
// - ctx: 上下文
// - req: SQL执行请求,包含SQL语句和参数
// 返回:
// - *rdb.ExecuteSQLResponse: 执行结果
// - error: 错误信息
func (m *mysqlService) ExecuteSQL(ctx context.Context, req *rdb.ExecuteSQLRequest) (*rdb.ExecuteSQLResponse, error) {
if req == nil {
return nil, fmt.Errorf("invalid request")
}
logs.CtxInfof(ctx, "[ExecuteSQL] req is %v", req)
var processedSQL string // 处理后的SQL语句
var processedParams []interface{} // 处理后的参数
var err error
// 根据SQL类型处理参数
if req.SQLType == entity2.SQLType_Raw {
// 原始SQL类型:不处理参数,直接使用原始SQL
processedSQL = req.SQL
processedParams = nil
} else {
// 参数化SQL类型:处理切片参数
processedSQL, processedParams, err = m.processSliceParams(req.SQL, req.Params)
if err != nil {
return nil, fmt.Errorf("failed to process parameters: %v", err)
}
}
// 获取SQL操作类型(SELECT/INSERT/UPDATE/DELETE)
operation, err := sqlparser.NewSQLParser().GetSQLOperation(processedSQL)
if err != nil {
return nil, err
}
// 处理非SELECT操作(INSERT/UPDATE/DELETE)
if operation != sqlparsercontract.OperationTypeSelect {
// 执行非查询SQL,返回影响的行数
result := m.db.WithContext(ctx).Exec(processedSQL, processedParams...)
if result.Error != nil {
return nil, fmt.Errorf("failed to execute SQL: %v", result.Error)
}
// 构建结果集(非查询操作没有数据行)
resultSet := &entity2.ResultSet{
Columns: []string{}, // 空列名
Rows: []map[string]interface{}{}, // 空数据行
AffectedRows: result.RowsAffected, // 影响的行数
}
return &rdb.ExecuteSQLResponse{
ResultSet: resultSet,
}, nil
}
// 处理SELECT查询操作
// 执行查询并获取结果集
rows, err := m.db.WithContext(ctx).Raw(processedSQL, processedParams...).Rows()
if err != nil {
return nil, fmt.Errorf("failed to execute SQL: %v", err)
}
defer rows.Close() // 确保结果集被关闭
// 获取列名信息
columns, err := rows.Columns()
if err != nil {
return nil, fmt.Errorf("failed to get columns: %v", err)
}
// 初始化结果集
resultSet := &entity2.ResultSet{
Columns: columns, // 设置列名
Rows: make([]map[string]interface{}, 0), // 初始化数据行列表
}
// 遍历查询结果的每一行
for rows.Next() {
// 为当前行准备值容器
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range values {
valuePtrs[i] = &values[i] // 创建指向值的指针
}
// 扫描当前行数据到值容器
if err := rows.Scan(valuePtrs...); err != nil {
return nil, fmt.Errorf("failed to scan row: %v", err)
}
// 将当前行数据转换为map格式
rowData := make(map[string]interface{})
for i, col := range columns {
rowData[col] = values[i] // 列名 -> 值的映射
}
resultSet.Rows = append(resultSet.Rows, rowData)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error while reading rows: %v", err)
}
// 返回查询结果
return &rdb.ExecuteSQLResponse{
ResultSet: resultSet,
}, nil
}
四、多知识库并发检索机制
多知识库并发检索机制深度解析:
retrieveChannels函数是ES检索和向量检索的核心实现,负责统一处理多个知识库的并发检索。这个函数在整个RAG系统中扮演着"检索调度器"的角色,它的主要作用是:
核心功能定位:
- 统一入口:为ES检索和向量检索提供统一的处理入口
- 并发调度:管理多个知识库的并发检索任务
- 资源控制:通过限制并发数避免系统资源过载
- 结果聚合:安全地合并来自不同知识库的检索结果
在RAG流程中的位置:
用户查询 → 2. 查询重写 → 3. 多知识库并发检索 → 4. 结果重排序 → 5. 结果打包
设计优势:
- 性能优化:通过并发检索显著提升响应速度
- 容错机制:单个知识库失败不影响整体检索
- 资源管理:限制并发数避免系统过载
- 精确检索:通过分区和DSL过滤确保检索精度
执行流程图:
详细执行步骤:
- 知识库遍历:遍历所有相关的知识库,为每个知识库创建独立的检索任务
- 分区过滤:根据文档ID创建分区列表,限制检索范围,提高检索效率
- DSL条件构建:构建文档ID的过滤条件,确保只检索指定文档
- 多字段匹配:对于表格文档,支持在多个索引字段中进行匹配
- 并发检索:使用
errgroup并发执行多个知识库的检索任务 - 结果合并:使用互斥锁安全地合并所有检索结果
核心优化策略:
- 分区检索:通过文档ID分区,减少检索范围,提高效率
- 并发处理:支持多个知识库的并发检索,最多同时处理2个知识库
- 字段索引:只对设置了索引的字段进行匹配,避免无效检索
- 错误隔离:单个知识库检索失败不影响其他知识库的检索
- 结果去重:通过文档ID确保结果唯一性
// backend/domain/knowledge/service/retrieve.go
// 多知识库并发检索函数 - 负责多知识库并发检索的核心实现
func (k *knowledgeSVC) retrieveChannels(ctx context.Context, req *RetrieveContext, manager searchstore.Manager) (result []*schema.Document, err error) {
// 1. 查询预处理:根据策略选择使用原始查询还是重写后的查询
query := req.OriginQuery
if req.Strategy.EnableQueryRewrite && req.RewrittenQuery != nil {
query = *req.RewrittenQuery
}
// 2. 并发控制:使用互斥锁保护结果合并,限制最大并发数为2
mu := sync.Mutex{}
eg, ctx := errgroup.WithContext(ctx)
eg.SetLimit(2) // 最多同时处理2个知识库,避免资源竞争
// 3. 遍历所有知识库,为每个知识库创建独立的检索任务
for knowledgeID, knowledgeInfo := range req.KnowledgeInfoMap {
kid := knowledgeID
info := knowledgeInfo
collectionName := getCollectionName(kid) // 获取知识库对应的集合名称
// 4. DSL条件构建:创建文档ID的过滤条件,确保只检索指定文档
dsl := &searchstore.DSL{
Op: searchstore.OpIn, // 使用IN操作符
Field: "document_id", // 过滤字段为document_id
Value: knowledgeInfo.DocumentIDs, // 过滤值为指定的文档ID列表
}
// 5. 分区过滤:根据文档ID创建分区列表,限制检索范围
partitions := make([]string, 0, len(req.Documents))
for _, doc := range req.Documents {
if doc.KnowledgeID == kid {
partitions = append(partitions, strconv.FormatInt(doc.ID, 10))
}
}
if len(partitions) == 0 {
continue // 如果没有相关文档,跳过当前知识库
}
// 6. 构建检索选项:设置分区键、分区列表和DSL过滤条件
opts := []retriever.Option{
searchstore.WithRetrieverPartitionKey(fieldNameDocumentID), // 设置分区键
searchstore.WithPartitions(partitions), // 设置分区列表
retriever.WithDSLInfo(dsl.DSL()), // 设置DSL过滤条件
}
// 7. 多字段匹配:对于表格文档,支持在多个索引字段中进行匹配
if info.DocumentType == knowledgeModel.DocumentTypeTable && !k.enableCompactTable {
var matchCols []string
for _, col := range info.TableColumns {
if col.Indexing { // 只对设置了索引的字段进行匹配
matchCols = append(matchCols, getColName(col.ID))
}
}
opts = append(opts, searchstore.WithMultiMatch(matchCols, query)) // 添加多字段匹配选项
}
// 8. 并发检索:使用goroutine并发执行检索任务
eg.Go(func() error {
// 获取当前知识库的搜索存储实例
ss, err := manager.GetSearchStore(ctx, collectionName)
if err != nil {
return errorx.New(errno.ErrKnowledgeSearchStoreCode, errorx.KV("msg", err.Error()))
}
// 执行检索操作
retrievedDocs, err := ss.Retrieve(ctx, query, opts...)
if err != nil {
return errorx.New(errno.ErrKnowledgeRetrieveExecFailCode, errorx.KV("msg", err.Error()))
}
// 9. 结果合并:使用互斥锁安全地合并检索结果
mu.Lock()
result = append(result, retrievedDocs...)
mu.Unlock()
return nil
})
}
// 10. 等待所有并发任务完成,检查是否有错误
if err = eg.Wait(); err != nil {
return nil, err
}
return
}
五、重排序与结果融合
5.1 重排序机制
重排序机制深度剖析:
- 结果收集:从并行执行的结果映射中提取各种检索方式的结果
- 数据转换:将文档转换为重排序器需要的格式,包含文档和分数信息
- 策略路由:根据检索策略选择要重排序的结果集:
- 语义搜索:只重排序向量检索结果
- 全文搜索:只重排序ES检索结果
- 混合搜索:重排序向量和ES检索结果
- NL2SQL:如果启用,则包含NL2SQL结果
- 查询选择:选择使用原始查询或重写后的查询
- 重排序执行:调用重排序器对所有结果进行重新排序
- 分数过滤:过滤掉分数低于最小阈值的结果
重排序算法原理:
- 多模态融合:将不同检索方式的结果进行统一排序
- 语义相关性:基于查询与文档的语义相似度进行重排序
- 分数归一化:将不同检索方式的分数进行标准化处理
- 阈值过滤:通过最小分数阈值过滤低质量结果
- 结果去重:去除重复的文档,确保结果唯一性
// backend/domain/knowledge/service/retrieve.go
func (k *knowledgeSVC) reRankNode(ctx context.Context, resultMap map[string]any) (retrieveResult []*schema.Document, err error) {
// 首先获取下retrieve上下文
retrieveCtx, ok := resultMap["passRequestContext"].(*RetrieveContext)
if !ok {
logs.CtxErrorf(ctx, "retrieve context is not found")
return nil, errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", "retrieve context is not found"))
}
// 获取下向量化召回的接口
vectorRetrieveResult, ok := resultMap["vectorRetrieveNode"].([]*schema.Document)
if !ok {
logs.CtxErrorf(ctx, "vector retrieve result is not found")
vectorRetrieveResult = []*schema.Document{}
}
// 获取下es召回的接口
esRetrieveResult, ok := resultMap["esRetrieveNode"].([]*schema.Document)
if !ok {
logs.CtxErrorf(ctx, "es retrieve result is not found")
esRetrieveResult = []*schema.Document{}
}
// 获取下nl2sql召回的接口
nl2SqlRetrieveResult, ok := resultMap["nl2SqlRetrieveNode"].([]*schema.Document)
if !ok {
logs.CtxErrorf(ctx, "nl2sql retrieve result is not found")
nl2SqlRetrieveResult = []*schema.Document{}
}
docs2RerankData := func(docs []*schema.Document) []*rerank.Data {
data := make([]*rerank.Data, 0, len(docs))
for i := range docs {
doc := docs[i]
data = append(data, &rerank.Data{Document: doc, Score: doc.Score()})
}
return data
}
// 根据召回策略从不同渠道获取召回结果
var retrieveResultArr [][]*rerank.Data
if retrieveCtx.Strategy.EnableNL2SQL {
// nl2sql结果
retrieveResultArr = append(retrieveResultArr, docs2RerankData(nl2SqlRetrieveResult))
}
switch retrieveCtx.Strategy.SearchType {
case knowledgeModel.SearchTypeSemantic:
retrieveResultArr = append(retrieveResultArr, docs2RerankData(vectorRetrieveResult))
case knowledgeModel.SearchTypeFullText:
retrieveResultArr = append(retrieveResultArr, docs2RerankData(esRetrieveResult))
case knowledgeModel.SearchTypeHybrid:
retrieveResultArr = append(retrieveResultArr, docs2RerankData(vectorRetrieveResult))
retrieveResultArr = append(retrieveResultArr, docs2RerankData(esRetrieveResult))
default:
retrieveResultArr = append(retrieveResultArr, docs2RerankData(vectorRetrieveResult))
}
query := retrieveCtx.OriginQuery
if retrieveCtx.Strategy.EnableQueryRewrite && retrieveCtx.RewrittenQuery != nil {
query = ptr.From(retrieveCtx.RewrittenQuery)
}
resp, err := k.reranker.Rerank(ctx, &rerank.Request{
Query: query,
Data: retrieveResultArr,
TopN: retrieveCtx.Strategy.TopK,
})
if err != nil {
logs.CtxErrorf(ctx, "rerank failed: %v", err)
return nil, err
}
retrieveResult = make([]*schema.Document, 0, len(resp.SortedData))
for _, item := range resp.SortedData {
if item.Score < ptr.From(retrieveCtx.Strategy.MinScore) {
continue
}
doc := item.Document
doc.WithScore(item.Score)
retrieveResult = append(retrieveResult, doc)
}
return retrieveResult, nil
}
5.2 RRF重排序实现
// backend/infra/impl/document/rerank/rrf/rrf.go
// RRF重排序器实现
func NewRRFReranker(k int64) rerank.Reranker {
if k == 0 {
k = 60
}
return &rrfReranker{k}
}
type rrfReranker struct {
k int64
}
func (r *rrfReranker) Rerank(ctx context.Context, req *rerank.Request) (*rerank.Response, error) {
if req == nil || req.Data == nil || len(req.Data) == 0 {
return nil, fmt.Errorf("invalid request: no data provided")
}
// 计算每个文档的RRF分数
id2Score := make(map[string]float64)
id2Data := make(map[string]*rerank.Data)
for _, resultList := range req.Data {
for rank := range resultList {
result := resultList[rank]
if result != nil && result.Document != nil {
// RRF公式:1 / (k + rank)
score := 1.0 / (float64(rank) + float64(r.k))
if score > id2Score[result.Document.ID] {
id2Score[result.Document.ID] = score
id2Data[result.Document.ID] = result
}
}
}
}
// 按分数排序
var sorted []*rerank.Data
for _, data := range id2Data {
sorted = append(sorted, data)
}
sort.Slice(sorted, func(i, j int) bool {
return id2Score[sorted[i].Document.ID] > id2Score[sorted[j].Document.ID]
})
// 限制返回数量
topN := int64(len(sorted))
if req.TopN != nil && ptr.From(req.TopN) != 0 && ptr.From(req.TopN) < topN {
topN = ptr.From(req.TopN)
}
return &rerank.Response{SortedData: sorted[:topN]}, nil
}
RRF算法说明:
- k参数:控制不同来源的权重,k越大,排名差异越小(默认60)
- RRF公式:
1 / (k + rank),排名越靠前分数越高 - 分数聚合:同一文档在不同来源中取RRF分数最高的作为最终分数(Coze Studio定制实现)
- 排序输出:按RRF分数降序排列,返回TopN结果
RRF算法示例:
假设有三个检索来源(向量检索、ES检索、NL2SQL检索),每个来源返回以下结果:
| 来源 | 排名 | 文档ID | 原始分数 |
|---|---|---|---|
| 向量检索 | 1 | doc_001 | 0.95 |
| 向量检索 | 2 | doc_002 | 0.87 |
| 向量检索 | 3 | doc_003 | 0.76 |
| ES检索 | 1 | doc_002 | 0.92 |
| ES检索 | 2 | doc_001 | 0.88 |
| ES检索 | 3 | doc_004 | 0.81 |
| NL2SQL检索 | 1 | doc_003 | 0.89 |
| NL2SQL检索 | 2 | doc_005 | 0.85 |
RRF分数计算(k=60):
| 文档ID | 向量检索RRF | ES检索RRF | NL2SQL检索RRF | 最终RRF分数(取最高) |
|---|---|---|---|---|
| doc_001 | 1/(60+1) = 0.0164 | 1/(60+2) = 0.0161 | - | 0.0164 |
| doc_002 | 1/(60+2) = 0.0161 | 1/(60+1) = 0.0164 | - | 0.0164 |
| doc_003 | 1/(60+3) = 0.0159 | - | 1/(60+1) = 0.0164 | 0.0164 |
| doc_004 | - | 1/(60+3) = 0.0159 | - | 0.0159 |
| doc_005 | - | - | 1/(60+2) = 0.0161 | 0.0161 |
最终排序结果:
- doc_001 (RRF: 0.0164)
- doc_002 (RRF: 0.0164)
- doc_003 (RRF: 0.0164)
- doc_005 (RRF: 0.0161)
- doc_004 (RRF: 0.0159)
算法特点:
- 取最高分数:同一文档在不同来源中取RRF分数最高的作为最终分数
- 排名敏感:排名越靠前,RRF分数越高
- 去重处理:避免同一文档重复出现
- 可调节性:通过k参数控制排名差异的影响程度
5.3 结果打包与格式化
结果打包的作用:
- 去重处理:去除重复的检索结果
- 元数据丰富:添加文档信息、来源说明等
- 数量控制:限制返回结果数量
- 格式统一:统一结果格式,便于后续处理
// backend/domain/knowledge/service/retrieve.go
// 结果打包节点
func (k *knowledgeSVC) packResults(ctx context.Context, retrieveResult []*schema.Document) (results []*knowledgeModel.RetrieveSlice, err error) {
if len(retrieveResult) == 0 {
return nil, nil
}
// 1. 收集所有相关的ID
sliceIDs := make(sets.Set[int64])
docIDs := make(sets.Set[int64])
knowledgeIDs := make(sets.Set[int64])
documentMap := map[int64]*model.KnowledgeDocument{}
knowledgeMap := map[int64]*model.Knowledge{}
sliceScoreMap := map[int64]float64{}
for _, doc := range retrieveResult {
id, err := strconv.ParseInt(doc.ID, 10, 64)
if err != nil {
logs.CtxErrorf(ctx, "convert id failed: %v", err)
return nil, errorx.New(errno.ErrKnowledgeSystemCode, errorx.KV("msg", "convert id failed"))
}
sliceIDs[id] = struct{}{}
sliceScoreMap[id] = doc.Score()
}
// 2. 批量查询相关数据
slices, err := k.sliceRepo.MGetSlices(ctx, sliceIDs.ToSlice())
if err != nil {
logs.CtxErrorf(ctx, "mget slices failed: %v", err)
return nil, errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", err.Error()))
}
for _, slice := range slices {
docIDs[slice.DocumentID] = struct{}{}
knowledgeIDs[slice.KnowledgeID] = struct{}{}
}
// 3. 查询知识库和文档信息
knowledgeModels, err := k.knowledgeRepo.FilterEnableKnowledge(ctx, knowledgeIDs.ToSlice())
if err != nil {
logs.CtxErrorf(ctx, "filter enable knowledge failed: %v", err)
return nil, errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", err.Error()))
}
for _, kn := range knowledgeModels {
knowledgeMap[kn.ID] = kn
}
documents, err := k.documentRepo.MGetByID(ctx, docIDs.ToSlice())
if err != nil {
logs.CtxErrorf(ctx, "mget documents failed: %v", err)
return nil, errorx.New(errno.ErrKnowledgeDBCode, errorx.KV("msg", err.Error()))
}
for _, doc := range documents {
documentMap[doc.ID] = doc
}
// 4. 构建结果
results = []*knowledgeModel.RetrieveSlice{}
for i := range slices {
doc := documentMap[slices[i].DocumentID]
kn := knowledgeMap[slices[i].KnowledgeID]
sliceEntity := entity.Slice{
Info: knowledgeModel.Info{
ID: slices[i].ID,
CreatorID: slices[i].CreatorID,
SpaceID: doc.SpaceID,
AppID: kn.AppID,
CreatedAtMs: slices[i].CreatedAt,
UpdatedAtMs: slices[i].UpdatedAt,
},
KnowledgeID: slices[i].KnowledgeID,
DocumentID: slices[i].DocumentID,
DocumentName: doc.Name,
Sequence: int64(slices[i].Sequence),
ByteCount: int64(len(slices[i].Content)),
SliceStatus: knowledgeModel.SliceStatus(slices[i].Status),
CharCount: int64(utf8.RuneCountInString(slices[i].Content)),
}
// 5. 添加额外信息
docUri := documentMap[slices[i].DocumentID].URI
var docURL string
if len(docUri) != 0 {
docURL, err = k.storage.GetObjectUrl(ctx, docUri)
if err != nil {
logs.CtxErrorf(ctx, "get object url failed: %v", err)
return nil, errorx.New(errno.ErrKnowledgeGetObjectURLFailCode, errorx.KV("msg", err.Error()))
}
}
sliceEntity.Extra = map[string]string{
consts.KnowledgeName: kn.Name,
consts.DocumentURL: docURL,
}
// 6. 根据文档类型处理内容
switch knowledgeModel.DocumentType(doc.DocumentType) {
case knowledgeModel.DocumentTypeText:
sliceEntity.RawContent = []*knowledgeModel.SliceContent{
{Type: knowledgeModel.SliceContentTypeText, Text: ptr.Of(k.formatSliceContent(ctx, slices[i].Content))},
}
case knowledgeModel.DocumentTypeTable:
// 表格数据处理逻辑
case knowledgeModel.DocumentTypeImage:
img := fmt.Sprintf(`<img src="" data-tos-key="%s">`, documentMap[slices[i].DocumentID].URI)
sliceEntity.RawContent = []*knowledgeModel.SliceContent{
{Type: knowledgeModel.SliceContentTypeText, Text: ptr.Of(k.formatSliceContent(ctx, img+slices[i].Content))},
}
default:
}
results = append(results, &knowledgeModel.RetrieveSlice{
Slice: &sliceEntity,
Score: sliceScoreMap[slices[i].ID],
})
}
// 7. 更新命中计数
err = k.sliceRepo.IncrementHitCount(ctx, sliceIDs.ToSlice())
if err != nil {
logs.CtxWarnf(ctx, "increment hit count failed: %v", err)
}
return results, nil
}
更多推荐
所有评论(0)