Spring AI RAG with the Volcengine Embedding Model and Alibaba Cloud Tair (Enterprise Edition)

In the previous post, we covered vectorizing and querying data with Alibaba's Tair.

Today let's talk about optimizing RAG queries. Before getting into the technology, here is how I think about query optimization.

  1. We should interpret a question in light of the questioner's whole conversation. In technical terms: use the conversational context to pin down what the current question really means.
  2. When we chat, we think hard about what the other person actually means and ignore the filler. In technical terms: compression (drop the meaningless parts) and rewriting (surface the real question).
  3. Finally, we try to make the answer as complete as possible. In technical terms: query expansion.
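The three ideas above can be sketched as a tiny pipeline. This is a toy illustration in plain Java, not the Spring AI API: the stage names mirror the transformers used later in this post, but the string logic is invented purely for demonstration.

```java
import java.util.List;

// Conceptual sketch of the three ideas above: contextualize (fold the
// conversation into a standalone question), rewrite (strip filler, keep the
// real intent), and expand (generate variant queries). The string logic is a
// toy stand-in: in Spring AI each stage is delegated to an LLM via
// CompressionQueryTransformer, RewriteQueryTransformer and MultiQueryExpander.
public class QueryPipelineSketch {

    // Stage 1: fold conversation history into a self-contained question.
    static String compress(List<String> history, String question) {
        // Toy rule: attach the last topic from the history so the question
        // no longer depends on the conversation to be understood.
        if (history.isEmpty()) {
            return question;
        }
        return question + " (topic: " + history.get(history.size() - 1) + ")";
    }

    // Stage 2: rewrite the question, dropping meaningless filler.
    static String rewrite(String question) {
        return question.replace("umm, ", "").trim();
    }

    // Stage 3: expand one question into several retrieval variants.
    static List<String> expand(String question) {
        return List.of(question,
                "In other words: " + question,
                "Background on: " + question);
    }
}
```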

Transforming the query

First, define the query request entity:

@Data
public class QueryRequest {

    private String query;  // the user's query text
 
    private List<HistoryItem> history = new ArrayList<>();  // conversation history (for multi-turn scenarios)
 
    @Data
    public static class HistoryItem {
        private String role;  // "user" or "assistant"
        private String content;  // the message content
    }
}

Step 1: fold the conversation history into the query

 private Query buildOriginalQuery(QueryRequest request) {
        List<Message> historyMessages = new ArrayList<>();
        request.getHistory().forEach(historyItem -> {
            Message msg = null;
            if("user".equals(historyItem.getRole())){
                msg = new UserMessage(historyItem.getContent());
            }else{
                msg = new AssistantMessage(historyItem.getContent());
            }
            historyMessages.add(msg);
        });
        return Query.builder()
                .text(request.getQuery())
                .history(historyMessages)
                .build();
    }

This combines the conversation history with the current question, turning them into a single, self-contained query.

// 1. Build the original query (with conversation history)
Query originalQuery = buildOriginalQuery(request);

Step 2: compression

    @Bean
    public QueryTransformer compressionQueryTransformer() {
        return CompressionQueryTransformer.builder()
                .chatClientBuilder(chatClientBuilder)
                .build();
    }

Compress the original query:

 Query compressedQuery = compressionQueryTransformer.transform(originalQuery);

Step 3: rewriting

    @Bean
    public QueryTransformer rewriteQueryTransformer() {
        return RewriteQueryTransformer.builder()
                .chatClientBuilder(chatClientBuilder.build().mutate())
                .build();
    }

Rewrite the compressed query:

Query rewrittenQuery = rewriteQueryTransformer.transform(compressedQuery);

Step 4: query expansion

    @Bean
    public QueryExpander multiQueryExpander() {
        return MultiQueryExpander.builder()
                .chatClientBuilder(chatClientBuilder)
                .numberOfQueries(3)  // generate 3 query variants
                .includeOriginal(false)  // do not include the original query
                .build();
    }

Expand the rewritten query:

List<Query> expandedQueries = multiQueryExpander.expand(rewrittenQuery);

Vector search

At this point, the processing of the query itself is complete. To search the documents, that is, to run a vector search, we need to build a document retriever.

    @Bean
    public DocumentRetriever vectorStoreDocumentRetriever() {
        return TairVectorStoreDocumentRetriever
                .builder()
                .vectorStore(tairVectorStore)
                .topK(5)  // recall the top-5 documents
                .similarityThreshold(0.7)  // similarity threshold (a balanced default)
                .build();
    }

Since we use TairVectorStore as the vector store, we provide our own DocumentRetriever implementation:

import com.alibaba.cloud.ai.vectorstore.tair.TairVectorStore;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionTextParser;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.util.List;
import java.util.function.Supplier;


public final class TairVectorStoreDocumentRetriever implements DocumentRetriever {

	public static final String FILTER_EXPRESSION = "vector_store_filter_expression";

	private final TairVectorStore tairVectorStore;

	private final Double similarityThreshold;

	private final Integer topK;


	private final Supplier<Filter.Expression> filterExpression;

	public TairVectorStoreDocumentRetriever(TairVectorStore tairVectorStore, @Nullable Double similarityThreshold,
											@Nullable Integer topK, @Nullable Supplier<Filter.Expression> filterExpression) {
		Assert.notNull(tairVectorStore, "vectorStore cannot be null");
		Assert.isTrue(similarityThreshold == null || similarityThreshold >= 0.0,
				"similarityThreshold must be equal to or greater than 0.0");
		Assert.isTrue(topK == null || topK > 0, "topK must be greater than 0");
		this.tairVectorStore = tairVectorStore;
		this.similarityThreshold = similarityThreshold != null ? similarityThreshold
				: SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL;
		this.topK = topK != null ? topK : SearchRequest.DEFAULT_TOP_K;
		this.filterExpression = filterExpression != null ? filterExpression : () -> null;
	}

	@Override
	public List<Document> retrieve(Query query) {
		Assert.notNull(query, "query cannot be null");
		var requestFilterExpression = computeRequestFilterExpression(query);
		var searchRequest = SearchRequest.builder()
			.query(query.text())
			.filterExpression(requestFilterExpression)
			.similarityThreshold(this.similarityThreshold)
			.topK(this.topK)
			.build();
		return this.tairVectorStore.doSimilaritySearch(searchRequest);
	}

	private Filter.Expression computeRequestFilterExpression(Query query) {
		var contextFilterExpression = query.context().get(FILTER_EXPRESSION);
		if (contextFilterExpression != null) {
			if (contextFilterExpression instanceof Filter.Expression) {
				return (Filter.Expression) contextFilterExpression;
			}
			else if (StringUtils.hasText(contextFilterExpression.toString())) {
				return new FilterExpressionTextParser().parse(contextFilterExpression.toString());
			}
		}
		return this.filterExpression.get();
	}

	public static Builder builder() {
		return new Builder();
	}

	public static final class Builder {

		private TairVectorStore tairVectorStore;

		private Double similarityThreshold;

		private Integer topK;

		private Supplier<Filter.Expression> filterExpression;

		private Builder() {
		}

		public Builder vectorStore(TairVectorStore tairVectorStore) {
			this.tairVectorStore = tairVectorStore;
			return this;
		}

		public Builder similarityThreshold(Double similarityThreshold) {
			this.similarityThreshold = similarityThreshold;
			return this;
		}

		public Builder topK(Integer topK) {
			this.topK = topK;
			return this;
		}

		public Builder filterExpression(Filter.Expression filterExpression) {
			this.filterExpression = () -> filterExpression;
			return this;
		}

		public Builder filterExpression(Supplier<Filter.Expression> filterExpression) {
			this.filterExpression = filterExpression;
			return this;
		}

		public TairVectorStoreDocumentRetriever build() {
			return new TairVectorStoreDocumentRetriever(this.tairVectorStore, this.similarityThreshold, this.topK,
					this.filterExpression);
		}

	}

}

The core change is simply swapping the default similaritySearch call for doSimilaritySearch, and using the concrete tairVectorStore instead of the generic VectorStore:

@Override
	public List<Document> retrieve(Query query) {
		Assert.notNull(query, "query cannot be null");
		var requestFilterExpression = computeRequestFilterExpression(query);
		var searchRequest = SearchRequest.builder()
			.query(query.text())
			.filterExpression(requestFilterExpression)
			.similarityThreshold(this.similarityThreshold)
			.topK(this.topK)
			.build();
		return this.tairVectorStore.doSimilaritySearch(searchRequest);
	}

Since we expanded the query, we run retrieval over every expanded query (the sequential stream below could be switched to parallelStream() if true parallel retrieval is needed):

  List<List<Document>> allDocuments = expandedQueries
                    .stream()
                    .map(documentRetriever::retrieve)
                    .toList();

Then merge the retrieval results:

    @Bean
    public DocumentJoiner concatenationDocumentJoiner() {
        return new ConcatenationDocumentJoiner();
    }

Map<Query, List<List<Document>>> documentsForQuery = new HashMap<>();
documentsForQuery.put(rewrittenQuery, allDocuments);
List<Document> joinedDocuments = documentJoiner.join(documentsForQuery);
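As I understand it, ConcatenationDocumentJoiner concatenates the per-query result lists and drops duplicate documents. A minimal stand-alone sketch of that merge, where Doc is a toy stand-in for Spring AI's Document (the real joiner's exact de-duplication rule may differ):

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Minimal sketch of concatenation-style joining: flatten the per-query result
// lists in order and drop documents whose ID was already seen. This mirrors
// the behavior I'd expect from ConcatenationDocumentJoiner; Doc stands in for
// org.springframework.ai.document.Document.
public class JoinSketch {

    record Doc(String id, String text) {}

    static List<Doc> join(List<List<Doc>> resultsPerQuery) {
        Set<String> seen = new LinkedHashSet<>();
        List<Doc> joined = new ArrayList<>();
        for (List<Doc> results : resultsPerQuery) {
            for (Doc doc : results) {
                if (seen.add(doc.id())) {  // add() returns false for duplicates
                    joined.add(doc);
                }
            }
        }
        return joined;
    }
}
```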

That completes the vector-search side.

LLM processing

Finally, we call the LLM to generate the answer from the retrieved results.

Build the prompt:

    /**
     * Build the generation prompt.
     * @param query the rewritten user query
     * @param documents the retrieved reference documents
     * @return the prompt text sent to the LLM
     */
    private String buildGenerationPrompt(String query, List<Document> documents) {
        StringBuilder context = new StringBuilder();
        for (Document doc : documents) {
            context.append("Document ID: ").append(doc.getId()).append("\n")
                    .append("Content: ").append(doc.getText()).append("\n\n");
        }

        return """
                Answer the user's question based on the reference documents below,
                making sure the answer is accurate and grounded.
                If the documents cannot answer the question, say so directly; do not make things up.
                Reference documents:
                %s
                User question: %s
                """.formatted(context.toString(), query);
    }
    // 4. Generation: assemble the prompt and call the LLM
    String prompt = buildGenerationPrompt(rewrittenQuery.text(), joinedDocuments);
    String answer = chatClient.prompt()
            .user(prompt)
            .call()
            .content();
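The prompt assembly is plain string building, so it can be checked in isolation. A self-contained variant, where Doc is a toy stand-in for Spring AI's Document:

```java
import java.util.List;

// Stand-alone version of buildGenerationPrompt, using a toy Doc record in
// place of org.springframework.ai.document.Document so the string assembly
// can be verified without Spring AI on the classpath.
public class PromptSketch {

    record Doc(String id, String text) {}

    static String buildGenerationPrompt(String query, List<Doc> documents) {
        StringBuilder context = new StringBuilder();
        for (Doc doc : documents) {
            context.append("Document ID: ").append(doc.id()).append("\n")
                    .append("Content: ").append(doc.text()).append("\n\n");
        }
        return """
                Answer the user's question based on the reference documents below,
                making sure the answer is accurate and grounded.
                If the documents cannot answer the question, say so directly; do not make things up.
                Reference documents:
                %s
                User question: %s
                """.formatted(context.toString(), query);
    }
}
```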

And that's basically the whole pipeline.
