From 749e8d45eaaacb5734cfec0ec79d1578f39f91af Mon Sep 17 00:00:00 2001 From: Sun Yuhan <1085481446@qq.com> Date: Tue, 11 Mar 2025 10:40:58 +0800 Subject: [PATCH] feature: Add the ability to pass custom metadata for MessageChatMemoryAdvisor and PromptChatMemoryAdvisor Signed-off-by: Sun Yuhan <1085481446@qq.com> --- .../advisor/MessageChatMemoryAdvisor.java | 25 +++++++----- .../advisor/PromptChatMemoryAdvisor.java | 40 ++++++++++++------- 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java index cd1c53cb301..c7355bc128f 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/MessageChatMemoryAdvisor.java @@ -16,9 +16,9 @@ package org.springframework.ai.chat.client.advisor; -import java.util.ArrayList; -import java.util.List; +import java.util.*; +import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; @@ -48,7 +48,7 @@ public MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversatio } public MessageChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize, - int order) { + int order) { super(chatMemory, defaultConversationId, chatHistoryWindowSize, true, order); } @@ -92,9 +92,9 @@ private AdvisedRequest before(AdvisedRequest request) { // 3. Create a new request with the advised messages. AdvisedRequest advisedRequest = AdvisedRequest.from(request).messages(advisedMessages).build(); - // 4. Add the new user input to the conversation memory. - UserMessage userMessage = new UserMessage(request.userText(), request.media()); + Map metadata = new HashMap<>(request.adviseContext()); + UserMessage userMessage = new UserMessage(request.userText(), request.media(), metadata); this.getChatMemoryStore().add(this.doGetConversationId(request.adviseContext()), userMessage); return advisedRequest; @@ -103,10 +103,17 @@ private AdvisedRequest before(AdvisedRequest request) { private void observeAfter(AdvisedResponse advisedResponse) { List assistantMessages = advisedResponse.response() - .getResults() - .stream() - .map(g -> (Message) g.getOutput()) - .toList(); + .getResults() + .stream() + .map(g -> { + AssistantMessage output = g.getOutput(); + Map metadata = Optional.ofNullable(output.getMetadata()).orElse(new HashMap<>()); + metadata.putAll(advisedResponse.adviseContext()); + AssistantMessage assistantMessage = new AssistantMessage(output.getText(), + metadata, output.getToolCalls(), output.getMedia()); + return (Message) assistantMessage; + }) + .toList(); this.getChatMemoryStore().add(this.doGetConversationId(advisedResponse.adviseContext()), assistantMessages); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java index aa709878839..ad41844f195 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/PromptChatMemoryAdvisor.java @@ -19,8 +19,10 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; +import org.springframework.ai.chat.messages.AssistantMessage; import reactor.core.publisher.Flux; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; @@ -66,13 +68,13 @@ public PromptChatMemoryAdvisor(ChatMemory chatMemory, String systemTextAdvise) { } public PromptChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize, - String systemTextAdvise) { + String systemTextAdvise) { this(chatMemory, defaultConversationId, chatHistoryWindowSize, systemTextAdvise, Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER); } public PromptChatMemoryAdvisor(ChatMemory chatMemory, String defaultConversationId, int chatHistoryWindowSize, - String systemTextAdvise, int order) { + String systemTextAdvise, int order) { super(chatMemory, defaultConversationId, chatHistoryWindowSize, true, order); this.systemTextAdvise = systemTextAdvise; } @@ -106,13 +108,13 @@ private AdvisedRequest before(AdvisedRequest request) { // 1. Advise system parameters. List memoryMessages = this.getChatMemoryStore() - .get(this.doGetConversationId(request.adviseContext()), - this.doGetChatMemoryRetrieveSize(request.adviseContext())); + .get(this.doGetConversationId(request.adviseContext()), + this.doGetChatMemoryRetrieveSize(request.adviseContext())); String memory = (memoryMessages != null) ? memoryMessages.stream() - .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT) - .map(m -> m.getMessageType() + ":" + ((Content) m).getText()) - .collect(Collectors.joining(System.lineSeparator())) : ""; + .filter(m -> m.getMessageType() == MessageType.USER || m.getMessageType() == MessageType.ASSISTANT) + .map(m -> m.getMessageType() + ":" + ((Content) m).getText()) + .collect(Collectors.joining(System.lineSeparator())) : ""; Map advisedSystemParams = new HashMap<>(request.systemParams()); advisedSystemParams.put("memory", memory); @@ -122,12 +124,13 @@ private AdvisedRequest before(AdvisedRequest request) { // 3. Create a new request with the advised system text and parameters. AdvisedRequest advisedRequest = AdvisedRequest.from(request) - .systemText(advisedSystemText) - .systemParams(advisedSystemParams) - .build(); + .systemText(advisedSystemText) + .systemParams(advisedSystemParams) + .build(); // 4. Add the new user input to the conversation memory. - UserMessage userMessage = new UserMessage(request.userText(), request.media()); + Map metadata = new HashMap<>(request.adviseContext()); + UserMessage userMessage = new UserMessage(request.userText(), request.media(), metadata); this.getChatMemoryStore().add(this.doGetConversationId(request.adviseContext()), userMessage); return advisedRequest; @@ -136,10 +139,17 @@ private AdvisedRequest before(AdvisedRequest request) { private void observeAfter(AdvisedResponse advisedResponse) { List assistantMessages = advisedResponse.response() - .getResults() - .stream() - .map(g -> (Message) g.getOutput()) - .toList(); + .getResults() + .stream() + .map(g -> { + AssistantMessage output = g.getOutput(); + Map metadata = Optional.ofNullable(output.getMetadata()).orElse(new HashMap<>()); + metadata.putAll(advisedResponse.adviseContext()); + AssistantMessage assistantMessage = new AssistantMessage(output.getText(), + metadata, output.getToolCalls(), output.getMedia()); + return (Message) assistantMessage; + }) + .toList(); this.getChatMemoryStore().add(this.doGetConversationId(advisedResponse.adviseContext()), assistantMessages); }