spring ai alibaba agent之ReflectAgent源码解读
【摘要】本文深入解读了Spring AI Alibaba中的ReflectAgent源码,重点分析了其设计原理和使用场景。文章首先推荐读者参考官方文档了解底层设计,然后通过流程图展示了ReflectAgent的工作流程。作者提供了官方示例代码,包括AssistantGraphNode和JudgeGraphNode两个核心节点的实现,演示了如何构建LLM节点、设置系统提示模板以及处理状态流转。这种A
spring ai alibaba agent之ReflectAgent源码解读
如果不熟悉的可以看下官方写的这篇文章[关于spring ai alibaba graph 使用指南于源码解读](Spring AI Alibaba Graph 使用指南与源码解读),这边文章会让你对底层的设计和实现有所了解。本文讲的更像一个ReflectAgent使用指南,目前官网没有关于正式的使用文档,只是给出一些使用demo,这篇文章会分析整个ReflectAgent流程和使用流程注意。这个agent适合场景写作、nl2sql生成等。
官网使用示例
流程图如下:

官网给出的使用demo,主要代码:
@Configuration
public class RelectionAutoconfiguration {
public static class AssistantGraphNode implements NodeAction {
private final LlmNode llmNode;
private SystemPromptTemplate systemPromptTemplate;
private final String NODE_ID = "call_model";
private static final String CLASSIFIER_PROMPT_TEMPLATE = """
You are an essay assistant tasked with writing excellent 5-paragraph essays.
Generate the best essay possible for the user's request.
If the user provides critique, respond with a revised version of your previous attempts.
Only return the main content I need, without adding any other interactive language.
Please answer in Chinese:
""";
public AssistantGraphNode(ChatClient chatClient) {
this.systemPromptTemplate = new SystemPromptTemplate(CLASSIFIER_PROMPT_TEMPLATE);
this.llmNode = LlmNode.builder()
.systemPromptTemplate(systemPromptTemplate.render())
.chatClient(chatClient)
.messagesKey("messages")
.build();
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private ChatClient chatClient;
public Builder chatClient(ChatClient chatClient) {
this.chatClient = chatClient;
return this;
}
public AssistantGraphNode build() {
if (chatClient == null) {
throw new IllegalArgumentException("ChatClient must be provided");
}
return new AssistantGraphNode(chatClient);
}
}
@Override
public Map<String, Object> apply(OverAllState overAllState) throws Exception {
List<Message> messages = (List<Message>) overAllState.value(MESSAGES).get();
StateGraph stateGraph = new StateGraph(() -> {
Map<String, KeyStrategy> strategies = new HashMap<>();
strategies.put(MESSAGES, new AppendStrategy());
return strategies;
}).addNode(this.NODE_ID, node_async(llmNode)).addEdge(START, this.NODE_ID).addEdge(this.NODE_ID, END);
OverAllState invokeState = stateGraph.compile().invoke(Map.of(MESSAGES, messages)).get();
List<Message> reactMessages = (List<Message>) invokeState.value(MESSAGES).orElseThrow();
return Map.of(MESSAGES, reactMessages);
}
}
public static class JudgeGraphNode implements NodeAction {
private final LlmNode llmNode;
private final String NODE_ID = "judge_response";
private SystemPromptTemplate systemPromptTemplate;
private static final String CLASSIFIER_PROMPT_TEMPLATE = """
You are a teacher grading a student's essay submission. Provide detailed feedback and revision suggestions for the essay.
Your feedback should cover the following aspects:
- Length : Is the essay sufficiently developed? Does it meet the required length or need expansion/shortening?
- Depth : Are the ideas well-developed? Is there sufficient analysis, evidence, or explanation?
- Structure : Is the organization logical and clear? Are the introduction, transitions, and conclusion effective?
- Style and Tone : Is the writing style appropriate for the purpose and audience? Is the tone consistent and professional?
- Language Use : Are vocabulary, grammar, and sentence structure accurate and varied?
- Focus only on providing actionable suggestions for improvement. Do not include grades, scores, or overall summary evaluations.
Please respond in Chinese .
""";
public JudgeGraphNode(ChatClient chatClient) {
this.systemPromptTemplate = new SystemPromptTemplate(CLASSIFIER_PROMPT_TEMPLATE);
this.llmNode = LlmNode.builder()
.chatClient(chatClient)
.systemPromptTemplate(systemPromptTemplate.render())
.messagesKey(MESSAGES)
.build();
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
private ChatClient chatClient;
public JudgeGraphNode.Builder chatClient(ChatClient chatClient) {
this.chatClient = chatClient;
return this;
}
public JudgeGraphNode build() {
if (chatClient == null) {
throw new IllegalArgumentException("ChatClient must be provided");
}
return new JudgeGraphNode(chatClient);
}
}
@Override
public Map<String, Object> apply(OverAllState allState) throws Exception {
List<Message> messages = (List<Message>) allState.value(MESSAGES).get();
StateGraph stateGraph = new StateGraph(() -> {
Map<String, KeyStrategy> strategies = new HashMap<>();
strategies.put(MESSAGES, new AppendStrategy());
return strategies;
}).addNode(this.NODE_ID, node_async(llmNode)).addEdge(START, this.NODE_ID).addEdge(this.NODE_ID, END);
CompiledGraph compile = stateGraph.compile();
OverAllState invokeState = compile.invoke(Map.of(MESSAGES, messages)).get();
UnaryOperator<List<Message>> convertLastToUserMessage = messageList -> {
int size = messageList.size();
if (size == 0)
return messageList;
Message last = messageList.get(size - 1);
messageList.set(size - 1, new UserMessage(last.getText()));
return messageList;
};
List<Message> reactMessages = (List<Message>) invokeState.value(MESSAGES).orElseThrow();
convertLastToUserMessage.apply(reactMessages);
return Map.of(MESSAGES, reactMessages);
}
}
@Bean
public CompiledGraph reflectGraph(ChatModel chatModel) throws GraphStateException {
ChatClient chatClient = ChatClient.builder(chatModel)
.defaultAdvisors(new SimpleLoggerAdvisor())
.defaultOptions(OpenAiChatOptions.builder().internalToolExecutionEnabled(false).build())
.build();
AssistantGraphNode assistantGraphNode = AssistantGraphNode.builder().chatClient(chatClient).build();
JudgeGraphNode judgeGraphNode = JudgeGraphNode.builder().chatClient(chatClient).build();
ReflectAgent reflectAgent = ReflectAgent.builder()
.graph(assistantGraphNode)
.reflection(judgeGraphNode)
.maxIterations(2)
.build();
return reflectAgent.getAndCompileGraph();
}
}
测试入口:
@RestController
@RequestMapping("/reflection")
public class ReflectionController {
private static final Logger logger = LoggerFactory.getLogger(ReflectionController.class);
private CompiledGraph compiledGraph;
public ReflectionController(@Qualifier("reflectGraph") CompiledGraph compiledGraph) {
this.compiledGraph = compiledGraph;
}
@GetMapping("/chat")
public String simpleChat(String query) throws GraphRunnerException {
return compiledGraph.invoke(Map.of(ReflectAgent.MESSAGES, List.of(new UserMessage(query))))
.get()
.<List<Message>>value(ReflectAgent.MESSAGES)
.orElseThrow()
.stream()
.filter(message -> message.getMessageType() == MessageType.ASSISTANT)
.reduce((first, second) -> second)
.map(Message::getText)
.orElseThrow();
}
}
这个demo主要是用于写作场景,
- AssistantGraphNode根据用户的写作要求或者建议,进行内容写作的。
- JudgeGraphNode 是根据AssistantGraphNode生成的内容,对内容进行给出修改建议。
使用特别说明:
-
AssistantGraphNode、JudgeGraphNode最后结果都是使用LLM能力,也是使用spring ai alibaba中标准的LlmNode节点实现,OverAllState状态流转中,默认使用了key名为messages,在LlmNode、AssistantGraphNode、JudgeGraphNode中流转,这个默认方式不限于这个示例,很多实现都是基于这个事实,spring ai alibaba graph 中内置的ReflectAgent、ReactAgent等都是基于此。
-
AssistantGraphNode、JudgeGraphNode内部基于graph实现的,他们的StateGraph 对于key名为messages采用的是追加策略,这个很重要,通过模型生成的写作内容和反馈内容追加到了messages上(这里可以理解为是大模型chat历史上下文的叠,作为下一阶段模型的上下文输入。),:
StateGraph stateGraph = new StateGraph(() -> { Map<String, KeyStrategy> strategies = new HashMap<>(); strategies.put(MESSAGES, new AppendStrategy());//都是采用追加的策略
基于上,我们在看ReflectAgent源码中几个关键方法:
public StateGraph createReflectionGraph(NodeAction graph, NodeAction reflection, int maxIterations)
throws GraphStateException {
this.maxIterations = maxIterations;
logger.debug("Creating reflection graph with max iterations: {}", maxIterations);
StateGraph stateGraph = new StateGraph(() -> {
HashMap<String, KeyStrategy> keyStrategyHashMap = new HashMap<>();
keyStrategyHashMap.put(MESSAGES, new ReplaceStrategy());
keyStrategyHashMap.put(ITERATION_NUM, new ReplaceStrategy());
return keyStrategyHashMap;
}).addNode(GRAPH_NODE_ID, node_async(graph))
.addNode(REFLECTION_NODE_ID, node_async(reflection))
.addEdge(START, GRAPH_NODE_ID)
.addConditionalEdges(GRAPH_NODE_ID, edge_async(this::graphCount),
Map.of(REFLECTION_NODE_ID, REFLECTION_NODE_ID, END, END))
.addConditionalEdges(REFLECTION_NODE_ID, edge_async(this::apply),
Map.of(GRAPH_NODE_ID, GRAPH_NODE_ID, END, END));
logger.info("Reflection graph created successfully with {} nodes", 2);
return stateGraph;
}
public String graphCount(OverAllState state) throws Exception {
Optional<Object> iterationNumOptional = state.value(ITERATION_NUM);
if (!iterationNumOptional.isPresent()) {
logger.debug("Initializing iteration counter to 1");
state.updateState(Map.of(ITERATION_NUM, 1));
}
else {
Integer iterationNum = (Integer) iterationNumOptional.get();
logger.info("Current iteration: {} | Max iterations: {}", iterationNum, maxIterations);
if (iterationNum >= maxIterations) {
logger.info("Iteration limit reached, stopping reflection cycle");
state.updateState(Map.of(ITERATION_NUM, 0));
this.printMessage(state);
return END;
}
int nextIterationNum = iterationNum + 1;
logger.debug("Incrementing iteration counter from {} to {}", iterationNum, nextIterationNum);
state.updateState(Map.of(ITERATION_NUM, nextIterationNum));
}
Integer updatedCount = (Integer) state.value(ITERATION_NUM).orElseThrow();
logger.debug("Updated iteration count: {}", updatedCount);
return REFLECTION_NODE_ID;
}
public String apply(OverAllState state) throws Exception {
List<Message> messages = (List<Message>) state.value(MESSAGES).get();
int messageSize = messages.size();
logger.debug("Processing messages, found {} messages in state", messageSize);
if (messageSize == 0) {
logger.info("No messages to process, ending reflection cycle");
return END;
}
if (messages.get(messages.size() - 1).getMessageType().equals(MessageType.ASSISTANT)) {
logger.info("Last message is from assistant: {}", messages.get(messages.size() - 1).getText());
return END;
}
logger.debug("Last message is from user, continuing to graph node");
return "graph";
}
-
createReflectionGraph 方法创建graph,在创建从写作节点到反馈节点时,加了一个条件,这个条件就是迭代的次数,从写作到反馈节点跑完记录为一次,也就是说如果你设置最大迭代次数,都会跑满。
-
apply 方法是反馈节点再次回到写作节点判断的关键方法,可以看到最后从反馈节点到写作节点关键的判断方法:
if (messages.get(messages.size() - 1).getMessageType().equals(MessageType.ASSISTANT)) { logger.info("Last message is from assistant: {}", messages.get(messages.size() - 1).getText()); return END; }反馈JudgeGraphNode节点执行后能够再次到达写作AssistantGraphNode节点,必须满足条件:状态state中messages中最后一条信息必须是非ASSISTANT类型的。这就要求自定义JudgeGraphNode执行完以后,必须更新最后一条消息类型设置为非ASSISTANT类型,这是使用ReflectAgent前提。这里为什么这么判断,因为这个也是我们常用的chat的多轮对话方。把JudgeGraphNode给出的反馈内容,作为用户角色内容发起,模型在AssistantGraphNode节点时拿到上下文,再次写作,达到想要的效果。
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)