From 7ec2c9e16482fa274295aa9a62676dd7379aad69 Mon Sep 17 00:00:00 2001 From: Christian Bromann Date: Wed, 29 Oct 2025 13:42:21 -0700 Subject: [PATCH 1/8] fix(langchain): improved state schema typing --- libs/langchain/src/agents/middleware/index.ts | 2 +- .../{callLimit.ts => modelCallLimit.ts} | 0 ...llLimit.test.ts => modelCallLimit.test.ts} | 2 +- libs/langchain/src/agents/middleware/types.ts | 134 ++++++++++-------- .../src/agents/tests/middleware.test-d.ts | 25 +++- 5 files changed, 99 insertions(+), 64 deletions(-) rename libs/langchain/src/agents/middleware/{callLimit.ts => modelCallLimit.ts} (100%) rename libs/langchain/src/agents/middleware/tests/{callLimit.test.ts => modelCallLimit.test.ts} (98%) diff --git a/libs/langchain/src/agents/middleware/index.ts b/libs/langchain/src/agents/middleware/index.ts index d2ddeb5dc635..30f3b1de4979 100644 --- a/libs/langchain/src/agents/middleware/index.ts +++ b/libs/langchain/src/agents/middleware/index.ts @@ -40,7 +40,7 @@ export { export { modelCallLimitMiddleware, type ModelCallLimitMiddlewareConfig, -} from "./callLimit.js"; +} from "./modelCallLimit.js"; export { modelFallbackMiddleware } from "./modelFallback.js"; export { toolRetryMiddleware, diff --git a/libs/langchain/src/agents/middleware/callLimit.ts b/libs/langchain/src/agents/middleware/modelCallLimit.ts similarity index 100% rename from libs/langchain/src/agents/middleware/callLimit.ts rename to libs/langchain/src/agents/middleware/modelCallLimit.ts diff --git a/libs/langchain/src/agents/middleware/tests/callLimit.test.ts b/libs/langchain/src/agents/middleware/tests/modelCallLimit.test.ts similarity index 98% rename from libs/langchain/src/agents/middleware/tests/callLimit.test.ts rename to libs/langchain/src/agents/middleware/tests/modelCallLimit.test.ts index 344d395f5d2e..45f52f30ad2c 100644 --- a/libs/langchain/src/agents/middleware/tests/callLimit.test.ts +++ b/libs/langchain/src/agents/middleware/tests/modelCallLimit.test.ts @@ -4,7 +4,7 @@ import { tool } from "@langchain/core/tools"; import { MemorySaver } from "@langchain/langgraph-checkpoint"; import { FakeToolCallingChatModel } from "../../tests/utils.js"; -import { modelCallLimitMiddleware } from "../callLimit.js"; +import { modelCallLimitMiddleware } from "../modelCallLimit.js"; import { createAgent } from "../../index.js"; const toolCallMessage1 = new AIMessage({ diff --git a/libs/langchain/src/agents/middleware/types.ts b/libs/langchain/src/agents/middleware/types.ts index 31ee773df487..bfba57cc70a1 100644 --- a/libs/langchain/src/agents/middleware/types.ts +++ b/libs/langchain/src/agents/middleware/types.ts @@ -21,8 +21,21 @@ type PromiseOrValue = T | Promise; export type AnyAnnotationRoot = AnnotationRoot; -type NormalizedSchemaInput = - TSchema extends InteropZodObject ? InferInteropZodInput : {}; +type NormalizedSchemaInput< + TSchema extends + | InteropZodObject + | InteropZodDefault + | undefined + | never = any +> = [TSchema] extends [never] + ? AgentBuiltInState + : TSchema extends InteropZodObject + ? InferInteropZodOutput & AgentBuiltInState + : TSchema extends InteropZodDefault + ? InferInteropZodOutput & AgentBuiltInState + : TSchema extends Record + ? TSchema & AgentBuiltInState + : AgentBuiltInState; /** * Result type for middleware functions. @@ -61,13 +74,10 @@ export interface ToolCallRequest< * Takes a tool call request and returns the tool result or a command. */ export type ToolCallHandler< - TSchema extends InteropZodObject | undefined = any, + TSchema extends Record = AgentBuiltInState, TContext = unknown > = ( - request: ToolCallRequest< - NormalizedSchemaInput & AgentBuiltInState, - TContext - > + request: ToolCallRequest ) => PromiseOrValue; /** @@ -75,14 +85,14 @@ export type ToolCallHandler< * Allows middleware to intercept and modify tool execution. */ export type WrapToolCallHook< - TSchema extends InteropZodObject | undefined = any, + TSchema extends + | InteropZodObject + | InteropZodDefault + | undefined = undefined, TContext = unknown > = ( - request: ToolCallRequest< - NormalizedSchemaInput & AgentBuiltInState, - TContext - >, - handler: ToolCallHandler + request: ToolCallRequest, TContext>, + handler: ToolCallHandler, TContext> ) => PromiseOrValue; /** @@ -93,13 +103,13 @@ export type WrapToolCallHook< * @returns The AI message response from the model */ export type WrapModelCallHandler< - TSchema extends InteropZodObject | undefined = any, + TSchema extends + | InteropZodObject + | InteropZodDefault + | undefined = undefined, TContext = unknown > = ( - request: ModelRequest< - NormalizedSchemaInput & AgentBuiltInState, - TContext - > + request: ModelRequest, TContext> ) => PromiseOrValue; /** @@ -116,13 +126,14 @@ export type WrapModelCallHandler< * @returns The AI message response from the model (or a modified version) */ export type WrapModelCallHook< - TSchema extends InteropZodObject | undefined = any, + TSchema extends + | InteropZodObject + | InteropZodDefault + | undefined + | never = undefined, TContext = unknown > = ( - request: ModelRequest< - NormalizedSchemaInput & AgentBuiltInState, - TContext - >, + request: ModelRequest, TContext>, handler: WrapModelCallHandler ) => PromiseOrValue; @@ -134,13 +145,10 @@ export type WrapModelCallHook< * @param runtime - The runtime context containing metadata, signal, writer, interrupt, etc. * @returns A middleware result containing partial state updates or undefined to pass through */ -export type BeforeAgentHandler< - TSchema extends InteropZodObject | undefined = any, - TContext = unknown -> = ( - state: NormalizedSchemaInput & AgentBuiltInState, +type BeforeAgentHandler = ( + state: TSchema, runtime: Runtime -) => PromiseOrValue>>>; +) => PromiseOrValue>>; /** * Hook type for the beforeAgent lifecycle event. @@ -148,12 +156,15 @@ export type BeforeAgentHandler< * This hook is called once at the start of the agent invocation. */ export type BeforeAgentHook< - TSchema extends InteropZodObject | undefined = any, + TSchema extends + | InteropZodObject + | InteropZodDefault + | undefined = undefined, TContext = unknown > = - | BeforeAgentHandler + | BeforeAgentHandler, TContext> | { - hook: BeforeAgentHandler; + hook: BeforeAgentHandler, TContext>; canJumpTo?: JumpToTarget[]; }; @@ -165,13 +176,10 @@ export type BeforeAgentHook< * @param runtime - The runtime context containing metadata, signal, writer, interrupt, etc. * @returns A middleware result containing partial state updates or undefined to pass through */ -export type BeforeModelHandler< - TSchema extends InteropZodObject | undefined = any, - TContext = unknown -> = ( - state: NormalizedSchemaInput & AgentBuiltInState, +type BeforeModelHandler = ( + state: TSchema, runtime: Runtime -) => PromiseOrValue>>>; +) => PromiseOrValue>>; /** * Hook type for the beforeModel lifecycle event. @@ -179,12 +187,15 @@ export type BeforeModelHandler< * This hook is called before each model invocation. */ export type BeforeModelHook< - TSchema extends InteropZodObject | undefined = any, + TSchema extends + | InteropZodObject + | InteropZodDefault + | undefined = undefined, TContext = unknown > = - | BeforeModelHandler + | BeforeModelHandler, TContext> | { - hook: BeforeModelHandler; + hook: BeforeModelHandler, TContext>; canJumpTo?: JumpToTarget[]; }; @@ -197,13 +208,10 @@ export type BeforeModelHook< * @param runtime - The runtime context containing metadata, signal, writer, interrupt, etc. * @returns A middleware result containing partial state updates or undefined to pass through */ -export type AfterModelHandler< - TSchema extends InteropZodObject | undefined = any, - TContext = unknown -> = ( - state: NormalizedSchemaInput & AgentBuiltInState, +type AfterModelHandler = ( + state: TSchema, runtime: Runtime -) => PromiseOrValue>>>; +) => PromiseOrValue>>; /** * Hook type for the afterModel lifecycle event. @@ -211,12 +219,15 @@ export type AfterModelHandler< * This hook is called after each model invocation. */ export type AfterModelHook< - TSchema extends InteropZodObject | undefined = any, + TSchema extends + | InteropZodObject + | InteropZodDefault + | undefined = undefined, TContext = unknown > = - | AfterModelHandler + | AfterModelHandler, TContext> | { - hook: AfterModelHandler; + hook: AfterModelHandler, TContext>; canJumpTo?: JumpToTarget[]; }; @@ -228,13 +239,10 @@ export type AfterModelHook< * @param runtime - The runtime context containing metadata, signal, writer, interrupt, etc. * @returns A middleware result containing partial state updates or undefined to pass through */ -export type AfterAgentHandler< - TSchema extends InteropZodObject | undefined = any, - TContext = unknown -> = ( - state: NormalizedSchemaInput & AgentBuiltInState, +type AfterAgentHandler = ( + state: TSchema, runtime: Runtime -) => PromiseOrValue>>>; +) => PromiseOrValue>>; /** * Hook type for the afterAgent lifecycle event. @@ -242,12 +250,15 @@ export type AfterAgentHandler< * This hook is called once at the end of the agent invocation. */ export type AfterAgentHook< - TSchema extends InteropZodObject | undefined = any, + TSchema extends + | InteropZodObject + | InteropZodDefault + | undefined = undefined, TContext = unknown > = - | AfterAgentHandler + | AfterAgentHandler, TContext> | { - hook: AfterAgentHandler; + hook: AfterAgentHandler, TContext>; canJumpTo?: JumpToTarget[]; }; @@ -255,7 +266,10 @@ export type AfterAgentHook< * Base middleware interface. */ export interface AgentMiddleware< - TSchema extends InteropZodObject | undefined = any, + TSchema extends + | InteropZodObject + | InteropZodDefault + | undefined = any, TContextSchema extends | InteropZodObject | InteropZodDefault diff --git a/libs/langchain/src/agents/tests/middleware.test-d.ts b/libs/langchain/src/agents/tests/middleware.test-d.ts index 3f95d6d934ac..72b1fcd2a222 100644 --- a/libs/langchain/src/agents/tests/middleware.test-d.ts +++ b/libs/langchain/src/agents/tests/middleware.test-d.ts @@ -4,6 +4,7 @@ import { HumanMessage, BaseMessage, AIMessage } from "@langchain/core/messages"; import { tool } from "@langchain/core/tools"; import { createAgent, createMiddleware } from "../index.js"; +import type { AgentBuiltInState } from "../runtime.js"; import type { ServerTool, ClientTool } from "../tools.js"; describe("middleware types", () => { @@ -169,14 +170,33 @@ describe("middleware types", () => { .default({ customRequiredContextProp: "default value", }), - beforeModel: async (_state, runtime) => { + stateSchema: z.object({ + customDefaultStateProp: z.string().default("default value"), + customOptionalStateProp: z.string().optional(), + customRequiredStateProp: z.string(), + }), + beforeModel: async (state, runtime) => { + expectTypeOf(state).toEqualTypeOf< + { + customDefaultStateProp: string; + customOptionalStateProp?: string; + customRequiredStateProp: string; + } & AgentBuiltInState + >(); expectTypeOf(runtime.context).toEqualTypeOf<{ customDefaultContextProp: string; customOptionalContextProp?: string; customRequiredContextProp: string; }>(); }, - afterModel: async (_state, runtime) => { + afterModel: async (state, runtime) => { + expectTypeOf(state).toEqualTypeOf< + { + customDefaultStateProp: string; + customOptionalStateProp?: string; + customRequiredStateProp: string; + } & AgentBuiltInState + >(); expectTypeOf(runtime.context).toEqualTypeOf<{ customDefaultContextProp: string; customOptionalContextProp?: string; @@ -209,6 +229,7 @@ describe("middleware types", () => { await agent.invoke( { messages: [new HumanMessage("Hello, world!")], + customRequiredStateProp: "default value", }, { configurable: { From bafbc600ea778168ab091efc09a13b0ae98d0583 Mon Sep 17 00:00:00 2001 From: Christian Bromann Date: Wed, 29 Oct 2025 13:49:00 -0700 Subject: [PATCH 2/8] fix tests --- libs/langchain/src/agents/middleware/types.ts | 6 +++++- libs/langchain/src/agents/tests/state.test.ts | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/libs/langchain/src/agents/middleware/types.ts b/libs/langchain/src/agents/middleware/types.ts index bfba57cc70a1..eedd68b27dcd 100644 --- a/libs/langchain/src/agents/middleware/types.ts +++ b/libs/langchain/src/agents/middleware/types.ts @@ -40,7 +40,11 @@ type NormalizedSchemaInput< /** * Result type for middleware functions. */ -export type MiddlewareResult = TState | void; +export type MiddlewareResult = + | (TState & { + jumpTo?: JumpToTarget; + }) + | void; /** * Represents a tool call request for the wrapToolCall hook. diff --git a/libs/langchain/src/agents/tests/state.test.ts b/libs/langchain/src/agents/tests/state.test.ts index ceea197eed74..e390d426ab00 100644 --- a/libs/langchain/src/agents/tests/state.test.ts +++ b/libs/langchain/src/agents/tests/state.test.ts @@ -146,6 +146,7 @@ describe("middleware state management", () => { const model = new FakeToolCallingModel({}); const middleware = createMiddleware({ name: "middleware", + // @ts-expect-error _privateState is not an expected return type beforeModel: async (_, runtime) => { expect(runtime.threadLevelCallCount).toBe(0); expect(runtime.runModelCallCount).toBe(0); @@ -165,6 +166,7 @@ describe("middleware state management", () => { expect(request.runtime.runModelCallCount).toBe(0); return handler(request); }, + // @ts-expect-error _privateState is not an expected return type afterModel: async (_, runtime) => { expect(runtime.threadLevelCallCount).toBe(1); expect(runtime.runModelCallCount).toBe(1); From c2fd38966d8cd5031733632bcc761b4287f0dcb4 Mon Sep 17 00:00:00 2001 From: Christian Bromann Date: Wed, 29 Oct 2025 13:50:38 -0700 Subject: [PATCH 3/8] minor tweak --- libs/langchain/src/agents/middleware/types.ts | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/libs/langchain/src/agents/middleware/types.ts b/libs/langchain/src/agents/middleware/types.ts index eedd68b27dcd..d04edbb05015 100644 --- a/libs/langchain/src/agents/middleware/types.ts +++ b/libs/langchain/src/agents/middleware/types.ts @@ -89,10 +89,7 @@ export type ToolCallHandler< * Allows middleware to intercept and modify tool execution. */ export type WrapToolCallHook< - TSchema extends - | InteropZodObject - | InteropZodDefault - | undefined = undefined, + TSchema extends InteropZodObject | undefined = undefined, TContext = unknown > = ( request: ToolCallRequest, TContext>, From 0762092087a9fc12dd5c6b0e1efed041dfd9b2c9 Mon Sep 17 00:00:00 2001 From: Christian Bromann Date: Wed, 29 Oct 2025 13:56:08 -0700 Subject: [PATCH 4/8] more tweaks --- libs/langchain/src/agents/middleware/types.ts | 44 ++++--------------- 1 file changed, 8 insertions(+), 36 deletions(-) diff --git a/libs/langchain/src/agents/middleware/types.ts b/libs/langchain/src/agents/middleware/types.ts index d04edbb05015..840644be1440 100644 --- a/libs/langchain/src/agents/middleware/types.ts +++ b/libs/langchain/src/agents/middleware/types.ts @@ -22,17 +22,11 @@ type PromiseOrValue = T | Promise; export type AnyAnnotationRoot = AnnotationRoot; type NormalizedSchemaInput< - TSchema extends - | InteropZodObject - | InteropZodDefault - | undefined - | never = any + TSchema extends InteropZodObject | undefined | never = any > = [TSchema] extends [never] ? AgentBuiltInState : TSchema extends InteropZodObject ? InferInteropZodOutput & AgentBuiltInState - : TSchema extends InteropZodDefault - ? InferInteropZodOutput & AgentBuiltInState : TSchema extends Record ? TSchema & AgentBuiltInState : AgentBuiltInState; @@ -104,10 +98,7 @@ export type WrapToolCallHook< * @returns The AI message response from the model */ export type WrapModelCallHandler< - TSchema extends - | InteropZodObject - | InteropZodDefault - | undefined = undefined, + TSchema extends InteropZodObject | undefined = undefined, TContext = unknown > = ( request: ModelRequest, TContext> @@ -127,11 +118,7 @@ export type WrapModelCallHandler< * @returns The AI message response from the model (or a modified version) */ export type WrapModelCallHook< - TSchema extends - | InteropZodObject - | InteropZodDefault - | undefined - | never = undefined, + TSchema extends InteropZodObject | undefined = undefined, TContext = unknown > = ( request: ModelRequest, TContext>, @@ -157,10 +144,7 @@ type BeforeAgentHandler = ( * This hook is called once at the start of the agent invocation. */ export type BeforeAgentHook< - TSchema extends - | InteropZodObject - | InteropZodDefault - | undefined = undefined, + TSchema extends InteropZodObject | undefined = undefined, TContext = unknown > = | BeforeAgentHandler, TContext> @@ -188,10 +172,7 @@ type BeforeModelHandler = ( * This hook is called before each model invocation. */ export type BeforeModelHook< - TSchema extends - | InteropZodObject - | InteropZodDefault - | undefined = undefined, + TSchema extends InteropZodObject | undefined = undefined, TContext = unknown > = | BeforeModelHandler, TContext> @@ -220,10 +201,7 @@ type AfterModelHandler = ( * This hook is called after each model invocation. */ export type AfterModelHook< - TSchema extends - | InteropZodObject - | InteropZodDefault - | undefined = undefined, + TSchema extends InteropZodObject | undefined = undefined, TContext = unknown > = | AfterModelHandler, TContext> @@ -251,10 +229,7 @@ type AfterAgentHandler = ( * This hook is called once at the end of the agent invocation. */ export type AfterAgentHook< - TSchema extends - | InteropZodObject - | InteropZodDefault - | undefined = undefined, + TSchema extends InteropZodObject | undefined = undefined, TContext = unknown > = | AfterAgentHandler, TContext> @@ -267,10 +242,7 @@ export type AfterAgentHook< * Base middleware interface. */ export interface AgentMiddleware< - TSchema extends - | InteropZodObject - | InteropZodDefault - | undefined = any, + TSchema extends InteropZodObject | undefined = any, TContextSchema extends | InteropZodObject | InteropZodDefault From 5e2dfd896d662df7b58dcd2f1a7d4e31faeec451 Mon Sep 17 00:00:00 2001 From: Christian Bromann Date: Wed, 29 Oct 2025 13:57:14 -0700 Subject: [PATCH 5/8] Fix state schema typing in langchain --- .changeset/twenty-clocks-raise.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/twenty-clocks-raise.md diff --git a/.changeset/twenty-clocks-raise.md b/.changeset/twenty-clocks-raise.md new file mode 100644 index 000000000000..9660e9c1a142 --- /dev/null +++ b/.changeset/twenty-clocks-raise.md @@ -0,0 +1,5 @@ +--- +"langchain": patch +--- + +fix(langchain): improved state schema typing From 2c35fbab8884ca09ae858671245684a76a44e510 Mon Sep 17 00:00:00 2001 From: Christian Bromann Date: Wed, 29 Oct 2025 14:06:01 -0700 Subject: [PATCH 6/8] fix(langchain): don't allow default or optional context schemas --- libs/langchain/src/agents/middleware.ts | 18 +---- .../src/agents/tests/middleware.test-d.ts | 67 ++++++------------- 2 files changed, 22 insertions(+), 63 deletions(-) diff --git a/libs/langchain/src/agents/middleware.ts b/libs/langchain/src/agents/middleware.ts index 08cc21492805..b022c1e05cf5 100644 --- a/libs/langchain/src/agents/middleware.ts +++ b/libs/langchain/src/agents/middleware.ts @@ -1,8 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import type { InteropZodObject, - InteropZodDefault, - InteropZodOptional, InferInteropZodOutput, } from "@langchain/core/utils/types"; @@ -52,11 +50,7 @@ import type { */ export function createMiddleware< TSchema extends InteropZodObject | undefined = undefined, - TContextSchema extends - | InteropZodObject - | InteropZodOptional - | InteropZodDefault - | undefined = undefined + TContextSchema extends InteropZodObject | undefined = undefined >(config: { /** * The name of the middleware @@ -232,15 +226,7 @@ export function createMiddleware< } type NormalizeContextSchema< - TContextSchema extends - | InteropZodObject - | InteropZodOptional - | InteropZodDefault - | undefined = undefined + TContextSchema extends InteropZodObject | undefined = undefined > = TContextSchema extends InteropZodObject ? InferInteropZodOutput - : TContextSchema extends InteropZodDefault - ? InferInteropZodOutput - : TContextSchema extends InteropZodOptional - ? Partial> : never; diff --git a/libs/langchain/src/agents/tests/middleware.test-d.ts b/libs/langchain/src/agents/tests/middleware.test-d.ts index 72b1fcd2a222..b89cf9bb7de4 100644 --- a/libs/langchain/src/agents/tests/middleware.test-d.ts +++ b/libs/langchain/src/agents/tests/middleware.test-d.ts @@ -133,11 +133,9 @@ describe("middleware types", () => { it("doesn't require users to pass in a context if a middleware has optional context schema", async () => { const middleware = createMiddleware({ name: "Middleware", - contextSchema: z - .object({ - customOptionalContextProp: z.string().default("default value"), - }) - .optional(), + contextSchema: z.object({ + customOptionalContextProp: z.string().optional(), + }), }); const agent = createAgent({ @@ -158,18 +156,13 @@ describe("middleware types", () => { ); }); - it("doesn't require users to pass in a context if a middleware has context schema with defaults", async () => { + it("doesn't require users to pass in a context if a middleware has context schema with defaults or optional", async () => { const middleware = createMiddleware({ name: "Middleware", - contextSchema: z - .object({ - customDefaultContextProp: z.string().default("default value"), - customOptionalContextProp: z.string().optional(), - customRequiredContextProp: z.string(), - }) - .default({ - customRequiredContextProp: "default value", - }), + contextSchema: z.object({ + customDefaultContextProp: z.string().default("default value"), + customOptionalContextProp: z.string().optional(), + }), stateSchema: z.object({ customDefaultStateProp: z.string().default("default value"), customOptionalStateProp: z.string().optional(), @@ -186,7 +179,6 @@ describe("middleware types", () => { expectTypeOf(runtime.context).toEqualTypeOf<{ customDefaultContextProp: string; customOptionalContextProp?: string; - customRequiredContextProp: string; }>(); }, afterModel: async (state, runtime) => { @@ -200,7 +192,6 @@ describe("middleware types", () => { expectTypeOf(runtime.context).toEqualTypeOf<{ customDefaultContextProp: string; customOptionalContextProp?: string; - customRequiredContextProp: string; }>(); }, wrapModelCall: async (request, handler) => { @@ -210,7 +201,6 @@ describe("middleware types", () => { expectTypeOf(request.runtime.context).toEqualTypeOf<{ customDefaultContextProp: string; customOptionalContextProp?: string; - customRequiredContextProp: string; }>(); return handler({ @@ -242,40 +232,23 @@ describe("middleware types", () => { it("doesn't require users to pass in a context if a middleware has context schema as optional", async () => { const middleware = createMiddleware({ name: "Middleware", - contextSchema: z - .object({ - customOptionalContextProp: z.string().default("default value"), - }) - .optional(), + contextSchema: z.object({ + customOptionalContextProp: z.string().default("default value"), + }), beforeModel: async (_state, runtime) => { - expectTypeOf(runtime.context).toEqualTypeOf< - Partial< - | { - customOptionalContextProp: string; - } - | undefined - > - >(); + expectTypeOf(runtime.context).toEqualTypeOf<{ + customOptionalContextProp: string; + }>(); }, afterModel: async (_state, runtime) => { - expectTypeOf(runtime.context).toEqualTypeOf< - Partial< - | { - customOptionalContextProp: string; - } - | undefined - > - >(); + expectTypeOf(runtime.context).toEqualTypeOf<{ + customOptionalContextProp: string; + }>(); }, wrapModelCall: async (request) => { - expectTypeOf(request.runtime.context).toEqualTypeOf< - Partial< - | { - customOptionalContextProp: string; - } - | undefined - > - >(); + expectTypeOf(request.runtime.context).toEqualTypeOf<{ + customOptionalContextProp: string; + }>(); return new AIMessage("foobar"); }, From dcffe06c4076ff0ed3e59e6f1dead0ba0dc30bd5 Mon Sep 17 00:00:00 2001 From: Christian Bromann Date: Wed, 29 Oct 2025 17:00:03 -0700 Subject: [PATCH 7/8] fix(langchain): remove _privateState --- libs/langchain/src/agents/ReactAgent.ts | 89 +++++-------- libs/langchain/src/agents/annotation.ts | 7 +- .../src/agents/middleware/modelCallLimit.ts | 29 +++-- .../middleware/tests/modelCallLimit.test.ts | 15 ++- .../middleware/tests/toolCallLimit.test.ts | 119 ++++++++---------- .../src/agents/middleware/toolCallLimit.ts | 118 +++++++---------- libs/langchain/src/agents/nodes/AgentNode.ts | 50 +------- libs/langchain/src/agents/nodes/ToolNode.ts | 23 +--- libs/langchain/src/agents/nodes/middleware.ts | 25 ++-- libs/langchain/src/agents/runtime.ts | 29 +---- .../src/agents/tests/middleware.test.ts | 40 +----- .../src/agents/tests/runtime.test.ts | 9 +- libs/langchain/src/agents/tests/state.test.ts | 105 ---------------- 13 files changed, 197 insertions(+), 461 deletions(-) diff --git a/libs/langchain/src/agents/ReactAgent.ts b/libs/langchain/src/agents/ReactAgent.ts index 73ba80d8a1f8..2e3e34228170 100644 --- a/libs/langchain/src/agents/ReactAgent.ts +++ b/libs/langchain/src/agents/ReactAgent.ts @@ -46,11 +46,7 @@ import type { JumpTo, UserInput, } from "./types.js"; -import type { - PrivateState, - InvokeConfiguration, - StreamConfiguration, -} from "./runtime.js"; +import type { InvokeConfiguration, StreamConfiguration } from "./runtime.js"; import type { AgentMiddleware, InferMiddlewareContextInputs, @@ -249,10 +245,20 @@ export class ReactAgent< throw new Error(`Middleware ${m.name} is defined multiple times`); } + const getState = () => { + return { + ...beforeAgentNode?.getState(), + ...beforeModelNode?.getState(), + ...afterModelNode?.getState(), + ...afterAgentNode?.getState(), + ...this.#agentNode.getState(), + }; + }; + middlewareNames.add(m.name); if (m.beforeAgent) { beforeAgentNode = new BeforeAgentNode(m, { - getPrivateState: () => this.#agentNode.getState()._privateState, + getState, }); const name = `${m.name}.before_agent`; beforeAgentNodes.push({ @@ -268,7 +274,7 @@ export class ReactAgent< } if (m.beforeModel) { beforeModelNode = new BeforeModelNode(m, { - getPrivateState: () => this.#agentNode.getState()._privateState, + getState, }); const name = `${m.name}.before_model`; beforeModelNodes.push({ @@ -284,7 +290,7 @@ export class ReactAgent< } if (m.afterModel) { afterModelNode = new AfterModelNode(m, { - getPrivateState: () => this.#agentNode.getState()._privateState, + getState, }); const name = `${m.name}.after_model`; afterModelNodes.push({ @@ -300,7 +306,7 @@ export class ReactAgent< } if (m.afterAgent) { afterAgentNode = new AfterAgentNode(m, { - getPrivateState: () => this.#agentNode.getState()._privateState, + getState, }); const name = `${m.name}.after_agent`; afterAgentNodes.push({ @@ -316,15 +322,7 @@ export class ReactAgent< } if (m.wrapModelCall) { - wrapModelCallHookMiddleware.push([ - m, - () => ({ - ...beforeAgentNode?.getState(), - ...beforeModelNode?.getState(), - ...afterModelNode?.getState(), - ...afterAgentNode?.getState(), - }), - ]); + wrapModelCallHookMiddleware.push([m, getState]); } } @@ -350,7 +348,6 @@ export class ReactAgent< const toolNode = new ToolNode(toolClasses.filter(isClientTool), { signal: this.options.signal, wrapToolCall: wrapToolCallHandler, - getPrivateState: () => this.#agentNode.getState()._privateState, }); allNodeWorkflows.addNode("tools", toolNode); } @@ -944,7 +941,8 @@ export class ReactAgent< * Initialize middleware states if not already present in the input state. */ async #initializeMiddlewareStates( - state: InvokeStateParameter + state: InvokeStateParameter, + config: RunnableConfig ): Promise> { if ( !this.options.middleware || @@ -959,10 +957,13 @@ export class ReactAgent< this.options.middleware, state ); - const updatedState = { ...state } as InvokeStateParameter< - StateSchema, - TMiddleware - >; + const threadState = await this.#graph + .getState(config) + .catch(() => ({ values: {} })); + const updatedState = { + ...threadState.values, + ...state, + } as InvokeStateParameter; if (!updatedState) { return updatedState; } @@ -977,35 +978,6 @@ export class ReactAgent< return updatedState; } - /** - * Populate the private state of the agent node from the previous state. - */ - async #populatePrivateState(config?: RunnableConfig) { - /** - * not needed if thread_id is not provided - */ - if (!config?.configurable?.thread_id) { - return; - } - const prevState = (await this.#graph.getState(config as any)) as { - values: { - _privateState: PrivateState; - }; - }; - - /** - * not need if state is empty - */ - if (!prevState.values._privateState) { - return; - } - - this.#agentNode.setState({ - structuredResponse: undefined, - _privateState: prevState.values._privateState, - }); - } - /** * Executes the agent with the given state and returns the final state after all processing. * @@ -1061,8 +1033,10 @@ export class ReactAgent< StructuredResponseFormat, TMiddleware >; - const initializedState = await this.#initializeMiddlewareStates(state); - await this.#populatePrivateState(config); + const initializedState = await this.#initializeMiddlewareStates( + state, + config as RunnableConfig + ); return this.#graph.invoke( initializedState, @@ -1120,7 +1094,10 @@ export class ReactAgent< InferMiddlewareContextInputs > ): Promise> { - const initializedState = await this.#initializeMiddlewareStates(state); + const initializedState = await this.#initializeMiddlewareStates( + state, + config as RunnableConfig + ); return this.#graph.stream(initializedState, config as Record); } diff --git a/libs/langchain/src/agents/annotation.ts b/libs/langchain/src/agents/annotation.ts index 769cbdfed856..ace135cd68b5 100644 --- a/libs/langchain/src/agents/annotation.ts +++ b/libs/langchain/src/agents/annotation.ts @@ -30,7 +30,12 @@ export function createAgentAnnotationConditional< const zodSchema: Record = { messages: withLangGraph(z.custom(), MessagesZodMeta), jumpTo: z - .union([z.literal("model_request"), z.literal("tools"), z.undefined()]) + .union([ + z.literal("model_request"), + z.literal("tools"), + z.literal("end"), + z.undefined(), + ]) .optional(), }; diff --git a/libs/langchain/src/agents/middleware/modelCallLimit.ts b/libs/langchain/src/agents/middleware/modelCallLimit.ts index 4cc3652d74b8..6d6896ef5a67 100644 --- a/libs/langchain/src/agents/middleware/modelCallLimit.ts +++ b/libs/langchain/src/agents/middleware/modelCallLimit.ts @@ -27,6 +27,14 @@ export type ModelCallLimitMiddlewareConfig = Partial< InferInteropZodInput >; +/** + * Middleware state schema to track the number of model calls made at the thread and run level. + */ +const stateSchema = z.object({ + threadModelCallCount: z.number().default(0), + runModelCallCount: z.number().default(0), +}); + /** * Error thrown when the model call limit is exceeded. * @@ -133,6 +141,7 @@ export function modelCallLimitMiddleware( return createMiddleware({ name: "ModelCallLimitMiddleware", contextSchema, + stateSchema, beforeModel: { canJumpTo: ["end"], hook: (state, runtime) => { @@ -145,13 +154,12 @@ export function modelCallLimitMiddleware( const runLimit = runtime.context.runLimit ?? middlewareOptions?.runLimit; - if ( - typeof threadLimit === "number" && - threadLimit <= runtime.threadLevelCallCount - ) { + const threadCount = state.threadModelCallCount; + const runCount = state.runModelCallCount; + if (typeof threadLimit === "number" && threadLimit <= threadCount) { const error = new ModelCallLimitMiddlewareError({ threadLimit, - threadCount: runtime.threadLevelCallCount, + threadCount, }); if (exitBehavior === "end") { return { @@ -162,13 +170,10 @@ export function modelCallLimitMiddleware( throw error; } - if ( - typeof runLimit === "number" && - runLimit <= runtime.runModelCallCount - ) { + if (typeof runLimit === "number" && runLimit <= runCount) { const error = new ModelCallLimitMiddlewareError({ runLimit, - runCount: runtime.runModelCallCount, + runCount, }); if (exitBehavior === "end") { return { @@ -183,5 +188,9 @@ export function modelCallLimitMiddleware( return state; }, }, + afterModel: (state) => ({ + runModelCallCount: state.runModelCallCount + 1, + threadModelCallCount: state.threadModelCallCount + 1, + }), }); } diff --git a/libs/langchain/src/agents/middleware/tests/modelCallLimit.test.ts b/libs/langchain/src/agents/middleware/tests/modelCallLimit.test.ts index 45f52f30ad2c..d1541264b290 100644 --- a/libs/langchain/src/agents/middleware/tests/modelCallLimit.test.ts +++ b/libs/langchain/src/agents/middleware/tests/modelCallLimit.test.ts @@ -52,6 +52,10 @@ const tools = [ name: "tool_1", description: "tool_1", }), + tool(() => "barfoo", { + name: "tool_2", + description: "tool_2", + }), ]; describe("ModelCallLimitMiddleware", () => { @@ -143,14 +147,19 @@ describe("ModelCallLimitMiddleware", () => { checkpointer, }); if (exitBehavior === "throw") { - await expect( - agent2.invoke({ messages: ["Hello, world!"] }, config) - ).resolves.not.toThrow(); + const result = await agent2.invoke( + { messages: ["Hello, world!"] }, + config + ); + await expect(result.runModelCallCount).toBe(3); + await expect(result.threadModelCallCount).toBe(3); } else { const result = await agent2.invoke( { messages: ["Hello, world!"] }, config ); + await expect(result.runModelCallCount).toBe(3); + await expect(result.threadModelCallCount).toBe(3); expect(result.messages.at(-1)?.content).not.toContain( "Model call limits exceeded" ); diff --git a/libs/langchain/src/agents/middleware/tests/toolCallLimit.test.ts b/libs/langchain/src/agents/middleware/tests/toolCallLimit.test.ts index 83824426a8d4..a2dce0feac66 100644 --- a/libs/langchain/src/agents/middleware/tests/toolCallLimit.test.ts +++ b/libs/langchain/src/agents/middleware/tests/toolCallLimit.test.ts @@ -1,12 +1,10 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ -/* eslint-disable no-instanceof/no-instanceof */ import { describe, it, expect } from "vitest"; import { z } from "zod/v3"; import { tool } from "@langchain/core/tools"; import { HumanMessage, AIMessage, - ToolMessage, type BaseMessage, type ToolCall, } from "@langchain/core/messages"; @@ -477,26 +475,17 @@ describe("toolCallLimitMiddleware", () => { // Test with state that has exceeded limit const state = { - messages: [ - new HumanMessage("Q"), - new AIMessage({ - content: "", - tool_calls: [ - { id: "1", name: "search", args: { query: "test1" } }, - ] as ToolCall[], - }), - new ToolMessage("Result", "1"), - new AIMessage({ - content: "", - tool_calls: [ - { id: "2", name: "search", args: { query: "test2" } }, - ] as ToolCall[], - }), - ], + messages: [], + threadToolCallCount: { + __all__: 2, + }, + runToolCallCount: { + __all__: 2, + }, }; await expect(async () => { - const fn = getHookFunction(middleware.beforeModel!); + const fn = getHookFunction(middleware.beforeModel as any); await fn(state as any, {} as any); }).rejects.toThrow(ToolCallLimitExceededError); }); @@ -509,39 +498,33 @@ describe("toolCallLimitMiddleware", () => { }); const state = { - messages: [ - new HumanMessage("Q"), - new AIMessage({ - content: "", - tool_calls: [ - { id: "1", name: "search", args: { query: "test1" } }, - ] as ToolCall[], - }), - new ToolMessage("Result", "1"), - new AIMessage({ - content: "", - tool_calls: [ - { id: "2", name: "search", args: { query: "test2" } }, - ] as ToolCall[], - }), - ], + messages: [], + threadToolCallCount: { + __all__: 2, + }, + runToolCallCount: { + __all__: 2, + }, }; try { - const fn = getHookFunction(middleware.beforeModel!); + const fn = getHookFunction(middleware.beforeModel as any); await fn(state as any, {} as any); expect.fail("Should have thrown error"); } catch (error) { expect(error).toBeInstanceOf(ToolCallLimitExceededError); - if (error instanceof ToolCallLimitExceededError) { - expect(error.threadCount).toBe(2); - expect(error.threadLimit).toBe(2); - expect(error.runCount).toBe(2); - expect(error.runLimit).toBe(1); - expect(error.toolName).toBeUndefined(); - expect(error.message).toContain("thread limit reached (2/2)"); - expect(error.message).toContain("run limit reached (2/1)"); - } + const toolCallLimitExceededError = error as ToolCallLimitExceededError; + expect(toolCallLimitExceededError.threadCount).toBe(2); + expect(toolCallLimitExceededError.threadLimit).toBe(2); + expect(toolCallLimitExceededError.runCount).toBe(2); + expect(toolCallLimitExceededError.runLimit).toBe(1); + expect(toolCallLimitExceededError.toolName).toBeUndefined(); + expect(toolCallLimitExceededError.message).toContain( + "thread limit reached (2/2)" + ); + expect(toolCallLimitExceededError.message).toContain( + "run limit reached (2/1)" + ); } }); @@ -553,36 +536,26 @@ describe("toolCallLimitMiddleware", () => { }); const state = { - messages: [ - new HumanMessage("Q"), - new AIMessage({ - content: "", - tool_calls: [ - { id: "1", name: "search", args: { query: "test1" } }, - { id: "2", name: "calculator", args: { expression: "1+1" } }, - ] as ToolCall[], - }), - new ToolMessage("Result", "1"), - new ToolMessage("Result", "2"), - new AIMessage({ - content: "", - tool_calls: [ - { id: "3", name: "search", args: { query: "test2" } }, - ] as ToolCall[], - }), - ], + messages: [], + threadToolCallCount: { + search: 2, + }, + runToolCallCount: { + search: 2, + }, }; try { - const fn = getHookFunction(middleware.beforeModel!); + const fn = getHookFunction(middleware.beforeModel! as any); await fn(state as any, {} as any); expect.fail("Should have thrown error"); } catch (error) { expect(error).toBeInstanceOf(ToolCallLimitExceededError); - if (error instanceof ToolCallLimitExceededError) { - expect(error.toolName).toBe("search"); - expect(error.message).toContain("'search' tool call"); - } + const toolCallLimitExceededError = error as ToolCallLimitExceededError; + expect(toolCallLimitExceededError.toolName).toBe("search"); + expect(toolCallLimitExceededError.message).toContain( + "'search' tool call" + ); } }); }); @@ -699,7 +672,7 @@ describe("toolCallLimitMiddleware", () => { messages: [], }; - const fn = getHookFunction(middleware.beforeModel!); + const fn = getHookFunction(middleware.beforeModel! as any); const result = await fn(state as any, {} as any); expect(result).toBeUndefined(); }); @@ -722,9 +695,15 @@ describe("toolCallLimitMiddleware", () => { ] as ToolCall[], }), ], + threadToolCallCount: { + __all__: 3, + }, + runToolCallCount: { + __all__: 3, + }, }; - const fn = getHookFunction(middleware.beforeModel!); + const fn = getHookFunction(middleware.beforeModel! as any); const result = await fn(state as any, {} as any); // Should hit limit (3 tool calls) diff --git a/libs/langchain/src/agents/middleware/toolCallLimit.ts b/libs/langchain/src/agents/middleware/toolCallLimit.ts index eb4f00c7a47e..e856013e1372 100644 --- a/libs/langchain/src/agents/middleware/toolCallLimit.ts +++ b/libs/langchain/src/agents/middleware/toolCallLimit.ts @@ -2,70 +2,12 @@ * Tool call limit middleware for agents. */ -import { AIMessage, BaseMessage, HumanMessage } from "@langchain/core/messages"; +import { AIMessage } from "@langchain/core/messages"; import { z } from "zod/v3"; import type { InferInteropZodInput } from "@langchain/core/utils/types"; import { createMiddleware } from "../middleware.js"; -/** - * Count tool calls in a list of messages. - * - * @param messages - List of messages to count tool calls in. - * @param toolName - If specified, only count calls to this specific tool. - * If undefined, count all tool calls. - * @returns The total number of tool calls (optionally filtered by toolName). - */ -function countToolCallsInMessages( - messages: BaseMessage[], - toolName?: string -): number { - let count = 0; - for (const message of messages) { - if (AIMessage.isInstance(message) && message.tool_calls) { - if (toolName === undefined) { - // Count all tool calls - count += message.tool_calls.length; - } else { - // Count only calls to the specified tool - count += message.tool_calls.filter((tc) => tc.name === toolName).length; - } - } - } - return count; -} - -/** - * Get messages from the current run (after the last HumanMessage). - * - * @param messages - Full list of messages. - * @returns Messages from the current run (after last HumanMessage). - */ -function getRunMessages(messages: BaseMessage[]): BaseMessage[] { - /** - * Find the last HumanMessage - */ - let lastHumanIndex = -1; - for (let i = messages.length - 1; i >= 0; i--) { - if (HumanMessage.isInstance(messages[i])) { - lastHumanIndex = i; - break; - } - } - - /** - * If no HumanMessage found, return all messages - */ - if (lastHumanIndex === -1) { - return messages; - } - - /** - * Return messages after the last HumanMessage - */ - return messages.slice(lastHumanIndex + 1); -} - /** * Build a message indicating which tool call limits were reached. * @@ -182,6 +124,16 @@ export type ToolCallLimitConfig = InferInteropZodInput< typeof ToolCallLimitOptionsSchema >; +/** + * Middleware state schema to track the number of model calls made at the thread and run level. + */ +const stateSchema = z.object({ + threadToolCallCount: z.record(z.string(), z.number()).default({}), + runToolCallCount: z.record(z.string(), z.number()).default({}), +}); + +const DEFAULT_TOOL_COUNT_KEY = "__all__"; + /** * Middleware that tracks tool call counts and enforces limits. * @@ -293,27 +245,25 @@ export function toolCallLimitMiddleware(options: ToolCallLimitConfig) { return createMiddleware({ name: middlewareName, + stateSchema, beforeModel: { canJumpTo: ["end"], hook: (state) => { - const messages = state.messages; - /** * Count tool calls in entire thread */ - const threadCount = countToolCallsInMessages( - messages, - options.toolName - ); + const threadCount = + state.threadToolCallCount?.[ + options.toolName ?? DEFAULT_TOOL_COUNT_KEY + ] ?? 0; /** * Count tool calls in current run (after last HumanMessage) */ - const runMessages = getRunMessages(messages); - const runCount = countToolCallsInMessages( - runMessages, - options.toolName - ); + const runCount = + state.runToolCallCount?.[ + options.toolName ?? DEFAULT_TOOL_COUNT_KEY + ] ?? 0; /** * Check if any limits are exceeded @@ -356,5 +306,33 @@ export function toolCallLimitMiddleware(options: ToolCallLimitConfig) { }; }, }, + afterModel: (state) => { + const lastAIMessage = [...state.messages] + .reverse() + .find(AIMessage.isInstance); + if (!lastAIMessage || !lastAIMessage.tool_calls) { + return state; + } + + const toolCallCount = lastAIMessage.tool_calls.filter( + (toolCall) => + options.toolName === undefined || toolCall.name === options.toolName + ).length; + if (toolCallCount === 0) { + return state; + } + + const countKey = options.toolName ?? DEFAULT_TOOL_COUNT_KEY; + const threadCounts = state.threadToolCallCount; + const runCounts = state.runToolCallCount; + + threadCounts[countKey] = (threadCounts[countKey] ?? 0) + toolCallCount; + runCounts[countKey] = (runCounts[countKey] ?? 0) + toolCallCount; + + return { + threadToolCallCount: threadCounts, + runToolCallCount: runCounts, + }; + }, }); } diff --git a/libs/langchain/src/agents/nodes/AgentNode.ts b/libs/langchain/src/agents/nodes/AgentNode.ts index 1b3cfe7c4090..f3a8ccab5431 100644 --- a/libs/langchain/src/agents/nodes/AgentNode.ts +++ b/libs/langchain/src/agents/nodes/AgentNode.ts @@ -25,7 +25,7 @@ import { } from "../utils.js"; import { mergeAbortSignals } from "../nodes/utils.js"; import { CreateAgentParams } from "../types.js"; -import type { InternalAgentState, Runtime, PrivateState } from "../runtime.js"; +import type { InternalAgentState, Runtime } from "../runtime.js"; import type { AgentMiddleware, AnyAnnotationRoot, @@ -103,18 +103,14 @@ export class AgentNode< ContextSchema extends AnyAnnotationRoot | InteropZodObject = AnyAnnotationRoot > extends RunnableCallable< InternalAgentState & PreHookAnnotation["State"], - | (( + | ( | { messages: BaseMessage[] } | { structuredResponse: StructuredResponseFormat } - ) & { _privateState: PrivateState }) + ) | Command > { #options: AgentNodeOptions; - #runState: Pick = { - runModelCallCount: 0, - }; - constructor( options: AgentNodeOptions ) { @@ -203,16 +199,10 @@ export class AgentNode< /** * return directly without invoking the model again */ - return { messages: [], _privateState: this.getState()._privateState }; + return { messages: [] }; } - const privateState = this.getState()._privateState; const response = await this.#invokeModel(state, config); - this.#runState.runModelCallCount++; - const _privateState = { - ...privateState, - threadLevelCallCount: privateState.threadLevelCallCount + 1, - }; /** * if we were able to generate a structured response, return it @@ -221,7 +211,6 @@ export class AgentNode< return { messages: [...state.messages, ...(response.messages || [])], structuredResponse: response.structuredResponse, - _privateState, }; } @@ -244,11 +233,10 @@ export class AgentNode< id: response.id, }), ], - _privateState, }; } - return { messages: [response], _privateState }; + return { messages: [response] }; } /** @@ -398,9 +386,7 @@ export class AgentNode< /** * Create runtime */ - const privateState = this.getState()._privateState; const runtime: Runtime = Object.freeze({ - ...privateState, context, writer: lgConfig.writer, interrupt: lgConfig.interrupt, @@ -529,7 +515,6 @@ export class AgentNode< } as InternalAgentState & PreHookAnnotation["State"], runtime: Object.freeze({ - ...this.getState()._privateState, context: lgConfig?.context, writer: lgConfig.writer, interrupt: lgConfig.interrupt, @@ -832,47 +817,24 @@ export class AgentNode< static get nodeOptions(): { input: z.ZodObject<{ messages: z.ZodArray>; - _privateState: z.ZodObject<{ - threadLevelCallCount: z.ZodNumber; - }>; }>; } { return { input: z.object({ messages: z.array(z.custom()), - _privateState: z.object({ - threadLevelCallCount: z.number(), - }), }), }; } getState(): { messages: BaseMessage[]; - _privateState: PrivateState; } { const state = super.getState(); - const origState = - state && !(state instanceof Command) - ? state - : ({ - _privateState: { - threadLevelCallCount: 0, - runModelCallCount: 0, - }, - } as { - messages?: BaseMessage[]; - _privateState?: PrivateState; - }); + const origState = state && !(state instanceof Command) ? state : {}; return { messages: [], ...origState, - _privateState: { - threadLevelCallCount: 0, - ...(origState._privateState ?? {}), - ...this.#runState, - }, }; } } diff --git a/libs/langchain/src/agents/nodes/ToolNode.ts b/libs/langchain/src/agents/nodes/ToolNode.ts index 069108a663cc..2f34fbf35677 100644 --- a/libs/langchain/src/agents/nodes/ToolNode.ts +++ b/libs/langchain/src/agents/nodes/ToolNode.ts @@ -21,7 +21,6 @@ import { RunnableCallable } from "../RunnableCallable.js"; import { PreHookAnnotation } from "../annotation.js"; import { mergeAbortSignals } from "./utils.js"; import { ToolInvocationError } from "../errors.js"; -import type { PrivateState } from "../runtime.js"; import type { AnyAnnotationRoot, WrapToolCallHook, @@ -70,11 +69,6 @@ export interface ToolNodeOptions { * The wrapper receives the tool call request and a handler function to execute the tool. */ wrapToolCall?: WrapToolCallHook; - /** - * Optional function to get the private state (threadLevelCallCount, runModelCallCount). - * Used to provide runtime metadata to wrapToolCall middleware. - */ - getPrivateState?: () => PrivateState; } const isBaseMessageArray = (input: unknown): input is BaseMessage[] => @@ -183,14 +177,11 @@ export class ToolNode< wrapToolCall?: WrapToolCallHook; - getPrivateState?: () => PrivateState; - constructor( tools: (StructuredToolInterface | DynamicTool | RunnableToolLike)[], public options?: ToolNodeOptions ) { - const { name, tags, handleToolErrors, wrapToolCall, getPrivateState } = - options ?? {}; + const { name, tags, handleToolErrors, wrapToolCall } = options ?? {}; super({ name, tags, @@ -204,7 +195,6 @@ export class ToolNode< this.tools = tools; this.handleToolErrors = handleToolErrors ?? this.handleToolErrors; this.wrapToolCall = wrapToolCall; - this.getPrivateState = getPrivateState; this.signal = options?.signal; } @@ -335,22 +325,11 @@ export class ToolNode< * Build runtime from LangGraph config */ const lgConfig = config as LangGraphRunnableConfig; - - /** - * Get private state if available - */ - const privateState = this.getPrivateState?.() || { - threadLevelCallCount: 0, - runModelCallCount: 0, - }; - const runtime = { context: lgConfig?.context, writer: lgConfig?.writer, interrupt: lgConfig?.interrupt, signal: lgConfig?.signal, - threadLevelCallCount: privateState.threadLevelCallCount, - runModelCallCount: privateState.runModelCallCount, }; /** diff --git a/libs/langchain/src/agents/nodes/middleware.ts b/libs/langchain/src/agents/nodes/middleware.ts index e99e36df60e7..ddc2c60a3b3b 100644 --- a/libs/langchain/src/agents/nodes/middleware.ts +++ b/libs/langchain/src/agents/nodes/middleware.ts @@ -5,9 +5,9 @@ import { interopParse } from "@langchain/core/utils/types"; import { RunnableCallable, RunnableCallableArgs } from "../RunnableCallable.js"; import type { JumpToTarget } from "../constants.js"; -import type { Runtime, PrivateState } from "../runtime.js"; +import type { Runtime } from "../runtime.js"; import type { AgentMiddleware, MiddlewareResult } from "../middleware/types.js"; -import { derivePrivateState, parseJumpToTarget } from "./utils.js"; +import { derivePrivateState } from "./utils.js"; import { getHookConstraint } from "../middleware/utils.js"; /** @@ -21,7 +21,7 @@ type NodeOutput> = | Command; export interface MiddlewareNodeOptions { - getPrivateState: () => PrivateState; + getState: () => Record; } export abstract class MiddlewareNode< @@ -49,7 +49,7 @@ export abstract class MiddlewareNode< ): Promise> | MiddlewareResult; async invokeMiddleware( - state: TStateSchema, + invokeState: TStateSchema, config?: LangGraphRunnableConfig ): Promise> { /** @@ -83,6 +83,15 @@ export abstract class MiddlewareNode< } } + const state: TStateSchema = { + ...invokeState, + ...this.#options.getState(), + /** + * don't overwrite possible outdated messages from other middleware nodes + */ + messages: invokeState.messages, + }; + /** * ToDo: implement later */ @@ -91,7 +100,6 @@ export abstract class MiddlewareNode< writer: config?.writer, interrupt: config?.interrupt, signal: config?.signal, - ...this.#options.getPrivateState(), }; const result = await this.runHook( @@ -109,7 +117,6 @@ export abstract class MiddlewareNode< }) ) ); - delete result?._privateState; /** * If result is undefined, return current state @@ -151,8 +158,6 @@ export abstract class MiddlewareNode< throw new Error(`Invalid jump target: ${result.jumpTo}, ${suggestion}.`); } - const jumpTo = parseJumpToTarget(result.jumpTo as string); - /** * If result is a control action, handle it */ @@ -165,7 +170,7 @@ export abstract class MiddlewareNode< return { ...state, ...(result.result || {}), - jumpTo, + jumpTo: result.jumpTo, }; } @@ -175,7 +180,7 @@ export abstract class MiddlewareNode< /** * If result is a state update, merge it with current state */ - return { ...state, ...result, jumpTo }; + return { ...state, ...result, jumpTo: result.jumpTo }; } get nodeOptions(): { diff --git a/libs/langchain/src/agents/runtime.ts b/libs/langchain/src/agents/runtime.ts index ed0dd821ed70..6bb199f31668 100644 --- a/libs/langchain/src/agents/runtime.ts +++ b/libs/langchain/src/agents/runtime.ts @@ -61,39 +61,13 @@ export type WithMaybeContext = undefined extends TContext export type Runtime = Partial< Omit, "context" | "configurable"> > & - WithMaybeContext & - PrivateState & { + WithMaybeContext & { configurable?: { thread_id?: string; [key: string]: unknown; }; }; -export interface RunLevelPrivateState { - /** - * The number of times the model has been called at the run level. - * This includes multiple agent invocations. - */ - runModelCallCount: number; -} -export interface ThreadLevelPrivateState { - /** - * The number of times the model has been called at the thread level. - * This includes multiple agent invocations within different environments - * using the same thread. - */ - threadLevelCallCount: number; -} - -/** - * As private state we consider all information we want to track within - * the lifecycle of the agent, without exposing it to the user. These informations - * are propagated to the user as _readonly_ runtime properties. - */ -export interface PrivateState - extends ThreadLevelPrivateState, - RunLevelPrivateState {} - export type InternalAgentState< StructuredResponseType extends Record | undefined = Record< string, @@ -101,7 +75,6 @@ export type InternalAgentState< > > = { messages: BaseMessage[]; - _privateState?: PrivateState; } & (StructuredResponseType extends ResponseFormatUndefined ? Record : { structuredResponse: StructuredResponseType }); diff --git a/libs/langchain/src/agents/tests/middleware.test.ts b/libs/langchain/src/agents/tests/middleware.test.ts index 42328743941e..2999a7ae9660 100644 --- a/libs/langchain/src/agents/tests/middleware.test.ts +++ b/libs/langchain/src/agents/tests/middleware.test.ts @@ -515,8 +515,6 @@ describe("middleware", () => { middlewareContext: number; }>(); expectTypeOf(request.systemPrompt!).toBeString(); - expectTypeOf(request.runtime.runModelCallCount).toBeNumber(); - expectTypeOf(request.runtime.threadLevelCallCount).toBeNumber(); // Capture state and runtime capturedState = request.state; @@ -563,19 +561,14 @@ describe("middleware", () => { "Test" ); - const { context, threadLevelCallCount, runModelCallCount } = - capturedRuntime; + const { context } = capturedRuntime; expect({ context, - threadLevelCallCount, - runModelCallCount, }).toMatchInlineSnapshot(` { "context": { "middlewareContext": 2, }, - "runModelCallCount": 0, - "threadLevelCallCount": 0, } `); }); @@ -1017,8 +1010,6 @@ describe("middleware", () => { `); expect(request.runtime.context).toEqual({ foo: 123 }); expect(request.state.bar).toBe(true); - expect(request.runtime.runModelCallCount).toBe(1); - expect(request.runtime.threadLevelCallCount).toBe(1); /** * Let's test if we can modify tool args @@ -2266,35 +2257,6 @@ describe("middleware", () => { expect(messageContents).toContain("Original message"); }); - it("should allow accessing runtime metadata in before_agent and after_agent", async () => { - const middleware = createMiddleware({ - name: "RuntimeAccessMiddleware", - beforeAgent: async (_state, runtime) => { - expect(runtime.threadLevelCallCount).toBe(0); - expect(runtime.runModelCallCount).toBe(0); - }, - afterAgent: async (_state, runtime) => { - // After the agent completes, counts should be updated - expect(runtime.threadLevelCallCount).toBe(1); - expect(runtime.runModelCallCount).toBe(1); - }, - }); - - const model = new FakeToolCallingChatModel({ - responses: [new AIMessage("Response")], - }); - - const agent = createAgent({ - model, - tools: [], - middleware: [middleware], - }); - - await agent.invoke({ - messages: [new HumanMessage("Test")], - }); - }); - it("should propagate state changes from before_agent through the entire agent execution", async () => { const middleware = createMiddleware({ name: "StateTracker", diff --git a/libs/langchain/src/agents/tests/runtime.test.ts b/libs/langchain/src/agents/tests/runtime.test.ts index 7d7d3d39d568..aad665c16638 100644 --- a/libs/langchain/src/agents/tests/runtime.test.ts +++ b/libs/langchain/src/agents/tests/runtime.test.ts @@ -10,7 +10,8 @@ describe("runtime", () => { const middleware = createMiddleware({ name: "middleware", beforeModel: async (_, runtime) => { - runtime.runModelCallCount = 123; + // @ts-expect-error context is typed as readonly + runtime.context = 123; }, }); @@ -31,7 +32,8 @@ describe("runtime", () => { const middleware = createMiddleware({ name: "middleware", afterModel: async (_, runtime) => { - runtime.runModelCallCount = 123; + // @ts-expect-error context is typed as readonly + runtime.context = 123; }, }); @@ -52,7 +54,8 @@ describe("runtime", () => { const middleware = createMiddleware({ name: "middleware", wrapModelCall: async (request, handler) => { - request.runtime.runModelCallCount = 123; + // @ts-expect-error context is typed as readonly + request.runtime.context = 123; return handler(request); }, }); diff --git a/libs/langchain/src/agents/tests/state.test.ts b/libs/langchain/src/agents/tests/state.test.ts index e390d426ab00..1ba29b621913 100644 --- a/libs/langchain/src/agents/tests/state.test.ts +++ b/libs/langchain/src/agents/tests/state.test.ts @@ -1,18 +1,10 @@ import { z } from "zod/v3"; import { describe, it, expect } from "vitest"; import { HumanMessage } from "@langchain/core/messages"; -import { MemorySaver } from "@langchain/langgraph-checkpoint"; import { createMiddleware, createAgent } from "../index.js"; import { FakeToolCallingModel } from "./utils.js"; -const checkpointer = new MemorySaver(); -const config = { - configurable: { - thread_id: "test-123", - }, -}; - describe("middleware state management", () => { it("should allow to define private state props with _ that doesn't leak out", async () => { expect.assertions(10); @@ -140,101 +132,4 @@ describe("middleware state management", () => { middlewareCAfterModelState: "middlewareCAfterModelState", }); }); - - it("should track thread level call count and run model call count as part of a private state", async () => { - expect.assertions(9); - const model = new FakeToolCallingModel({}); - const middleware = createMiddleware({ - name: "middleware", - // @ts-expect-error _privateState is not an expected return type - beforeModel: async (_, runtime) => { - expect(runtime.threadLevelCallCount).toBe(0); - expect(runtime.runModelCallCount).toBe(0); - - /** - * try to override the private state - */ - return { - _privateState: { - threadLevelCallCount: 123, - runModelCallCount: 123, - }, - }; - }, - wrapModelCall: async (request, handler) => { - expect(request.runtime.threadLevelCallCount).toBe(0); - expect(request.runtime.runModelCallCount).toBe(0); - return handler(request); - }, - // @ts-expect-error _privateState is not an expected return type - afterModel: async (_, runtime) => { - expect(runtime.threadLevelCallCount).toBe(1); - expect(runtime.runModelCallCount).toBe(1); - - /** - * try to override the private state - */ - return { - _privateState: { - threadLevelCallCount: 123, - runModelCallCount: 123, - }, - }; - }, - }); - - const agent = createAgent({ - model, - middleware: [middleware], - checkpointer, - }); - - const result = await agent.invoke( - { - messages: [new HumanMessage("What is the weather in Tokyo?")], - }, - config - ); - - // @ts-expect-error should not be defined in the state - expect(result.threadLevelCallCount).toBe(undefined); - // @ts-expect-error should not be defined in the state - expect(result.runModelCallCount).toBe(undefined); - // @ts-expect-error should not be defined in the state - expect(result._privateState).toBe(undefined); - }); - - it("should allow to continue counting thread level call count and run model call count across multiple invocations", async () => { - expect.assertions(6); - const model = new FakeToolCallingModel({}); - const middleware = createMiddleware({ - name: "middleware", - beforeModel: async (_, runtime) => { - expect(runtime.threadLevelCallCount).toBe(1); - expect(runtime.runModelCallCount).toBe(0); - }, - wrapModelCall: async (request, handler) => { - expect(request.runtime.threadLevelCallCount).toBe(1); - expect(request.runtime.runModelCallCount).toBe(0); - return handler(request); - }, - afterModel: async (_, runtime) => { - expect(runtime.threadLevelCallCount).toBe(2); - expect(runtime.runModelCallCount).toBe(1); - }, - }); - - const agent = createAgent({ - model, - middleware: [middleware], - checkpointer, - }); - - await agent.invoke( - { - messages: [new HumanMessage("What is the weather in Tokyo?")], - }, - config - ); - }); }); From c8fc585140a2aef1a5c907b6d7cd3851dd297e79 Mon Sep 17 00:00:00 2001 From: Christian Bromann Date: Thu, 30 Oct 2025 15:25:56 -0700 Subject: [PATCH 8/8] feat(@langchain/core): support tools with custom state or context provided by ToolRuntime --- libs/langchain-core/src/tools/index.ts | 230 ++++++++++++++++++++++--- libs/langchain-core/src/tools/types.ts | 97 +++++++++++ 2 files changed, 302 insertions(+), 25 deletions(-) diff --git a/libs/langchain-core/src/tools/index.ts b/libs/langchain-core/src/tools/index.ts index 43d2b9d94de6..78c4a29520cc 100644 --- a/libs/langchain-core/src/tools/index.ts +++ b/libs/langchain-core/src/tools/index.ts @@ -17,7 +17,6 @@ import { pickRunnableConfigKeys, type RunnableConfig, } from "../runnables/config.js"; -import type { RunnableFunc } from "../runnables/base.js"; import { isDirectToolOutput, ToolCall, ToolMessage } from "../messages/tool.js"; import { AsyncLocalStorageProviderSingleton } from "../singletons/index.js"; import { @@ -54,6 +53,7 @@ import type { StringInputToolSchema, ToolInterface, ToolOutputType, + ToolRuntime, } from "./types.js"; import { type JSONSchema, validatesOnlyStrings } from "../utils/json_schema.js"; @@ -71,6 +71,7 @@ export type { ToolReturnType, ToolRunnableConfig, ToolInputSchemaBase as ToolSchemaBase, + ToolRuntime, } from "./types.js"; export { @@ -511,14 +512,64 @@ export abstract class BaseToolkit { } } +/** + * Helper type to check if a schema is defined (not undefined). + */ +type IsSchemaDefined = T extends undefined ? false : true; + +/** + * Helper type to determine if runtime should be passed to the function. + */ +type ShouldPassRuntime< + StateSchema extends InteropZodObject | undefined, + ContextSchema extends InteropZodObject | undefined +> = IsSchemaDefined extends true + ? true + : IsSchemaDefined extends true + ? true + : false; + +/** + * Helper type to create RunnableFunc with optional runtime parameter. + */ +type RunnableFuncWithRuntime< + RunInput, + RunOutput, + StateSchema extends InteropZodObject | undefined, + ContextSchema extends InteropZodObject | undefined, + CallOptions extends RunnableConfig = RunnableConfig +> = ShouldPassRuntime extends true + ? ( + input: RunInput, + runtime: ToolRuntime, + options?: + | CallOptions + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | Record + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | (Record & CallOptions) + ) => RunOutput | Promise + : ( + input: RunInput, + options?: + | CallOptions + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | Record + // eslint-disable-next-line @typescript-eslint/no-explicit-any + | (Record & CallOptions) + ) => RunOutput | Promise; + /** * Parameters for the tool function. * Schema can be provided as Zod or JSON schema. * Both schema types will be validated. * @template {ToolInputSchemaBase} RunInput The input schema for the tool. */ -interface ToolWrapperParams - extends ToolParams { +interface ToolWrapperParams< + RunInput = ToolInputSchemaBase | undefined, + StateSchema extends InteropZodObject | undefined = undefined, + ContextSchema extends InteropZodObject | undefined = undefined +> extends ToolParams { /** * The name of the tool. If using with an LLM, this * will be passed as the tool name. @@ -552,6 +603,14 @@ interface ToolWrapperParams * an agent should stop looping. */ returnDirect?: boolean; + /** + * The state schema for the tool runtime. + */ + stateSchema?: StateSchema; + /** + * The context schema for the tool runtime. + */ + contextSchema?: ContextSchema; } /** @@ -562,6 +621,8 @@ interface ToolWrapperParams * @function * @template {ToolInputSchemaBase} SchemaT The input schema for the tool. * @template {ToolReturnType} ToolOutputT The output type of the tool. + * @template {InteropZodObject | undefined} StateSchema The state schema for the tool runtime. + * @template {InteropZodObject | undefined} ContextSchema The context schema for the tool runtime. * * @param {RunnableFunc, ToolOutputT>} func - The function to invoke when the tool is called. * @param {ToolWrapperParams} fields - An object containing the following properties: @@ -571,56 +632,90 @@ interface ToolWrapperParams * * @returns {DynamicStructuredTool} A new StructuredTool instance. */ -export function tool( - func: RunnableFunc< +export function tool< + SchemaT extends ZodStringV3, + ToolOutputT = ToolOutputType, + StateSchema extends InteropZodObject | undefined = undefined, + ContextSchema extends InteropZodObject | undefined = undefined +>( + func: RunnableFuncWithRuntime< InferInteropZodOutput, ToolOutputT, + StateSchema, + ContextSchema, ToolRunnableConfig >, - fields: ToolWrapperParams + fields: ToolWrapperParams ): DynamicTool; -export function tool( - func: RunnableFunc< +export function tool< + SchemaT extends ZodStringV4, + ToolOutputT = ToolOutputType, + StateSchema extends InteropZodObject | undefined = undefined, + ContextSchema extends InteropZodObject | undefined = undefined +>( + func: RunnableFuncWithRuntime< InferInteropZodOutput, ToolOutputT, + StateSchema, + ContextSchema, ToolRunnableConfig >, - fields: ToolWrapperParams + fields: ToolWrapperParams ): DynamicTool; export function tool< SchemaT extends ZodObjectV3, SchemaOutputT = InferInteropZodOutput, SchemaInputT = InferInteropZodInput, - ToolOutputT = ToolOutputType + ToolOutputT = ToolOutputType, + StateSchema extends InteropZodObject | undefined = undefined, + ContextSchema extends InteropZodObject | undefined = undefined >( - func: RunnableFunc, - fields: ToolWrapperParams + func: RunnableFuncWithRuntime< + SchemaOutputT, + ToolOutputT, + StateSchema, + ContextSchema, + ToolRunnableConfig + >, + fields: ToolWrapperParams ): DynamicStructuredTool; export function tool< SchemaT extends ZodObjectV4, SchemaOutputT = InferInteropZodOutput, SchemaInputT = InferInteropZodInput, - ToolOutputT = ToolOutputType + ToolOutputT = ToolOutputType, + StateSchema extends InteropZodObject | undefined = undefined, + ContextSchema extends InteropZodObject | undefined = undefined >( - func: RunnableFunc, - fields: ToolWrapperParams + func: RunnableFuncWithRuntime< + SchemaOutputT, + ToolOutputT, + StateSchema, + ContextSchema, + ToolRunnableConfig + >, + fields: ToolWrapperParams ): DynamicStructuredTool; export function tool< SchemaT extends JSONSchema, SchemaOutputT = ToolInputSchemaOutputType, SchemaInputT = ToolInputSchemaInputType, - ToolOutputT = ToolOutputType + ToolOutputT = ToolOutputType, + StateSchema extends InteropZodObject | undefined = undefined, + ContextSchema extends InteropZodObject | undefined = undefined >( - func: RunnableFunc< + func: RunnableFuncWithRuntime< Parameters["func"]>[0], ToolOutputT, + StateSchema, + ContextSchema, ToolRunnableConfig >, - fields: ToolWrapperParams + fields: ToolWrapperParams ): DynamicStructuredTool; export function tool< @@ -630,15 +725,26 @@ export function tool< | JSONSchema = InteropZodObject, SchemaOutputT = ToolInputSchemaOutputType, SchemaInputT = ToolInputSchemaInputType, - ToolOutputT = ToolOutputType + ToolOutputT = ToolOutputType, + StateSchema extends InteropZodObject | undefined = undefined, + ContextSchema extends InteropZodObject | undefined = undefined >( - func: RunnableFunc, - fields: ToolWrapperParams + func: RunnableFuncWithRuntime< + SchemaOutputT, + ToolOutputT, + StateSchema, + ContextSchema, + ToolRunnableConfig + >, + fields: ToolWrapperParams ): | DynamicStructuredTool | DynamicTool { const isSimpleStringSchema = isSimpleStringZodSchema(fields.schema); const isStringJSONSchema = validatesOnlyStrings(fields.schema); + const hasStateSchema = fields.stateSchema !== undefined; + const hasContextSchema = fields.contextSchema !== undefined; + const shouldPassRuntime = hasStateSchema || hasContextSchema; // If the schema is not provided, or it's a simple string schema, create a DynamicTool if (!fields.schema || isSimpleStringSchema || isStringJSONSchema) { @@ -658,9 +764,50 @@ export function tool< pickRunnableConfigKeys(childConfig), async () => { try { - // TS doesn't restrict the type here based on the guard above - // eslint-disable-next-line @typescript-eslint/no-explicit-any - resolve(func(input as any, childConfig)); + if (shouldPassRuntime) { + // Construct runtime object from config + // State will be provided by ToolNode, but we create a minimal runtime here + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const lgConfig = config as any; + const toolConfig = childConfig as ToolRunnableConfig; + const runtime: ToolRuntime = { + state: (lgConfig?.state || + {}) as StateSchema extends InteropZodObject + ? InferInteropZodOutput + : Record, + toolCallId: toolConfig?.toolCall?.id || "", + config: toolConfig, + context: (lgConfig?.context || + undefined) as ContextSchema extends InteropZodObject + ? InferInteropZodOutput + : unknown, + store: lgConfig?.store || null, + writer: lgConfig?.writer || null, + }; + const funcWithRuntime = func as ( + input: unknown, + runtime: ToolRuntime, + options?: unknown + ) => ToolOutputT | Promise; + resolve( + await funcWithRuntime( + input as InferInteropZodOutput, + runtime, + childConfig + ) + ); + } else { + const funcWithoutRuntime = func as ( + input: unknown, + options?: unknown + ) => ToolOutputT | Promise; + resolve( + await funcWithoutRuntime( + input as InferInteropZodOutput, + childConfig + ) + ); + } } catch (e) { reject(e); } @@ -703,7 +850,40 @@ export function tool< pickRunnableConfigKeys(childConfig), async () => { try { - const result = await func(input, childConfig); + let result: ToolOutputT; + if (shouldPassRuntime) { + // Construct runtime object from config + // State will be provided by ToolNode, but we create a minimal runtime here + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const lgConfig = config as any; + const toolConfig = childConfig as ToolRunnableConfig; + const runtime: ToolRuntime = { + state: (lgConfig?.state || + {}) as StateSchema extends InteropZodObject + ? InferInteropZodOutput + : Record, + toolCallId: toolConfig?.toolCall?.id || "", + config: toolConfig, + context: (lgConfig?.context || + undefined) as ContextSchema extends InteropZodObject + ? InferInteropZodOutput + : unknown, + store: lgConfig?.store || null, + writer: lgConfig?.writer || null, + }; + const funcWithRuntime = func as ( + input: SchemaOutputT, + runtime: ToolRuntime, + options?: unknown + ) => ToolOutputT | Promise; + result = await funcWithRuntime(input, runtime, childConfig); + } else { + const funcWithoutRuntime = func as ( + input: SchemaOutputT, + options?: unknown + ) => ToolOutputT | Promise; + result = await funcWithoutRuntime(input, childConfig); + } /** * If the signal is aborted, we don't want to resolve the promise diff --git a/libs/langchain-core/src/tools/types.ts b/libs/langchain-core/src/tools/types.ts index 5eb31e866598..14be5f3013a3 100644 --- a/libs/langchain-core/src/tools/types.ts +++ b/libs/langchain-core/src/tools/types.ts @@ -20,9 +20,11 @@ import { type InferInteropZodInput, type InferInteropZodOutput, type InteropZodType, + type InteropZodObject, isInteropZodSchema, } from "../utils/types/zod.js"; import { JSONSchema } from "../utils/json_schema.js"; +import type { BaseStore } from "../stores.js"; export type ResponseFormat = "content" | "content_and_artifact" | string; @@ -425,3 +427,98 @@ export function isLangChainTool(tool?: unknown): tool is StructuredToolParams { isStructuredTool(tool as any) ); } + +/** + * Runtime context automatically injected into tools. + * + * When a tool function has a parameter named `tool_runtime` with type hint + * `ToolRuntime`, the tool execution system will automatically inject an instance + * containing: + * + * - `state`: The current graph state + * - `toolCallId`: The ID of the current tool call + * - `config`: `RunnableConfig` for the current execution + * - `context`: Runtime context + * - `store`: `BaseStore` instance for persistent storage + * - `writer`: Stream writer for streaming output + * + * No `Annotated` wrapper is needed - just use `runtime: ToolRuntime` + * as a parameter. + * + * @example + * ```typescript + * import { tool, ToolRuntime } from "@langchain/core/tools"; + * import { z } from "zod"; + * + * const stateSchema = z.object({ + * messages: z.array(z.any()), + * userId: z.string().optional(), + * }); + * + * const greet = tool( + * async ({ name }, runtime) => { + * // Access state + * const messages = runtime.state.messages; + * + * // Access tool_call_id + * console.log(`Tool call ID: ${runtime.toolCallId}`); + * + * // Access config + * console.log(`Run ID: ${runtime.config.runId}`); + * + * // Access runtime context + * const userId = runtime.context?.userId; + * + * // Access store + * await runtime.store?.mset([["key", "value"]]); + * + * // Stream output + * runtime.writer?.("Processing..."); + * + * return `Hello! User ID: ${runtime.state.userId || "unknown"} ${name}`; + * }, + * { + * name: "greet", + * description: "Use this to greet the user once you found their info.", + * schema: z.object({ name: z.string() }), + * stateSchema, + * } + * ); + * ``` + * + * @template StateT - The type of the state schema (inferred from stateSchema) + * @template ContextT - The type of the context schema (inferred from contextSchema) + */ +export interface ToolRuntime< + StateT extends InteropZodObject | undefined = undefined, + ContextT extends InteropZodObject | undefined = undefined +> { + /** + * The current graph state. + */ + state: StateT extends InteropZodObject + ? InferInteropZodOutput + : Record; + /** + * The ID of the current tool call. + */ + toolCallId: string; + /** + * RunnableConfig for the current execution. + */ + config: ToolRunnableConfig; + /** + * Runtime context (from langgraph `Runtime`). + */ + context: ContextT extends InteropZodObject + ? InferInteropZodOutput + : unknown; + /** + * BaseStore instance for persistent storage (from langgraph `Runtime`). + */ + store: BaseStore | null; + /** + * Stream writer for streaming output (from langgraph `Runtime`). + */ + writer: ((chunk: unknown) => void) | null; +}