diff --git a/.changeset/thin-cherries-carry.md b/.changeset/thin-cherries-carry.md
new file mode 100644
index 000000000000..1bafae08ceab
--- /dev/null
+++ b/.changeset/thin-cherries-carry.md
@@ -0,0 +1,5 @@
+---
+"langchain": patch
+---
+
+fix(langchain): keep tool call / AIMessage pairings when summarizing
diff --git a/libs/langchain/src/agents/middleware/summarization.ts b/libs/langchain/src/agents/middleware/summarization.ts
index 503cb0383060..dc9820ec4792 100644
--- a/libs/langchain/src/agents/middleware/summarization.ts
+++ b/libs/langchain/src/agents/middleware/summarization.ts
@@ -681,7 +681,21 @@ async function findTokenBasedCutoff(
   }
 
   /**
-   * Find safe cutoff point that preserves tool pairs
+   * Find a safe cutoff point that preserves AI/Tool pairs.
+   * If the cutoff lands on a ToolMessage, move backward to include the AIMessage.
+   */
+  const safeCutoff = findSafeCutoffPoint(messages, cutoffCandidate);
+
+  /**
+   * If findSafeCutoffPoint moved backward (or stayed put), the point is safe;
+   * if it moved forward (the fallback case), fall through to the backward scan.
+   */
+  if (safeCutoff <= cutoffCandidate) {
+    return safeCutoff;
+  }
+
+  /**
+   * Fallback: iterate backward to find a safe cutoff
    */
   for (let i = cutoffCandidate; i >= 0; i--) {
     if (isSafeCutoffPoint(messages, i)) {
@@ -705,6 +719,23 @@ function findSafeCutoff(
 
   const targetCutoff = messages.length - messagesToKeep;
 
+  /**
+   * First, try to find a safe cutoff point using findSafeCutoffPoint.
+   * This handles the case where the cutoff lands on a ToolMessage by moving
+   * backward to include the corresponding AIMessage.
+   */
+  const safeCutoff = findSafeCutoffPoint(messages, targetCutoff);
+
+  /**
+   * If findSafeCutoffPoint moved backward (found the matching AIMessage), use it.
+   */
+  if (safeCutoff <= targetCutoff) {
+    return safeCutoff;
+  }
+
+  /**
+   * Fallback: iterate backward to find a safe cutoff
+   */
   for (let i = targetCutoff; i >= 0; i--) {
     if (isSafeCutoffPoint(messages, i)) {
       return i;
@@ -773,6 +804,58 @@ function extractToolCallIds(aiMessage: AIMessage): Set<string> {
   return toolCallIds;
 }
 
+/**
+ * Find a safe cutoff point that doesn't split AI/Tool message pairs.
+ *
+ * If the message at `cutoffIndex` is a `ToolMessage`, search backward for the
+ * `AIMessage` containing the corresponding `tool_calls` and adjust the cutoff to
+ * include it. This ensures tool call requests and responses stay together.
+ *
+ * Falls back to advancing forward past `ToolMessage` objects only if no matching
+ * `AIMessage` is found (edge case).
+ */
+function findSafeCutoffPoint(
+  messages: BaseMessage[],
+  cutoffIndex: number
+): number {
+  if (
+    cutoffIndex >= messages.length ||
+    !ToolMessage.isInstance(messages[cutoffIndex])
+  ) {
+    return cutoffIndex;
+  }
+
+  // Collect tool_call_ids from consecutive ToolMessages at/after the cutoff
+  const toolCallIds = new Set<string>();
+  let idx = cutoffIndex;
+  while (idx < messages.length && ToolMessage.isInstance(messages[idx])) {
+    const toolMsg = messages[idx] as ToolMessage;
+    if (toolMsg.tool_call_id) {
+      toolCallIds.add(toolMsg.tool_call_id);
+    }
+    idx++;
+  }
+
+  // Search backward for the AIMessage with matching tool_calls
+  for (let i = cutoffIndex - 1; i >= 0; i--) {
+    const msg = messages[i];
+    if (AIMessage.isInstance(msg) && hasToolCalls(msg)) {
+      const aiToolCallIds = extractToolCallIds(msg as AIMessage);
+      // Check whether any of the collected tool_call_ids overlap
+      for (const id of toolCallIds) {
+        if (aiToolCallIds.has(id)) {
+          // Found the AIMessage - move the cutoff to include it
+          return i;
+        }
+      }
+    }
+  }
+
+  // Fallback: no matching AIMessage found, so advance past the ToolMessages to
+  // avoid orphaned tool responses
+  return idx;
+}
+
 /**
  * Check if cutoff separates an AI message from its corresponding tool messages
  */
diff --git a/libs/langchain/src/agents/middleware/tests/summarization.test.ts b/libs/langchain/src/agents/middleware/tests/summarization.test.ts
index 9502145b34b9..4a66c5da0a3d 100644
--- a/libs/langchain/src/agents/middleware/tests/summarization.test.ts
+++ b/libs/langchain/src/agents/middleware/tests/summarization.test.ts
@@ -989,4 +989,202 @@ describe("summarizationMiddleware", () => {
     // Verify the main model's content DOES appear in the stream
     expect(allStreamedAIContent).toContain(MAIN_MODEL_CONTENT);
   });
+
+  it("should move cutoff backward to preserve AI/Tool pairs when cutoff lands on ToolMessage", async () => {
+    const summarizationModel = createMockSummarizationModel();
+
+    const model = new FakeToolCallingChatModel({
+      responses: [new AIMessage("Final response")],
+    });
+
+    const middleware = summarizationMiddleware({
+      model: summarizationModel as any,
+      trigger: { tokens: 100 },
+      keep: { tokens: 150 }, // Token budget that would land the cutoff on a ToolMessage
+    });
+
+    const agent = createAgent({
+      model,
+      middleware: [middleware],
+    });
+
+    // Create messages where aggressive summarization would normally land on a ToolMessage
+    // Structure: HumanMessage(long) -> AIMessage(with tool_calls) -> ToolMessage -> HumanMessage -> HumanMessage
+    const messages = [
+      new HumanMessage("x".repeat(300)), // ~75 tokens
+      new AIMessage({
+        content: "y".repeat(200), // ~50 tokens
+        tool_calls: [
+          { id: "call_preserve", name: "test_tool", args: { test: true } },
+        ],
+      }),
+      new ToolMessage({
+        content: "z".repeat(50), // ~12 tokens
+        tool_call_id: "call_preserve",
+        name: "test_tool",
+      }),
+      new HumanMessage("a".repeat(180)), // ~45 tokens
+      new HumanMessage("b".repeat(160)), // ~40 tokens
+    ];
+    // Total: ~222 tokens with ~150 to keep. If the cutoff lands on the
+    // ToolMessage, the middleware should move it backward to include the
+    // AIMessage that contains the matching tool_calls.
+
+    const result = await agent.invoke({ messages });
+
+    // Find the preserved messages (after the summary)
+    const summaryIndex = result.messages.findIndex(
+      (msg) =>
+        HumanMessage.isInstance(msg) &&
+        typeof msg.content === "string" &&
+        msg.content.includes("Here is a summary")
+    );
+
+    const preservedMessages = result.messages.slice(summaryIndex + 1);
+
+    // The AIMessage with tool_calls should be preserved along with its ToolMessage
+    const hasAIWithToolCalls = preservedMessages.some(
+      (msg) =>
+        AIMessage.isInstance(msg) && msg.tool_calls && msg.tool_calls.length > 0
+    );
+    const hasMatchingToolMessage = preservedMessages.some(
+      (msg) =>
+        ToolMessage.isInstance(msg) && msg.tool_call_id === "call_preserve"
+    );
+
+    // Both must be present - the AI/Tool pair should be kept together
+    expect(hasAIWithToolCalls).toBe(true);
+    expect(hasMatchingToolMessage).toBe(true);
+  });
+
+  it("should handle orphan ToolMessage by advancing forward", async () => {
+    /**
+     * Edge case: if a ToolMessage has no matching AIMessage (orphan),
+     * the middleware should fall back to advancing past ToolMessages.
+     */
+    const summarizationModel = createMockSummarizationModel();
+
+    const model = new FakeToolCallingChatModel({
+      responses: [new AIMessage("Final response")],
+    });
+
+    const middleware = summarizationMiddleware({
+      model: summarizationModel as any,
+      trigger: { tokens: 50 },
+      keep: { messages: 2 },
+    });
+
+    const agent = createAgent({
+      model,
+      middleware: [middleware],
+    });
+
+    // Create messages with an orphan ToolMessage (no matching AIMessage)
+    const messages = [
+      new HumanMessage("x".repeat(200)),
+      new AIMessage("No tool calls here"), // No tool_calls
+      new ToolMessage({
+        content: "Orphan result",
+        tool_call_id: "orphan_call", // No matching AIMessage
+        name: "orphan_tool",
+      }),
+      new HumanMessage("y".repeat(200)),
+      new HumanMessage("Final question"),
+    ];
+
+    const result = await agent.invoke({ messages });
+
+    // Verify we don't crash and the conversation continues
+    expect(result.messages.length).toBeGreaterThan(0);
+  });
+
+  it("should preserve many parallel tool calls together with AIMessage", async () => {
+    /**
+     * Port of Python test: test_summarization_middleware_many_parallel_tool_calls_safety
+     *
+     * When an AIMessage has many parallel tool calls (e.g., reading 10 files),
+     * all corresponding ToolMessages should be preserved along with the AIMessage.
+     */
+    const summarizationModel = createMockSummarizationModel();
+
+    const model = new FakeToolCallingChatModel({
+      responses: [new AIMessage("All files read and summarized")],
+    });
+
+    const middleware = summarizationMiddleware({
+      model: summarizationModel as any,
+      trigger: { tokens: 100 },
+      keep: { messages: 5 }, // This would normally cut in the middle of the tool responses
+    });
+
+    const agent = createAgent({
+      model,
+      middleware: [middleware],
+    });
+
+    // Create 10 parallel tool calls
+    const toolCalls = Array.from({ length: 10 }, (_, i) => ({
+      id: `call_${i}`,
+      name: "read_file",
+      args: { file: `file${i}.txt` },
+    }));
+
+    const aiMessage = new AIMessage({
+      content: "I'll read all 10 files",
+      tool_calls: toolCalls,
+    });
+
+    const toolMessages = toolCalls.map(
+      (tc) =>
+        new ToolMessage({
+          content: `Contents of ${tc.args.file}`,
+          tool_call_id: tc.id,
+          name: tc.name,
+        })
+    );
+
+    const messages = [
+      new HumanMessage("x".repeat(500)), // Long message to trigger summarization
+      aiMessage,
+      ...toolMessages,
+      new HumanMessage("Now summarize them"),
+    ];
+
+    const result = await agent.invoke({ messages });
+
+    // Find the preserved messages
+    const summaryIndex = result.messages.findIndex(
+      (msg) =>
+        HumanMessage.isInstance(msg) &&
+        typeof msg.content === "string" &&
+        msg.content.includes("Here is a summary")
+    );
+
+    if (summaryIndex === -1) {
+      // Summarization might not have triggered; that's fine
+      return;
+    }
+
+    const preservedMessages = result.messages.slice(summaryIndex + 1);
+
+    // If the AIMessage with tool_calls is preserved, all of its ToolMessages should be too
+    const preservedAI = preservedMessages.find(
+      (msg) =>
+        AIMessage.isInstance(msg) && msg.tool_calls && msg.tool_calls.length > 0
+    );
+
+    if (preservedAI && AIMessage.isInstance(preservedAI)) {
+      // Count preserved ToolMessages that match this AI's tool_calls
+      const aiToolCallIds = new Set(
+        preservedAI.tool_calls?.map((tc) => tc.id) ?? []
+      );
+      const matchingToolMessages = preservedMessages.filter(
+        (msg) =>
+          ToolMessage.isInstance(msg) && aiToolCallIds.has(msg.tool_call_id)
+      );
+
+      // All matching tool messages should be preserved
+      expect(matchingToolMessages.length).toBe(preservedAI.tool_calls?.length);
+    }
+  });
 });
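Illustration (not part of the patch): a minimal sketch of how the new findSafeCutoffPoint resolves a cutoff that lands on a tool response, assuming the cutoff semantics implied by findSafeCutoff above (messages before the cutoff index are summarized, messages from the index onward are kept). The message classes are the real exports of @langchain/core/messages; the file name and contents are made up.

    import { AIMessage, HumanMessage, ToolMessage } from "@langchain/core/messages";

    const messages = [
      new HumanMessage("please read the config"), // index 0
      new AIMessage({
        content: "reading it now",
        tool_calls: [{ id: "call_1", name: "read_file", args: { file: "app.cfg" } }],
      }), // index 1: issues call_1
      new ToolMessage({
        content: "key=value",
        tool_call_id: "call_1",
        name: "read_file",
      }), // index 2: answers call_1
      new HumanMessage("now summarize it"), // index 3
    ];

    // A candidate cutoff of 2 points at the ToolMessage, which would summarize
    // away the AIMessage that issued call_1 while keeping its response. The
    // backward search finds that AIMessage at index 1, so the safe cutoff
    // becomes 1 and the request/response pair stays on the same side.
    findSafeCutoffPoint(messages, 2); // -> 1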