-
Notifications
You must be signed in to change notification settings - Fork 20.4k
fix(langchain): keep tool call / AIMessage pairings when summarizing
#34609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+124
−32
Merged
Changes from 4 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
2092730
fix(core,langchain): use `get_buffer_string` for message summarization
mdrxy 665a876
Merge branch 'master' into mdrxy/fix-summarization
mdrxy d89af47
fix(langchain): keep tool call / `AIMessage` pairings when summarizing
mdrxy 9e2046c
Merge branch 'master' into mdrxy/fix-cutoff-summarization
mdrxy ad19a0b
Merge branch 'master' into mdrxy/fix-cutoff-summarization
mdrxy 0dc042a
add test
mdrxy 5bcbba5
Merge branch 'master' into mdrxy/fix-cutoff-summarization
mdrxy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -281,8 +281,8 @@ def token_counter(messages): | |
| ] | ||
|
|
||
|
|
||
| def test_summarization_middleware_token_retention_advances_past_tool_messages() -> None: | ||
| """Ensure token retention advances past tool messages for aggressive summarization.""" | ||
| def test_summarization_middleware_token_retention_preserves_ai_tool_pairs() -> None: | ||
| """Ensure token retention preserves AI/Tool message pairs together.""" | ||
|
|
||
| def token_counter(messages: list[AnyMessage]) -> int: | ||
| return sum(len(getattr(message, "content", "")) for message in messages) | ||
|
|
@@ -297,7 +297,7 @@ def token_counter(messages: list[AnyMessage]) -> int: | |
| # Total tokens: 300 + 200 + 50 + 180 + 160 = 890 | ||
| # Target keep: 500 tokens (50% of 1000) | ||
| # Binary search finds cutoff around index 2 (ToolMessage) | ||
| # We advance past it to index 3 (HumanMessage) | ||
| # We move back to index 1 to preserve the AIMessage with its ToolMessage | ||
| messages: list[AnyMessage] = [ | ||
| HumanMessage(content="H" * 300), | ||
| AIMessage( | ||
|
|
@@ -314,14 +314,15 @@ def token_counter(messages: list[AnyMessage]) -> int: | |
| assert result is not None | ||
|
|
||
| preserved_messages = result["messages"][2:] | ||
| # With aggressive summarization, we advance past the ToolMessage | ||
| # So we preserve messages from index 3 onward (the two HumanMessages) | ||
| assert preserved_messages == messages[3:] | ||
| # We move the cutoff back to include the AIMessage with its ToolMessage | ||
| # So we preserve messages from index 1 onward (AI + Tool + Human + Human) | ||
| assert preserved_messages == messages[1:] | ||
|
|
||
| # Verify preserved tokens are within budget | ||
| target_token_count = int(1000 * 0.5) | ||
| preserved_tokens = middleware.token_counter(preserved_messages) | ||
| assert preserved_tokens <= target_token_count | ||
| # Verify the AI/Tool pair is preserved together | ||
| assert isinstance(preserved_messages[0], AIMessage) | ||
| assert preserved_messages[0].tool_calls | ||
| assert isinstance(preserved_messages[1], ToolMessage) | ||
| assert preserved_messages[1].tool_call_id == preserved_messages[0].tool_calls[0]["id"] | ||
|
|
||
|
|
||
| def test_summarization_middleware_missing_profile() -> None: | ||
|
|
@@ -666,7 +667,7 @@ def token_counter_small(messages): | |
|
|
||
|
|
||
| def test_summarization_middleware_find_safe_cutoff_point() -> None: | ||
| """Test _find_safe_cutoff_point finds safe cutoff past ToolMessages.""" | ||
| """Test `_find_safe_cutoff_point` preserves AI/Tool message pairs.""" | ||
| model = FakeToolCallingModel() | ||
| middleware = SummarizationMiddleware( | ||
| model=model, trigger=("messages", 10), keep=("messages", 2) | ||
|
|
@@ -676,16 +677,22 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None: | |
| HumanMessage(content="msg1"), | ||
| AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]), | ||
| ToolMessage(content="result1", tool_call_id="call1"), | ||
| ToolMessage(content="result2", tool_call_id="call2"), | ||
| ToolMessage(content="result2", tool_call_id="call2"), # orphan - no matching AI | ||
| HumanMessage(content="msg2"), | ||
| ] | ||
|
|
||
| # Starting at a non-ToolMessage returns the same index | ||
| assert middleware._find_safe_cutoff_point(messages, 0) == 0 | ||
| assert middleware._find_safe_cutoff_point(messages, 1) == 1 | ||
|
|
||
| # Starting at a ToolMessage advances to the next non-ToolMessage | ||
| assert middleware._find_safe_cutoff_point(messages, 2) == 4 | ||
| # Starting at ToolMessage with matching AIMessage moves back to include it | ||
| # ToolMessage at index 2 has tool_call_id="call1" which matches AIMessage at index 1 | ||
| assert middleware._find_safe_cutoff_point(messages, 2) == 1 | ||
|
|
||
| # Starting at orphan ToolMessage (no matching AIMessage) falls back to advancing | ||
| # ToolMessage at index 3 has tool_call_id="call2" with no matching AIMessage | ||
| # Since we only collect from cutoff_index onwards, only {call2} is collected | ||
| # No match found, so we fall back to advancing past ToolMessages | ||
| assert middleware._find_safe_cutoff_point(messages, 3) == 4 | ||
|
|
||
| # Starting at the HumanMessage after tools returns that index | ||
|
|
@@ -699,6 +706,25 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None: | |
| assert middleware._find_safe_cutoff_point(messages, len(messages) + 5) == len(messages) + 5 | ||
|
|
||
|
|
||
| def test_summarization_middleware_find_safe_cutoff_point_orphan_tool() -> None: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This test passes on master.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We're changing the convention (summarizing less vs. more). I've added a test that fails on master. |
||
| """Test `_find_safe_cutoff_point` with truly orphan `ToolMessage` (no matching `AIMessage`).""" | ||
| model = FakeToolCallingModel() | ||
| middleware = SummarizationMiddleware( | ||
| model=model, trigger=("messages", 10), keep=("messages", 2) | ||
| ) | ||
|
|
||
| # Messages where ToolMessage has no matching AIMessage at all | ||
| messages: list[AnyMessage] = [ | ||
| HumanMessage(content="msg1"), | ||
| AIMessage(content="ai_no_tools"), # No tool_calls | ||
| ToolMessage(content="orphan_result", tool_call_id="orphan_call"), | ||
| HumanMessage(content="msg2"), | ||
| ] | ||
|
|
||
| # Starting at orphan ToolMessage falls back to advancing forward | ||
| assert middleware._find_safe_cutoff_point(messages, 2) == 3 | ||
|
|
||
|
|
||
| def test_summarization_middleware_zero_and_negative_target_tokens() -> None: | ||
| """Test handling of edge cases with target token calculations.""" | ||
| # Test with very small fraction that rounds to zero | ||
|
|
@@ -814,7 +840,7 @@ def _llm_type(self) -> str: | |
|
|
||
|
|
||
| def test_summarization_middleware_many_parallel_tool_calls_safety() -> None: | ||
| """Test cutoff safety with many parallel tool calls extending beyond old search range.""" | ||
| """Test cutoff safety preserves AI message with many parallel tool calls.""" | ||
| middleware = SummarizationMiddleware( | ||
| model=MockChatModel(), trigger=("messages", 15), keep=("messages", 5) | ||
| ) | ||
|
|
@@ -826,20 +852,21 @@ def test_summarization_middleware_many_parallel_tool_calls_safety() -> None: | |
| ] | ||
| messages: list[AnyMessage] = [human_message, ai_message, *tool_messages] | ||
|
|
||
| # Cutoff at index 7 (a ToolMessage) advances to index 12 (end of messages) | ||
| assert middleware._find_safe_cutoff_point(messages, 7) == 12 | ||
| # Cutoff at index 7 (a ToolMessage) moves back to index 1 (AIMessage) | ||
| # to preserve the AI/Tool pair together | ||
| assert middleware._find_safe_cutoff_point(messages, 7) == 1 | ||
|
|
||
| # Any cutoff pointing at a ToolMessage (indices 2-11) advances to index 12 | ||
| # Any cutoff pointing at a ToolMessage (indices 2-11) moves back to index 1 | ||
| for i in range(2, 12): | ||
| assert middleware._find_safe_cutoff_point(messages, i) == 12 | ||
| assert middleware._find_safe_cutoff_point(messages, i) == 1 | ||
|
|
||
| # Cutoff at index 0, 1 (before tool messages) stays the same | ||
| assert middleware._find_safe_cutoff_point(messages, 0) == 0 | ||
| assert middleware._find_safe_cutoff_point(messages, 1) == 1 | ||
|
|
||
|
|
||
| def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None: | ||
| """Test _find_safe_cutoff advances past ToolMessages to find safe cutoff.""" | ||
| def test_summarization_middleware_find_safe_cutoff_preserves_ai_tool_pair() -> None: | ||
| """Test `_find_safe_cutoff` preserves AI/Tool message pairs together.""" | ||
| middleware = SummarizationMiddleware( | ||
| model=MockChatModel(), trigger=("messages", 10), keep=("messages", 3) | ||
| ) | ||
|
|
@@ -862,15 +889,15 @@ def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None | |
| ] | ||
|
|
||
| # Target cutoff index is len(messages) - messages_to_keep = 6 - 3 = 3 | ||
| # Index 3 is a ToolMessage, so we advance past the tool sequence to index 5 | ||
| # Index 3 is a ToolMessage, we move back to index 1 to include AIMessage | ||
| cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=3) | ||
| assert cutoff == 5 | ||
| assert cutoff == 1 | ||
|
|
||
| # With messages_to_keep=2, target cutoff index is 6 - 2 = 4 | ||
| # Index 4 is a ToolMessage, so we advance past the tool sequence to index 5 | ||
| # This is aggressive - we keep only 1 message instead of 2 | ||
| # Index 4 is a ToolMessage, we move back to index 1 to include AIMessage | ||
| # This preserves the AI + Tools + Human, more than requested but valid | ||
| cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=2) | ||
| assert cutoff == 5 | ||
| assert cutoff == 1 | ||
|
|
||
|
|
||
| def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None: | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we also don't want to land on an AIMessage with tool calls, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The previous logic advanced forward past `ToolMessage` objects (aggressive summarization). This approach takes the opposite approach, searching backward to include the `AIMessage` requesting the tools (in other words, preserve more context for the sake of atomicity).

If the cutoff lands on an `AIMessage` with `tool_calls`, the corresponding `ToolMessage` responses would be in the summarized portion, creating the reverse orphaning problem — tool call requests without their responses. Landing on an `AIMessage` with `tool_calls` is safe because the `ToolMessage` objects come after it and will be preserved together.