Spring AI Alibaba Agent 之 ReflectAgent 源码解读

如果不熟悉 Spring AI Alibaba Graph,可以先阅读官方文章《Spring AI Alibaba Graph 使用指南与源码解读》,这篇文章能帮助你了解其底层的设计与实现。本文更像是一份 ReflectAgent 使用指南。目前官网还没有正式的使用文档,只给出了一些使用 demo。本文将分析 ReflectAgent 的整体执行流程以及使用时的注意事项。ReflectAgent 适用于写作、NL2SQL 生成等场景。

官网使用示例

流程图如下:

(原文此处为流程图图片,未能保留)

官网给出的使用demo,主要代码:

@Configuration
public class RelectionAutoconfiguration {

	/**
	 * Writer node: asks the LLM to write an essay for the user's request, or to revise its
	 * previous attempt when the conversation ends with critique.
	 *
	 * <p>Each call runs a one-node sub-graph ({@code START -> call_model -> END}) whose
	 * {@code MESSAGES} key uses {@code AppendStrategy}, so the model reply is appended to the
	 * incoming history and the extended history is returned under {@code MESSAGES}.
	 */
	public static class AssistantGraphNode implements NodeAction {

		private static final String NODE_ID = "call_model";

		private static final String SYSTEM_PROMPT = """
					You are an essay assistant tasked with writing excellent 5-paragraph essays.
				    Generate the best essay possible for the user's request.
				    If the user provides critique, respond with a revised version of your previous attempts.
				    Only return the main content I need, without adding any other interactive language.
				    Please answer in Chinese:
				""";

		private final LlmNode llmNode;

		public AssistantGraphNode(ChatClient chatClient) {
			SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(SYSTEM_PROMPT);
			this.llmNode = LlmNode.builder()
				.systemPromptTemplate(systemPromptTemplate.render())
				.chatClient(chatClient)
				// Use the shared key constant (as JudgeGraphNode and apply() do) instead of a
				// duplicated literal, so the LlmNode always reads the key the sub-graph writes.
				.messagesKey(MESSAGES)
				.build();
		}

		public static Builder builder() {
			return new Builder();
		}

		/** Fluent builder; {@link #build()} fails fast when no {@link ChatClient} was supplied. */
		public static class Builder {

			private ChatClient chatClient;

			public Builder chatClient(ChatClient chatClient) {
				this.chatClient = chatClient;
				return this;
			}

			public AssistantGraphNode build() {
				if (chatClient == null) {
					throw new IllegalArgumentException("ChatClient must be provided");
				}
				return new AssistantGraphNode(chatClient);
			}

		}

		/**
		 * Runs the writer LLM over the current conversation.
		 * @param overAllState outer graph state; must contain {@code MESSAGES}
		 * @return a map holding the extended message history under {@code MESSAGES}
		 * @throws IllegalStateException if {@code MESSAGES} is missing from the input or the
		 * sub-graph result
		 */
		@Override
		public Map<String, Object> apply(OverAllState overAllState) throws Exception {
			List<Message> messages = overAllState.<List<Message>>value(MESSAGES)
				.orElseThrow(() -> new IllegalStateException("Missing state key: " + MESSAGES));

			StateGraph stateGraph = new StateGraph(() -> {
				Map<String, KeyStrategy> strategies = new HashMap<>();
				// Append so the model reply is added to, not substituted for, the history.
				strategies.put(MESSAGES, new AppendStrategy());
				return strategies;
			}).addNode(NODE_ID, node_async(llmNode)).addEdge(START, NODE_ID).addEdge(NODE_ID, END);

			OverAllState invokeState = stateGraph.compile().invoke(Map.of(MESSAGES, messages)).orElseThrow();
			List<Message> reactMessages = invokeState.<List<Message>>value(MESSAGES)
				.orElseThrow(() -> new IllegalStateException("Missing state key: " + MESSAGES));

			return Map.of(MESSAGES, reactMessages);
		}

	}

	/**
	 * Critic node: asks the LLM for actionable revision suggestions on the latest essay.
	 *
	 * <p>Like {@link AssistantGraphNode}, it runs a one-node sub-graph that appends the reply to
	 * {@code MESSAGES}. The reply is then re-typed as a {@link UserMessage}. ReflectAgent only
	 * routes back to the writer when the last message is not an assistant message, and the
	 * writer then sees the critique as a user turn.
	 */
	public static class JudgeGraphNode implements NodeAction {

		private static final String NODE_ID = "judge_response";

		private static final String SYSTEM_PROMPT = """
					You are a teacher grading a student's essay submission. Provide detailed feedback and revision suggestions for the essay.

					Your feedback should cover the following aspects:

					- Length : Is the essay sufficiently developed? Does it meet the required length or need expansion/shortening?
					- Depth : Are the ideas well-developed? Is there sufficient analysis, evidence, or explanation?
					- Structure : Is the organization logical and clear? Are the introduction, transitions, and conclusion effective?
					- Style and Tone : Is the writing style appropriate for the purpose and audience? Is the tone consistent and professional?
					- Language Use : Are vocabulary, grammar, and sentence structure accurate and varied?
					- Focus only on providing actionable suggestions for improvement. Do not include grades, scores, or overall summary evaluations.

					Please respond in Chinese .
				""";

		private final LlmNode llmNode;

		public JudgeGraphNode(ChatClient chatClient) {
			SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(SYSTEM_PROMPT);
			this.llmNode = LlmNode.builder()
				.chatClient(chatClient)
				.systemPromptTemplate(systemPromptTemplate.render())
				.messagesKey(MESSAGES)
				.build();
		}

		public static Builder builder() {
			return new Builder();
		}

		/** Fluent builder; {@link #build()} fails fast when no {@link ChatClient} was supplied. */
		public static class Builder {

			private ChatClient chatClient;

			public JudgeGraphNode.Builder chatClient(ChatClient chatClient) {
				this.chatClient = chatClient;
				return this;
			}

			public JudgeGraphNode build() {
				if (chatClient == null) {
					throw new IllegalArgumentException("ChatClient must be provided");
				}
				return new JudgeGraphNode(chatClient);
			}

		}

		/**
		 * Runs the critic LLM over the current conversation and turns its reply into a user turn.
		 * @param allState outer graph state; must contain {@code MESSAGES}
		 * @return a map holding the extended history, whose last element is a
		 * {@link UserMessage}, under {@code MESSAGES}
		 * @throws IllegalStateException if {@code MESSAGES} is missing from the input or the
		 * sub-graph result
		 */
		@Override
		public Map<String, Object> apply(OverAllState allState) throws Exception {
			List<Message> messages = allState.<List<Message>>value(MESSAGES)
				.orElseThrow(() -> new IllegalStateException("Missing state key: " + MESSAGES));

			StateGraph stateGraph = new StateGraph(() -> {
				Map<String, KeyStrategy> strategies = new HashMap<>();
				strategies.put(MESSAGES, new AppendStrategy());
				return strategies;
			}).addNode(NODE_ID, node_async(llmNode)).addEdge(START, NODE_ID).addEdge(NODE_ID, END);

			CompiledGraph compile = stateGraph.compile();

			OverAllState invokeState = compile.invoke(Map.of(MESSAGES, messages)).orElseThrow();
			List<Message> reactMessages = invokeState.<List<Message>>value(MESSAGES)
				.orElseThrow(() -> new IllegalStateException("Missing state key: " + MESSAGES));

			return Map.of(MESSAGES, convertLastToUserMessage(reactMessages));
		}

		/**
		 * Returns a copy of {@code messageList} whose last element is replaced by a
		 * {@link UserMessage} carrying the same text; an empty input yields an empty copy.
		 * Copying avoids mutating a list owned by the sub-graph state, which may be unmodifiable.
		 */
		private static List<Message> convertLastToUserMessage(List<Message> messageList) {
			List<Message> converted = new java.util.ArrayList<>(messageList);
			if (!converted.isEmpty()) {
				int lastIndex = converted.size() - 1;
				converted.set(lastIndex, new UserMessage(converted.get(lastIndex).getText()));
			}
			return converted;
		}

	}

	/**
	 * Builds the compiled reflection graph. {@link AssistantGraphNode} writes and
	 * {@link JudgeGraphNode} critiques. {@link ReflectAgent} alternates between them, and
	 * {@code maxIterations(2)} caps the number of writer-to-judge rounds.
	 */
	@Bean
	public CompiledGraph reflectGraph(ChatModel chatModel) throws GraphStateException {

		ChatClient chatClient = ChatClient.builder(chatModel)
			.defaultAdvisors(new SimpleLoggerAdvisor())
			// Disable Spring AI's internal tool execution for this client.
			.defaultOptions(OpenAiChatOptions.builder().internalToolExecutionEnabled(false).build())
			.build();

		AssistantGraphNode assistantGraphNode = AssistantGraphNode.builder().chatClient(chatClient).build();
		JudgeGraphNode judgeGraphNode = JudgeGraphNode.builder().chatClient(chatClient).build();

		ReflectAgent reflectAgent = ReflectAgent.builder()
			.graph(assistantGraphNode)
			.reflection(judgeGraphNode)
			.maxIterations(2)
			.build();

		return reflectAgent.getAndCompileGraph();
	}

}

测试入口:

/**
 * HTTP entry point for the reflection demo.
 *
 * <p>{@code GET /reflection/chat?query=...} seeds the reflection graph with the query as a
 * single user message. It returns the text of the last assistant message in the final history.
 */
@RestController
@RequestMapping("/reflection")
public class ReflectionController {

	private static final Logger logger = LoggerFactory.getLogger(ReflectionController.class);

	private final CompiledGraph compiledGraph;

	public ReflectionController(@Qualifier("reflectGraph") CompiledGraph compiledGraph) {
		this.compiledGraph = compiledGraph;
	}

	@GetMapping("/chat")
	public String simpleChat(String query) throws GraphRunnerException {
		var finalState = compiledGraph.invoke(Map.of(ReflectAgent.MESSAGES, List.of(new UserMessage(query)))).get();
		List<Message> history = finalState.<List<Message>>value(ReflectAgent.MESSAGES).orElseThrow();

		// Keep only the model's replies and take the newest one.
		var lastReply = history.stream()
			.filter(ReflectionController::isAssistant)
			.reduce((older, newer) -> newer);

		return lastReply.map(Message::getText).orElseThrow();
	}

	/** True when {@code message} was produced by the model (assistant role). */
	private static boolean isAssistant(Message message) {
		return message.getMessageType() == MessageType.ASSISTANT;
	}

}

这个 demo 主要用于写作场景:

  • AssistantGraphNode 根据用户的写作要求或修改建议进行内容写作。
  • JudgeGraphNode 针对 AssistantGraphNode 生成的内容给出修改建议。

使用特别说明

  • AssistantGraphNode 和 JudgeGraphNode 最终都调用 LLM,且都使用 Spring AI Alibaba 中标准的 LlmNode 节点实现。在 OverAllState 的状态流转中,默认使用名为 messages 的 key,在 LlmNode、AssistantGraphNode、JudgeGraphNode 之间传递对话消息。这一约定并不限于本示例,很多实现都基于它,Spring AI Alibaba Graph 内置的 ReflectAgent、ReactAgent 等也是如此。

  • AssistantGraphNode 和 JudgeGraphNode 内部都基于 graph 实现,它们的 StateGraph 对 key 为 messages 的状态采用追加(Append)策略。这一点很重要:模型生成的写作内容和反馈内容都会追加到 messages 上。可以把它理解为大模型 chat 历史上下文的叠加,它会作为下一阶段模型的上下文输入:

    StateGraph stateGraph = new StateGraph(() -> {
    				Map<String, KeyStrategy> strategies = new HashMap<>();
    				strategies.put(MESSAGES, new AppendStrategy());//都是采用追加的策略
    

基于以上内容,我们再来看 ReflectAgent 源码中的几个关键方法:

	/**
	 * Builds the two-node reflection loop and records {@code maxIterations} on this agent.
	 *
	 * <p>Topology:
	 * <ul>
	 * <li>{@code START -> GRAPH_NODE_ID}.</li>
	 * <li>After the graph node, {@link #graphCount} routes to the reflection node or {@code END}.</li>
	 * <li>After the reflection node, {@link #apply} routes back to the graph node or to {@code END}.</li>
	 * </ul>
	 * Both state keys use {@code ReplaceStrategy}.
	 * @param graph the generating (writer) node
	 * @param reflection the critiquing node
	 * @param maxIterations iteration limit consulted by {@link #graphCount}
	 * @return the uncompiled state graph
	 * @throws GraphStateException if the graph definition is rejected
	 */
	public StateGraph createReflectionGraph(NodeAction graph, NodeAction reflection, int maxIterations)
			throws GraphStateException {
		this.maxIterations = maxIterations;
		logger.debug("Creating reflection graph with max iterations: {}", maxIterations);

		// Routing tables for the two conditional edges: router result -> target node id.
		Map<String, String> routesAfterGraph = Map.of(REFLECTION_NODE_ID, REFLECTION_NODE_ID, END, END);
		Map<String, String> routesAfterReflection = Map.of(GRAPH_NODE_ID, GRAPH_NODE_ID, END, END);

		StateGraph reflectionGraph = new StateGraph(() -> {
			Map<String, KeyStrategy> strategies = new HashMap<>();
			strategies.put(MESSAGES, new ReplaceStrategy());
			strategies.put(ITERATION_NUM, new ReplaceStrategy());
			return strategies;
		}).addNode(GRAPH_NODE_ID, node_async(graph))
			.addNode(REFLECTION_NODE_ID, node_async(reflection))
			.addEdge(START, GRAPH_NODE_ID)
			.addConditionalEdges(GRAPH_NODE_ID, edge_async(this::graphCount), routesAfterGraph)
			.addConditionalEdges(REFLECTION_NODE_ID, edge_async(this::apply), routesAfterReflection);

		logger.info("Reflection graph created successfully with {} nodes", 2);
		return reflectionGraph;
	}

	/**
	 * Routing function for the edge leaving the graph (writer) node. It also maintains the
	 * iteration counter stored under {@code ITERATION_NUM}.
	 *
	 * <p>How the counter evolves:
	 * <ul>
	 * <li>The first call stores 1.</li>
	 * <li>Later calls increment the stored value while it is below {@code maxIterations}.</li>
	 * <li>Once the stored value has reached {@code maxIterations}, the counter is reset to 0,
	 * {@code printMessage(state)} is called and the graph ends.</li>
	 * </ul>
	 * @param state current overall state; updated in place via {@code updateState}
	 * @return {@code REFLECTION_NODE_ID} to run the reflection node next, or {@code END}
	 */
	public String graphCount(OverAllState state) throws Exception {
		Optional<Object> storedCount = state.value(ITERATION_NUM);

		if (storedCount.isPresent()) {
			Integer current = (Integer) storedCount.get();
			logger.info("Current iteration: {} | Max iterations: {}", current, maxIterations);

			if (current >= maxIterations) {
				logger.info("Iteration limit reached, stopping reflection cycle");
				state.updateState(Map.of(ITERATION_NUM, 0));
				this.printMessage(state);
				return END;
			}

			int next = current + 1;
			logger.debug("Incrementing iteration counter from {} to {}", current, next);
			state.updateState(Map.of(ITERATION_NUM, next));
		}
		else {
			logger.debug("Initializing iteration counter to 1");
			state.updateState(Map.of(ITERATION_NUM, 1));
		}

		Integer updatedCount = (Integer) state.value(ITERATION_NUM).orElseThrow();
		logger.debug("Updated iteration count: {}", updatedCount);

		return REFLECTION_NODE_ID;
	}

	/**
	 * Routing function for the edge leaving the reflection node.
	 *
	 * <p>The loop ends when the history is empty or its last message is an assistant message.
	 * Otherwise the flow returns to the graph (writer) node. A reflection node that wants another
	 * round must therefore leave a non-assistant message (e.g. a user message) last.
	 * @param state current overall state; {@code MESSAGES} is read as a list of messages, and a
	 * missing key is treated as an empty history
	 * @return {@code END} or {@code GRAPH_NODE_ID}
	 */
	public String apply(OverAllState state) throws Exception {
		List<Message> messages = state.<List<Message>>value(MESSAGES).orElse(List.of());
		int messageSize = messages.size();

		logger.debug("Processing messages, found {} messages in state", messageSize);

		if (messageSize == 0) {
			logger.info("No messages to process, ending reflection cycle");
			return END;
		}

		Message lastMessage = messages.get(messageSize - 1);
		if (lastMessage.getMessageType() == MessageType.ASSISTANT) {
			logger.info("Last message is from assistant: {}", lastMessage.getText());
			return END;
		}
		logger.debug("Last message is from user, continuing to graph node");
		// Return the same constant the conditional-edge map is keyed by, rather than a
		// duplicated literal that could drift from it.
		return GRAPH_NODE_ID;
	}
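  • createReflectionGraph 方法负责创建 graph。它在从写作节点到反馈节点的边上加了一个条件判断(graphCount),判断依据就是迭代次数:写作节点每执行完一次并进入反馈节点,计数加一;计数达到 maxIterations 之后,写作节点再执行完一次就直接结束,不再进入反馈节点。由于示例中的 JudgeGraphNode 总会把最后一条消息转换为用户消息,循环只能依靠这个计数结束,因此只要设置了最大迭代次数,就一定会跑满。例如 maxIterations 为 2 时,写作节点会执行 3 次,反馈节点会执行 2 次。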

  • apply 方法决定反馈节点执行完后是否回到写作节点,关键的判断如下:

    if (messages.get(messages.size() - 1).getMessageType().equals(MessageType.ASSISTANT)) {
    			logger.info("Last message is from assistant: {}", messages.get(messages.size() - 1).getText());
    			return END;
    		}
    

    反馈节点 JudgeGraphNode 执行后,要想再次进入写作节点 AssistantGraphNode,必须满足一个条件:状态 state 中 messages 的最后一条消息不能是 ASSISTANT 类型。这就要求自定义的 JudgeGraphNode 执行完后,必须把最后一条消息转换为非 ASSISTANT 类型(示例中转换为 UserMessage),这是使用 ReflectAgent 的前提。之所以这样判断,是因为这正是常见的 chat 多轮对话方式:把 JudgeGraphNode 给出的反馈内容作为用户角色的消息发出,模型在 AssistantGraphNode 节点中拿到完整上下文后再次写作,从而达到想要的效果。

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐