Spring AI (Part 8): RAG Query Optimization
In the previous article, "Spring AI RAG with the Volcengine embedding model + Alibaba Cloud Tair (Enterprise Edition)", we covered vectorizing data and querying it with Alibaba's Tair. Today let's talk about optimizing RAG queries. Before getting into the technology, here is how I think about query optimization:
- When answering someone, we take their whole conversation so far into account. In technical terms: use the conversational context to pin down what the current question actually means.
- We also think about what the other person is really asking, stripping away filler. In technical terms: compression (drop the meaningless parts) and rewriting (surface the real question).
- Finally, we try to make the answer as complete as possible. In technical terms: query expansion.
Query Transformation
First, define the query request entity:
import lombok.Data;
import java.util.ArrayList;
import java.util.List;

@Data
public class QueryRequest {
    private String query; // the user's query text
    private List<HistoryItem> history = new ArrayList<>(); // chat history (multi-turn scenarios)

    @Data
    public static class HistoryItem {
        private String role;    // "user" or "assistant"
        private String content; // message content
    }
}
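For instance, a two-turn conversation followed by a fresh question could be assembled like this (a sketch; the setters come from Lombok's @Data, and the sample texts are made up):

```java
// Hypothetical request for a multi-turn conversation; sample content is illustrative.
QueryRequest request = new QueryRequest();
request.setQuery("Does it support vector search?");

QueryRequest.HistoryItem turn1 = new QueryRequest.HistoryItem();
turn1.setRole("user");
turn1.setContent("What is Tair?");

QueryRequest.HistoryItem turn2 = new QueryRequest.HistoryItem();
turn2.setRole("assistant");
turn2.setContent("Tair is Alibaba Cloud's Redis-compatible in-memory database.");

request.getHistory().add(turn1);
request.getHistory().add(turn2);
```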
Step 1: fold the chat history into the query
private Query buildOriginalQuery(QueryRequest request) {
    List<Message> historyMessages = new ArrayList<>();
    request.getHistory().forEach(historyItem -> {
        Message msg;
        if ("user".equals(historyItem.getRole())) {
            msg = new UserMessage(historyItem.getContent());
        } else {
            msg = new AssistantMessage(historyItem.getContent());
        }
        historyMessages.add(msg);
    });
    return Query.builder()
            .text(request.getQuery())
            .history(historyMessages)
            .build();
}
This combines the chat history and the current question into a single self-contained query.

// 1. Build the original query (with chat history)
Query originalQuery = buildOriginalQuery(request);
Step 2: compression
@Bean
public QueryTransformer compressionQueryTransformer() {
    return CompressionQueryTransformer.builder()
            .chatClientBuilder(chatClientBuilder)
            .build();
}
Compress the original query. Compression condenses the history plus the follow-up into one standalone question; e.g. a follow-up "Does it support vector search?" after a chat about Tair may become "Does Alibaba Cloud Tair support vector search?".
Query compressedQuery = compressionQueryTransformer.transform(originalQuery);
Step 3: rewriting
@Bean
public QueryTransformer rewriteQueryTransformer() {
    return RewriteQueryTransformer.builder()
            .chatClientBuilder(chatClientBuilder.build().mutate())
            .build();
}
Rewrite the compressed query. Rewriting turns noisy, conversational phrasing into a retrieval-friendly query; e.g. "um, can you tell me about that tair vector thing" may become "Tair vector search features".
Query rewrittenQuery = rewriteQueryTransformer.transform(compressedQuery);
Step 4: query expansion
@Bean
public QueryExpander multiQueryExpander() {
    return MultiQueryExpander.builder()
            .chatClientBuilder(chatClientBuilder)
            .numberOfQueries(3)      // generate 3 query variants
            .includeOriginal(false)  // do not include the original query
            .build();
}
Expand the rewritten query into multiple variants:
List<Query> expandedQueries = multiQueryExpander.expand(rewrittenQuery);
Vector Search
That completes the query-side processing. To actually search documents, i.e. run the vector query, we need to build a document retriever.
@Bean
public DocumentRetriever vectorStoreDocumentRetriever() {
    return TairVectorStoreDocumentRetriever.builder()
            .vectorStore(tairVectorStore)
            .topK(5)                  // recall the top-5 documents
            .similarityThreshold(0.7) // similarity threshold (a balanced default)
            .build();
}
Since our vector store is TairVectorStore, we provide our own DocumentRetriever implementation:
import com.alibaba.cloud.ai.vectorstore.tair.TairVectorStore;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.util.List;
import java.util.function.Supplier;

public final class TairVectorStoreDocumentRetriever implements DocumentRetriever {

    public static final String FILTER_EXPRESSION = "vector_store_filter_expression";

    private final TairVectorStore tairVectorStore;
    private final Double similarityThreshold;
    private final Integer topK;
    private final Supplier<Filter.Expression> filterExpression;

    public TairVectorStoreDocumentRetriever(TairVectorStore tairVectorStore, @Nullable Double similarityThreshold,
            @Nullable Integer topK, @Nullable Supplier<Filter.Expression> filterExpression) {
        Assert.notNull(tairVectorStore, "vectorStore cannot be null");
        Assert.isTrue(similarityThreshold == null || similarityThreshold >= 0.0,
                "similarityThreshold must be equal to or greater than 0.0");
        Assert.isTrue(topK == null || topK > 0, "topK must be greater than 0");
        this.tairVectorStore = tairVectorStore;
        this.similarityThreshold = similarityThreshold != null ? similarityThreshold
                : SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL;
        this.topK = topK != null ? topK : SearchRequest.DEFAULT_TOP_K;
        this.filterExpression = filterExpression != null ? filterExpression : () -> null;
    }

    @Override
    public List<Document> retrieve(Query query) {
        Assert.notNull(query, "query cannot be null");
        var requestFilterExpression = computeRequestFilterExpression(query);
        var searchRequest = SearchRequest.builder()
                .query(query.text())
                .filterExpression(requestFilterExpression)
                .similarityThreshold(this.similarityThreshold)
                .topK(this.topK)
                .build();
        return this.tairVectorStore.doSimilaritySearch(searchRequest);
    }

    // A filter passed in the query context takes precedence over the default filter supplier.
    private Filter.Expression computeRequestFilterExpression(Query query) {
        var contextFilterExpression = query.context().get(FILTER_EXPRESSION);
        if (contextFilterExpression != null) {
            if (contextFilterExpression instanceof Filter.Expression) {
                return (Filter.Expression) contextFilterExpression;
            }
            else if (StringUtils.hasText(contextFilterExpression.toString())) {
                return new FilterExpressionTextParser().parse(contextFilterExpression.toString());
            }
        }
        return this.filterExpression.get();
    }

    public static Builder builder() {
        return new Builder();
    }

    public static final class Builder {

        private TairVectorStore tairVectorStore;
        private Double similarityThreshold;
        private Integer topK;
        private Supplier<Filter.Expression> filterExpression;

        private Builder() {
        }

        public Builder vectorStore(TairVectorStore tairVectorStore) {
            this.tairVectorStore = tairVectorStore;
            return this;
        }

        public Builder similarityThreshold(Double similarityThreshold) {
            this.similarityThreshold = similarityThreshold;
            return this;
        }

        public Builder topK(Integer topK) {
            this.topK = topK;
            return this;
        }

        public Builder filterExpression(Filter.Expression filterExpression) {
            this.filterExpression = () -> filterExpression;
            return this;
        }

        public Builder filterExpression(Supplier<Filter.Expression> filterExpression) {
            this.filterExpression = filterExpression;
            return this;
        }

        public TairVectorStoreDocumentRetriever build() {
            return new TairVectorStoreDocumentRetriever(this.tairVectorStore, this.similarityThreshold, this.topK,
                    this.filterExpression);
        }
    }
}
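Note that the class also honors a per-request metadata filter passed through the query context (the FILTER_EXPRESSION key), in addition to the default filter set on the builder. A minimal sketch, assuming documents carry a made-up `category` metadata field:

```java
// Sketch: attach a per-request filter through the query context.
// The "category" metadata field is a hypothetical example.
Query filteredQuery = Query.builder()
        .text("How do I create a vector index in Tair?")
        .context(Map.of(TairVectorStoreDocumentRetriever.FILTER_EXPRESSION,
                "category == 'tair-docs'")) // parsed by FilterExpressionTextParser
        .build();
```

Passing this query to `documentRetriever.retrieve(filteredQuery)` restricts the similarity search to matching documents.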
The core change relative to Spring AI's built-in VectorStoreDocumentRetriever is in retrieve(): the default similaritySearch call is replaced with doSimilaritySearch, and the VectorStore field is replaced with TairVectorStore.
Since we expanded the query, we run retrieval once per expanded query (the sequential stream below could be switched to parallelStream() for truly parallel retrieval):

List<List<Document>> allDocuments = expandedQueries
        .stream()
        .map(documentRetriever::retrieve)
        .toList();
Then merge the per-query results. ConcatenationDocumentJoiner concatenates the lists, keeping the first occurrence when the same document shows up for multiple queries:
@Bean
public DocumentJoiner concatenationDocumentJoiner() {
    return new ConcatenationDocumentJoiner();
}

Map<Query, List<List<Document>>> documentsForQuery = new HashMap<>();
documentsForQuery.put(rewrittenQuery, allDocuments);
List<Document> joinedDocuments = documentJoiner.join(documentsForQuery);
That completes the vector-search side.
LLM Generation
Finally, call the LLM to generate the answer. First, build the prompt:
/**
 * Build the generation prompt.
 * @param query     the (rewritten) user query
 * @param documents the retrieved reference documents
 * @return the prompt handed to the LLM
 */
private String buildGenerationPrompt(String query, List<Document> documents) {
    StringBuilder context = new StringBuilder();
    for (Document doc : documents) {
        context.append("Document ID: ").append(doc.getId()).append("\n")
               .append("Content: ").append(doc.getText()).append("\n\n");
    }
    return """
            Answer the user's question based on the reference documents below,
            making sure the answer is accurate and grounded.
            If the documents cannot answer the question, say so directly; do not make things up.
            Reference documents:
            %s
            User question: %s
            """.formatted(context.toString(), query);
}
// 4. Generation: assemble the prompt and call the LLM
String prompt = buildGenerationPrompt(rewrittenQuery.text(), joinedDocuments);
String answer = chatClient.prompt()
        .user(prompt)
        .call()
        .content();
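Putting it all together, the end-to-end flow can be sketched as a single service method (illustrative wiring of the beans defined above; the injected field names are assumptions):

```java
// Illustrative end-to-end method; assumes the transformers, expander,
// retriever, joiner and a ChatClient are injected as fields.
public String answer(QueryRequest request) {
    // 1. Original query, carrying the chat history
    Query originalQuery = buildOriginalQuery(request);
    // 2. Compress the multi-turn context into a standalone query
    Query compressedQuery = compressionQueryTransformer.transform(originalQuery);
    // 3. Rewrite it into a retrieval-friendly form
    Query rewrittenQuery = rewriteQueryTransformer.transform(compressedQuery);
    // 4. Expand into several query variants
    List<Query> expandedQueries = multiQueryExpander.expand(rewrittenQuery);
    // 5. Retrieve per variant, then merge the results
    List<List<Document>> allDocuments = expandedQueries.stream()
            .map(documentRetriever::retrieve)
            .toList();
    List<Document> joinedDocuments =
            documentJoiner.join(Map.of(rewrittenQuery, allDocuments));
    // 6. Build the prompt and generate the answer
    String prompt = buildGenerationPrompt(rewrittenQuery.text(), joinedDocuments);
    return chatClient.prompt().user(prompt).call().content();
}
```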
And that's the whole flow.