5 changes: 4 additions & 1 deletion libs/core/langchain_core/messages/base.py
@@ -266,6 +266,9 @@ def text(self) -> TextAccessor:
Can be used as both property (`message.text`) and method (`message.text()`).
Handles both string and list content types (e.g. for content blocks). Only
extracts blocks with `type: 'text'`; other block types are ignored.
!!! deprecated
As of `langchain-core` 1.0.0, calling `.text()` as a method is deprecated.
Use `.text` as a property instead. This method will be removed in 2.0.0.
@@ -277,7 +280,7 @@ def text(self) -> TextAccessor:
if isinstance(self.content, str):
text_value = self.content
else:
# must be a list
# Must be a list
blocks = [
block
for block in self.content
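As a quick illustration of the deprecation note above, a minimal sketch (the message content is invented, and the output comments assume the accessor renders as the plain text):

```python
from langchain_core.messages import AIMessage

msg = AIMessage(content="Hello!")

# Preferred: property-style access
print(msg.text)    # Hello!

# Deprecated as of langchain-core 1.0.0: method-style access still works,
# but emits a deprecation warning and is slated for removal in 2.0.0
print(msg.text())  # Hello!
```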
3 changes: 3 additions & 0 deletions libs/core/langchain_core/messages/utils.py
@@ -148,13 +148,16 @@ def get_buffer_string(
else:
msg = f"Got unsupported message type: {m}"
raise ValueError(msg) # noqa: TRY004

message = f"{role}: {m.text}"

if isinstance(m, AIMessage):
if m.tool_calls:
message += f"{m.tool_calls}"
elif "function_call" in m.additional_kwargs:
# Legacy behavior assumes only one function call per message
message += f"{m.additional_kwargs['function_call']}"

string_messages.append(message)

return "\n".join(string_messages)
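To illustrate the formatting this helper produces, a rough sketch (the messages are invented and the exact rendering of the appended `tool_calls` list may differ):

```python
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages.utils import get_buffer_string

messages = [
    HumanMessage(content="What's the weather in NYC?"),
    AIMessage(
        content="Checking now.",
        tool_calls=[{"name": "get_weather", "args": {"city": "NYC"}, "id": "call_1"}],
    ),
]

# Each message renders as "<role>: <text>"; for an AIMessage with tool calls,
# the tool_calls list is appended after the text, roughly:
#   Human: What's the weather in NYC?
#   AI: Checking now.[{'name': 'get_weather', 'args': {'city': 'NYC'}, 'id': 'call_1', ...}]
print(get_buffer_string(messages))
```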
55 changes: 46 additions & 9 deletions libs/langchain_v1/langchain/agents/middleware/summarization.py
@@ -7,13 +7,18 @@
from typing import Any, Literal, cast

from langchain_core.messages import (
AIMessage,
AnyMessage,
MessageLikeRepresentation,
RemoveMessage,
ToolMessage,
)
from langchain_core.messages.human import HumanMessage
from langchain_core.messages.utils import count_tokens_approximately, trim_messages
from langchain_core.messages.utils import (
count_tokens_approximately,
get_buffer_string,
trim_messages,
)
from langgraph.graph.message import (
REMOVE_ALL_MESSAGES,
)
@@ -474,13 +479,37 @@ def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -
def _find_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int) -> int:
"""Find a safe cutoff point that doesn't split AI/Tool message pairs.

If the message at cutoff_index is a ToolMessage, advance until we find
a non-ToolMessage. This ensures we never cut in the middle of parallel
tool call responses.
If the message at `cutoff_index` is a `ToolMessage`, search backward for the
`AIMessage` containing the corresponding `tool_calls` and adjust the cutoff to
include it. This ensures tool call requests and responses stay together.

Falls back to advancing forward past `ToolMessage` objects only if no matching
`AIMessage` is found (edge case).
"""
while cutoff_index < len(messages) and isinstance(messages[cutoff_index], ToolMessage):
cutoff_index += 1
return cutoff_index
if cutoff_index >= len(messages) or not isinstance(messages[cutoff_index], ToolMessage):
Collaborator:
I think we also don't want to land on an AIMessage with tool calls, right?

Member Author:
The previous logic advanced forward past ToolMessage objects (aggressive summarization). This change takes the opposite approach, searching backward to include the AIMessage that requested the tools (in other words, preserving more context for the sake of atomicity).

If the cutoff lands on an AIMessage with tool_calls, the corresponding ToolMessage responses would be in the summarized portion, creating the reverse orphaning problem — tool call requests without their responses. Landing on an AIMessage with tool_calls is safe because the ToolMessage objects come after it and will be preserved together.
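A small sketch of the difference under an assumed message sequence (it mirrors the updated tests below; indices are annotated in comments):

```python
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

messages = [
    HumanMessage(content="question"),                                               # index 0
    AIMessage(content="", tool_calls=[{"name": "t", "args": {}, "id": "call_1"}]),  # index 1
    ToolMessage(content="result", tool_call_id="call_1"),                           # index 2
    HumanMessage(content="follow-up"),                                              # index 3
]

# Old behavior: a cutoff landing on index 2 advanced forward past the ToolMessage
# to index 3, so the AIMessage/ToolMessage pair ended up in the summarized portion.
# New behavior: the cutoff walks backward to index 1, keeping the AIMessage that
# issued call_1 together with its ToolMessage response in the preserved tail.
```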

return cutoff_index

# Collect tool_call_ids from consecutive ToolMessages at/after cutoff
tool_call_ids: set[str] = set()
idx = cutoff_index
while idx < len(messages) and isinstance(messages[idx], ToolMessage):
tool_msg = cast("ToolMessage", messages[idx])
if tool_msg.tool_call_id:
tool_call_ids.add(tool_msg.tool_call_id)
idx += 1

# Search backward for AIMessage with matching tool_calls
for i in range(cutoff_index - 1, -1, -1):
msg = messages[i]
if isinstance(msg, AIMessage) and msg.tool_calls:
ai_tool_call_ids = {tc.get("id") for tc in msg.tool_calls if tc.get("id")}
if tool_call_ids & ai_tool_call_ids:
# Found the AIMessage - move cutoff to include it
return i

# Fallback: no matching AIMessage found, advance past ToolMessages to avoid
# orphaned tool responses
return idx

def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Generate summary for the given messages."""
@@ -491,8 +520,12 @@ def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
if not trimmed_messages:
return "Previous conversation was too long to summarize."

# Format messages to avoid token inflation from metadata when str() is called on
# message objects
formatted_messages = get_buffer_string(trimmed_messages)

try:
response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
response = self.model.invoke(self.summary_prompt.format(messages=formatted_messages))
return response.text.strip()
except Exception as e:
return f"Error generating summary: {e!s}"
@@ -506,9 +539,13 @@ async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str
if not trimmed_messages:
return "Previous conversation was too long to summarize."

# Format messages to avoid token inflation from metadata when str() is called on
# message objects
formatted_messages = get_buffer_string(trimmed_messages)

try:
response = await self.model.ainvoke(
self.summary_prompt.format(messages=trimmed_messages)
self.summary_prompt.format(messages=formatted_messages)
)
return response.text.strip()
except Exception as e:
@@ -4,6 +4,7 @@
from langchain_core.language_models import ModelProfile
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
from langchain_core.messages.utils import count_tokens_approximately, get_buffer_string
from langchain_core.outputs import ChatGeneration, ChatResult
from langgraph.graph.message import REMOVE_ALL_MESSAGES

@@ -280,8 +281,8 @@ def token_counter(messages):
]


def test_summarization_middleware_token_retention_advances_past_tool_messages() -> None:
"""Ensure token retention advances past tool messages for aggressive summarization."""
def test_summarization_middleware_token_retention_preserves_ai_tool_pairs() -> None:
"""Ensure token retention preserves AI/Tool message pairs together."""

def token_counter(messages: list[AnyMessage]) -> int:
return sum(len(getattr(message, "content", "")) for message in messages)
@@ -296,7 +297,7 @@ def token_counter(messages: list[AnyMessage]) -> int:
# Total tokens: 300 + 200 + 50 + 180 + 160 = 890
# Target keep: 500 tokens (50% of 1000)
# Binary search finds cutoff around index 2 (ToolMessage)
# We advance past it to index 3 (HumanMessage)
# We move back to index 1 to preserve the AIMessage with its ToolMessage
messages: list[AnyMessage] = [
HumanMessage(content="H" * 300),
AIMessage(
@@ -313,14 +314,15 @@ def token_counter(messages: list[AnyMessage]) -> int:
assert result is not None

preserved_messages = result["messages"][2:]
# With aggressive summarization, we advance past the ToolMessage
# So we preserve messages from index 3 onward (the two HumanMessages)
assert preserved_messages == messages[3:]
# We move the cutoff back to include the AIMessage with its ToolMessage
# So we preserve messages from index 1 onward (AI + Tool + Human + Human)
assert preserved_messages == messages[1:]

# Verify preserved tokens are within budget
target_token_count = int(1000 * 0.5)
preserved_tokens = middleware.token_counter(preserved_messages)
assert preserved_tokens <= target_token_count
# Verify the AI/Tool pair is preserved together
assert isinstance(preserved_messages[0], AIMessage)
assert preserved_messages[0].tool_calls
assert isinstance(preserved_messages[1], ToolMessage)
assert preserved_messages[1].tool_call_id == preserved_messages[0].tool_calls[0]["id"]


def test_summarization_middleware_missing_profile() -> None:
@@ -665,7 +667,7 @@ def token_counter_small(messages):


def test_summarization_middleware_find_safe_cutoff_point() -> None:
"""Test _find_safe_cutoff_point finds safe cutoff past ToolMessages."""
"""Test `_find_safe_cutoff_point` preserves AI/Tool message pairs."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(
model=model, trigger=("messages", 10), keep=("messages", 2)
@@ -675,16 +677,22 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None:
HumanMessage(content="msg1"),
AIMessage(content="ai", tool_calls=[{"name": "tool", "args": {}, "id": "call1"}]),
ToolMessage(content="result1", tool_call_id="call1"),
ToolMessage(content="result2", tool_call_id="call2"),
ToolMessage(content="result2", tool_call_id="call2"), # orphan - no matching AI
HumanMessage(content="msg2"),
]

# Starting at a non-ToolMessage returns the same index
assert middleware._find_safe_cutoff_point(messages, 0) == 0
assert middleware._find_safe_cutoff_point(messages, 1) == 1

# Starting at a ToolMessage advances to the next non-ToolMessage
assert middleware._find_safe_cutoff_point(messages, 2) == 4
# Starting at ToolMessage with matching AIMessage moves back to include it
# ToolMessage at index 2 has tool_call_id="call1" which matches AIMessage at index 1
assert middleware._find_safe_cutoff_point(messages, 2) == 1

# Starting at orphan ToolMessage (no matching AIMessage) falls back to advancing
# ToolMessage at index 3 has tool_call_id="call2" with no matching AIMessage
# Since we only collect from cutoff_index onwards, only {call2} is collected
# No match found, so we fall back to advancing past ToolMessages
assert middleware._find_safe_cutoff_point(messages, 3) == 4

# Starting at the HumanMessage after tools returns that index
@@ -698,6 +706,25 @@ def test_summarization_middleware_find_safe_cutoff_point() -> None:
assert middleware._find_safe_cutoff_point(messages, len(messages) + 5) == len(messages) + 5


def test_summarization_middleware_find_safe_cutoff_point_orphan_tool() -> None:
Collaborator:
This test passes on master

Member Author:
We're changing the convention (summarizing less vs. more)

I've added a test that fails on master

"""Test `_find_safe_cutoff_point` with truly orphan `ToolMessage` (no matching `AIMessage`)."""
model = FakeToolCallingModel()
middleware = SummarizationMiddleware(
model=model, trigger=("messages", 10), keep=("messages", 2)
)

# Messages where ToolMessage has no matching AIMessage at all
messages: list[AnyMessage] = [
HumanMessage(content="msg1"),
AIMessage(content="ai_no_tools"), # No tool_calls
ToolMessage(content="orphan_result", tool_call_id="orphan_call"),
HumanMessage(content="msg2"),
]

# Starting at orphan ToolMessage falls back to advancing forward
assert middleware._find_safe_cutoff_point(messages, 2) == 3


def test_summarization_middleware_zero_and_negative_target_tokens() -> None:
"""Test handling of edge cases with target token calculations."""
# Test with very small fraction that rounds to zero
@@ -813,7 +840,7 @@ def _llm_type(self) -> str:


def test_summarization_middleware_many_parallel_tool_calls_safety() -> None:
"""Test cutoff safety with many parallel tool calls extending beyond old search range."""
"""Test cutoff safety preserves AI message with many parallel tool calls."""
middleware = SummarizationMiddleware(
model=MockChatModel(), trigger=("messages", 15), keep=("messages", 5)
)
@@ -825,20 +852,21 @@ def test_summarization_middleware_many_parallel_tool_calls_safety() -> None:
]
messages: list[AnyMessage] = [human_message, ai_message, *tool_messages]

# Cutoff at index 7 (a ToolMessage) advances to index 12 (end of messages)
assert middleware._find_safe_cutoff_point(messages, 7) == 12
# Cutoff at index 7 (a ToolMessage) moves back to index 1 (AIMessage)
# to preserve the AI/Tool pair together
assert middleware._find_safe_cutoff_point(messages, 7) == 1

# Any cutoff pointing at a ToolMessage (indices 2-11) advances to index 12
# Any cutoff pointing at a ToolMessage (indices 2-11) moves back to index 1
for i in range(2, 12):
assert middleware._find_safe_cutoff_point(messages, i) == 12
assert middleware._find_safe_cutoff_point(messages, i) == 1

# Cutoff at index 0, 1 (before tool messages) stays the same
assert middleware._find_safe_cutoff_point(messages, 0) == 0
assert middleware._find_safe_cutoff_point(messages, 1) == 1


def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None:
"""Test _find_safe_cutoff advances past ToolMessages to find safe cutoff."""
def test_summarization_middleware_find_safe_cutoff_preserves_ai_tool_pair() -> None:
"""Test `_find_safe_cutoff` preserves AI/Tool message pairs together."""
middleware = SummarizationMiddleware(
model=MockChatModel(), trigger=("messages", 10), keep=("messages", 3)
)
@@ -861,15 +889,15 @@ def test_summarization_middleware_find_safe_cutoff_advances_past_tools() -> None
]

# Target cutoff index is len(messages) - messages_to_keep = 6 - 3 = 3
# Index 3 is a ToolMessage, so we advance past the tool sequence to index 5
# Index 3 is a ToolMessage, we move back to index 1 to include AIMessage
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=3)
assert cutoff == 5
assert cutoff == 1

# With messages_to_keep=2, target cutoff index is 6 - 2 = 4
# Index 4 is a ToolMessage, so we advance past the tool sequence to index 5
# This is aggressive - we keep only 1 message instead of 2
# Index 4 is a ToolMessage, we move back to index 1 to include AIMessage
# This preserves the AI + Tools + Human, more than requested but valid
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=2)
assert cutoff == 5
assert cutoff == 1


def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None:
@@ -891,3 +919,58 @@ def test_summarization_middleware_cutoff_at_start_of_tool_sequence() -> None:
# Index 2 is an AIMessage (safe cutoff point), so no adjustment needed
cutoff = middleware._find_safe_cutoff(messages, messages_to_keep=4)
assert cutoff == 2


def test_create_summary_uses_get_buffer_string_format() -> None:
"""Test that `_create_summary` formats messages using `get_buffer_string`.

Ensures that messages are formatted efficiently for the summary prompt, avoiding
token inflation from metadata when `str()` is called on message objects.

This ensures the token count of the formatted prompt stays below what
`count_tokens_approximately` estimates for the raw messages.
"""
# Create messages with metadata that would inflate str() representation
messages: list[AnyMessage] = [
HumanMessage(content="What is the weather in NYC?"),
AIMessage(
content="Let me check the weather for you.",
tool_calls=[{"name": "get_weather", "args": {"city": "NYC"}, "id": "call_123"}],
usage_metadata={"input_tokens": 50, "output_tokens": 30, "total_tokens": 80},
response_metadata={"model": "gpt-4", "finish_reason": "tool_calls"},
),
ToolMessage(
content="72F and sunny",
tool_call_id="call_123",
name="get_weather",
),
AIMessage(
content="It is 72F and sunny in NYC!",
usage_metadata={
"input_tokens": 100,
"output_tokens": 25,
"total_tokens": 125,
},
response_metadata={"model": "gpt-4", "finish_reason": "stop"},
),
]

# Verify the token ratio is favorable (get_buffer_string < str)
approx_tokens = count_tokens_approximately(messages)
buffer_string = get_buffer_string(messages)
buffer_tokens_estimate = len(buffer_string) / 4 # ~4 chars per token

# The ratio should be less than 1.0 (buffer_string uses fewer tokens than counted)
ratio = buffer_tokens_estimate / approx_tokens
assert ratio < 1.0, (
f"get_buffer_string should produce fewer tokens than count_tokens_approximately. "
f"Got ratio {ratio:.2f}x (expected < 1.0)"
)

# Verify str() would have been worse
str_tokens_estimate = len(str(messages)) / 4
str_ratio = str_tokens_estimate / approx_tokens
assert str_ratio > 1.5, (
f"str(messages) should produce significantly more tokens. "
f"Got ratio {str_ratio:.2f}x (expected > 1.5)"
)
10 changes: 5 additions & 5 deletions libs/langchain_v1/uv.lock

Some generated files are not rendered by default.