55 changes: 51 additions & 4 deletions libs/langchain_v1/langchain/agents/middleware/summarization.py
@@ -482,6 +482,49 @@ def _find_safe_cutoff_point(self, messages: list[AnyMessage], cutoff_index: int)
cutoff_index += 1
return cutoff_index

def _format_clean_history(self, messages: list[AnyMessage]) -> str:
"""Formats messages to a simple string (Role: Content) to save tokens.

This method strips away metadata (token usage, logprobs) and handles
multimodal content (like images) by extracting only the text.

Args:
messages: The list of conversation messages to format.

Returns:
A clean string representation of the conversation history.
"""
formatted_lines = []
for msg in messages:
# 1. Map standard message types to readable roles
role = msg.type
role_mapping = {
"human": "User",
"ai": "Assistant",
"system": "System",
"tool": "Tool Output",
"function": "Function Output",
"chat": "Chat",
}
display_role = role_mapping.get(role.lower(), role.capitalize())

# 2. Handle content (strings vs multimodal lists)
content = msg.content
if isinstance(content, list):
text_parts = []
for item in content:
if isinstance(item, str):
text_parts.append(item)
elif isinstance(item, dict):
text_parts.append(item.get("text", str(item)))
else:
text_parts.append(str(item))
content = " ".join(text_parts)

formatted_lines.append(f"{display_role}: {content}")

return "\n".join(formatted_lines)

def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Generate summary for the given messages."""
if not messages_to_summarize:
@@ -491,8 +534,11 @@ def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
if not trimmed_messages:
return "Previous conversation was too long to summarize."

# FIX: Convert rich objects to a plain string before formatting
clean_history = self._format_clean_history(trimmed_messages)

try:
response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
response = self.model.invoke(self.summary_prompt.format(messages=clean_history))
return response.text.strip()
except Exception as e:
return f"Error generating summary: {e!s}"
@@ -506,10 +552,11 @@ async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str
if not trimmed_messages:
return "Previous conversation was too long to summarize."

# FIX: Convert rich objects to a plain string before formatting
clean_history = self._format_clean_history(trimmed_messages)

try:
response = await self.model.ainvoke(
self.summary_prompt.format(messages=trimmed_messages)
)
response = await self.model.ainvoke(self.summary_prompt.format(messages=clean_history))
return response.text.strip()
except Exception as e:
return f"Error generating summary: {e!s}"
@@ -0,0 +1,74 @@
from typing import Any

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatResult

from langchain.agents.middleware.summarization import SummarizationMiddleware


# --- 1. Define a Mock Model for Testing ---
# This avoids needing real API keys (e.g. OpenAI) for unit tests.
class FakeChatModel(BaseChatModel):
def _generate(
self, messages: list[BaseMessage], stop: list[str] | None = None, **kwargs: Any
) -> ChatResult:
return ChatResult(generations=[ChatGeneration(message=AIMessage(content="Summary"))])

@property
def _llm_type(self) -> str:
return "fake-chat-model"


# --- 2. The Tests ---


def test_format_clean_history_basic():
"""Test that basic human/ai messages are formatted correctly."""
# Use the FakeChatModel instead of a real provider model such as "gpt-3.5-turbo"
middleware = SummarizationMiddleware(model=FakeChatModel())

messages = [
HumanMessage(content="Hello"),
AIMessage(
content="Hi there", response_metadata={"token_usage": 100}
), # Metadata that should be stripped
]

result = middleware._format_clean_history(messages)

expected = "User: Hello\nAssistant: Hi there"
assert result == expected


def test_format_clean_history_tool_messages():
"""Test that tool messages are handled correctly."""
middleware = SummarizationMiddleware(model=FakeChatModel())

messages = [
HumanMessage(content="Search for apples"),
AIMessage(content="", tool_calls=[{"name": "search", "args": {}, "id": "123"}]),
ToolMessage(content="Found red apples", tool_call_id="123"),
]

result = middleware._format_clean_history(messages)

# We expect the tool output to be included clearly
assert "Tool Output: Found red apples" in result


def test_format_clean_history_multimodal():
"""Test that multimodal (list) content is flattened to string."""
middleware = SummarizationMiddleware(model=FakeChatModel())

# Simulate a GPT-4o multimodal message
multimodal_content = [
{"type": "text", "text": "Look at this image"},
{"type": "image_url", "image_url": {"url": "http://image.com"}},
]
messages = [HumanMessage(content=multimodal_content)]

result = middleware._format_clean_history(messages)

# It should extract "Look at this image" and handle the dict gracefully
assert "User: Look at this image" in result