Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions libs/langchain_v1/langchain/agents/middleware/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Literal, cast

from langchain_core.messages import (
AIMessage,
AnyMessage,
MessageLikeRepresentation,
RemoveMessage,
Expand Down Expand Up @@ -478,13 +479,37 @@ def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -
def _find_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> int:
"""Find a safe cutoff point that doesn't split AI/Tool message pairs.

If the message at cutoff_index is a ToolMessage, advance until we find
a non-ToolMessage. This ensures we never cut in the middle of parallel
tool call responses.
If the message at `cutoff_index` is a `ToolMessage`, search backward for the
`AIMessage` containing the corresponding `tool_calls` and adjust the cutoff to
include it. This ensures tool call requests and responses stay together.

Falls back to advancing forward past `ToolMessage` objects only if no matching
`AIMessage` is found (edge case).
"""
while cutoff_index < len(messages) and isinstance(messages[cutoff_index], ToolMessage):
cutoff_index += 1
return cutoff_index
if cutoff_index >= len(messages) or not isinstance(messages[cutoff_index], ToolMessage):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we also don't want to land on an AIMessage with tool calls, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous logic advanced forward past ToolMessage objects (aggressive summarization). This change takes the opposite approach, searching backward to include the AIMessage requesting the tools (in other words, preserving more context for the sake of atomicity).

If the cutoff lands on an AIMessage with tool_calls, the corresponding ToolMessage responses would be in the summarized portion, creating the reverse orphaning problem — tool call requests without their responses. Landing on an AIMessage with tool_calls is safe because the ToolMessage objects come after it and will be preserved together.

return cutoff_index

# Collect tool_call_ids from consecutive ToolMessages at/after cutoff
tool_call_ids: set[str] = set()
idx = cutoff_index
while idx < len(messages) and isinstance(messages[idx], ToolMessage):
tool_msg = cast("ToolMessage", messages[idx])
if tool_msg.tool_call_id:
tool_call_ids.add(tool_msg.tool_call_id)
idx += 1

# Search backward for AIMessage with matching tool_calls
for i in range(cutoff_index - 1, -1, -1):
msg = messages[i]
if isinstance(msg, AIMessage) and msg.tool_calls:
ai_tool_call_ids = {tc.get("id") for tc in msg.tool_calls if tc.get("id")}
if tool_call_ids & ai_tool_call_ids:
# Found the AIMessage - move cutoff to include it
return i

# Fallback: no matching AIMessage found, advance past ToolMessages to avoid
# orphaned tool responses
return idx

def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Generate summary for the given messages."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,8 @@ def token_counter(messages):
]


def test_summarization_middleware_token_retention_advances_past_tool_messages() -> None:
"""Ensure token retention advances past tool messages for aggressive summarization."""
def test_summarization_middleware_token_retention_preserves_ai_tool_pairs() -> None:
"""Ensure token retention preserves AI/Tool message pairs together."""

def token_counter(messages: list[AnyMessage]) -> int:
return sum(len(getattr(message, "content", "")) for message in messages)
Expand All @@ -297,7 +297,7 @@ def token_counter(messages: list[AnyMessage]) -> int:
# Total tokens: 300 + 200 + 50 + 180 + 160 = 890
# Target keep: 500 tokens (50% of 1000)
# Binary search finds cutoff around index 2 (ToolMessage)
# We advance past it to index 3 (HumanMessage)
# We move back to index 1 to preserve the AIMessage with its ToolMessage
messages: list[AnyMessage] = [
HumanMessage(content="H" * 300),
AIMessage(
Expand All @@ -314,14 +314,15 @@ def token_counter(messages: list[AnyMessage]) -> int:
assert result is not None

preserved_messages = result["messages"][2:]
# With aggressive summarization, we advance past the ToolMessage
# So we preserve messages from index 3 onward (the two HumanMessages)
assert preserved_messages == messages[3:]
# We move the cutoff back to include the AIMessage with its ToolMessage
# So we preserve messages from index 1 onward (AI + Tool + Human + Human)
assert preserved_messages == messages[1:]

# Verify preserved tokens are within budget
target_token_count = int(1000 * 0.5)
preserved_tokens = middleware.token_counter(preserved_messages)
assert preserved_tokens <= target_token_count
# Verify the AI/Tool pair is preserved together
assert isinstance(preserved_messages[0], AIMessage)
assert preserved_messages[0].tool_calls
assert isinstance(preserved_messages[1], ToolMessage)
assert preserved_messages[1].tool_call_id == preserved_messages[0].tool_calls[0]["id"]


def test_summarization_middleware_missing_profile() -> None:
Expand Down Expand Up @@ -666,7 +667,7 @@ def token_counter_small(messages):


def test_summarization_middleware_find_safe_cutoff_point() -> None:
"""Test _find_safe_cutoff_point finds safe cutoff past ToolMessages."""
"""Test `_find_safe_cutoff_point` preserves AI/Tool message pairs."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(
model=model, trigger=("messages", 10), keep=("messages", 2)
Expand All @@ -676,16 +677,22 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None:
HumanMessage(content="msg1"),
AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]),
ToolMessage(content="result1", tool_call_id="call1"),
ToolMessage(content="result2", tool_call_id="call2"),
ToolMessage(content="result2", tool_call_id="call2"), # orphan - no matching AI
HumanMessage(content="msg2"),
]

# Starting at a non-ToolMessage returns the same index
assert middleware._find_safe_cutoff_point(messages, 0) == 0
assert middleware._find_safe_cutoff_point(messages, 1) == 1

# Starting at a ToolMessage advances to the next non-ToolMessage
assert middleware._find_safe_cutoff_point(messages, 2) == 4
# Starting at ToolMessage with matching AIMessage moves back to include it
# ToolMessage at index 2 has tool_call_id="call1" which matches AIMessage at index 1
assert middleware._find_safe_cutoff_point(messages, 2) == 1

# Starting at orphan ToolMessage (no matching AIMessage) falls back to advancing
# ToolMessage at index 3 has tool_call_id="call2" with no matching AIMessage
# Since we only collect from cutoff_index onwards, only {call2} is collected
# No match found, so we fall back to advancing past ToolMessages
assert middleware._find_safe_cutoff_point(messages, 3) == 4

# Starting at the HumanMessage after tools returns that index
Expand All @@ -699,6 +706,25 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None:
assert middleware._find_safe_cutoff_point(messages, len(messages) + 5) == len(messages) + 5


def test_summarization_middleware_find_safe_cutoff_point_orphan_tool() -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test passes on master

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're changing the convention (summarizing less vs. more)

I've added a test that fails on master

"""Test `_find_safe_cutoff_point` with truly orphan `ToolMessage` (no matching `AIMessage`)."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(
model=model, trigger=("messages", 10), keep=("messages", 2)
)

# Messages where ToolMessage has no matching AIMessage at all
messages: list[AnyMessage] = [
HumanMessage(content="msg1"),
AIMessage(content="ai_no_tools"), # No tool_calls
ToolMessage(content="orphan_result", tool_call_id="orphan_call"),
HumanMessage(content="msg2"),
]

# Starting at orphan ToolMessage falls back to advancing forward
assert middleware._find_safe_cutoff_point(messages, 2) == 3


def test_summarization_middleware_zero_and_negative_target_tokens() -> None:
"""Test handling of edge cases with target token calculations."""
# Test with very small fraction that rounds to zero
Expand Down Expand Up @@ -814,7 +840,7 @@ def _llm_type(self) -> str:


def test_summarization_middleware_many_parallel_tool_calls_safety() -> None:
"""Test cutoff safety with many parallel tool calls extending beyond old search range."""
"""Test cutoff safety preserves AI message with many parallel tool calls."""
middleware = SummarizationMiddleware(
model=MockChatModel(), trigger=("messages", 15), keep=("messages", 5)
)
Expand All @@ -826,20 +852,21 @@ def test_summarization_middleware_many_parallel_tool_calls_safety() -> None:
]
messages: list[AnyMessage] = [human_message, ai_message, *tool_messages]

# Cutoff at index 7 (a ToolMessage) advances to index 12 (end of messages)
assert middleware._find_safe_cutoff_point(messages, 7) == 12
# Cutoff at index 7 (a ToolMessage) moves back to index 1 (AIMessage)
# to preserve the AI/Tool pair together
assert middleware._find_safe_cutoff_point(messages, 7) == 1

# Any cutoff pointing at a ToolMessage (indices 2-11) advances to index 12
# Any cutoff pointing at a ToolMessage (indices 2-11) moves back to index 1
for i in range(2, 12):
assert middleware._find_safe_cutoff_point(messages, i) == 12
assert middleware._find_safe_cutoff_point(messages, i) == 1

# Cutoff at index 0, 1 (before tool messages) stays the same
assert middleware._find_safe_cutoff_point(messages, 0) == 0
assert middleware._find_safe_cutoff_point(messages, 1) == 1


def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None:
"""Test _find_safe_cutoff advances past ToolMessages to find safe cutoff."""
def test_summarization_middleware_find_safe_cutoff_preserves_ai_tool_pair() -> None:
"""Test `_find_safe_cutoff` preserves AI/Tool message pairs together."""
middleware = SummarizationMiddleware(
model=MockChatModel(), trigger=("messages", 10), keep=("messages", 3)
)
Expand All @@ -862,15 +889,15 @@ def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None
]

# Target cutoff index is len(messages) - messages_to_keep = 6 - 3 = 3
# Index 3 is a ToolMessage, so we advance past the tool sequence to index 5
# Index 3 is a ToolMessage, we move back to index 1 to include AIMessage
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=3)
assert cutoff == 5
assert cutoff == 1

# With messages_to_keep=2, target cutoff index is 6 - 2 = 4
# Index 4 is a ToolMessage, so we advance past the tool sequence to index 5
# This is aggressive - we keep only 1 message instead of 2
# Index 4 is a ToolMessage, we move back to index 1 to include AIMessage
# This preserves the AI + Tools + Human, more than requested but valid
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=2)
assert cutoff == 5
assert cutoff == 1


def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None:
Expand Down