diff --git a/libs/langchain_v1/langchain/agents/middleware/summarization.py b/libs/langchain_v1/langchain/agents/middleware/summarization.py index 10baf724662c0..7abac7e750ce4 100644 --- a/libs/langchain_v1/langchain/agents/middleware/summarization.py +++ b/libs/langchain_v1/langchain/agents/middleware/summarization.py @@ -12,7 +12,9 @@ RemoveMessage, ToolMessage, ) +from langchain_core.messages.ai import AIMessage from langchain_core.messages.human import HumanMessage +from langchain_core.messages.system import SystemMessage from langchain_core.messages.utils import count_tokens_approximately, trim_messages from langgraph.graph.message import ( REMOVE_ALL_MESSAGES, @@ -491,8 +493,11 @@ def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str: if not trimmed_messages: return "Previous conversation was too long to summarize." + # Sanitize messages to remove irrelevant metadata + sanitized_messages = self._sanitize_messages_for_summary(trimmed_messages) + try: - response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages)) + response = self.model.invoke(self.summary_prompt.format(messages=sanitized_messages)) return response.text.strip() except Exception as e: return f"Error generating summary: {e!s}" @@ -506,9 +511,12 @@ async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str if not trimmed_messages: return "Previous conversation was too long to summarize." 
+ # Sanitize messages to remove irrelevant metadata + sanitized_messages = self._sanitize_messages_for_summary(trimmed_messages) + try: response = await self.model.ainvoke( - self.summary_prompt.format(messages=trimmed_messages) + self.summary_prompt.format(messages=sanitized_messages) ) return response.text.strip() except Exception as e: @@ -533,3 +541,73 @@ def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMess ) except Exception: return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:] + + def _sanitize_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]: + """Create lightweight message copies with only essential content for summarization. + + Strips metadata fields that consume tokens but don't contribute to context: + + - usage_metadata (token counts) + - response_metadata (headers, logprobs, model info) + - additional_kwargs (provider-specific data) + - Simplifies tool_calls to just name, args, and id + - Removes artifact from ToolMessages + + Args: + messages: Messages to sanitize. + + Returns: + Sanitized message copies with only essential content. 
+ """ + sanitized: list[AnyMessage] = [] + for msg in messages: + if isinstance(msg, AIMessage): + # Simplify tool_calls to just name, args, and id + simplified_tool_calls = ( + [ + {"name": tc["name"], "args": tc.get("args", {}), "id": tc.get("id", "")} + for tc in msg.tool_calls + ] + if msg.tool_calls + else [] + ) + + sanitized.append( + AIMessage( + content=msg.content, + tool_calls=simplified_tool_calls, + name=msg.name, + id=msg.id, + ) + ) + elif isinstance(msg, ToolMessage): + sanitized.append( + ToolMessage( + content=msg.content, + tool_call_id=msg.tool_call_id, + name=msg.name, + id=msg.id, + status=msg.status, + ) + ) + elif isinstance(msg, HumanMessage): + sanitized.append( + HumanMessage( + content=msg.content, + name=msg.name, + id=msg.id, + ) + ) + elif isinstance(msg, SystemMessage): + sanitized.append( + SystemMessage( + content=msg.content, + name=msg.name, + id=msg.id, + ) + ) + else: + # For other message types, preserve the original message + # since we don't know what fields are required + sanitized.append(msg) + return sanitized diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py index 2507f6f3e5d61..6f7da8a2c1966 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py @@ -3,7 +3,14 @@ import pytest from langchain_core.language_models import ModelProfile from langchain_core.language_models.chat_models import BaseChatModel -from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage +from langchain_core.messages import ( + AIMessage, + AnyMessage, + HumanMessage, + RemoveMessage, + SystemMessage, + ToolMessage, +) from langchain_core.outputs import ChatGeneration, ChatResult from langgraph.graph.message import 
REMOVE_ALL_MESSAGES @@ -891,3 +898,316 @@ def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None: # Index 2 is an AIMessage (safe cutoff point), so no adjustment needed cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=4) assert cutoff == 2 + + +def test_sanitize_messages_for_summary_strips_metadata() -> None: + """Test that _sanitize_messages_for_summary strips irrelevant metadata.""" + middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 5)) + + # Create messages with metadata + messages: list[AnyMessage] = [ + AIMessage( + content="AI response", + usage_metadata={"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}, + response_metadata={"model": "gpt-4", "logprobs": [0.1, 0.2]}, + additional_kwargs={"provider_data": "some_value"}, + ), + HumanMessage( + content="Human message", + additional_kwargs={"extra": "data"}, + ), + ] + + sanitized = middleware._sanitize_messages_for_summary(messages) + + # Check that metadata is stripped + assert len(sanitized) == 2 + assert isinstance(sanitized[0], AIMessage) + assert sanitized[0].content == "AI response" + assert sanitized[0].usage_metadata is None + assert sanitized[0].response_metadata == {} + assert sanitized[0].additional_kwargs == {} + + assert isinstance(sanitized[1], HumanMessage) + assert sanitized[1].content == "Human message" + assert sanitized[1].additional_kwargs == {} + + +def test_sanitize_messages_for_summary_simplifies_tool_calls() -> None: + """Test that tool_calls are simplified to just name, args, and id.""" + middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 5)) + + # Create AIMessage with tool_calls (using valid format) + messages: list[AnyMessage] = [ + AIMessage( + content="Calling tools", + tool_calls=[ + {"name": "search", "args": {"query": "test"}, "id": "call_123"}, + {"name": "calculator", "args": {"expression": "2+2"}, "id": "call_456"}, + ], + ) + ] + + sanitized = 
middleware._sanitize_messages_for_summary(messages) + + assert len(sanitized) == 1 + assert isinstance(sanitized[0], AIMessage) + assert len(sanitized[0].tool_calls) == 2 + + # Check first tool call + tool_call_1 = sanitized[0].tool_calls[0] + assert tool_call_1["name"] == "search" + assert tool_call_1["args"] == {"query": "test"} + assert tool_call_1["id"] == "call_123" + + # Check second tool call + tool_call_2 = sanitized[0].tool_calls[1] + assert tool_call_2["name"] == "calculator" + assert tool_call_2["args"] == {"expression": "2+2"} + assert tool_call_2["id"] == "call_456" + + +def test_sanitize_messages_for_summary_handles_empty_tool_calls() -> None: + """Test that AIMessage with no tool_calls is handled correctly.""" + middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 5)) + + messages: list[AnyMessage] = [ + AIMessage(content="No tools", tool_calls=[]), + AIMessage(content="Also no tools"), + ] + + sanitized = middleware._sanitize_messages_for_summary(messages) + + assert len(sanitized) == 2 + assert sanitized[0].tool_calls == [] + assert sanitized[1].tool_calls == [] + + +def test_sanitize_messages_for_summary_removes_tool_message_artifact() -> None: + """Test that ToolMessage artifact is removed.""" + middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 5)) + + messages: list[AnyMessage] = [ + ToolMessage( + content="Tool result", + tool_call_id="call_123", + artifact={"large_data": "x" * 1000, "binary": b"data"}, + status="success", + ) + ] + + sanitized = middleware._sanitize_messages_for_summary(messages) + + assert len(sanitized) == 1 + assert isinstance(sanitized[0], ToolMessage) + assert sanitized[0].content == "Tool result" + assert sanitized[0].tool_call_id == "call_123" + assert sanitized[0].status == "success" + assert sanitized[0].artifact is None + + +def test_sanitize_messages_for_summary_preserves_essential_fields() -> None: + """Test that essential fields like content, name, and id 
are preserved.""" + middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 5)) + + messages: list[AnyMessage] = [ + HumanMessage(content="User message", name="user1", id="msg_1"), + AIMessage(content="AI message", name="assistant", id="msg_2"), + ToolMessage(content="Tool output", tool_call_id="call_1", name="search", id="msg_3"), + ] + + sanitized = middleware._sanitize_messages_for_summary(messages) + + assert len(sanitized) == 3 + + # Check HumanMessage + assert sanitized[0].content == "User message" + assert sanitized[0].name == "user1" + assert sanitized[0].id == "msg_1" + + # Check AIMessage + assert sanitized[1].content == "AI message" + assert sanitized[1].name == "assistant" + assert sanitized[1].id == "msg_2" + + # Check ToolMessage + assert sanitized[2].content == "Tool output" + assert sanitized[2].name == "search" + assert sanitized[2].id == "msg_3" + + +def test_sanitize_messages_for_summary_handles_system_messages() -> None: + """Test that SystemMessage is handled correctly.""" + middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 5)) + + messages: list[AnyMessage] = [ + SystemMessage( + content="System prompt", + name="system", + id="sys_1", + additional_kwargs={"extra": "data"}, + ) + ] + + sanitized = middleware._sanitize_messages_for_summary(messages) + + assert len(sanitized) == 1 + assert isinstance(sanitized[0], SystemMessage) + assert sanitized[0].content == "System prompt" + assert sanitized[0].name == "system" + assert sanitized[0].id == "sys_1" + assert sanitized[0].additional_kwargs == {} + + +def test_sanitize_messages_for_summary_handles_mixed_message_types() -> None: + """Test sanitization with a realistic mix of message types.""" + middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 5)) + + messages: list[AnyMessage] = [ + SystemMessage(content="You are a helpful assistant", id="sys_1"), + HumanMessage(content="What's the weather?", id="msg_1"), + 
AIMessage( + content="Let me check", + tool_calls=[{"name": "weather", "args": {"city": "NYC"}, "id": "call_1"}], + usage_metadata={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}, + id="msg_2", + ), + ToolMessage( + content="Sunny, 72°F", + tool_call_id="call_1", + artifact={"raw_data": {"temp": 72, "condition": "sunny"}}, + id="msg_3", + ), + AIMessage( + content="It's sunny and 72°F in NYC", + response_metadata={"model": "gpt-4", "finish_reason": "stop"}, + id="msg_4", + ), + ] + + sanitized = middleware._sanitize_messages_for_summary(messages) + + assert len(sanitized) == 5 + + # SystemMessage + assert isinstance(sanitized[0], SystemMessage) + assert sanitized[0].content == "You are a helpful assistant" + + # HumanMessage + assert isinstance(sanitized[1], HumanMessage) + assert sanitized[1].content == "What's the weather?" + + # AIMessage with tool_calls + assert isinstance(sanitized[2], AIMessage) + assert sanitized[2].content == "Let me check" + assert len(sanitized[2].tool_calls) == 1 + assert sanitized[2].tool_calls[0]["name"] == "weather" + assert sanitized[2].usage_metadata is None + + # ToolMessage + assert isinstance(sanitized[3], ToolMessage) + assert sanitized[3].content == "Sunny, 72°F" + assert sanitized[3].artifact is None + + # AIMessage without tool_calls + assert isinstance(sanitized[4], AIMessage) + assert sanitized[4].content == "It's sunny and 72°F in NYC" + assert sanitized[4].response_metadata == {} + + +def test_sanitize_messages_for_summary_handles_empty_list() -> None: + """Test that empty message list is handled correctly.""" + middleware = SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 5)) + + sanitized = middleware._sanitize_messages_for_summary([]) + + assert sanitized == [] + + +def test_sanitize_messages_for_summary_handles_tool_call_missing_fields() -> None: + """Test that fully specified tool_calls, including one with empty args, are preserved.""" + middleware = 
SummarizationMiddleware(model=MockChatModel(), trigger=("messages", 5)) + + # Create AIMessage with valid tool_calls + messages: list[AnyMessage] = [ + AIMessage( + content="Calling tool", + tool_calls=[ + {"name": "search", "args": {}, "id": "call_1"}, + {"name": "calculator", "args": {"expr": "1+1"}, "id": "call_2"}, + ], + ) + ] + + sanitized = middleware._sanitize_messages_for_summary(messages) + + assert len(sanitized) == 1 + assert len(sanitized[0].tool_calls) == 2 + + # Verify tool calls are preserved correctly + assert sanitized[0].tool_calls[0]["name"] == "search" + assert sanitized[0].tool_calls[0]["args"] == {} + assert sanitized[0].tool_calls[0]["id"] == "call_1" + + assert sanitized[0].tool_calls[1]["name"] == "calculator" + assert sanitized[0].tool_calls[1]["args"] == {"expr": "1+1"} + assert sanitized[0].tool_calls[1]["id"] == "call_2" + + +def test_sanitize_messages_integration_with_create_summary() -> None: + """Test that sanitization is properly integrated into _create_summary.""" + middleware = SummarizationMiddleware( + model=MockChatModel(), trigger=("messages", 5), trim_tokens_to_summarize=None + ) + + # Create messages with lots of metadata + messages: list[AnyMessage] = [ + AIMessage( + content="Response", + usage_metadata={"input_tokens": 1000, "output_tokens": 500, "total_tokens": 1500}, + response_metadata={"model": "gpt-4", "headers": {"x-custom": "value"}}, + additional_kwargs={"provider_specific": "data" * 100}, + ) + ] + + # Call _create_summary which should sanitize before formatting + summary = middleware._create_summary(messages) + + # Should successfully create summary without including metadata + assert summary == "Generated summary" + + +async def test_sanitize_messages_integration_with_acreate_summary() -> None: + """Test that sanitization is properly integrated into _acreate_summary.""" + + class AsyncMockModel(BaseChatModel): + def _generate(self, messages, **kwargs): + return 
ChatResult(generations=[ChatGeneration(message=AIMessage(content="Sync"))]) + + async def _agenerate(self, messages, **kwargs): + return ChatResult( + generations=[ChatGeneration(message=AIMessage(content="Async summary"))] + ) + + @property + def _llm_type(self): + return "mock" + + middleware = SummarizationMiddleware( + model=AsyncMockModel(), trigger=("messages", 5), trim_tokens_to_summarize=None + ) + + # Create messages with lots of metadata + messages: list[AnyMessage] = [ + AIMessage( + content="Response", + usage_metadata={"input_tokens": 1000, "output_tokens": 500, "total_tokens": 1500}, + response_metadata={"model": "gpt-4", "headers": {"x-custom": "value"}}, + ) + ] + + # Call _acreate_summary which should sanitize before formatting + summary = await middleware._acreate_summary(messages) + + # Should successfully create summary without including metadata + assert summary == "Async summary" diff --git a/libs/langchain_v1/uv.lock b/libs/langchain_v1/uv.lock index c8765dca4c88b..7eef2bee17e8b 100644 --- a/libs/langchain_v1/uv.lock +++ b/libs/langchain_v1/uv.lock @@ -2380,12 +2380,12 @@ requires-dist = [ ] [package.metadata.requires-dev] -lint = [{ name = "ruff", specifier = ">=0.13.1,<0.14.0" }] +lint = [{ name = "ruff", specifier = ">=0.14.10,<0.15.0" }] test = [{ name = "langchain-core", editable = "../core" }] test-integration = [] typing = [ { name = "langchain-core", editable = "../core" }, - { name = "mypy", specifier = ">=1.18.1,<1.19.0" }, + { name = "mypy", specifier = ">=1.19.1,<1.20.0" }, { name = "types-pyyaml", specifier = ">=6.0.12.2,<7.0.0.0" }, ] @@ -2407,7 +2407,7 @@ dev = [ ] lint = [ { name = "langchain-core", editable = "../core" }, - { name = "ruff", specifier = ">=0.13.1,<0.14.0" }, + { name = "ruff", specifier = ">=0.14.10,<0.15.0" }, ] test = [ { name = "freezegun", specifier = ">=1.2.2,<2.0.0" }, @@ -2433,7 +2433,7 @@ test-integration = [ typing = [ { name = "beautifulsoup4", specifier = ">=4.13.5,<5.0.0" }, { name = "lxml-stubs", 
specifier = ">=0.5.1,<1.0.0" }, - { name = "mypy", specifier = ">=1.18.1,<1.19.0" }, + { name = "mypy", specifier = ">=1.19.1,<1.20.0" }, { name = "tiktoken", specifier = ">=0.8.0,<1.0.0" }, { name = "types-requests", specifier = ">=2.31.0.20240218,<3.0.0.0" }, ]