diff --git a/comps/agent/langchain/src/strategy/react/planner.py b/comps/agent/langchain/src/strategy/react/planner.py index 9771f6220d..773cc199ce 100644 --- a/comps/agent/langchain/src/strategy/react/planner.py +++ b/comps/agent/langchain/src/strategy/react/planner.py @@ -147,6 +147,7 @@ async def non_streaming_run(self, query, config): from ...persistence import AgentPersistence, PersistenceConfig from ...utils import setup_chat_model +from .utils import assemble_history, assemble_memory, convert_json_to_tool_call class AgentState(TypedDict): @@ -174,16 +175,20 @@ def __init__(self, tools, args): llm = setup_chat_model(args) self.tools = tools self.chain = prompt | llm | output_parser + self.with_memory = args.with_memory def __call__(self, state): - from .utils import assemble_history, convert_json_to_tool_call print("---CALL Agent node---") messages = state["messages"] # assemble a prompt from messages - query = messages[0].content - history = assemble_history(messages) + if self.with_memory: + query, history = assemble_memory(messages) + print("@@@ Query: ", history) + else: + query = messages[0].content + history = assemble_history(messages) print("@@@ History: ", history) tools_descriptions = tool_renderer(self.tools) diff --git a/comps/agent/langchain/src/strategy/react/utils.py b/comps/agent/langchain/src/strategy/react/utils.py index f303b424a6..19e51032bd 100644 --- a/comps/agent/langchain/src/strategy/react/utils.py +++ b/comps/agent/langchain/src/strategy/react/utils.py @@ -5,7 +5,7 @@ import uuid from huggingface_hub import ChatCompletionOutputFunctionDefinition, ChatCompletionOutputToolCall -from langchain_core.messages import AIMessage, ToolMessage +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from langchain_core.messages.tool import ToolCall from langchain_core.output_parsers import BaseOutputParser @@ -82,3 +82,41 @@ def assemble_history(messages): query_history += f"Assistant Output: {m.content}\n" return query_history + + +def assemble_memory(messages): + """ + messages: Human, AI, TOOL, AI, TOOL, etc. + """ + query = "" + query_id = None + query_history = "" + breaker = "-" * 10 + + # get query + for m in messages[::-1]: + if isinstance(m, HumanMessage): + query = m.content + query_id = m.id + break + + for m in messages: + if isinstance(m, AIMessage): + # if there is tool call + if hasattr(m, "tool_calls") and len(m.tool_calls) > 0: + for tool_call in m.tool_calls: + tool = tool_call["name"] + tc_args = tool_call["args"] + id = tool_call["id"] + tool_output = get_tool_output(messages, id) + query_history += f"Tool Call: {tool} - {tc_args}\nTool Output: {tool_output}\n{breaker}\n" + else: + # did not make tool calls + query_history += f"Assistant Output: {m.content}\n" + + elif isinstance(m, HumanMessage): + if m.id == query_id: + continue + query_history += f"Human Input: {m.content}\n" + + return query, query_history