From 1fa7f6b441055f87925828a26cf3bcf463309e83 Mon Sep 17 00:00:00 2001 From: Hunter Lovell Date: Wed, 20 Aug 2025 10:20:18 -0700 Subject: [PATCH] wip: chat model without binding --- .../src/language_models/base.ts | 66 +----- .../src/language_models/chat_models.ts | 203 +++++++++--------- .../src/language_models/tests/v1.test.ts | 160 ++++++++++++++ libs/langchain-core/src/types/type-utils.ts | 2 + 4 files changed, 273 insertions(+), 158 deletions(-) create mode 100644 libs/langchain-core/src/language_models/tests/v1.test.ts diff --git a/libs/langchain-core/src/language_models/base.ts b/libs/langchain-core/src/language_models/base.ts index b06b9f21a8e8..f89a3221d0cf 100644 --- a/libs/langchain-core/src/language_models/base.ts +++ b/libs/langchain-core/src/language_models/base.ts @@ -560,49 +560,15 @@ export abstract class BaseLanguageModel< throw new Error("Use .toJSON() instead"); } - withStructuredOutput?< - // eslint-disable-next-line @typescript-eslint/no-explicit-any - RunOutput extends Record = Record - >( - schema: - | ZodTypeV3 - // eslint-disable-next-line @typescript-eslint/no-explicit-any - | Record, + withStructuredOutput?( + schema: InteropZodType | JSONSchema, config?: StructuredOutputMethodOptions - ): Runnable; + ): BaseLanguageModel; - withStructuredOutput?< - // eslint-disable-next-line @typescript-eslint/no-explicit-any - RunOutput extends Record = Record - >( - schema: - | ZodTypeV3 - // eslint-disable-next-line @typescript-eslint/no-explicit-any - | Record, + withStructuredOutput?( + schema: InteropZodType | JSONSchema, config?: StructuredOutputMethodOptions - ): Runnable; - - withStructuredOutput?< - // eslint-disable-next-line @typescript-eslint/no-explicit-any - RunOutput extends Record = Record - >( - schema: - | ZodTypeV4 - // eslint-disable-next-line @typescript-eslint/no-explicit-any - | Record, - config?: StructuredOutputMethodOptions - ): Runnable; - - withStructuredOutput?< - // eslint-disable-next-line @typescript-eslint/no-explicit-any - RunOutput extends Record = Record - >( - schema: - | ZodTypeV4 - // eslint-disable-next-line @typescript-eslint/no-explicit-any - | Record, - config?: StructuredOutputMethodOptions - ): Runnable; + ): BaseLanguageModel<{ raw: BaseMessage; parsed: Output }, CallOptions>; /** * Model wrapper that returns outputs formatted to match the given schema. @@ -617,24 +583,12 @@ export abstract class BaseLanguageModel< * @param {boolean | undefined} [includeRaw=false] Whether to include the raw output in the result. Defaults to false. * @returns {Runnable | Runnable} A new runnable that calls the LLM with structured output. */ - withStructuredOutput?< - // eslint-disable-next-line @typescript-eslint/no-explicit-any - RunOutput extends Record = Record - >( - schema: - | InteropZodType - // eslint-disable-next-line @typescript-eslint/no-explicit-any - | Record, + withStructuredOutput?( + schema: InteropZodType | JSONSchema, config?: StructuredOutputMethodOptions ): - | Runnable - | Runnable< - BaseLanguageModelInput, - { - raw: BaseMessage; - parsed: RunOutput; - } - >; + | BaseLanguageModel + | BaseLanguageModel<{ raw: BaseMessage; parsed: Output }, CallOptions>; } /** diff --git a/libs/langchain-core/src/language_models/chat_models.ts b/libs/langchain-core/src/language_models/chat_models.ts index c0ac0af0553c..0fda69b1074c 100644 --- a/libs/langchain-core/src/language_models/chat_models.ts +++ b/libs/langchain-core/src/language_models/chat_models.ts @@ -1,5 +1,3 @@ -import type { ZodType as ZodTypeV3 } from "zod/v3"; -import type { $ZodType as ZodTypeV4 } from "zod/v4/core"; import { AIMessage, type BaseMessage, @@ -46,18 +44,17 @@ import { import { Runnable, RunnableLambda, - RunnableSequence, RunnableToolLike, } from "../runnables/base.js"; import { concat } from "../utils/stream.js"; -import { RunnablePassthrough } from "../runnables/passthrough.js"; import { getSchemaDescription, InteropZodType, isInteropZodSchema, } from "../utils/types/zod.js"; import { callbackHandlerPrefersStreaming } from "../callbacks/base.js"; -import { toJsonSchema } from "../utils/json_schema.js"; +import { JSONSchema, toJsonSchema } from "../utils/json_schema.js"; +import { Constructor } from "../types/type-utils.js"; // eslint-disable-next-line @typescript-eslint/no-explicit-any export type ToolChoice = string | Record | "auto" | "any"; @@ -183,15 +180,19 @@ export type BindToolsInput = | RunnableToolLike | StructuredToolParams; +export type ChatModelOutputParser = Runnable; +export type InferChatModelOutputParser = TOutput extends BaseMessage + ? ChatModelOutputParser + : undefined; + /** * Base class for chat models. It extends the BaseLanguageModel class and * provides methods for generating chat based on input messages. */ export abstract class BaseChatModel< CallOptions extends BaseChatModelCallOptions = BaseChatModelCallOptions, - // TODO: Fix the parameter order on the next minor version. - OutputMessageType extends BaseMessageChunk = AIMessageChunk -> extends BaseLanguageModel { + TOutput = BaseMessage +> extends BaseLanguageModel { // Backwards compatibility since fields have been moved to RunnableConfig declare ParsedCallOptions: Omit< CallOptions, @@ -203,10 +204,30 @@ export abstract class BaseChatModel< disableStreaming = false; - constructor(fields: BaseChatModelParams) { + outputParser: InferChatModelOutputParser; + + defaultOptions: CallOptions; + + constructor(protected fields: BaseChatModelParams) { super(fields); } + protected _combineCallOptions( + additionalOptions?: Partial + ): CallOptions { + return { + ...this.defaultOptions, + ...additionalOptions, + }; + } + + protected async _parseOutput(output: BaseMessage): Promise { + if (this.outputParser) { + return this.outputParser.invoke(output); + } + return output as TOutput; + } + _combineLLMOutput?( ...llmOutputs: LLMResult["llmOutput"][] ): LLMResult["llmOutput"]; @@ -229,10 +250,10 @@ export abstract class BaseChatModel< * matching the provider's specific tool schema. * @param kwargs Any additional parameters to bind. */ - bindTools?( + abstract bindTools?( tools: BindToolsInput[], - kwargs?: Partial - ): Runnable; + options?: Partial + ): BaseChatModel; /** * Invokes the chat model with a single input. @@ -243,7 +264,7 @@ export abstract class BaseChatModel< async invoke( input: BaseLanguageModelInput, options?: CallOptions - ): Promise { + ): Promise { const promptValue = BaseChatModel._convertInputToPromptValue(input); const result = await this.generatePrompt( [promptValue], @@ -251,8 +272,7 @@ export abstract class BaseChatModel< options?.callbacks ); const chatGeneration = result.generations[0][0] as ChatGeneration; - // TODO: Remove cast after figuring out inheritance - return chatGeneration.message as OutputMessageType; + return this._parseOutput(chatGeneration.message); } // eslint-disable-next-line require-yield @@ -267,7 +287,7 @@ export abstract class BaseChatModel< async *_streamIterator( input: BaseLanguageModelInput, options?: CallOptions - ): AsyncGenerator { + ): AsyncGenerator { // Subclass check required to avoid double callbacks with default implementation if ( this._streamResponseChunks === @@ -278,8 +298,9 @@ export abstract class BaseChatModel< } else { const prompt = BaseChatModel._convertInputToPromptValue(input); const messages = prompt.toChatMessages(); + const combinedOptions = this._combineCallOptions(options); const [runnableConfig, callOptions] = - this._separateRunnableConfigFromCallOptionsCompat(options); + this._separateRunnableConfigFromCallOptionsCompat(combinedOptions); const inheritableMetadata = { ...runnableConfig.metadata, @@ -326,7 +347,7 @@ export abstract class BaseChatModel< ...chunk.generationInfo, ...chunk.message.response_metadata, }; - yield chunk.message as OutputMessageType; + yield await this._parseOutput(chunk.message); if (!generationChunk) { generationChunk = chunk; } else { @@ -904,68 +925,52 @@ export abstract class BaseChatModel< return result.content; } - withStructuredOutput< - // eslint-disable-next-line @typescript-eslint/no-explicit-any - RunOutput extends Record = Record - >( - outputSchema: - | ZodTypeV4 - // eslint-disable-next-line @typescript-eslint/no-explicit-any - | Record, - config?: StructuredOutputMethodOptions - ): Runnable; + withConfig(config: Partial): this { + const Cls = this.constructor as Constructor; + const instance = new Cls(this.fields); + instance.defaultOptions = { + ...this.defaultOptions, + ...config, + }; + return instance; + } - withStructuredOutput< - // eslint-disable-next-line @typescript-eslint/no-explicit-any - RunOutput extends Record = Record - >( - outputSchema: - | ZodTypeV4 - // eslint-disable-next-line @typescript-eslint/no-explicit-any - | Record, - config?: StructuredOutputMethodOptions - ): Runnable; + /** @internal */ + protected withOutputParser( + outputParser: ChatModelOutputParser + ): BaseChatModel { + const Cls = this.constructor as Constructor< + BaseChatModel + >; + const instance = new Cls(this.fields); + instance.outputParser = outputParser as InferChatModelOutputParser; + instance.defaultOptions = this.defaultOptions; + return instance; + } - withStructuredOutput< - // eslint-disable-next-line @typescript-eslint/no-explicit-any - RunOutput extends Record = Record - >( - outputSchema: - | ZodTypeV3 - // eslint-disable-next-line @typescript-eslint/no-explicit-any - | Record, + withStructuredOutput( + schema: InteropZodType | JSONSchema, config?: StructuredOutputMethodOptions - ): Runnable; + ): BaseChatModel; - withStructuredOutput< - // eslint-disable-next-line @typescript-eslint/no-explicit-any - RunOutput extends Record = Record - >( - outputSchema: - | ZodTypeV3 - // eslint-disable-next-line @typescript-eslint/no-explicit-any - | Record, + withStructuredOutput( + schema: InteropZodType | JSONSchema, config?: StructuredOutputMethodOptions - ): Runnable; + ): BaseChatModel; - withStructuredOutput< - // eslint-disable-next-line @typescript-eslint/no-explicit-any - RunOutput extends Record = Record - >( - outputSchema: - | InteropZodType - // eslint-disable-next-line @typescript-eslint/no-explicit-any - | Record, + withStructuredOutput( + schema: InteropZodType | JSONSchema, config?: StructuredOutputMethodOptions ): - | Runnable - | Runnable< - BaseLanguageModelInput, - { - raw: BaseMessage; - parsed: RunOutput; - } - > { + | BaseChatModel + | BaseChatModel; + + withStructuredOutput( + schema: InteropZodType | JSONSchema, + config?: StructuredOutputMethodOptions + ): + | BaseChatModel + | BaseChatModel { if (typeof this.bindTools !== "function") { throw new Error( `Chat model must implement ".bindTools()" to use withStructuredOutput.` @@ -976,9 +981,6 @@ export abstract class BaseChatModel< `"strict" mode is not supported for this model by default.` ); } - // eslint-disable-next-line @typescript-eslint/no-explicit-any - const schema: Record | InteropZodType = - outputSchema; const name = config?.name; const description = getSchemaDescription(schema) ?? "A function available to call."; @@ -1004,7 +1006,7 @@ export abstract class BaseChatModel< }, ]; } else { - if ("name" in schema) { + if ("name" in schema && typeof schema.name === "string") { functionName = schema.name; } tools = [ @@ -1020,8 +1022,8 @@ export abstract class BaseChatModel< } const llm = this.bindTools(tools); - const outputParser = RunnableLambda.from( - (input: AIMessageChunk): RunOutput => { + const toolMessageParser = RunnableLambda.from( + (input: AIMessageChunk): Output => { if (!input.tool_calls || input.tool_calls.length === 0) { throw new Error("No tool calls found in the response."); } @@ -1031,37 +1033,34 @@ export abstract class BaseChatModel< if (!toolCall) { throw new Error(`No tool call found with name ${functionName}.`); } - return toolCall.args as RunOutput; + return toolCall.args as Output; } ); if (!includeRaw) { - return llm.pipe(outputParser).withConfig({ - runName: "StructuredOutput", - }) as Runnable; + return llm.withOutputParser( + toolMessageParser.withConfig({ + runName: "StructuredOutput", + }) + ); } - const parserAssign = RunnablePassthrough.assign({ - // eslint-disable-next-line @typescript-eslint/no-explicit-any - parsed: (input: any, config) => outputParser.invoke(input.raw, config), - }); - const parserNone = RunnablePassthrough.assign({ - parsed: () => null, - }); - const parsedWithFallback = parserAssign.withFallbacks({ - fallbacks: [parserNone], - }); - return RunnableSequence.from< - BaseLanguageModelInput, - { raw: BaseMessage; parsed: RunOutput } - >([ - { - raw: llm, - }, - parsedWithFallback, - ]).withConfig({ - runName: "StructuredOutputRunnable", + const rawOutputParser = RunnableLambda.from< + AIMessageChunk, + { raw: BaseMessage; parsed: Output } + >( + async ( + input: AIMessageChunk + ): Promise<{ raw: BaseMessage; parsed: Output }> => { + return { + raw: input, + parsed: await toolMessageParser.invoke(input), + }; + } + ).withConfig({ + runName: "StructuredOutput", }); + return llm.withOutputParser(rawOutputParser); } } diff --git a/libs/langchain-core/src/language_models/tests/v1.test.ts b/libs/langchain-core/src/language_models/tests/v1.test.ts new file mode 100644 index 000000000000..290612af00d6 --- /dev/null +++ b/libs/langchain-core/src/language_models/tests/v1.test.ts @@ -0,0 +1,160 @@ +import { describe, it } from "vitest"; +import { z } from "zod"; +import { + BaseChatModel, + BaseChatModelCallOptions, + BaseChatModelParams, + BindToolsInput, +} from "../chat_models.js"; +import { BaseMessage } from "../../messages/base.js"; +import { AIMessage, AIMessageChunk } from "../../messages/ai.js"; +import { ChatGenerationChunk, ChatResult } from "../../outputs.js"; +import { StructuredOutputMethodOptions } from "../base.js"; +import { InteropZodType } from "../../utils/types/zod.js"; +import { JSONSchema } from "../../utils/json_schema.js"; + +interface SimpleChatModelCallOptions extends BaseChatModelCallOptions { + tools?: BindToolsInput[]; + temperature?: number; +} + +class SimpleChatModel< + CallOptions extends SimpleChatModelCallOptions, + Output +> extends BaseChatModel { + constructor(fields: BaseChatModelParams) { + super(fields); + } + + _llmType() { + return "simple_chat_model"; + } + + async _generate( + _messages: BaseMessage[], + _options: SimpleChatModelCallOptions + ): Promise { + return Promise.resolve({ + generations: [ + { + message: new AIMessage({ + content: "wassup", + }), + text: "wassup", + }, + ], + }); + } + + async *_streamResponseChunks( + _messages: BaseMessage[], + _options: SimpleChatModelCallOptions + ): AsyncGenerator { + for (const char of "wassup") { + yield new ChatGenerationChunk({ + text: char, + message: new AIMessageChunk({ + content: char, + }), + }); + } + } + + withConfig(config: SimpleChatModelCallOptions) { + return super.withConfig(config); + } + + bindTools(tools: BindToolsInput[]) { + return this.withConfig({ + tools: [...(this.defaultOptions.tools ?? []), ...tools], + }); + } + + withStructuredOutput( + schema: InteropZodType | JSONSchema, + config?: StructuredOutputMethodOptions + ): SimpleChatModel; + + withStructuredOutput( + schema: InteropZodType | JSONSchema, + config?: StructuredOutputMethodOptions + ): SimpleChatModel; + + withStructuredOutput( + schema: InteropZodType | JSONSchema, + config?: StructuredOutputMethodOptions + ): + | SimpleChatModel + | SimpleChatModel { + return super.withStructuredOutput(schema, config) as + | SimpleChatModel + | SimpleChatModel; + } +} + +describe("SimpleChatModel", () => { + const toolDefinition = (name: string) => ({ + type: "function", + function: { + name, + description: "A test tool", + parameters: { + type: "object", + }, + }, + }); + + it("can string multiple bindTools calls", () => { + const model = new SimpleChatModel({}) + .bindTools([toolDefinition("foo")]) + .bindTools([toolDefinition("bar")]) + .withStructuredOutput( + z.object({ + foo: z.string(), + bar: z.string(), + }) + ); + expect(model.defaultOptions.tools?.length).toBe(2); + expect(model.outputParser).toBeDefined(); + }); + + it("can string multiple withStructuredOutput calls", () => { + const model = new SimpleChatModel({}) + .withStructuredOutput( + z.object({ + foo: z.string(), + }) + ) + .withStructuredOutput( + z.object({ + bar: z.string(), + }) + ); + // this should technically be 1, but `SimpleChatModel` doesn't + // handle tools with naming collisions (pretty simply to rectify) + expect(model.defaultOptions.tools?.length).toBe(2); + expect(model.outputParser).toBeDefined(); + }); + + it("can string multiple config modifying calls", () => { + const model = new SimpleChatModel({}) + .withConfig({ temperature: 0.5 }) + .bindTools([toolDefinition("foo")]) + .withConfig({ temperature: 0.7 }) + .withStructuredOutput( + z.object({ + foo: z.string(), + }) + ) + .bindTools([toolDefinition("bar")]) + .withStructuredOutput( + z.object({ + bar: z.string(), + }) + ); + expect(model.defaultOptions.temperature).toBe(0.7); + // same case here + expect(model.defaultOptions.tools?.length).toBe(4); + expect(model.outputParser).toBeDefined(); + }); +}); diff --git a/libs/langchain-core/src/types/type-utils.ts b/libs/langchain-core/src/types/type-utils.ts index e2c1e6970a52..ebe88416c768 100644 --- a/libs/langchain-core/src/types/type-utils.ts +++ b/libs/langchain-core/src/types/type-utils.ts @@ -1,3 +1,5 @@ // Utility for marking only some keys of an interface as optional // Compare to Partial which marks all keys as optional export type Optional = Omit & Partial>; + +export type Constructor = new (...args: unknown[]) => T;