/**
 * Basic example demonstrating tool call limit middleware.
 *
 * This middleware helps prevent infinite loops or excessive tool usage
 * by limiting the number of tool calls an agent can make.
 */

import { z } from "zod";
import { createAgent, tool, toolCallLimitMiddleware } from "langchain";
import { ChatOpenAI } from "@langchain/openai";
import { HumanMessage } from "@langchain/core/messages";
import { MemorySaver } from "@langchain/langgraph";

/**
 * Thread configuration shared by both invocations below.
 * The checkpointer stores conversation state under this thread id.
 */
const config = { configurable: { thread_id: "demo-thread" } };

/**
 * Define a simple search tool
 */
const searchTool = tool(
  async ({ query }) => {
    console.log(`Searching for: ${query}`);
    return `Results for: ${query}`;
  },
  {
    name: "search",
    description: "Search for information",
    schema: z.object({
      query: z.string(),
    }),
  }
);

/**
 * Build an agent whose tool usage is capped at 3 calls per thread.
 *
 * Every call returns an agent with its own `MemorySaver`, so the agents
 * created here do not share conversation state with each other.
 */
function createLimitedAgent() {
  return createAgent({
    model: new ChatOpenAI({ model: "gpt-4o-mini" }),
    tools: [searchTool],
    checkpointer: new MemorySaver(),
    middleware: [
      /**
       * Limit to 3 tool calls per conversation
       */
      toolCallLimitMiddleware({
        threadLimit: 3,
        /**
         * Gracefully end when limit is reached
         */
        exitBehavior: "end",
      }),
    ],
  });
}

/**
 * Example conversation that would exceed the limit: five searches are
 * requested while only three tool calls are allowed on the thread.
 */
const agent = createLimitedAgent();

const result = await agent.invoke(
  {
    messages: [
      new HumanMessage(
        "Search for 'AI', 'ML', 'Deep Learning', 'Neural Networks', and 'LLMs'"
      ),
    ],
  },
  config
);

console.log("\nAgent response:");
console.log(result.messages[result.messages.length - 1].content);

/**
 * Example conversation that asks for only two searches.
 *
 * A second agent with a fresh checkpointer is used on purpose: re-using the
 * first agent would continue the thread whose tool call budget was already
 * consumed by the conversation above.
 */
const agent2 = createLimitedAgent();

const result2 = await agent2.invoke(
  {
    messages: [new HumanMessage("Search for 'AI' and 'ML'")],
  },
  config
);

console.log("\nAgent response:");
console.log(result2.messages[result2.messages.length - 1].content);
ToolCallLimitExceededError, +} from "../toolCallLimit.js"; +import { createAgent } from "../../index.js"; +import { FakeToolCallingChatModel } from "../../../tests/utils.js"; + +describe("toolCallLimitMiddleware", () => { + // Helper to create test tools + const searchTool = tool(async ({ query }) => `Results for: ${query}`, { + name: "search", + description: "Search for information", + schema: z.object({ + query: z.string(), + }), + }); + + const calculatorTool = tool( + async ({ expression }) => `Result: ${expression}`, + { + name: "calculator", + description: "Calculate an expression", + schema: z.object({ + expression: z.string(), + }), + } + ); + + describe("Initialization and Validation", () => { + it("should throw error if no limits are specified", () => { + expect(() => + toolCallLimitMiddleware({ + // No threadLimit or runLimit + } as any) + ).toThrow("At least one limit must be specified"); + }); + + it("should throw error for invalid exit behavior", () => { + expect(() => + toolCallLimitMiddleware({ + threadLimit: 5, + exitBehavior: "invalid" as any, + }) + ).toThrow("Invalid exit behavior: invalid"); + }); + + it("should generate correct middleware name without tool name", () => { + const middleware = toolCallLimitMiddleware({ + threadLimit: 5, + }); + expect(middleware.name).toBe("ToolCallLimitMiddleware"); + }); + + it("should generate correct middleware name with tool name", () => { + const middleware = toolCallLimitMiddleware({ + toolName: "search", + threadLimit: 5, + }); + expect(middleware.name).toBe("ToolCallLimitMiddleware[search]"); + }); + }); + + describe("Thread-level Limits", () => { + it("should allow tool calls under thread limit", async () => { + const model = new FakeToolCallingChatModel({ + responses: [ + new AIMessage({ + content: "", + tool_calls: [{ id: "1", name: "search", args: { query: "test1" } }], + }), + new AIMessage("Response after tool call"), + ], + }); + + const middleware = toolCallLimitMiddleware({ + threadLimit: 5, + 
exitBehavior: "end", + }); + + const agent = createAgent({ + model, + tools: [searchTool], + middleware: [middleware], + }); + + const result = await agent.invoke({ + messages: [new HumanMessage("Search for something")], + }); + + // Should complete successfully + expect(result.messages.length).toBeGreaterThan(0); + const lastMessage = result.messages[result.messages.length - 1]; + expect(lastMessage.content).not.toContain("thread limit"); + }); + + it("should terminate when thread limit is exceeded", async () => { + const model = new FakeToolCallingChatModel({ + responses: [ + new AIMessage({ + content: "", + tool_calls: [ + { id: "1", name: "search", args: { query: "test1" } }, + { id: "2", name: "search", args: { query: "test2" } }, + ], + }), + new AIMessage({ + content: "", + tool_calls: [{ id: "3", name: "search", args: { query: "test3" } }], + }), + new AIMessage("Should not reach here"), + ], + }); + + const middleware = toolCallLimitMiddleware({ + threadLimit: 3, + exitBehavior: "end", + }); + + const agent = createAgent({ + model, + tools: [searchTool], + middleware: [middleware], + }); + + const result = await agent.invoke({ + messages: [new HumanMessage("Search for things")], + }); + + const lastMessage = result.messages[result.messages.length - 1]; + expect(lastMessage).toBeInstanceOf(AIMessage); + expect(lastMessage.content).toContain("thread limit reached (3/3)"); + }); + + it("should persist thread count across multiple runs", async () => { + const model = new FakeToolCallingChatModel({ + responses: [ + // First run: 2 tool calls + new AIMessage({ + content: "", + tool_calls: [ + { id: "1", name: "search", args: { query: "test1" } }, + { id: "2", name: "calculator", args: { expression: "1+1" } }, + ], + }), + new AIMessage("First run response"), + // Second run: 2 more tool calls (total: 4, at limit) + new AIMessage({ + content: "", + tool_calls: [ + { id: "3", name: "search", args: { query: "test2" } }, + { id: "4", name: "calculator", args: { 
expression: "2+2" } }, + ], + }), + new AIMessage("Second run response"), + // Third run: would exceed limit + new AIMessage("Should be blocked"), + ], + }); + + const middleware = toolCallLimitMiddleware({ + threadLimit: 4, + exitBehavior: "end", + }); + + const checkpointer = new MemorySaver(); + const agent = createAgent({ + model, + tools: [searchTool, calculatorTool], + middleware: [middleware] as const, + checkpointer, + }); + + const threadConfig = { configurable: { thread_id: "test-thread" } }; + + // First run + await agent.invoke( + { messages: [new HumanMessage("First question")] }, + threadConfig + ); + + // Second run + await agent.invoke( + { messages: [new HumanMessage("Second question")] }, + threadConfig + ); + + const agent2 = createAgent({ + model, + tools: [searchTool, calculatorTool], + middleware: [middleware] as const, + checkpointer, + }); + + // Third run should hit limit + const finalResult = await agent2.invoke( + { messages: [new HumanMessage("Third question")] }, + threadConfig + ); + + const lastMessage = finalResult.messages[finalResult.messages.length - 1]; + expect(lastMessage.content).toContain("thread limit reached (4/4)"); + }); + }); + + describe("Run-level Limits", () => { + it("should allow tool calls under run limit", async () => { + const model = new FakeToolCallingChatModel({ + responses: [ + new AIMessage({ + content: "", + tool_calls: [{ id: "1", name: "search", args: { query: "test" } }], + }), + new AIMessage("Response"), + ], + }); + + const middleware = toolCallLimitMiddleware({ + runLimit: 2, + exitBehavior: "end", + }); + + const agent = createAgent({ + model, + tools: [searchTool], + middleware: [middleware], + }); + + const result = await agent.invoke({ + messages: [new HumanMessage("Search")], + }); + + const lastMessage = result.messages[result.messages.length - 1]; + expect(lastMessage.content).not.toContain("run limit"); + }); + + it("should terminate when run limit is exceeded", async () => { + const model = 
new FakeToolCallingChatModel({ + responses: [ + new AIMessage({ + content: "", + tool_calls: [ + { id: "1", name: "search", args: { query: "test1" } }, + { id: "2", name: "search", args: { query: "test2" } }, + ], + }), + new AIMessage("Should not reach here"), + ], + }); + + const middleware = toolCallLimitMiddleware({ + runLimit: 2, + exitBehavior: "end", + }); + + const agent = createAgent({ + model, + tools: [searchTool], + middleware: [middleware], + }); + + const result = await agent.invoke({ + messages: [new HumanMessage("Search for things")], + }); + + const lastMessage = result.messages[result.messages.length - 1]; + expect(lastMessage.content).toContain("run limit reached (2/2)"); + }); + + it("should reset run count after new HumanMessage", async () => { + // Create separate model instances for each invocation + let callCount = 0; + const createModel = () => { + return new FakeToolCallingChatModel({ + responses: [ + new AIMessage({ + content: "", + tool_calls: [ + { + id: `${callCount++}`, + name: "search", + args: { query: "test1" }, + }, + { + id: `${callCount++}`, + name: "search", + args: { query: "test2" }, + }, + ], + }), + new AIMessage(`Response ${callCount}`), + ], + }); + }; + + const middleware = toolCallLimitMiddleware({ + runLimit: 2, + exitBehavior: "end", + }); + + const threadConfig = { configurable: { thread_id: "test-thread" } }; + + // First run - should hit run limit + const agent1 = createAgent({ + model: createModel(), + tools: [searchTool], + middleware: [middleware], + checkpointer: new MemorySaver(), + }); + + const result1 = await agent1.invoke( + { messages: [new HumanMessage("First question")] }, + threadConfig + ); + expect(result1.messages[result1.messages.length - 1].content).toContain( + "run limit reached (2/2)" + ); + + // Second run with new model - run count resets, should also hit limit + const agent2 = createAgent({ + model: createModel(), + tools: [searchTool], + middleware: [middleware], + checkpointer: new 
MemorySaver(), + }); + + const result2 = await agent2.invoke( + { messages: [new HumanMessage("Second question")] }, + threadConfig + ); + expect(result2.messages[result2.messages.length - 1].content).toContain( + "run limit reached (2/2)" + ); + }); + }); + + describe("Tool-specific Limits", () => { + it("should only count calls to specific tool", async () => { + const model = new FakeToolCallingChatModel({ + responses: [ + new AIMessage({ + content: "", + tool_calls: [ + { id: "1", name: "search", args: { query: "test" } }, + { id: "2", name: "calculator", args: { expression: "1+1" } }, + { id: "3", name: "calculator", args: { expression: "2+2" } }, + ], + }), + new AIMessage("Response"), + ], + }); + + const middleware = toolCallLimitMiddleware({ + toolName: "search", + threadLimit: 2, // Increased to allow 1 search call + exitBehavior: "end", + }); + + const agent = createAgent({ + model, + tools: [searchTool, calculatorTool], + middleware: [middleware], + }); + + const result = await agent.invoke({ + messages: [new HumanMessage("Do calculations")], + }); + + // Should complete - only 1 search call, calculators don't count + const lastMessage = result.messages[result.messages.length - 1]; + expect(lastMessage.content).not.toContain("thread limit"); + }); + + it("should terminate when specific tool limit is exceeded", async () => { + const model = new FakeToolCallingChatModel({ + responses: [ + new AIMessage({ + content: "", + tool_calls: [ + { id: "1", name: "search", args: { query: "test1" } }, + { id: "2", name: "calculator", args: { expression: "1+1" } }, + ], + }), + new AIMessage({ + content: "", + tool_calls: [{ id: "3", name: "search", args: { query: "test2" } }], + }), + new AIMessage("Should not reach here"), + ], + }); + + const middleware = toolCallLimitMiddleware({ + toolName: "search", + threadLimit: 2, + exitBehavior: "end", + }); + + const agent = createAgent({ + model, + tools: [searchTool, calculatorTool], + middleware: [middleware], + }); + + 
const result = await agent.invoke({ + messages: [new HumanMessage("Search and calculate")], + }); + + const lastMessage = result.messages[result.messages.length - 1]; + expect(lastMessage.content).toContain("'search' tool call"); + expect(lastMessage.content).toContain("thread limit reached (2/2)"); + }); + }); + + describe("Multiple Middleware Instances", () => { + it("should work with both global and tool-specific limiters", async () => { + const model = new FakeToolCallingChatModel({ + responses: [ + new AIMessage({ + content: "", + tool_calls: [ + { id: "1", name: "search", args: { query: "test1" } }, + { id: "2", name: "search", args: { query: "test2" } }, + { id: "3", name: "calculator", args: { expression: "1+1" } }, + ], + }), + new AIMessage("Should not reach here"), + ], + }); + + const globalLimiter = toolCallLimitMiddleware({ + threadLimit: 10, // Won't hit this + exitBehavior: "end", + }); + + const searchLimiter = toolCallLimitMiddleware({ + toolName: "search", + threadLimit: 2, // Will hit this + exitBehavior: "end", + }); + + const agent = createAgent({ + model, + tools: [searchTool, calculatorTool], + middleware: [globalLimiter, searchLimiter], + }); + + const result = await agent.invoke({ + messages: [new HumanMessage("Search and calculate")], + }); + + const lastMessage = result.messages[result.messages.length - 1]; + expect(lastMessage.content).toContain("'search' tool call"); + expect(lastMessage.content).toContain("thread limit reached (2/2)"); + }); + }); + + describe("Error Behavior", () => { + it("should throw ToolCallLimitExceededError when exitBehavior is error", async () => { + const middleware = toolCallLimitMiddleware({ + threadLimit: 2, + exitBehavior: "error", + }); + + // Test with state that has exceeded limit + const state = { + messages: [ + new HumanMessage("Q"), + new AIMessage({ + content: "", + tool_calls: [ + { id: "1", name: "search", args: { query: "test1" } }, + ] as ToolCall[], + }), + new ToolMessage("Result", "1"), + 
new AIMessage({ + content: "", + tool_calls: [ + { id: "2", name: "search", args: { query: "test2" } }, + ] as ToolCall[], + }), + ], + }; + + await expect(async () => { + await middleware.beforeModel!(state as any, {} as any); + }).rejects.toThrow(ToolCallLimitExceededError); + }); + + it("should include correct information in ToolCallLimitExceededError", async () => { + const middleware = toolCallLimitMiddleware({ + threadLimit: 2, + runLimit: 1, + exitBehavior: "error", + }); + + const state = { + messages: [ + new HumanMessage("Q"), + new AIMessage({ + content: "", + tool_calls: [ + { id: "1", name: "search", args: { query: "test1" } }, + ] as ToolCall[], + }), + new ToolMessage("Result", "1"), + new AIMessage({ + content: "", + tool_calls: [ + { id: "2", name: "search", args: { query: "test2" } }, + ] as ToolCall[], + }), + ], + }; + + try { + await middleware.beforeModel!(state as any, {} as any); + expect.fail("Should have thrown error"); + } catch (error) { + expect(error).toBeInstanceOf(ToolCallLimitExceededError); + if (error instanceof ToolCallLimitExceededError) { + expect(error.threadCount).toBe(2); + expect(error.threadLimit).toBe(2); + expect(error.runCount).toBe(2); + expect(error.runLimit).toBe(1); + expect(error.toolName).toBeUndefined(); + expect(error.message).toContain("thread limit reached (2/2)"); + expect(error.message).toContain("run limit reached (2/1)"); + } + } + }); + + it("should include tool name in error for tool-specific limits", async () => { + const middleware = toolCallLimitMiddleware({ + toolName: "search", + threadLimit: 2, + exitBehavior: "error", + }); + + const state = { + messages: [ + new HumanMessage("Q"), + new AIMessage({ + content: "", + tool_calls: [ + { id: "1", name: "search", args: { query: "test1" } }, + { id: "2", name: "calculator", args: { expression: "1+1" } }, + ] as ToolCall[], + }), + new ToolMessage("Result", "1"), + new ToolMessage("Result", "2"), + new AIMessage({ + content: "", + tool_calls: [ + { id: 
"3", name: "search", args: { query: "test2" } }, + ] as ToolCall[], + }), + ], + }; + + try { + await middleware.beforeModel!(state as any, {} as any); + expect.fail("Should have thrown error"); + } catch (error) { + expect(error).toBeInstanceOf(ToolCallLimitExceededError); + if (error instanceof ToolCallLimitExceededError) { + expect(error.toolName).toBe("search"); + expect(error.message).toContain("'search' tool call"); + } + } + }); + }); + + describe("Combined Thread and Run Limits", () => { + it("should check both thread and run limits", async () => { + const model = new FakeToolCallingChatModel({ + responses: [ + new AIMessage({ + content: "", + tool_calls: [ + { id: "1", name: "search", args: { query: "test1" } }, + { id: "2", name: "search", args: { query: "test2" } }, + ], + }), + new AIMessage("Should not reach here"), + ], + }); + + const middleware = toolCallLimitMiddleware({ + threadLimit: 5, // Won't hit this + runLimit: 2, // Will hit this + exitBehavior: "end", + }); + + const agent = createAgent({ + model, + tools: [searchTool], + middleware: [middleware], + }); + + const result = await agent.invoke({ + messages: [new HumanMessage("Search")], + }); + + const lastMessage = result.messages[result.messages.length - 1]; + expect(lastMessage.content).toContain("run limit reached (2/2)"); + }); + + it("should report correct limit type when thread limit is hit first", async () => { + const model = new FakeToolCallingChatModel({ + responses: [ + new AIMessage({ + content: "", + tool_calls: [{ id: "1", name: "search", args: { query: "test1" } }], + }), + new AIMessage({ + content: "", + tool_calls: [{ id: "2", name: "search", args: { query: "test2" } }], + }), + new AIMessage("Should not reach here"), + ], + }); + + const middleware = toolCallLimitMiddleware({ + threadLimit: 2, // Will hit this + runLimit: 10, // Won't hit this + exitBehavior: "end", + }); + + const agent = createAgent({ + model, + tools: [searchTool], + middleware: [middleware], + 
checkpointer: new MemorySaver(), + }); + + const result = await agent.invoke( + { + messages: [new HumanMessage("Search")], + }, + { configurable: { thread_id: "test-thread" } } + ); + + const lastMessage = result.messages[result.messages.length - 1]; + expect(lastMessage.content).toContain("thread limit reached (2/2)"); + }); + }); + + describe("Edge Cases", () => { + it("should handle messages with no tool calls", async () => { + const model = new FakeToolCallingChatModel({ + responses: [new AIMessage("Just a response, no tool calls")], + }); + + const middleware = toolCallLimitMiddleware({ + threadLimit: 2, + exitBehavior: "end", + }); + + const agent = createAgent({ + model, + tools: [searchTool], + middleware: [middleware], + }); + + const result = await agent.invoke({ + messages: [new HumanMessage("Hello")], + }); + + // Should complete without hitting limit + const lastMessage = result.messages[result.messages.length - 1]; + expect(lastMessage.content).not.toContain("thread limit"); + expect(lastMessage.content).not.toContain("run limit"); + }); + + it("should handle empty message history", async () => { + const middleware = toolCallLimitMiddleware({ + threadLimit: 5, + exitBehavior: "end", + }); + + const state = { + messages: [], + }; + + const result = await middleware.beforeModel!(state as any, {} as any); + expect(result).toBeUndefined(); + }); + + it("should correctly count multiple tool calls in single AIMessage", async () => { + const middleware = toolCallLimitMiddleware({ + threadLimit: 3, + exitBehavior: "end", + }); + + const state = { + messages: [ + new HumanMessage("Do multiple things"), + new AIMessage({ + content: "", + tool_calls: [ + { id: "1", name: "search", args: { query: "test1" } }, + { id: "2", name: "search", args: { query: "test2" } }, + { id: "3", name: "calculator", args: { expression: "1+1" } }, + ] as ToolCall[], + }), + ], + }; + + const result = await middleware.beforeModel!(state as any, {} as any); + + // Should hit limit (3 
/**
 * Tally the tool calls requested by the AI messages in a message list.
 *
 * Only messages recognised by `AIMessage.isInstance` that carry a
 * `tool_calls` array contribute to the total; every other message is ignored.
 *
 * @param messages - Messages to scan.
 * @param toolName - When given, only calls whose `name` equals this value are
 *   counted. When omitted, every tool call is counted.
 * @returns The number of matching tool calls.
 */
function countToolCallsInMessages(
  messages: BaseMessage[],
  toolName?: string
): number {
  return messages.reduce((total, message) => {
    if (!AIMessage.isInstance(message) || !message.tool_calls) {
      return total;
    }
    const calls = message.tool_calls;
    const matching =
      toolName === undefined
        ? calls.length
        : calls.filter((tc) => tc.name === toolName).length;
    return total + matching;
  }, 0);
}
/**
 * Select the messages that belong to the current run.
 *
 * The current run is everything that follows the most recent HumanMessage.
 *
 * @param messages - Full message history.
 * @returns The messages after the last HumanMessage, or the input array
 *   itself when the history contains no HumanMessage.
 */
function getRunMessages(messages: BaseMessage[]): BaseMessage[] {
  /**
   * Walk backwards so the first match is the most recent HumanMessage
   */
  let cursor = messages.length;
  while (cursor > 0) {
    cursor -= 1;
    if (HumanMessage.isInstance(messages[cursor])) {
      return messages.slice(cursor + 1);
    }
  }

  /**
   * No HumanMessage anywhere: the whole history counts as the current run
   */
  return messages;
}

/**
 * Describe which tool call limits have been reached.
 *
 * A limit counts as reached when it is set and the matching count is greater
 * than or equal to it. The text is used both as the injected AI message and
 * as the message of the thrown error.
 *
 * @param threadCount - Tool calls counted over the whole thread.
 * @param runCount - Tool calls counted over the current run.
 * @param threadLimit - Thread limit, or undefined when not configured.
 * @param runLimit - Run limit, or undefined when not configured.
 * @param toolName - Name of the limited tool, or undefined when all tools are limited.
 * @returns A sentence listing every reached limit as `(count/limit)`.
 */
function buildToolLimitExceededMessage(
  threadCount: number,
  runCount: number,
  threadLimit: number | undefined,
  runLimit: number | undefined,
  toolName: string | undefined
): string {
  const isReached = (count: number, limit: number | undefined): boolean =>
    limit !== undefined && count >= limit;

  /**
   * Thread limit is always reported before run limit
   */
  const reached = [
    isReached(threadCount, threadLimit)
      ? `thread limit reached (${threadCount}/${threadLimit})`
      : undefined,
    isReached(runCount, runLimit)
      ? `run limit reached (${runCount}/${runLimit})`
      : undefined,
  ].filter((entry): entry is string => entry !== undefined);

  const subject = toolName ? `'${toolName}' tool call` : "Tool call";
  const plural = reached.length > 1 ? "s" : "";

  return `${subject} limit${plural}: ${reached.join(
    ", "
  )}. Stopping to prevent further tool calls.`;
}
/**
 * Error thrown by the tool call limit middleware when its exit behavior is
 * "error" and a thread or run limit has been reached.
 *
 * The counts and limits that triggered the error are exposed as properties so
 * callers can inspect them without parsing the message.
 */
export class ToolCallLimitExceededError extends Error {
  /**
   * Tool calls counted over the whole thread.
   */
  threadCount: number;
  /**
   * Tool calls counted over the current run.
   */
  runCount: number;
  /**
   * Configured thread limit, or undefined when none was set.
   */
  threadLimit: number | undefined;
  /**
   * Configured run limit, or undefined when none was set.
   */
  runLimit: number | undefined;
  /**
   * Name of the limited tool, or undefined when the limit covers all tools.
   */
  toolName: string | undefined;

  constructor(
    threadCount: number,
    runCount: number,
    threadLimit: number | undefined,
    runLimit: number | undefined,
    toolName: string | undefined = undefined
  ) {
    /**
     * The error message is the same text used for the injected AI message
     */
    super(
      buildToolLimitExceededMessage(
        threadCount,
        runCount,
        threadLimit,
        runLimit,
        toolName
      )
    );

    this.name = "ToolCallLimitExceededError";
    this.threadCount = threadCount;
    this.runCount = runCount;
    this.threadLimit = threadLimit;
    this.runLimit = runLimit;
    this.toolName = toolName;
  }
}

/**
 * Schema describing the options accepted by the Tool Call Limit middleware.
 */
export const ToolCallLimitOptionsSchema = z.object({
  /**
   * Restrict counting to the tool with this name.
   * Leave undefined to count calls to every tool.
   */
  toolName: z.string().optional(),
  /**
   * Maximum number of tool calls allowed per thread.
   * Leave undefined for no thread limit.
   */
  threadLimit: z.number().optional(),
  /**
   * Maximum number of tool calls allowed per run.
   * Leave undefined for no run limit.
   */
  runLimit: z.number().optional(),
  /**
   * Reaction when a limit is reached (defaults to "end").
   * - "end": jump to the end of the agent execution and inject an artificial
   *   AI message stating that the limit was reached.
   * - "error": throw a ToolCallLimitExceededError.
   */
  exitBehavior: z.enum(["end", "error"]).default("end"),
});

/**
 * Input type of the middleware options, inferred from the schema.
 */
export type ToolCallLimitConfig = InferInteropZodInput<
  typeof ToolCallLimitOptionsSchema
>;
/**
 * Middleware that tracks tool call counts and enforces limits.
 *
 * This middleware monitors the number of tool calls made during agent execution
 * and can terminate the agent when specified limits are reached. It supports
 * both thread-level and run-level call counting with configurable exit behaviors.
 *
 * Thread-level: The middleware counts all tool calls in the entire message history
 * it receives, so the count carries over between runs (invocations) whenever the
 * thread's history is persisted.
 *
 * Run-level: The middleware counts tool calls made after the last HumanMessage,
 * representing the current run (invocation) of the agent.
 *
 * The check runs before each model call and trips once a count is greater than
 * or equal to its limit. Counts are taken from the tool calls already present on
 * AI messages in the history, so a single model response that requests several
 * tool calls can leave the count above the limit (e.g. "2/1").
 *
 * The options are read once when the middleware is created; mutating the
 * options object afterwards has no effect.
 *
 * @param options - Configuration options for the middleware
 * @param options.toolName - Name of the specific tool to limit. If undefined, limits apply to all tools.
 * @param options.threadLimit - Maximum number of tool calls allowed per thread. undefined means no limit.
 * @param options.runLimit - Maximum number of tool calls allowed per run. undefined means no limit.
 * @param options.exitBehavior - What to do when limits are reached (defaults to "end").
 *   - "end": Jump to the end of the agent execution and inject an artificial AI message indicating that the limit was reached.
 *   - "error": throws a ToolCallLimitExceededError
 * @returns A middleware named `ToolCallLimitMiddleware`, or
 *   `ToolCallLimitMiddleware[<toolName>]` when a tool name is given.
 *
 * @throws {Error} If both limits are undefined or if exitBehavior is invalid.
 *
 * @example Limit all tool calls globally
 * ```ts
 * import { createAgent, toolCallLimitMiddleware } from "langchain";
 *
 * const globalLimiter = toolCallLimitMiddleware({
 *   threadLimit: 20,
 *   runLimit: 10,
 *   exitBehavior: "end"
 * });
 *
 * const agent = createAgent({
 *   model: "openai:gpt-4o",
 *   middleware: [globalLimiter]
 * });
 * ```
 *
 * @example Limit a specific tool
 * ```ts
 * import { createAgent, toolCallLimitMiddleware } from "langchain";
 *
 * const searchLimiter = toolCallLimitMiddleware({
 *   toolName: "search",
 *   threadLimit: 5,
 *   runLimit: 3,
 *   exitBehavior: "end"
 * });
 *
 * const agent = createAgent({
 *   model: "openai:gpt-4o",
 *   middleware: [searchLimiter]
 * });
 * ```
 *
 * @example Use both in the same agent
 * ```ts
 * import { createAgent, toolCallLimitMiddleware } from "langchain";
 *
 * const globalLimiter = toolCallLimitMiddleware({
 *   threadLimit: 20,
 *   runLimit: 10,
 *   exitBehavior: "end"
 * });
 *
 * const searchLimiter = toolCallLimitMiddleware({
 *   toolName: "search",
 *   threadLimit: 5,
 *   runLimit: 3,
 *   exitBehavior: "end"
 * });
 *
 * const agent = createAgent({
 *   model: "openai:gpt-4o",
 *   middleware: [globalLimiter, searchLimiter]
 * });
 * ```
 */
export function toolCallLimitMiddleware(options: ToolCallLimitConfig) {
  /**
   * Snapshot the settings so the middleware name and the enforced limits
   * cannot drift apart if the caller later mutates `options`
   */
  const { toolName, threadLimit, runLimit } = options;

  /**
   * Validate that at least one limit is specified
   */
  if (threadLimit === undefined && runLimit === undefined) {
    throw new Error(
      "At least one limit must be specified (threadLimit or runLimit)"
    );
  }

  /**
   * Apply default for exitBehavior and validate.
   * The runtime check guards untyped callers that bypass the type system.
   */
  const exitBehavior = options.exitBehavior ?? "end";
  if (exitBehavior !== "end" && exitBehavior !== "error") {
    throw new Error(
      `Invalid exit behavior: ${exitBehavior}. Must be 'end' or 'error'`
    );
  }

  /**
   * Generate the middleware name based on the tool name
   */
  const middlewareName = toolName
    ? `ToolCallLimitMiddleware[${toolName}]`
    : "ToolCallLimitMiddleware";

  return createMiddleware({
    name: middlewareName,
    beforeModelJumpTo: ["end"],
    async beforeModel(state: AgentBuiltInState) {
      const messages = state.messages;

      /**
       * Count tool calls in entire thread
       */
      const threadCount = countToolCallsInMessages(messages, toolName);

      /**
       * Count tool calls in current run (after last HumanMessage)
       */
      const runMessages = getRunMessages(messages);
      const runCount = countToolCallsInMessages(runMessages, toolName);

      /**
       * Check if any limits are reached
       */
      const threadLimitExceeded =
        threadLimit !== undefined && threadCount >= threadLimit;
      const runLimitExceeded = runLimit !== undefined && runCount >= runLimit;

      if (!threadLimitExceeded && !runLimitExceeded) {
        return undefined;
      }

      if (exitBehavior === "error") {
        throw new ToolCallLimitExceededError(
          threadCount,
          runCount,
          threadLimit,
          runLimit,
          toolName
        );
      }

      /**
       * Create a message indicating the limit was reached and end the run
       * without calling the model again
       */
      const limitMessage = buildToolLimitExceededMessage(
        threadCount,
        runCount,
        threadLimit,
        runLimit,
        toolName
      );
      const limitAiMessage = new AIMessage(limitMessage);

      return {
        jumpTo: "end",
        messages: [limitAiMessage],
      };
    },
  });
}