5 changes: 5 additions & 0 deletions .changeset/thin-cherries-carry.md
@@ -0,0 +1,5 @@
---
"langchain": patch
---

fix(langchain): keep tool call / AIMessage pairings when summarizing
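
Not part of the diff: a minimal sketch of the pairing this patch preserves, assuming the standard @langchain/core message constructors; the identifiers are illustrative.

import { AIMessage, ToolMessage } from "@langchain/core/messages";

// An AIMessage that issues a tool call and the ToolMessage carrying its
// result form a pair keyed by tool_call_id. A summarization cutoff between
// them would hand the model an orphaned tool response.
const pair = [
  new AIMessage({
    content: "",
    tool_calls: [{ id: "call_1", name: "lookup", args: { q: "weather" } }],
  }),
  new ToolMessage({ content: "42", tool_call_id: "call_1", name: "lookup" }),
];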
85 changes: 84 additions & 1 deletion libs/langchain/src/agents/middleware/summarization.ts
@@ -681,7 +681,21 @@ async function findTokenBasedCutoff(
}

/**
* Find safe cutoff point that preserves AI/Tool pairs.
* If the cutoff lands on a ToolMessage, move backward to include the AIMessage.
*/
const safeCutoff = findSafeCutoffPoint(messages, cutoffCandidate);

/**
* If findSafeCutoffPoint moved backward or left the cutoff unchanged, the
* point is already safe. A forward move (the orphaned-ToolMessage fallback)
* is re-checked by scanning backward below.
*/
if (safeCutoff <= cutoffCandidate) {
return safeCutoff;
}

/**
* Fallback: iterate backward to find a safe cutoff
*/
for (let i = cutoffCandidate; i >= 0; i--) {
if (isSafeCutoffPoint(messages, i)) {
@@ -705,6 +719,23 @@ function findSafeCutoff(

const targetCutoff = messages.length - messagesToKeep;

/**
* First, try to find a safe cutoff point using findSafeCutoffPoint.
* This handles the case where the cutoff lands on a ToolMessage by moving
* backward to include the corresponding AIMessage.
*/
const safeCutoff = findSafeCutoffPoint(messages, targetCutoff);

/**
* If findSafeCutoffPoint moved backward (found the matching AIMessage) or
* left the cutoff unchanged, use it.
*/
if (safeCutoff <= targetCutoff) {
return safeCutoff;
}

/**
* Fallback: iterate backward to find a safe cutoff
*/
for (let i = targetCutoff; i >= 0; i--) {
if (isSafeCutoffPoint(messages, i)) {
return i;
@@ -773,6 +804,58 @@ function extractToolCallIds(aiMessage: AIMessage): Set<string> {
return toolCallIds;
}

/**
* Find a safe cutoff point that doesn't split AI/Tool message pairs.
*
* If the message at `cutoffIndex` is a `ToolMessage`, search backward for the
* `AIMessage` containing the corresponding `tool_calls` and adjust the cutoff to
* include it. This ensures tool call requests and responses stay together.
*
* Falls back to advancing forward past `ToolMessage` objects only if no matching
* `AIMessage` is found (edge case).
*/
function findSafeCutoffPoint(
messages: BaseMessage[],
cutoffIndex: number
): number {
if (
cutoffIndex >= messages.length ||
!ToolMessage.isInstance(messages[cutoffIndex])
) {
return cutoffIndex;
}

// Collect tool_call_ids from consecutive ToolMessages at/after cutoff
const toolCallIds = new Set<string>();
let idx = cutoffIndex;
while (idx < messages.length && ToolMessage.isInstance(messages[idx])) {
const toolMsg = messages[idx] as ToolMessage;
if (toolMsg.tool_call_id) {
toolCallIds.add(toolMsg.tool_call_id);
}
idx++;
}

// Search backward for AIMessage with matching tool_calls
for (let i = cutoffIndex - 1; i >= 0; i--) {
const msg = messages[i];
if (AIMessage.isInstance(msg) && hasToolCalls(msg)) {
const aiToolCallIds = extractToolCallIds(msg as AIMessage);
// Check if there's any overlap between the tool_call_ids
for (const id of toolCallIds) {
if (aiToolCallIds.has(id)) {
// Found the AIMessage - move cutoff to include it
return i;
}
}
}
}

// Fallback: no matching AIMessage found, advance past ToolMessages to avoid
// orphaned tool responses
return idx;
}
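
Not part of the diff: a minimal usage sketch of findSafeCutoffPoint, using the same message constructors as the tests below; the history and indices are hypothetical.

// Index 2 lands on a ToolMessage, so the search walks backward and moves the
// cutoff to index 1, the AIMessage whose tool_calls include "call_1".
const history = [
  new HumanMessage("question"), // 0
  new AIMessage({
    content: "",
    tool_calls: [{ id: "call_1", name: "search", args: {} }],
  }), // 1
  new ToolMessage({ content: "result", tool_call_id: "call_1", name: "search" }), // 2
  new HumanMessage("follow-up"), // 3
];
// findSafeCutoffPoint(history, 2) === 1. For an orphaned ToolMessage with no
// matching AIMessage, it would instead advance past it and return 3.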

/**
* Check if cutoff separates an AI message from its corresponding tool messages
*/
198 changes: 198 additions & 0 deletions libs/langchain/src/agents/middleware/tests/summarization.test.ts
@@ -989,4 +989,202 @@ describe("summarizationMiddleware", () => {
// Verify the main model's content DOES appear in the stream
expect(allStreamedAIContent).toContain(MAIN_MODEL_CONTENT);
});

it("should move cutoff backward to preserve AI/Tool pairs when cutoff lands on ToolMessage", async () => {
const summarizationModel = createMockSummarizationModel();

const model = new FakeToolCallingChatModel({
responses: [new AIMessage("Final response")],
});

const middleware = summarizationMiddleware({
model: summarizationModel as any,
trigger: { tokens: 100 },
keep: { tokens: 150 }, // Token budget that would land cutoff on ToolMessage
});

const agent = createAgent({
model,
middleware: [middleware],
});

// Create messages where aggressive summarization would normally land on ToolMessage
// Structure: HumanMessage(long) -> AIMessage(with tool_calls) -> ToolMessage -> HumanMessage -> HumanMessage
const messages = [
new HumanMessage("x".repeat(300)), // ~75 tokens
new AIMessage({
content: "y".repeat(200), // ~50 tokens
tool_calls: [
{ id: "call_preserve", name: "test_tool", args: { test: true } },
],
}),
new ToolMessage({
content: "z".repeat(50), // ~12 tokens
tool_call_id: "call_preserve",
name: "test_tool",
}),
new HumanMessage("a".repeat(180)), // ~45 tokens
new HumanMessage("b".repeat(160)), // ~40 tokens
];
// Total: ~222 tokens, keep ~150 tokens
// If the cutoff lands on a ToolMessage, the middleware should move it
// backward to include the AIMessage containing the matching tool_calls.

const result = await agent.invoke({ messages });

// Find the preserved messages (after summary)
const summaryIndex = result.messages.findIndex(
(msg) =>
HumanMessage.isInstance(msg) &&
typeof msg.content === "string" &&
msg.content.includes("Here is a summary")
);

const preservedMessages = result.messages.slice(summaryIndex + 1);

// The AIMessage with tool_calls should be preserved along with its ToolMessage
const hasAIWithToolCalls = preservedMessages.some(
(msg) =>
AIMessage.isInstance(msg) && msg.tool_calls && msg.tool_calls.length > 0
);
const hasMatchingToolMessage = preservedMessages.some(
(msg) =>
ToolMessage.isInstance(msg) && msg.tool_call_id === "call_preserve"
);

// Both must be present - the AI/Tool pair should be kept together
expect(hasAIWithToolCalls).toBe(true);
expect(hasMatchingToolMessage).toBe(true);
});

it("should handle orphan ToolMessage by advancing forward", async () => {
/**
* Edge case: If a ToolMessage has no matching AIMessage (orphan),
* the middleware should fall back to advancing past ToolMessages.
*/
const summarizationModel = createMockSummarizationModel();

const model = new FakeToolCallingChatModel({
responses: [new AIMessage("Final response")],
});

const middleware = summarizationMiddleware({
model: summarizationModel as any,
trigger: { tokens: 50 },
keep: { messages: 2 },
});

const agent = createAgent({
model,
middleware: [middleware],
});

// Create messages with an orphan ToolMessage (no matching AIMessage)
const messages = [
new HumanMessage("x".repeat(200)),
new AIMessage("No tool calls here"), // No tool_calls
new ToolMessage({
content: "Orphan result",
tool_call_id: "orphan_call", // No matching AIMessage
name: "orphan_tool",
}),
new HumanMessage("y".repeat(200)),
new HumanMessage("Final question"),
];

const result = await agent.invoke({ messages });

// Verify we don't crash and the conversation continues
expect(result.messages.length).toBeGreaterThan(0);
});

it("should preserve many parallel tool calls together with AIMessage", async () => {
/**
* Port of Python test: test_summarization_middleware_many_parallel_tool_calls_safety
*
* When an AIMessage has many parallel tool calls (e.g., reading 10 files),
* all corresponding ToolMessages should be preserved along with the AIMessage.
*/
const summarizationModel = createMockSummarizationModel();

const model = new FakeToolCallingChatModel({
responses: [new AIMessage("All files read and summarized")],
});

const middleware = summarizationMiddleware({
model: summarizationModel as any,
trigger: { tokens: 100 },
keep: { messages: 5 }, // This would normally cut in the middle of tool responses
});

const agent = createAgent({
model,
middleware: [middleware],
});

// Create 10 parallel tool calls
const toolCalls = Array.from({ length: 10 }, (_, i) => ({
id: `call_${i}`,
name: "read_file",
args: { file: `file${i}.txt` },
}));

const aiMessage = new AIMessage({
content: "I'll read all 10 files",
tool_calls: toolCalls,
});

const toolMessages = toolCalls.map(
(tc) =>
new ToolMessage({
content: `Contents of ${tc.args.file}`,
tool_call_id: tc.id,
name: tc.name,
})
);

const messages = [
new HumanMessage("x".repeat(500)), // Long message to trigger summarization
aiMessage,
...toolMessages,
new HumanMessage("Now summarize them"),
];

const result = await agent.invoke({ messages });

// Find preserved messages
const summaryIndex = result.messages.findIndex(
(msg) =>
HumanMessage.isInstance(msg) &&
typeof msg.content === "string" &&
msg.content.includes("Here is a summary")
);

if (summaryIndex === -1) {
// Summarization might not have triggered; that's fine
return;
}

const preservedMessages = result.messages.slice(summaryIndex + 1);

// If the AIMessage with tool_calls is preserved, all its ToolMessages should be too
const preservedAI = preservedMessages.find(
(msg) =>
AIMessage.isInstance(msg) && msg.tool_calls && msg.tool_calls.length > 0
);

if (preservedAI && AIMessage.isInstance(preservedAI)) {
// Count preserved ToolMessages that match this AI's tool_calls
const aiToolCallIds = new Set(
preservedAI.tool_calls?.map((tc) => tc.id) ?? []
);
const matchingToolMessages = preservedMessages.filter(
(msg) =>
ToolMessage.isInstance(msg) && aiToolCallIds.has(msg.tool_call_id)
);

// All matching tool messages should be preserved
expect(matchingToolMessages.length).toBe(preservedAI.tool_calls?.length);
}
});
});