fix(langchain): keep tool call / AIMessage pairings when summarizing #34609
```diff
@@ -4,6 +4,7 @@
 from langchain_core.language_models import ModelProfile
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
+from langchain_core.messages.utils import count_tokens_approximately, get_buffer_string
 from langchain_core.outputs import ChatGeneration, ChatResult
 from langgraph.graph.message import REMOVE_ALL_MESSAGES
```
```diff
@@ -280,8 +281,8 @@ def token_counter(messages):
     ]


-def test_summarization_middleware_token_retention_advances_past_tool_messages() -> None:
-    """Ensure token retention advances past tool messages for aggressive summarization."""
+def test_summarization_middleware_token_retention_preserves_ai_tool_pairs() -> None:
+    """Ensure token retention preserves AI/Tool message pairs together."""

     def token_counter(messages: list[AnyMessage]) -> int:
         return sum(len(getattr(message, "content", "")) for message in messages)
```
```diff
@@ -296,7 +297,7 @@ def token_counter(messages: list[AnyMessage]) -> int:
     # Total tokens: 300 + 200 + 50 + 180 + 160 = 890
     # Target keep: 500 tokens (50% of 1000)
     # Binary search finds cutoff around index 2 (ToolMessage)
-    # We advance past it to index 3 (HumanMessage)
+    # We move back to index 1 to preserve the AIMessage with its ToolMessage
     messages: list[AnyMessage] = [
         HumanMessage(content="H" * 300),
         AIMessage(
@@ -313,14 +314,15 @@ def token_counter(messages: list[AnyMessage]) -> int:
     assert result is not None

     preserved_messages = result["messages"][2:]
-    # With aggressive summarization, we advance past the ToolMessage
-    # So we preserve messages from index 3 onward (the two HumanMessages)
-    assert preserved_messages == messages[3:]
+    # We move the cutoff back to include the AIMessage with its ToolMessage
+    # So we preserve messages from index 1 onward (AI + Tool + Human + Human)
+    assert preserved_messages == messages[1:]

-    # Verify preserved tokens are within budget
-    target_token_count = int(1000 * 0.5)
-    preserved_tokens = middleware.token_counter(preserved_messages)
-    assert preserved_tokens <= target_token_count
+    # Verify the AI/Tool pair is preserved together
+    assert isinstance(preserved_messages[0], AIMessage)
+    assert preserved_messages[0].tool_calls
+    assert isinstance(preserved_messages[1], ToolMessage)
+    assert preserved_messages[1].tool_call_id == preserved_messages[0].tool_calls[0]["id"]


 def test_summarization_middleware_missing_profile() -> None:
```
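A note on the arithmetic in the hunk above, using the token sizes from the test's own comments: preserving from index 1 retains 200 + 50 + 180 + 160 = 590 tokens, which exceeds the 500-token target. This is why the within-budget assertion was removed; the new convention accepts going over the token target in exchange for keeping the AI/Tool pair atomic (see the author's explanation at the end of this page).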
```diff
@@ -665,7 +667,7 @@ def token_counter_small(messages):


 def test_summarization_middleware_find_safe_cutoff_point() -> None:
-    """Test _find_safe_cutoff_point finds safe cutoff past ToolMessages."""
+    """Test `_find_safe_cutoff_point` preserves AI/Tool message pairs."""
     model = FakeToolCallingModel()
     middleware = SummarizationMiddleware(
         model=model, trigger=("messages", 10), keep=("messages", 2)
```
```diff
@@ -675,16 +677,22 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None:
         HumanMessage(content="msg1"),
         AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]),
         ToolMessage(content="result1", tool_call_id="call1"),
-        ToolMessage(content="result2", tool_call_id="call2"),
+        ToolMessage(content="result2", tool_call_id="call2"),  # orphan - no matching AI
         HumanMessage(content="msg2"),
     ]

     # Starting at a non-ToolMessage returns the same index
     assert middleware._find_safe_cutoff_point(messages, 0) == 0
     assert middleware._find_safe_cutoff_point(messages, 1) == 1

-    # Starting at a ToolMessage advances to the next non-ToolMessage
-    assert middleware._find_safe_cutoff_point(messages, 2) == 4
+    # Starting at ToolMessage with matching AIMessage moves back to include it
+    # ToolMessage at index 2 has tool_call_id="call1" which matches AIMessage at index 1
+    assert middleware._find_safe_cutoff_point(messages, 2) == 1
+
+    # Starting at orphan ToolMessage (no matching AIMessage) falls back to advancing
+    # ToolMessage at index 3 has tool_call_id="call2" with no matching AIMessage
+    # Since we only collect from cutoff_index onwards, only {call2} is collected
+    # No match found, so we fall back to advancing past ToolMessages
+    assert middleware._find_safe_cutoff_point(messages, 3) == 4

     # Starting at the HumanMessage after tools returns that index
@@ -698,6 +706,25 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None:
     assert middleware._find_safe_cutoff_point(messages, len(messages) + 5) == len(messages) + 5


+def test_summarization_middleware_find_safe_cutoff_point_orphan_tool() -> None:
```
**Collaborator:** This test passes on master.

**Author (Member):** We're changing the convention (summarizing less vs. more); I've added a test that fails on master.
| """Test `_find_safe_cutoff_point` with truly orphan `ToolMessage` (no matching `AIMessage`).""" | ||
| model = FakeToolCallingModel() | ||
| middleware = SummarizationMiddleware( | ||
| model=model, trigger=("messages", 10), keep=("messages", 2) | ||
| ) | ||
|
|
||
| # Messages where ToolMessage has no matching AIMessage at all | ||
| messages: list[AnyMessage] = [ | ||
| HumanMessage(content="msg1"), | ||
| AIMessage(content="ai_no_tools"), # No tool_calls | ||
| ToolMessage(content="orphan_result", tool_call_id="orphan_call"), | ||
| HumanMessage(content="msg2"), | ||
| ] | ||
|
|
||
| # Starting at orphan ToolMessage falls back to advancing forward | ||
| assert middleware._find_safe_cutoff_point(messages, 2) == 3 | ||
|
|
||
|
|
||
| def test_summarization_middleware_zero_and_negative_target_tokens() -> None: | ||
| """Test handling of edge cases with target token calculations.""" | ||
| # Test with very small fraction that rounds to zero | ||
|
|
```diff
@@ -813,7 +840,7 @@ def _llm_type(self) -> str:


 def test_summarization_middleware_many_parallel_tool_calls_safety() -> None:
-    """Test cutoff safety with many parallel tool calls extending beyond old search range."""
+    """Test cutoff safety preserves AI message with many parallel tool calls."""
     middleware = SummarizationMiddleware(
         model=MockChatModel(), trigger=("messages", 15), keep=("messages", 5)
     )
```
```diff
@@ -825,20 +852,21 @@ def test_summarization_middleware_many_parallel_tool_calls_safety() -> None:
     ]
     messages: list[AnyMessage] = [human_message, ai_message, *tool_messages]

-    # Cutoff at index 7 (a ToolMessage) advances to index 12 (end of messages)
-    assert middleware._find_safe_cutoff_point(messages, 7) == 12
+    # Cutoff at index 7 (a ToolMessage) moves back to index 1 (AIMessage)
+    # to preserve the AI/Tool pair together
+    assert middleware._find_safe_cutoff_point(messages, 7) == 1

-    # Any cutoff pointing at a ToolMessage (indices 2-11) advances to index 12
+    # Any cutoff pointing at a ToolMessage (indices 2-11) moves back to index 1
     for i in range(2, 12):
-        assert middleware._find_safe_cutoff_point(messages, i) == 12
+        assert middleware._find_safe_cutoff_point(messages, i) == 1

     # Cutoff at index 0, 1 (before tool messages) stays the same
     assert middleware._find_safe_cutoff_point(messages, 0) == 0
     assert middleware._find_safe_cutoff_point(messages, 1) == 1


-def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None:
-    """Test _find_safe_cutoff advances past ToolMessages to find safe cutoff."""
+def test_summarization_middleware_find_safe_cutoff_preserves_ai_tool_pair() -> None:
+    """Test `_find_safe_cutoff` preserves AI/Tool message pairs together."""
     middleware = SummarizationMiddleware(
         model=MockChatModel(), trigger=("messages", 10), keep=("messages", 3)
     )
```
```diff
@@ -861,15 +889,15 @@ def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None:
     ]

     # Target cutoff index is len(messages) - messages_to_keep = 6 - 3 = 3
-    # Index 3 is a ToolMessage, so we advance past the tool sequence to index 5
+    # Index 3 is a ToolMessage, we move back to index 1 to include AIMessage
     cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=3)
-    assert cutoff == 5
+    assert cutoff == 1

     # With messages_to_keep=2, target cutoff index is 6 - 2 = 4
-    # Index 4 is a ToolMessage, so we advance past the tool sequence to index 5
-    # This is aggressive - we keep only 1 message instead of 2
+    # Index 4 is a ToolMessage, we move back to index 1 to include AIMessage
+    # This preserves the AI + Tools + Human, more than requested but valid
     cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=2)
-    assert cutoff == 5
+    assert cutoff == 1


 def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None:
```
```diff
@@ -891,3 +919,58 @@ def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None:
     # Index 2 is an AIMessage (safe cutoff point), so no adjustment needed
     cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=4)
     assert cutoff == 2
+
+
+def test_create_summary_uses_get_buffer_string_format() -> None:
+    """Test that `_create_summary` formats messages using `get_buffer_string`.
+
+    Ensures that messages are formatted efficiently for the summary prompt, avoiding
+    token inflation from metadata when `str()` is called on message objects.
+
+    This ensures the token count of the formatted prompt stays below what
+    `count_tokens_approximately` estimates for the raw messages.
+    """
+    # Create messages with metadata that would inflate str() representation
+    messages: list[AnyMessage] = [
+        HumanMessage(content="What is the weather in NYC?"),
+        AIMessage(
+            content="Let me check the weather for you.",
+            tool_calls=[{"name": "get_weather", "args": {"city": "NYC"}, "id": "call_123"}],
+            usage_metadata={"input_tokens": 50, "output_tokens": 30, "total_tokens": 80},
+            response_metadata={"model": "gpt-4", "finish_reason": "tool_calls"},
+        ),
+        ToolMessage(
+            content="72F and sunny",
+            tool_call_id="call_123",
+            name="get_weather",
+        ),
+        AIMessage(
+            content="It is 72F and sunny in NYC!",
+            usage_metadata={
+                "input_tokens": 100,
+                "output_tokens": 25,
+                "total_tokens": 125,
+            },
+            response_metadata={"model": "gpt-4", "finish_reason": "stop"},
+        ),
+    ]
+
+    # Verify the token ratio is favorable (get_buffer_string < str)
+    approx_tokens = count_tokens_approximately(messages)
+    buffer_string = get_buffer_string(messages)
+    buffer_tokens_estimate = len(buffer_string) / 4  # ~4 chars per token
+
+    # The ratio should be less than 1.0 (buffer_string uses fewer tokens than counted)
+    ratio = buffer_tokens_estimate / approx_tokens
+    assert ratio < 1.0, (
+        f"get_buffer_string should produce fewer tokens than count_tokens_approximately. "
+        f"Got ratio {ratio:.2f}x (expected < 1.0)"
+    )
+
+    # Verify str() would have been worse
+    str_tokens_estimate = len(str(messages)) / 4
+    str_ratio = str_tokens_estimate / approx_tokens
+    assert str_ratio > 1.5, (
+        f"str(messages) should produce significantly more tokens. "
+        f"Got ratio {str_ratio:.2f}x (expected > 1.5)"
+    )
```
**Reviewer:** I think we also don't want to land on an `AIMessage` with tool calls, right?

**Author:** The previous logic advanced forward past `ToolMessage` objects (aggressive summarization). This change takes the opposite approach, searching backward to include the `AIMessage` requesting the tools; in other words, it preserves more context for the sake of atomicity. If an `AIMessage` with `tool_calls` were separated from its responses, with the `ToolMessage` responses left in the summarized portion, we would get the reverse orphaning problem: tool call requests without their responses. Landing on an `AIMessage` with `tool_calls` is safe because the `ToolMessage` objects come after it and will be preserved together.
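A minimal sketch of the backward-search convention described in this thread, assuming standard `langchain_core` message types; the function name and structure here are illustrative, not the actual `SummarizationMiddleware._find_safe_cutoff_point` implementation:

```python
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage


def find_safe_cutoff_point(messages: list[AnyMessage], cutoff_index: int) -> int:
    """Pick a cutoff that never separates an AIMessage from its ToolMessages."""
    # A cutoff that does not point at a ToolMessage is already safe.
    if cutoff_index >= len(messages) or not isinstance(messages[cutoff_index], ToolMessage):
        return cutoff_index

    # Collect the tool_call_ids of ToolMessages at or after the cutoff.
    pending_ids = {
        m.tool_call_id for m in messages[cutoff_index:] if isinstance(m, ToolMessage)
    }

    # Search backward for the AIMessage that issued any of those calls.
    for i in range(cutoff_index - 1, -1, -1):
        message = messages[i]
        if isinstance(message, AIMessage) and any(
            call["id"] in pending_ids for call in message.tool_calls
        ):
            return i  # keep the AI/Tool pair together (may keep more than requested)

    # Orphan ToolMessages (no matching AIMessage): fall back to advancing forward.
    while cutoff_index < len(messages) and isinstance(messages[cutoff_index], ToolMessage):
        cutoff_index += 1
    return cutoff_index


# Mirrors the expectations in test_summarization_middleware_find_safe_cutoff_point:
messages: list[AnyMessage] = [
    HumanMessage(content="msg1"),
    AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]),
    ToolMessage(content="result1", tool_call_id="call1"),
    ToolMessage(content="result2", tool_call_id="call2"),  # orphan
    HumanMessage(content="msg2"),
]
assert find_safe_cutoff_point(messages, 2) == 1  # moves back to the AIMessage
assert find_safe_cutoff_point(messages, 3) == 4  # orphan: advances past the tool run
```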