diff --git a/dspy/predict/react.py b/dspy/predict/react.py index 5f87879f80..2e7b973595 100644 --- a/dspy/predict/react.py +++ b/dspy/predict/react.py @@ -4,7 +4,9 @@ from litellm import ContextWindowExceededError import dspy +from dspy.adapters.types.history import History from dspy.adapters.types.tool import Tool +from dspy.primitives.example import Example from dspy.primitives.module import Module from dspy.signatures.signature import ensure_signature @@ -93,6 +95,34 @@ def _format_trajectory(self, trajectory: dict[str, Any]): trajectory_signature = dspy.Signature(f"{', '.join(trajectory.keys())} -> x") return adapter.format_user_message_content(trajectory_signature, trajectory) + def _format_history_with_trajectory(self, history: History | None) -> History | None: + if history is None or not history.messages: + return history + + formatted_messages = [] + changed = False + + for message in history.messages: + if isinstance(message, Example): + message_data = dict(message.items()) + elif isinstance(message, dict): + message_data = dict(message) + else: + formatted_messages.append(message) + continue + + trajectory_value = message_data.get("trajectory") + if trajectory_value and not isinstance(trajectory_value, str): + message_data["trajectory"] = self._format_trajectory(trajectory_value) + changed = True + + formatted_messages.append(message_data) + + if not changed: + return history + + return History(messages=formatted_messages) + def forward(self, **input_args): trajectory = {} max_iters = input_args.pop("max_iters", self.max_iters) @@ -146,10 +176,14 @@ async def aforward(self, **input_args): def _call_with_potential_trajectory_truncation(self, module, trajectory, **input_args): for _ in range(3): try: - return module( - **input_args, - trajectory=self._format_trajectory(trajectory), - ) + call_kwargs = dict(input_args) + call_kwargs["trajectory"] = self._format_trajectory(trajectory) + + history_value = call_kwargs.get("history") + if isinstance(history_value, History): + call_kwargs["history"] = self._format_history_with_trajectory(history_value) + + return module(**call_kwargs) except ContextWindowExceededError: logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.") trajectory = self.truncate_trajectory(trajectory) @@ -157,10 +191,14 @@ def _call_with_potential_trajectory_truncation(self, module, trajectory, **input async def _async_call_with_potential_trajectory_truncation(self, module, trajectory, **input_args): for _ in range(3): try: - return await module.acall( - **input_args, - trajectory=self._format_trajectory(trajectory), - ) + call_kwargs = dict(input_args) + call_kwargs["trajectory"] = self._format_trajectory(trajectory) + + history_value = call_kwargs.get("history") + if isinstance(history_value, History): + call_kwargs["history"] = self._format_history_with_trajectory(history_value) + + return await module.acall(**call_kwargs) except ContextWindowExceededError: logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.") trajectory = self.truncate_trajectory(trajectory) diff --git a/tests/predict/test_react.py b/tests/predict/test_react.py index 55ff596072..96adc888cf 100644 --- a/tests/predict/test_react.py +++ b/tests/predict/test_react.py @@ -50,6 +50,73 @@ def make_images(): assert sum(1 for part in observation_content if isinstance(part, dict) and part.get("type") == "image_url") == 2 +def test_history_trajectory_uses_chat_format(): + def math_tool(expression: str) -> str: + return str(eval(expression)) + + adapter = dspy.ChatAdapter() + lm = DummyLM( + [ + { + "next_thought": "I should use the math tool.", + "next_tool_name": "math_tool", + "next_tool_args": {"expression": "2+2"}, + }, + { + "next_thought": "That answers the question; time to finish.", + "next_tool_name": "finish", + "next_tool_args": {}, + }, + { + "reasoning": "Computed 2+2 with the provided tool.", + "answer": "4", + }, + { + "next_thought": "Reusing the math tool for another calculation.", + "next_tool_name": "math_tool", + "next_tool_args": {"expression": "3*4"}, + }, + { + "next_thought": "I have the second result and can finish now.", + "next_tool_name": "finish", + "next_tool_args": {}, + }, + { + "reasoning": "Computed 3*4 with the provided tool.", + "answer": "12", + }, + ], + adapter=adapter, + ) + + dspy.settings.configure(lm=lm, adapter=adapter) + + class HistorySignature(dspy.Signature): + history: dspy.History = dspy.InputField() + question: str = dspy.InputField() + answer: str = dspy.OutputField() + + react = dspy.ReAct(HistorySignature, tools=[math_tool]) + + history = dspy.History(messages=[]) + + q1 = "What is 2+2?" + first_outputs = react(history=history, question=q1) + history.messages.append({"question": q1, **first_outputs}) + + q2 = "Now compute 3*4." + react(history=history, question=q2) + + messages = lm.history[-1]["messages"] + user_messages = [message for message in messages if message.get("role") == "user"] + assert len(user_messages) >= 2 + + first_user_content = user_messages[0]["content"] + assert "[[ ## trajectory ## ]]" in first_user_content + assert "[[ ## thought_0 ## ]]" in first_user_content + assert '"thought_0"' not in first_user_content + + def test_tool_calling_with_pydantic_args(): class CalendarEvent(BaseModel): name: str