diff --git a/libs/langchain_v1/langchain/agents/middleware/summarization.py b/libs/langchain_v1/langchain/agents/middleware/summarization.py index 37c782842b8c1..8f6fb2789cdf3 100644 --- a/libs/langchain_v1/langchain/agents/middleware/summarization.py +++ b/libs/langchain_v1/langchain/agents/middleware/summarization.py @@ -7,6 +7,7 @@ from typing import Any, Literal, cast from langchain_core.messages import ( + AIMessage, AnyMessage, MessageLikeRepresentation, RemoveMessage, @@ -478,13 +479,37 @@ def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) - def _find_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> int: """Find a safe cutoff point that doesn't split AI/Tool message pairs. - If the message at cutoff_index is a ToolMessage, advance until we find - a non-ToolMessage. This ensures we never cut in the middle of parallel - tool call responses. + If the message at `cutoff_index` is a `ToolMessage`, search backward for the + `AIMessage` containing the corresponding `tool_calls` and adjust the cutoff to + include it. This ensures tool call requests and responses stay together. + + Falls back to advancing forward past `ToolMessage` objects only if no matching + `AIMessage` is found (edge case). """ - while cutoff_index < len(messages) and isinstance(messages[cutoff_index], ToolMessage): - cutoff_index += 1 - return cutoff_index + if cutoff_index >= len(messages) or not isinstance(messages[cutoff_index], ToolMessage): + return cutoff_index + + # Collect tool_call_ids from consecutive ToolMessages at/after cutoff + tool_call_ids: set[str] = set() + idx = cutoff_index + while idx < len(messages) and isinstance(messages[idx], ToolMessage): + tool_msg = cast("ToolMessage", messages[idx]) + if tool_msg.tool_call_id: + tool_call_ids.add(tool_msg.tool_call_id) + idx += 1 + + # Search backward for AIMessage with matching tool_calls + for i in range(cutoff_index - 1, -1, -1): + msg = messages[i] + if isinstance(msg, AIMessage) and msg.tool_calls: + ai_tool_call_ids = {tc.get("id") for tc in msg.tool_calls if tc.get("id")} + if tool_call_ids & ai_tool_call_ids: + # Found the AIMessage - move cutoff to include it + return i + + # Fallback: no matching AIMessage found, advance past ToolMessages to avoid + # orphaned tool responses + return idx def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str: """Generate summary for the given messages.""" diff --git a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py index 7c2995061a60a..c0ac6c82f5150 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py +++ b/libs/langchain_v1/tests/unit_tests/agents/middleware/implementations/test_summarization.py @@ -281,8 +281,8 @@ def token_counter(messages): ] -def test_summarization_middleware_token_retention_advances_past_tool_messages() -> None: - """Ensure token retention advances past tool messages for aggressive summarization.""" +def test_summarization_middleware_token_retention_preserves_ai_tool_pairs() -> None: + """Ensure token retention preserves AI/Tool message pairs together.""" def token_counter(messages: list[AnyMessage]) -> int: return sum(len(getattr(message, "content", "")) for message in messages) @@ -297,7 +297,7 @@ def token_counter(messages: list[AnyMessage]) -> int: # Total tokens: 300 + 200 + 50 + 180 + 160 = 890 # Target keep: 500 tokens (50% of 1000) # Binary search finds cutoff around index 2 (ToolMessage) - # We advance past it to index 3 (HumanMessage) + # We move back to index 1 to preserve the AIMessage with its ToolMessage messages: list[AnyMessage] = [ HumanMessage(content="H" * 300), AIMessage( @@ -314,14 +314,15 @@ def token_counter(messages: list[AnyMessage]) -> int: assert result is not None preserved_messages = result["messages"][2:] - # With aggressive summarization, we advance past the ToolMessage - # So we preserve messages from index 3 onward (the two HumanMessages) - assert preserved_messages == messages[3:] + # We move the cutoff back to include the AIMessage with its ToolMessage + # So we preserve messages from index 1 onward (AI + Tool + Human + Human) + assert preserved_messages == messages[1:] - # Verify preserved tokens are within budget - target_token_count = int(1000 * 0.5) - preserved_tokens = middleware.token_counter(preserved_messages) - assert preserved_tokens <= target_token_count + # Verify the AI/Tool pair is preserved together + assert isinstance(preserved_messages[0], AIMessage) + assert preserved_messages[0].tool_calls + assert isinstance(preserved_messages[1], ToolMessage) + assert preserved_messages[1].tool_call_id == preserved_messages[0].tool_calls[0]["id"] def test_summarization_middleware_missing_profile() -> None: @@ -666,7 +667,7 @@ def token_counter_small(messages): def test_summarization_middleware_find_safe_cutoff_point() -> None: - """Test _find_safe_cutoff_point finds safe cutoff past ToolMessages.""" + """Test `_find_safe_cutoff_point` preserves AI/Tool message pairs.""" model = FakeToolCallingModel() middleware = SummarizationMiddleware( model=model, trigger=("messages", 10), keep=("messages", 2) @@ -676,7 +677,7 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None: HumanMessage(content="msg1"), AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]), ToolMessage(content="result1", tool_call_id="call1"), - ToolMessage(content="result2", tool_call_id="call2"), + ToolMessage(content="result2", tool_call_id="call2"), # orphan - no matching AI HumanMessage(content="msg2"), ] @@ -684,8 +685,14 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None: assert middleware._find_safe_cutoff_point(messages, 0) == 0 assert middleware._find_safe_cutoff_point(messages, 1) == 1 - # Starting at a ToolMessage advances to the next non-ToolMessage - assert middleware._find_safe_cutoff_point(messages, 2) == 4 + # Starting at ToolMessage with matching AIMessage moves back to include it + # ToolMessage at index 2 has tool_call_id="call1" which matches AIMessage at index 1 + assert middleware._find_safe_cutoff_point(messages, 2) == 1 + + # Starting at orphan ToolMessage (no matching AIMessage) falls back to advancing + # ToolMessage at index 3 has tool_call_id="call2" with no matching AIMessage + # Since we only collect from cutoff_index onwards, only {call2} is collected + # No match found, so we fall back to advancing past ToolMessages assert middleware._find_safe_cutoff_point(messages, 3) == 4 # Starting at the HumanMessage after tools returns that index @@ -699,6 +706,65 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None: assert middleware._find_safe_cutoff_point(messages, len(messages) + 5) == len(messages) + 5 +def test_summarization_middleware_find_safe_cutoff_point_orphan_tool() -> None: + """Test `_find_safe_cutoff_point` with truly orphan `ToolMessage` (no matching `AIMessage`).""" + model = FakeToolCallingModel() + middleware = SummarizationMiddleware( + model=model, trigger=("messages", 10), keep=("messages", 2) + ) + + # Messages where ToolMessage has no matching AIMessage at all + messages: list[AnyMessage] = [ + HumanMessage(content="msg1"), + AIMessage(content="ai_no_tools"), # No tool_calls + ToolMessage(content="orphan_result", tool_call_id="orphan_call"), + HumanMessage(content="msg2"), + ] + + # Starting at orphan ToolMessage falls back to advancing forward + assert middleware._find_safe_cutoff_point(messages, 2) == 3 + + +def test_summarization_cutoff_moves_backward_to_include_ai_message() -> None: + """Test that cutoff moves backward to include `AIMessage` with its `ToolMessage`s. + + Previously, when the cutoff landed on a `ToolMessage`, the code would advance + FORWARD past all `ToolMessage`s. This could result in orphaned `ToolMessage`s (kept + without their `AIMessage`) or aggressive summarization that removed AI/Tool pairs. + + The fix searches backward from a `ToolMessage` to find the `AIMessage` with matching + `tool_calls`, ensuring the pair stays together in the preserved messages. + """ + model = FakeToolCallingModel() + middleware = SummarizationMiddleware( + model=model, trigger=("messages", 10), keep=("messages", 2) + ) + + # Scenario: cutoff lands on ToolMessage that has a matching AIMessage before it + messages: list[AnyMessage] = [ + HumanMessage(content="initial question"), # index 0 + AIMessage( + content="I'll use a tool", + tool_calls=[{"name": "search", "args": {"q": "test"}, "id": "call_abc"}], + ), # index 1 + ToolMessage(content="search result", tool_call_id="call_abc"), # index 2 + HumanMessage(content="followup"), # index 3 + ] + + # When cutoff is at index 2 (ToolMessage), it should move BACKWARD to index 1 + # to include the AIMessage that generated the tool call + result = middleware._find_safe_cutoff_point(messages, 2) + + assert result == 1, ( + f"Expected cutoff to move backward to index 1 (AIMessage), got {result}. " + "The cutoff should preserve AI/Tool pairs together." + ) + + assert isinstance(messages[result], AIMessage) + assert messages[result].tool_calls # type: ignore[union-attr] + assert messages[result].tool_calls[0]["id"] == "call_abc" # type: ignore[union-attr] + + def test_summarization_middleware_zero_and_negative_target_tokens() -> None: """Test handling of edge cases with target token calculations.""" # Test with very small fraction that rounds to zero @@ -814,7 +880,7 @@ def _llm_type(self) -> str: def test_summarization_middleware_many_parallel_tool_calls_safety() -> None: - """Test cutoff safety with many parallel tool calls extending beyond old search range.""" + """Test cutoff safety preserves AI message with many parallel tool calls.""" middleware = SummarizationMiddleware( model=MockChatModel(), trigger=("messages", 15), keep=("messages", 5) ) @@ -826,20 +892,21 @@ def test_summarization_middleware_many_parallel_tool_calls_safety() -> None: ] messages: list[AnyMessage] = [human_message, ai_message, *tool_messages] - # Cutoff at index 7 (a ToolMessage) advances to index 12 (end of messages) - assert middleware._find_safe_cutoff_point(messages, 7) == 12 + # Cutoff at index 7 (a ToolMessage) moves back to index 1 (AIMessage) + # to preserve the AI/Tool pair together + assert middleware._find_safe_cutoff_point(messages, 7) == 1 - # Any cutoff pointing at a ToolMessage (indices 2-11) advances to index 12 + # Any cutoff pointing at a ToolMessage (indices 2-11) moves back to index 1 for i in range(2, 12): - assert middleware._find_safe_cutoff_point(messages, i) == 12 + assert middleware._find_safe_cutoff_point(messages, i) == 1 # Cutoff at index 0, 1 (before tool messages) stays the same assert middleware._find_safe_cutoff_point(messages, 0) == 0 assert middleware._find_safe_cutoff_point(messages, 1) == 1 -def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None: - """Test _find_safe_cutoff advances past ToolMessages to find safe cutoff.""" +def test_summarization_middleware_find_safe_cutoff_preserves_ai_tool_pair() -> None: + """Test `_find_safe_cutoff` preserves AI/Tool message pairs together.""" middleware = SummarizationMiddleware( model=MockChatModel(), trigger=("messages", 10), keep=("messages", 3) ) @@ -862,15 +929,15 @@ def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None ] # Target cutoff index is len(messages) - messages_to_keep = 6 - 3 = 3 - # Index 3 is a ToolMessage, so we advance past the tool sequence to index 5 + # Index 3 is a ToolMessage, we move back to index 1 to include AIMessage cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=3) - assert cutoff == 5 + assert cutoff == 1 # With messages_to_keep=2, target cutoff index is 6 - 2 = 4 - # Index 4 is a ToolMessage, so we advance past the tool sequence to index 5 - # This is aggressive - we keep only 1 message instead of 2 + # Index 4 is a ToolMessage, we move back to index 1 to include AIMessage + # This preserves the AI + Tools + Human, more than requested but valid cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=2) - assert cutoff == 5 + assert cutoff == 1 def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None: