diff --git a/libs/langchain-core/src/language_models/base.ts b/libs/langchain-core/src/language_models/base.ts
index b06b9f21a8e8..15e168d25873 100644
--- a/libs/langchain-core/src/language_models/base.ts
+++ b/libs/langchain-core/src/language_models/base.ts
@@ -1,6 +1,4 @@
 import type { Tiktoken, TiktokenModel } from "js-tiktoken/lite";
-import type { ZodType as ZodTypeV3 } from "zod/v3";
-import type { $ZodType as ZodTypeV4 } from "zod/v4/core";
 import { type BaseCache, InMemoryCache } from "../caches/base.js";
 import {
@@ -26,6 +24,7 @@ import {
   InteropZodObject,
   InteropZodType,
 } from "../utils/types/zod.js";
+import { type AIMessage, type AIMessageChunk } from "../messages/ai.js";
 
 // https://www.npmjs.com/package/js-tiktoken
 
@@ -280,6 +279,8 @@ export type BaseLanguageModelInput =
   | string
   | BaseMessageLike[];
 
+export type AnyAIMessage = AIMessage | AIMessageChunk;
+
 export type StructuredOutputType = InferInteropZodOutput<InteropZodObject>;
 
 export type StructuredOutputMethodOptions<IncludeRaw extends boolean = false> =
@@ -560,49 +561,15 @@ export abstract class BaseLanguageModel<
     throw new Error("Use .toJSON() instead");
   }
 
-  withStructuredOutput?<
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    RunOutput extends Record<string, any> = Record<string, any>
-  >(
-    schema:
-      | ZodTypeV3<RunOutput>
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
-      | Record<string, any>,
-    config?: StructuredOutputMethodOptions<false>
-  ): Runnable<BaseLanguageModelInput, RunOutput>;
-
-  withStructuredOutput?<
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    RunOutput extends Record<string, any> = Record<string, any>
-  >(
-    schema:
-      | ZodTypeV3<RunOutput>
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
-      | Record<string, any>,
-    config?: StructuredOutputMethodOptions<true>
-  ): Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;
-
-  withStructuredOutput?<
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    RunOutput extends Record<string, any> = Record<string, any>
-  >(
-    schema:
-      | ZodTypeV4<RunOutput>
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
-      | Record<string, any>,
+  withStructuredOutput?<Output extends Record<string, unknown>>(
+    schema: InteropZodType<Output> | JSONSchema,
     config?: StructuredOutputMethodOptions<false>
-  ): Runnable<BaseLanguageModelInput, RunOutput>;
+  ): BaseLanguageModel<Output, CallOptions>;
 
-  withStructuredOutput?<
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    RunOutput extends Record<string, any> = Record<string, any>
-  >(
-    schema:
-      | ZodTypeV4<RunOutput>
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
-      | Record<string, any>,
-    config?: StructuredOutputMethodOptions<true>
-  ): Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;
+  withStructuredOutput?<Output extends Record<string, unknown>>(
+    schema: InteropZodType<Output> | JSONSchema,
+    config: StructuredOutputMethodOptions<true>
+  ): BaseLanguageModel<{ raw: AnyAIMessage; parsed: Output }, CallOptions>;
 
   /**
    * Model wrapper that returns outputs formatted to match the given schema.
    *
@@ -615,26 +582,14 @@ export abstract class BaseLanguageModel<
    * @param {string} name The name of the function to call.
    * @param {"functionCalling" | "jsonMode"} [method=functionCalling] The method to use for getting the structured output. Defaults to "functionCalling".
    * @param {boolean | undefined} [includeRaw=false] Whether to include the raw output in the result. Defaults to false.
-   * @returns {Runnable<BaseLanguageModelInput, RunOutput> | Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>} A new runnable that calls the LLM with structured output.
+   * @returns {BaseLanguageModel<Output, CallOptions> | BaseLanguageModel<{ raw: AnyAIMessage; parsed: Output }, CallOptions>} A new model instance that calls the LLM with structured output.
    */
-  withStructuredOutput?<
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    RunOutput extends Record<string, any> = Record<string, any>
-  >(
-    schema:
-      | InteropZodType<RunOutput>
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
-      | Record<string, any>,
+  withStructuredOutput?<Output extends Record<string, unknown>>(
+    schema: InteropZodType<Output> | JSONSchema,
     config?: StructuredOutputMethodOptions<boolean>
   ):
-    | Runnable<BaseLanguageModelInput, RunOutput>
-    | Runnable<
-        BaseLanguageModelInput,
-        {
-          raw: BaseMessage;
-          parsed: RunOutput;
-        }
-      >;
+    | BaseLanguageModel<Output, CallOptions>
+    | BaseLanguageModel<{ raw: AnyAIMessage; parsed: Output }, CallOptions>;
 }
 
 /**
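For orientation, here is a hedged usage sketch (not part of the diff) of what the retyped optional method implies for callers; `MyChatModel` is a hypothetical stand-in for any concrete subclass that implements `withStructuredOutput`:

```ts
import { z } from "zod";
import { MyChatModel } from "./my_chat_model.js"; // hypothetical concrete model

const Joke = z.object({ setup: z.string(), punchline: z.string() });

async function demo() {
  // Default: the returned model's invoke() resolves to the parsed object.
  const structured = new MyChatModel({}).withStructuredOutput(Joke);
  const joke = await structured.invoke("Tell me a joke about TypeScript");
  console.log(joke.setup, joke.punchline);

  // includeRaw: true widens the output to { raw: AnyAIMessage; parsed: Output }.
  const withRaw = new MyChatModel({}).withStructuredOutput(Joke, {
    includeRaw: true,
  });
  const { raw, parsed } = await withRaw.invoke("Tell me a joke");
  console.log(raw.content, parsed.punchline);
}
```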
diff --git a/libs/langchain-core/src/language_models/chat_models.ts b/libs/langchain-core/src/language_models/chat_models.ts
index c0ac0af0553c..4f6c53579a9c 100644
--- a/libs/langchain-core/src/language_models/chat_models.ts
+++ b/libs/langchain-core/src/language_models/chat_models.ts
@@ -1,5 +1,3 @@
-import type { ZodType as ZodTypeV3 } from "zod/v3";
-import type { $ZodType as ZodTypeV4 } from "zod/v4/core";
 import {
   AIMessage,
   type BaseMessage,
@@ -7,7 +5,6 @@ import {
   type BaseMessageLike,
   HumanMessage,
   coerceMessageLikeToMessage,
-  AIMessageChunk,
   isAIMessageChunk,
   isBaseMessage,
   isAIMessage,
@@ -31,13 +28,14 @@ import {
   type BaseLanguageModelCallOptions,
   type BaseLanguageModelInput,
   type BaseLanguageModelParams,
+  type AnyAIMessage,
 } from "./base.js";
 import {
   CallbackManager,
   type CallbackManagerForLLMRun,
   type Callbacks,
 } from "../callbacks/manager.js";
-import type { RunnableConfig } from "../runnables/config.js";
+import { mergeConfigs, type RunnableConfig } from "../runnables/config.js";
 import type { BaseCache } from "../caches/base.js";
 import {
   StructuredToolInterface,
@@ -46,18 +44,17 @@ import {
 import {
   Runnable,
   RunnableLambda,
-  RunnableSequence,
   RunnableToolLike,
 } from "../runnables/base.js";
 import { concat } from "../utils/stream.js";
-import { RunnablePassthrough } from "../runnables/passthrough.js";
 import {
   getSchemaDescription,
   InteropZodType,
   isInteropZodSchema,
 } from "../utils/types/zod.js";
 import { callbackHandlerPrefersStreaming } from "../callbacks/base.js";
-import { toJsonSchema } from "../utils/json_schema.js";
+import { JSONSchema, toJsonSchema } from "../utils/json_schema.js";
+import { Constructor } from "../types/type-utils.js";
 
 // eslint-disable-next-line @typescript-eslint/no-explicit-any
 export type ToolChoice = string | Record<string, any> | "auto" | "any";
@@ -183,15 +180,19 @@ export type BindToolsInput =
   | RunnableToolLike
   | StructuredToolParams;
 
+export type ChatModelOutputParser<TOutput> = Runnable<AnyAIMessage, TOutput>;
+export type InferChatModelOutputParser<TOutput> = TOutput extends AnyAIMessage
+  ? undefined
+  : ChatModelOutputParser<TOutput>;
+
 /**
  * Base class for chat models. It extends the BaseLanguageModel class and
  * provides methods for generating chat based on input messages.
  */
 export abstract class BaseChatModel<
   CallOptions extends BaseChatModelCallOptions = BaseChatModelCallOptions,
-  // TODO: Fix the parameter order on the next minor version.
-  OutputMessageType extends BaseMessageChunk = AIMessageChunk
-> extends BaseLanguageModel<OutputMessageType, CallOptions> {
+  TOutput = AnyAIMessage
+> extends BaseLanguageModel<TOutput, CallOptions> {
   // Backwards compatibility since fields have been moved to RunnableConfig
   declare ParsedCallOptions: Omit<
     CallOptions,
@@ -203,10 +204,30 @@ export abstract class BaseChatModel<
 
   disableStreaming = false;
 
-  constructor(fields: BaseChatModelParams) {
+  outputParser: InferChatModelOutputParser<TOutput>;
+
+  defaultOptions: CallOptions;
+
+  constructor(protected fields: BaseChatModelParams) {
     super(fields);
   }
 
+  protected _combineCallOptions(
+    additionalOptions?: Partial<CallOptions>
+  ): Partial<CallOptions> {
+    return mergeConfigs(
+      this.defaultOptions,
+      additionalOptions
+    ) as Partial<CallOptions>;
+  }
+
+  protected async _parseOutput(output: BaseMessage): Promise<TOutput> {
+    if (this.outputParser) {
+      return this.outputParser.invoke(output);
+    }
+    return output as TOutput;
+  }
+
   _combineLLMOutput?(
     ...llmOutputs: LLMResult["llmOutput"][]
   ): LLMResult["llmOutput"];
@@ -216,7 +237,9 @@ export abstract class BaseChatModel<
   ): [RunnableConfig, this["ParsedCallOptions"]] {
     // For backwards compat, keep `signal` in both runnableConfig and callOptions
     const [runnableConfig, callOptions] =
-      super._separateRunnableConfigFromCallOptions(options);
+      super._separateRunnableConfigFromCallOptions(
+        this._combineCallOptions(options)
+      );
     (callOptions as this["ParsedCallOptions"]).signal = runnableConfig.signal;
     return [runnableConfig, callOptions as this["ParsedCallOptions"]];
   }
@@ -229,10 +252,7 @@ export abstract class BaseChatModel<
    * matching the provider's specific tool schema.
    * @param kwargs Any additional parameters to bind.
    */
-  bindTools?(
-    tools: BindToolsInput[],
-    kwargs?: Partial<CallOptions>
-  ): Runnable<BaseLanguageModelInput, OutputMessageType, CallOptions>;
+  bindTools?(tools: BindToolsInput[], options?: Partial<CallOptions>): this;
 
   /**
    * Invokes the chat model with a single input.
@@ -243,7 +263,7 @@ export abstract class BaseChatModel<
   async invoke(
     input: BaseLanguageModelInput,
     options?: CallOptions
-  ): Promise<OutputMessageType> {
+  ): Promise<TOutput> {
     const promptValue = BaseChatModel._convertInputToPromptValue(input);
     const result = await this.generatePrompt(
       [promptValue],
       [options ?? {}],
       options?.callbacks
     );
     const chatGeneration = result.generations[0][0] as ChatGeneration;
-    // TODO: Remove cast after figuring out inheritance
-    return chatGeneration.message as OutputMessageType;
+    return this._parseOutput(chatGeneration.message);
   }
 
   // eslint-disable-next-line require-yield
@@ -267,7 +286,7 @@ export abstract class BaseChatModel<
   async *_streamIterator(
     input: BaseLanguageModelInput,
     options?: CallOptions
-  ): AsyncGenerator<OutputMessageType> {
+  ): AsyncGenerator<TOutput> {
     // Subclass check required to avoid double callbacks with default implementation
     if (
       this._streamResponseChunks ===
@@ -326,7 +345,7 @@ export abstract class BaseChatModel<
           ...chunk.generationInfo,
          ...chunk.message.response_metadata,
         };
-        yield chunk.message as OutputMessageType;
+        yield await this._parseOutput(chunk.message);
         if (!generationChunk) {
           generationChunk = chunk;
         } else {
@@ -904,68 +923,52 @@ export abstract class BaseChatModel<
     return result.content;
   }
 
-  withStructuredOutput<
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    RunOutput extends Record<string, any> = Record<string, any>
-  >(
-    outputSchema:
-      | ZodTypeV4<RunOutput>
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
-      | Record<string, any>,
-    config?: StructuredOutputMethodOptions<false>
-  ): Runnable<BaseLanguageModelInput, RunOutput>;
+  withConfig(config: Partial<CallOptions>): this {
+    const Cls = this.constructor as Constructor<this>;
+    const instance = new Cls(this.fields);
+    instance.defaultOptions = {
+      ...this.defaultOptions,
+      ...config,
+    };
+    return instance;
+  }
 
-  withStructuredOutput<
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    RunOutput extends Record<string, any> = Record<string, any>
-  >(
-    outputSchema:
-      | ZodTypeV4<RunOutput>
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
-      | Record<string, any>,
-    config?: StructuredOutputMethodOptions<true>
-  ): Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;
+  /** @internal */
+  protected withOutputParser<Output extends Record<string, unknown>>(
+    outputParser: ChatModelOutputParser<Output>
+  ): BaseChatModel<CallOptions, Output> {
+    const Cls = this.constructor as Constructor<
+      BaseChatModel<CallOptions, Output>
+    >;
+    const instance = new Cls(this.fields);
+    instance.outputParser = outputParser as InferChatModelOutputParser<Output>;
+    instance.defaultOptions = this.defaultOptions;
+    return instance;
+  }
 
-  withStructuredOutput<
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    RunOutput extends Record<string, any> = Record<string, any>
-  >(
-    outputSchema:
-      | ZodTypeV3<RunOutput>
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
-      | Record<string, any>,
+  withStructuredOutput<Output extends Record<string, unknown>>(
+    schema: InteropZodType<Output> | JSONSchema,
     config?: StructuredOutputMethodOptions<false>
-  ): Runnable<BaseLanguageModelInput, RunOutput>;
+  ): BaseChatModel<CallOptions, Output>;
 
-  withStructuredOutput<
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    RunOutput extends Record<string, any> = Record<string, any>
-  >(
-    outputSchema:
-      | ZodTypeV3<RunOutput>
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
-      | Record<string, any>,
-    config?: StructuredOutputMethodOptions<true>
-  ): Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;
+  withStructuredOutput<Output extends Record<string, unknown>>(
+    schema: InteropZodType<Output> | JSONSchema,
+    config: StructuredOutputMethodOptions<true>
+  ): BaseChatModel<CallOptions, { raw: AnyAIMessage; parsed: Output }>;
 
-  withStructuredOutput<
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    RunOutput extends Record<string, any> = Record<string, any>
-  >(
-    outputSchema:
-      | InteropZodType<RunOutput>
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
-      | Record<string, any>,
+  withStructuredOutput<Output extends Record<string, unknown>>(
+    schema: InteropZodType<Output> | JSONSchema,
     config?: StructuredOutputMethodOptions<boolean>
   ):
-    | Runnable<BaseLanguageModelInput, RunOutput>
-    | Runnable<
-        BaseLanguageModelInput,
-        {
-          raw: BaseMessage;
-          parsed: RunOutput;
-        }
-      > {
+    | BaseChatModel<CallOptions, Output>
+    | BaseChatModel<CallOptions, { raw: AnyAIMessage; parsed: Output }>;
+
+  withStructuredOutput<Output extends Record<string, unknown>>(
+    schema: InteropZodType<Output> | JSONSchema,
+    config?: StructuredOutputMethodOptions<boolean>
+  ):
+    | BaseChatModel<CallOptions, Output>
+    | BaseChatModel<CallOptions, { raw: AnyAIMessage; parsed: Output }> {
     if (typeof this.bindTools !== "function") {
       throw new Error(
         `Chat model must implement ".bindTools()" to use withStructuredOutput.`
       );
     }
@@ -976,9 +979,6 @@ export abstract class BaseChatModel<
         `"strict" mode is not supported for this model by default.`
       );
     }
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    const schema: Record<string, any> | InteropZodType =
-      outputSchema;
     const name = config?.name;
     const description =
       getSchemaDescription(schema) ?? "A function available to call.";
@@ -1004,7 +1004,7 @@ export abstract class BaseChatModel<
         },
       ];
     } else {
-      if ("name" in schema) {
+      if ("name" in schema && typeof schema.name === "string") {
         functionName = schema.name;
       }
       tools = [
@@ -1020,8 +1020,8 @@ export abstract class BaseChatModel<
     }
 
     const llm = this.bindTools(tools);
-    const outputParser = RunnableLambda.from(
-      (input: AIMessageChunk): RunOutput => {
+    const toolMessageParser = RunnableLambda.from(
+      (input: AnyAIMessage): Output => {
         if (!input.tool_calls || input.tool_calls.length === 0) {
           throw new Error("No tool calls found in the response.");
         }
@@ -1031,37 +1031,34 @@ export abstract class BaseChatModel<
         if (!toolCall) {
           throw new Error(`No tool call found with name ${functionName}.`);
         }
-        return toolCall.args as RunOutput;
+        return toolCall.args as Output;
       }
     );
 
     if (!includeRaw) {
-      return llm.pipe(outputParser).withConfig({
-        runName: "StructuredOutput",
-      }) as Runnable<BaseLanguageModelInput, RunOutput>;
+      return llm.withOutputParser(
+        toolMessageParser.withConfig({
+          runName: "StructuredOutput",
+        })
+      );
     }
 
-    const parserAssign = RunnablePassthrough.assign({
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
-      parsed: (input: any, config) => outputParser.invoke(input.raw, config),
-    });
-    const parserNone = RunnablePassthrough.assign({
-      parsed: () => null,
-    });
-    const parsedWithFallback = parserAssign.withFallbacks({
-      fallbacks: [parserNone],
-    });
-    return RunnableSequence.from<
-      BaseLanguageModelInput,
-      { raw: BaseMessage; parsed: RunOutput }
-    >([
-      {
-        raw: llm,
-      },
-      parsedWithFallback,
-    ]).withConfig({
-      runName: "StructuredOutputRunnable",
+    const rawOutputParser = RunnableLambda.from<
+      AnyAIMessage,
+      { raw: AnyAIMessage; parsed: Output }
+    >(
+      async (
+        input: AnyAIMessage
+      ): Promise<{ raw: AnyAIMessage; parsed: Output }> => {
+        return {
+          raw: input,
+          parsed: await toolMessageParser.invoke(input),
+        };
+      }
+    ).withConfig({
+      runName: "StructuredOutput",
     });
+    return llm.withOutputParser(rawOutputParser);
   }
 }
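The behavioral consequence worth calling out: `withConfig()` is now copy-on-write rather than mutating. A minimal sketch of the contract, assuming any concrete `BaseChatModel` subclass:

```ts
import type { BaseChatModel } from "@langchain/core/language_models/chat_models";

// Hedged sketch; `model` stands in for any concrete BaseChatModel instance.
async function demo(model: BaseChatModel) {
  const tuned = model.withConfig({ stop: ["\n\n"] });

  // withConfig() re-instantiates via this.constructor instead of mutating,
  // so the original instance keeps its own defaults.
  console.assert(tuned !== model);
  console.assert(model.defaultOptions?.stop === undefined);

  // Per-call options win over defaults: _separateRunnableConfigFromCallOptions
  // now runs mergeConfigs(defaultOptions, options) before every generation.
  await tuned.invoke("hello", { stop: ["END"] }); // effective stop: ["END"]
}
```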
diff --git a/libs/langchain-core/src/language_models/tests/v1.test.ts b/libs/langchain-core/src/language_models/tests/v1.test.ts
new file mode 100644
index 000000000000..64142ccb6848
--- /dev/null
+++ b/libs/langchain-core/src/language_models/tests/v1.test.ts
@@ -0,0 +1,156 @@
+import { describe, expect, it } from "vitest";
+import { z } from "zod";
+import {
+  BaseChatModel,
+  BaseChatModelCallOptions,
+  BaseChatModelParams,
+  BindToolsInput,
+} from "../chat_models.js";
+import { BaseMessage } from "../../messages/base.js";
+import { AIMessage, AIMessageChunk } from "../../messages/ai.js";
+import { ChatGenerationChunk, ChatResult } from "../../outputs.js";
+import { StructuredOutputMethodOptions } from "../base.js";
+import { InteropZodType } from "../../utils/types/zod.js";
+import { JSONSchema } from "../../utils/json_schema.js";
+
+interface SimpleChatModelCallOptions extends BaseChatModelCallOptions {
+  tools?: BindToolsInput[];
+  temperature?: number;
+}
+
+class SimpleChatModel<
+  CallOptions extends SimpleChatModelCallOptions,
+  Output
+> extends BaseChatModel<CallOptions, Output> {
+  constructor(fields: BaseChatModelParams) {
+    super(fields);
+  }
+
+  _llmType() {
+    return "simple_chat_model";
+  }
+
+  async _generate(
+    _messages: BaseMessage[],
+    _options: SimpleChatModelCallOptions
+  ): Promise<ChatResult> {
+    return Promise.resolve({
+      generations: [
+        {
+          message: new AIMessage({
+            content: "wassup",
+          }),
+          text: "wassup",
+        },
+      ],
+    });
+  }
+
+  async *_streamResponseChunks(
+    _messages: BaseMessage[],
+    _options: SimpleChatModelCallOptions
+  ): AsyncGenerator<ChatGenerationChunk> {
+    for (const char of "wassup") {
+      yield new ChatGenerationChunk({
+        text: char,
+        message: new AIMessageChunk({
+          content: char,
+        }),
+      });
+    }
+  }
+
+  bindTools(tools: BindToolsInput[]) {
+    return this.withConfig({
+      tools: [...(this.defaultOptions?.tools ?? []), ...tools],
+    });
+  }
+
+  withStructuredOutput<RunOutput extends Record<string, unknown>>(
+    schema: InteropZodType<RunOutput> | JSONSchema,
+    config?: StructuredOutputMethodOptions<false>
+  ): SimpleChatModel<CallOptions, RunOutput>;
+
+  withStructuredOutput<RunOutput extends Record<string, unknown>>(
+    schema: InteropZodType<RunOutput> | JSONSchema,
+    config: StructuredOutputMethodOptions<true>
+  ): SimpleChatModel<
+    CallOptions,
+    { raw: AIMessage | AIMessageChunk; parsed: RunOutput }
+  >;
+
+  withStructuredOutput<RunOutput extends Record<string, unknown>>(
+    schema: InteropZodType<RunOutput> | JSONSchema,
+    config?: StructuredOutputMethodOptions<boolean>
+  ):
+    | SimpleChatModel<CallOptions, RunOutput>
+    | SimpleChatModel<
+        CallOptions,
+        { raw: AIMessage | AIMessageChunk; parsed: RunOutput }
+      > {
+    return super.withStructuredOutput(schema, config) as
+      | SimpleChatModel<CallOptions, RunOutput>
+      | SimpleChatModel<
+          CallOptions,
+          { raw: AIMessage | AIMessageChunk; parsed: RunOutput }
+        >;
+  }
+}
+
+describe("SimpleChatModel", () => {
+  const toolDefinition = (name: string) => ({
+    type: "function",
+    function: {
+      name,
+      description: "A test tool",
+      parameters: {
+        type: "object",
+      },
+    },
+  });
+
+  it("can string together multiple bindTools calls", () => {
+    const model = new SimpleChatModel({})
+      .bindTools([toolDefinition("foo")])
+      .bindTools([toolDefinition("bar")])
+      .withStructuredOutput(
+        z.object({
+          foo: z.string(),
+          bar: z.string(),
+        })
+      );
+    expect(model.defaultOptions.tools?.length).toBe(3); // + 1 structured output
+    expect(model.outputParser).toBeDefined();
+  });
+
+  it("can string together multiple withStructuredOutput calls", () => {
+    const model = new SimpleChatModel({})
+      .withStructuredOutput(
+        z.object({
+          foo: z.string(),
+        })
+      )
+      .withStructuredOutput(
+        z.object({
+          bar: z.string(),
+        })
+      );
+    // this should technically be 1, but `SimpleChatModel` doesn't
+    // handle tools with naming collisions (pretty simple to rectify)
+    expect(model.defaultOptions.tools?.length).toBe(2);
+    expect(model.outputParser).toBeDefined();
+  });
+
+  it("can string together multiple config-modifying calls", () => {
+    const model = new SimpleChatModel({})
+      .withConfig({ temperature: 0.5 })
+      .bindTools([toolDefinition("foo")])
+      .withConfig({ temperature: 0.7 })
+      .withStructuredOutput(
+        z.object({
+          foo: z.string(),
+        })
+      )
+      .bindTools([toolDefinition("bar")])
+      .withStructuredOutput(
+        z.object({
+          bar: z.string(),
+        })
+      );
+    expect(model.defaultOptions.temperature).toBe(0.7);
+    // same case here
+    expect(model.defaultOptions.tools?.length).toBe(4);
+    expect(model.outputParser).toBeDefined();
+  });
+});
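On the naming-collision caveat noted in the test above: a possible `SimpleChatModel.bindTools` variant that dedupes by tool name (hedged sketch, not part of the diff; it assumes every bound tool uses the OpenAI-style `{ function: { name } }` shape produced by `toolDefinition()`):

```ts
// Inside SimpleChatModel — dedupe by name, last write wins, so the
// "should technically be 1" expectation above would hold.
bindTools(tools: BindToolsInput[]) {
  const all = [...(this.defaultOptions?.tools ?? []), ...tools];
  const byName = new Map(
    all.map((t) => [(t as { function: { name: string } }).function.name, t])
  );
  return this.withConfig({ tools: [...byName.values()] });
}
```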
diff --git a/libs/langchain-core/src/prompts/tests/structured.test.ts b/libs/langchain-core/src/prompts/tests/structured.test.ts
index e421da9519dc..f11907777b03 100644
--- a/libs/langchain-core/src/prompts/tests/structured.test.ts
+++ b/libs/langchain-core/src/prompts/tests/structured.test.ts
@@ -2,13 +2,11 @@ import { ZodType, ZodTypeDef } from "zod";
 import { test, expect } from "vitest";
 
 import {
-  StructuredOutputMethodParams,
   StructuredOutputMethodOptions,
   BaseLanguageModelInput,
+  AnyAIMessage,
 } from "../../language_models/base.js";
-import { BaseMessage } from "../../messages/index.js";
-import { Runnable, RunnableLambda } from "../../runnables/base.js";
-import { RunnableConfig } from "../../runnables/config.js";
+import { RunnableLambda } from "../../runnables/base.js";
 import { FakeListChatModel } from "../../utils/testing/index.js";
 import { StructuredPrompt } from "../structured.js";
 import { load } from "../../load/index.js";
@@ -17,42 +15,25 @@ class FakeStructuredChatModel extends FakeListChatModel {
   withStructuredOutput<
     RunOutput extends Record<string, unknown> = Record<string, unknown>
   >(
-    _params:
-      | Record<string, unknown>
-      | StructuredOutputMethodParams<RunOutput, false>
-      | ZodType<RunOutput, ZodTypeDef, RunOutput>,
+    _params: Record<string, unknown> | ZodType<RunOutput, ZodTypeDef, RunOutput>,
     config?: StructuredOutputMethodOptions<false> | undefined
-  ): Runnable<BaseLanguageModelInput, RunOutput>;
+  ): FakeListChatModel<RunOutput>;
 
   withStructuredOutput<
     RunOutput extends Record<string, unknown> = Record<string, unknown>
   >(
-    _params:
-      | Record<string, unknown>
-      | StructuredOutputMethodParams<RunOutput, true>
-      | ZodType<RunOutput, ZodTypeDef, RunOutput>,
+    _params: Record<string, unknown> | ZodType<RunOutput, ZodTypeDef, RunOutput>,
     config?: StructuredOutputMethodOptions<true> | undefined
-  ): Runnable<
-    BaseLanguageModelInput,
-    { raw: BaseMessage; parsed: RunOutput },
-    RunnableConfig
-  >;
+  ): FakeListChatModel<{ raw: AnyAIMessage; parsed: RunOutput }>;
 
   withStructuredOutput<
     RunOutput extends Record<string, unknown> = Record<string, unknown>
   >(
-    _params:
-      | Record<string, unknown>
-      | StructuredOutputMethodParams<RunOutput, boolean>
-      | ZodType<RunOutput, ZodTypeDef, RunOutput>,
+    _params: Record<string, unknown> | ZodType<RunOutput, ZodTypeDef, RunOutput>,
     _config?: StructuredOutputMethodOptions<boolean> | undefined
   ):
-    | Runnable<BaseLanguageModelInput, RunOutput>
-    | Runnable<
-        BaseLanguageModelInput,
-        { raw: BaseMessage; parsed: RunOutput },
-        RunnableConfig
-      > {
+    | FakeListChatModel<RunOutput>
+    | FakeListChatModel<{ raw: AnyAIMessage; parsed: RunOutput }> {
     if (!_config?.includeRaw) {
       if (typeof _params === "object") {
         const func = RunnableLambda.from(
diff --git a/libs/langchain-core/src/types/type-utils.ts b/libs/langchain-core/src/types/type-utils.ts
index e2c1e6970a52..ebe88416c768 100644
--- a/libs/langchain-core/src/types/type-utils.ts
+++ b/libs/langchain-core/src/types/type-utils.ts
@@ -1,3 +1,5 @@
 // Utility for marking only some keys of an interface as optional
 // Compare to Partial which marks all keys as optional
 export type Optional<T, K extends keyof T> = Omit<T, K> & Partial<Pick<T, K>>;
+
+export type Constructor<T> = new (...args: unknown[]) => T;
diff --git a/libs/langchain-core/src/utils/testing/index.ts b/libs/langchain-core/src/utils/testing/index.ts
index f238a42f3b5d..375738651066 100644
--- a/libs/langchain-core/src/utils/testing/index.ts
+++ b/libs/langchain-core/src/utils/testing/index.ts
@@ -16,6 +16,7 @@ import {
   BaseChatModel,
   BaseChatModelCallOptions,
   BaseChatModelParams,
+  BindToolsInput,
 } from "../../language_models/chat_models.js";
 import { BaseLLMParams, LLM } from "../../language_models/llms.js";
 import {
@@ -42,12 +43,10 @@ import {
 } from "../../embeddings.js";
 import {
   StructuredOutputMethodParams,
-  BaseLanguageModelInput,
   StructuredOutputMethodOptions,
+  type AnyAIMessage,
 } from "../../language_models/base.js";
-import { toJsonSchema } from "../json_schema.js";
-
 import { VectorStore } from "../../vectorstores.js";
 import { cosine } from "../ml-distance/similarities.js";
 import { InteropZodObject, InteropZodType } from "../types/zod.js";
@@ -227,17 +226,15 @@ export class FakeStreamingChatModel extends BaseChatModel<
-  bindTools(tools: StructuredToolInterface[]) {
-    const toolDicts = tools.map((t) => {
-      switch (this.toolStyle) {
-        case "openai":
-          return {
-            type: "function",
-            function: {
-              name: t.name,
-              description: t.description,
-              parameters: toJsonSchema(t.schema),
-            },
-          };
-        case "anthropic":
-          return {
-            name: t.name,
-            description: t.description,
-            input_schema: toJsonSchema(t.schema),
-          };
-        case "bedrock":
-          return {
-            toolSpec: {
-              name: t.name,
-              description: t.description,
-              inputSchema: toJsonSchema(t.schema),
-            },
-          };
-        case "google":
-          return {
-            name: t.name,
-            description: t.description,
-            parameters: toJsonSchema(t.schema),
-          };
-        default:
-          throw new Error(`Unsupported tool style: ${this.toolStyle}`);
-      }
+  bindTools(tools: BindToolsInput[]) {
+    return this.withConfig({
+      tools: [...(this.defaultOptions?.tools ?? []), ...tools],
     });
-
-    const wrapped =
-      this.toolStyle === "google"
-        ? [{ functionDeclarations: toolDicts }]
-        : toolDicts;
-
-    /* creating a *new* instance – mirrors LangChain .bind semantics for type-safety and avoiding noise */
-    const next = new FakeStreamingChatModel({
-      sleep: this.sleep,
-      responses: this.responses,
-      chunks: this.chunks,
-      toolStyle: this.toolStyle,
-      thrownErrorString: this.thrownErrorString,
-    });
-    next.tools = merged;
-
-    return next.withConfig({ tools: wrapped } as BaseChatModelCallOptions);
   }
 
   async _generate(
@@ -420,7 +366,9 @@ export interface ToolSpec {
  * Interface specific to the Fake Streaming Chat model.
  */
 export interface FakeStreamingChatModelCallOptions
-  extends BaseChatModelCallOptions {}
+  extends BaseChatModelCallOptions {
+  tools?: BindToolsInput[];
+}
 
 /**
  * Interface for the Constructor-field specific to the Fake Streaming Chat model (all optional because we fill in defaults).
  */
@@ -478,7 +426,10 @@ export interface FakeListChatModelCallOptions extends BaseChatModelCallOptions {
  *   console.log({ secondResponse });
  * ```
  */
-export class FakeListChatModel extends BaseChatModel<FakeListChatModelCallOptions> {
+export class FakeListChatModel<RunOutput = AnyAIMessage> extends BaseChatModel<
+  FakeListChatModelCallOptions,
+  RunOutput
+> {
   static lc_name() {
     return "FakeListChatModel";
   }
@@ -611,7 +562,7 @@ export class FakeListChatModel extends BaseChatModel<
     _params: StructuredOutputMethodParams<RunOutput, false> | InteropZodType<RunOutput>,
     config?: StructuredOutputMethodOptions<false>
-  ): Runnable<BaseLanguageModelInput, RunOutput>;
+  ): BaseChatModel<FakeListChatModelCallOptions, RunOutput>;
 
   withStructuredOutput<
     // eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -623,7 +574,10 @@ export class FakeListChatModel extends BaseChatModel<
     _params: StructuredOutputMethodParams<RunOutput, true> | InteropZodType<RunOutput>,
     config?: StructuredOutputMethodOptions<true>
-  ): Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;
+  ): BaseChatModel<
+    FakeListChatModelCallOptions,
+    { raw: AnyAIMessage; parsed: RunOutput }
+  >;
 
   withStructuredOutput<
     // eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -636,21 +590,22 @@ export class FakeListChatModel extends BaseChatModel<
     _params: StructuredOutputMethodParams<RunOutput, boolean> | InteropZodType<RunOutput>,
     _config?: StructuredOutputMethodOptions<boolean>
   ):
-    | Runnable<BaseLanguageModelInput, RunOutput>
-    | Runnable<
-        BaseLanguageModelInput,
-        { raw: BaseMessage; parsed: RunOutput }
+    | BaseChatModel<FakeListChatModelCallOptions, RunOutput>
+    | BaseChatModel<
+        FakeListChatModelCallOptions,
+        { raw: AnyAIMessage; parsed: RunOutput }
       > {
-    return RunnableLambda.from(async (input) => {
-      const message = await this.invoke(input);
-      if (message.tool_calls?.[0]?.args) {
-        return message.tool_calls[0].args as RunOutput;
-      }
-      if (typeof message.content === "string") {
-        return JSON.parse(message.content);
-      }
-      throw new Error("No structured output found");
-    }) as Runnable<BaseLanguageModelInput, RunOutput>;
+    return this.withOutputParser(
+      RunnableLambda.from(async (message): Promise<RunOutput> => {
+        if (message.tool_calls?.[0]?.args) {
+          return message.tool_calls[0].args as RunOutput;
+        }
+        if (typeof message.content === "string") {
+          return JSON.parse(message.content);
+        }
+        throw new Error("No structured output found");
+      })
+    );
   }
 }
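A hedged usage sketch of the reworked fake in a test: `withStructuredOutput()` now returns a typed model rather than a bare `Runnable`, and its parser falls back to `JSON.parse` on string content when no tool calls are present.

```ts
import { z } from "zod";
import { FakeListChatModel } from "@langchain/core/utils/testing";

async function demo() {
  const fake = new FakeListChatModel({
    responses: ['{"answer": 42}'],
  }).withStructuredOutput(z.object({ answer: z.number() }));

  // The fake response has no tool calls, so the attached output parser
  // takes the JSON.parse branch over the string content.
  const out = await fake.invoke("anything"); // { answer: 42 }
  console.log(out.answer);
}
```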
diff --git a/libs/providers/langchain-anthropic/src/chat_models.ts b/libs/providers/langchain-anthropic/src/chat_models.ts
index 9afa0bd9fba1..1fd7a4a8a534 100644
--- a/libs/providers/langchain-anthropic/src/chat_models.ts
+++ b/libs/providers/langchain-anthropic/src/chat_models.ts
@@ -13,16 +13,11 @@ import {
 } from "@langchain/core/language_models/chat_models";
 import {
   type StructuredOutputMethodOptions,
-  type BaseLanguageModelInput,
   isOpenAITool,
 } from "@langchain/core/language_models/base";
 import { toJsonSchema } from "@langchain/core/utils/json_schema";
 import { BaseLLMOutputParser } from "@langchain/core/output_parsers";
-import {
-  Runnable,
-  RunnablePassthrough,
-  RunnableSequence,
-} from "@langchain/core/runnables";
+import { RunnableLambda } from "@langchain/core/runnables";
 import {
   InteropZodType,
   isInteropZodSchema,
@@ -622,9 +617,10 @@ function extractToken(chunk: AIMessageChunk): string | undefined {
  *
  */
 export class ChatAnthropicMessages<
-    CallOptions extends ChatAnthropicCallOptions = ChatAnthropicCallOptions
+    CallOptions extends ChatAnthropicCallOptions = ChatAnthropicCallOptions,
+    RunOutput = AIMessageChunk
   >
-  extends BaseChatModel<CallOptions, AIMessageChunk>
+  extends BaseChatModel<CallOptions, RunOutput>
   implements AnthropicInput
 {
   static lc_name() {
@@ -791,7 +787,7 @@ export class ChatAnthropicMessages<
   override bindTools(
     tools: ChatAnthropicToolType[],
     kwargs?: Partial<CallOptions>
-  ): Runnable<BaseLanguageModelInput, AIMessageChunk, CallOptions> {
+  ): this {
     return this.withConfig({
       tools: this.formatStructuredToolToAnthropic(tools),
       ...kwargs,
@@ -1093,7 +1089,7 @@ export class ChatAnthropicMessages<
     // eslint-disable-next-line @typescript-eslint/no-explicit-any
     | Record<string, any>,
     config?: StructuredOutputMethodOptions<false>
-  ): Runnable<BaseLanguageModelInput, RunOutput>;
+  ): ChatAnthropicMessages<CallOptions, RunOutput>;
 
   withStructuredOutput<
     // eslint-disable-next-line @typescript-eslint/no-explicit-any
     RunOutput extends Record<string, any> = Record<string, any>
   >(
     outputSchema:
       | InteropZodType<RunOutput>
       // eslint-disable-next-line @typescript-eslint/no-explicit-any
       | Record<string, any>,
     config?: StructuredOutputMethodOptions<true>
-  ): Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;
+  ): ChatAnthropicMessages<
+    CallOptions,
+    { raw: AIMessageChunk; parsed: RunOutput }
+  >;
 
   withStructuredOutput<
     // eslint-disable-next-line @typescript-eslint/no-explicit-any
     RunOutput extends Record<string, any> = Record<string, any>
   >(
     outputSchema:
       | InteropZodType<RunOutput>
       // eslint-disable-next-line @typescript-eslint/no-explicit-any
       | Record<string, any>,
     config?: StructuredOutputMethodOptions<boolean>
   ):
-    | Runnable<BaseLanguageModelInput, RunOutput>
-    | Runnable<
-        BaseLanguageModelInput,
-        { raw: BaseMessage; parsed: RunOutput }
+    | BaseChatModel<CallOptions, RunOutput>
+    | BaseChatModel<
+        CallOptions,
+        { raw: AIMessageChunk; parsed: RunOutput | null }
       > {
     // eslint-disable-next-line @typescript-eslint/no-explicit-any
     const schema: InteropZodType<RunOutput> | Record<string, any> =
       outputSchema;
     const { name, method, includeRaw } = config ?? {};
@@ -1130,6 +1129,12 @@ export class ChatAnthropicMessages<
     if (method === "jsonMode") {
       throw new Error(`Anthropic only supports "functionCalling" as a method.`);
     }
+    const thinkingAdmonition =
+      "Anthropic structured output relies on forced tool calling, " +
+      "which is not supported when `thinking` is enabled. This method will raise " +
+      "OutputParserException if tool calls are not " +
+      "generated. Consider disabling `thinking` or adjusting your prompt to ensure " +
+      "the tool is called.";
 
     let functionName = name ?? "extract";
     let outputParser: BaseLLMOutputParser<RunOutput>;
@@ -1148,6 +1153,8 @@ export class ChatAnthropicMessages<
         returnSingle: true,
         keyName: functionName,
         zodSchema: schema,
+        errorMsgIfNoToolCalls:
+          this.thinking?.type === "enabled" ? thinkingAdmonition : undefined,
       });
     } else {
       let anthropicTools: Anthropic.Messages.Tool;
@@ -1170,19 +1177,13 @@ export class ChatAnthropicMessages<
       outputParser = new AnthropicToolsOutputParser({
         returnSingle: true,
         keyName: functionName,
+        errorMsgIfNoToolCalls:
+          this.thinking?.type === "enabled" ? thinkingAdmonition : undefined,
       });
     }
 
-    let llm;
+    let llm: this;
     if (this.thinking?.type === "enabled") {
-      const thinkingAdmonition =
-        "Anthropic structured output relies on forced tool calling, " +
-        "which is not supported when `thinking` is enabled. This method will raise " +
-        "OutputParserException if tool calls are not " +
-        "generated. Consider disabling `thinking` or adjust your prompt to ensure " +
-        "the tool is called.";
-      console.warn(thinkingAdmonition);
-
       llm = this.withConfig({
         tools,
         ls_structured_output_format: {
@@ -1190,15 +1191,6 @@ export class ChatAnthropicMessages<
           schema: toJsonSchema(schema),
         },
       } as Partial<CallOptions>);
-
-      const raiseIfNoToolCalls = (message: AIMessageChunk) => {
-        if (!message.tool_calls || message.tool_calls.length === 0) {
-          throw new Error(thinkingAdmonition);
-        }
-        return message;
-      };
-
-      llm = llm.pipe(raiseIfNoToolCalls);
     } else {
       llm = this.withConfig({
         tools,
@@ -1214,32 +1206,16 @@ export class ChatAnthropicMessages<
     }
 
     if (!includeRaw) {
-      return llm.pipe(outputParser).withConfig({
-        runName: "ChatAnthropicStructuredOutput",
-      }) as Runnable<BaseLanguageModelInput, RunOutput>;
+      return llm.withOutputParser(outputParser);
     }
 
-    const parserAssign = RunnablePassthrough.assign({
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
-      parsed: (input: any, config) => outputParser.invoke(input.raw, config),
-    });
-    const parserNone = RunnablePassthrough.assign({
-      parsed: () => null,
-    });
-    const parsedWithFallback = parserAssign.withFallbacks({
-      fallbacks: [parserNone],
-    });
-    return RunnableSequence.from<
-      BaseLanguageModelInput,
-      { raw: BaseMessage; parsed: RunOutput }
-    >([
-      {
-        raw: llm,
-      },
-      parsedWithFallback,
-    ]).withConfig({
-      runName: "StructuredOutputRunnable",
-    });
+    const parserWithRaw = RunnableLambda.from(
+      async (message: AIMessageChunk, config) => ({
+        raw: message,
+        parsed: await outputParser.invoke(message, config).catch(() => null),
+      })
+    );
+    return llm.withOutputParser(parserWithRaw);
   }
 }
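A hedged consumer sketch of the new `thinking` interplay (the model id is illustrative): since the tool call can no longer be forced, parse failures surface through `errorMsgIfNoToolCalls` rather than a `console.warn` plus a pipe-time throw, and `includeRaw` lets the caller inspect what happened.

```ts
import { ChatAnthropic } from "@langchain/anthropic";
import { z } from "zod";

async function demo() {
  const model = new ChatAnthropic({
    model: "claude-3-7-sonnet-latest", // illustrative model id
    thinking: { type: "enabled", budget_tokens: 1024 },
  }).withStructuredOutput(z.object({ answer: z.string() }), {
    includeRaw: true,
  });

  const { raw, parsed } = await model.invoke("Why is the sky blue? Be brief.");
  if (parsed === null) {
    // No tool call was generated; fall back to the raw message
    // (including any thinking blocks).
    console.warn("Structured parse failed, raw output:", raw.content);
  }
}
```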
diff --git a/libs/providers/langchain-anthropic/src/output_parsers.ts b/libs/providers/langchain-anthropic/src/output_parsers.ts
index 36d02c3c1d97..3d3e02c3b5a3 100644
--- a/libs/providers/langchain-anthropic/src/output_parsers.ts
+++ b/libs/providers/langchain-anthropic/src/output_parsers.ts
@@ -12,7 +12,9 @@ import {
 
 // eslint-disable-next-line @typescript-eslint/no-explicit-any
 interface AnthropicToolsOutputParserParams<T extends Record<string, any>>
-  extends JsonOutputKeyToolsParserParamsInterop<T> {}
+  extends JsonOutputKeyToolsParserParamsInterop<T> {
+  errorMsgIfNoToolCalls?: string;
+}
 
 export class AnthropicToolsOutputParser<
   // eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -34,11 +36,16 @@ export class AnthropicToolsOutputParser<
 
   zodSchema?: InteropZodType<T>;
 
+  errorMsgIfNoToolCalls =
+    "No parseable tool calls provided to AnthropicToolsOutputParser.";
+
   constructor(params: AnthropicToolsOutputParserParams<T>) {
     super(params);
     this.keyName = params.keyName;
     this.returnSingle = params.returnSingle ?? this.returnSingle;
     this.zodSchema = params.zodSchema;
+    this.errorMsgIfNoToolCalls =
+      params.errorMsgIfNoToolCalls ?? this.errorMsgIfNoToolCalls;
   }
 
   protected async _validateResult(result: unknown): Promise<T> {
@@ -91,9 +98,7 @@ export class AnthropicToolsOutputParser<
       return tool;
     });
     if (tools[0] === undefined) {
-      throw new Error(
-        "No parseable tool calls provided to AnthropicToolsOutputParser."
-      );
+      throw new Error(this.errorMsgIfNoToolCalls);
     }
     const [tool] = tools;
     const validatedResult = await this._validateResult(tool.args);
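A hedged sketch of the new knob in isolation: a parser constructed with a domain-specific message for the no-tool-calls case.

```ts
import { AIMessage } from "@langchain/core/messages";
import { AnthropicToolsOutputParser } from "./output_parsers.js";

const parser = new AnthropicToolsOutputParser({
  keyName: "extract",
  returnSingle: true,
  errorMsgIfNoToolCalls: "Model did not call the `extract` tool.",
});

// A message without tool calls now rejects with the custom message
// instead of the generic default.
await parser
  .invoke(new AIMessage({ content: "plain text, no tools" }))
  .catch((e) => console.error(e.message));
```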
diff --git a/libs/providers/langchain-openai/src/chat_models.ts b/libs/providers/langchain-openai/src/chat_models.ts
index 02d7c28cbcb4..87ed0da9c564 100644
--- a/libs/providers/langchain-openai/src/chat_models.ts
+++ b/libs/providers/langchain-openai/src/chat_models.ts
@@ -45,18 +45,11 @@ import {
 import {
   isOpenAITool,
   type BaseFunctionCallOptions,
-  type BaseLanguageModelInput,
   type FunctionDefinition,
   type StructuredOutputMethodOptions,
-  type StructuredOutputMethodParams,
 } from "@langchain/core/language_models/base";
 import { NewTokenIndices } from "@langchain/core/callbacks/base";
-import {
-  Runnable,
-  RunnableLambda,
-  RunnablePassthrough,
-  RunnableSequence,
-} from "@langchain/core/runnables";
+import { Runnable, RunnableLambda } from "@langchain/core/runnables";
 import {
   JsonOutputParser,
   StructuredOutputParser,
@@ -545,9 +538,10 @@ export interface BaseChatOpenAIFields
 
 /** @internal */
 export abstract class BaseChatOpenAI<
-    CallOptions extends BaseChatOpenAICallOptions
+    CallOptions extends BaseChatOpenAICallOptions,
+    RunOutput = AIMessageChunk
   >
-  extends BaseChatModel<CallOptions, AIMessageChunk>
+  extends BaseChatModel<CallOptions, RunOutput>
   implements Partial<OpenAIChatInput>
 {
   temperature?: number;
@@ -627,8 +621,6 @@ export abstract class BaseChatOpenAI<
    */
   service_tier?: OpenAIClient.Chat.ChatCompletionCreateParams["service_tier"];
 
-  protected defaultOptions: CallOptions;
-
   _llmType() {
     return "openai";
   }
@@ -861,15 +853,6 @@ export abstract class BaseChatOpenAI<
       | undefined;
   }
 
-  protected _combineCallOptions(
-    additionalOptions?: this["ParsedCallOptions"]
-  ): this["ParsedCallOptions"] {
-    return {
-      ...this.defaultOptions,
-      ...(additionalOptions ?? {}),
-    };
-  }
-
   /** @internal */
   _getClientOptions(
     options: OpenAICoreRequestOptions | undefined
@@ -923,7 +906,7 @@ export abstract class BaseChatOpenAI<
   override bindTools(
     tools: ChatOpenAIToolType[],
     kwargs?: Partial<CallOptions>
-  ): Runnable<BaseLanguageModelInput, AIMessageChunk, CallOptions> {
+  ): this {
     let strict: boolean | undefined;
     if (kwargs?.strict !== undefined) {
       strict = kwargs.strict;
@@ -940,20 +923,6 @@ export abstract class BaseChatOpenAI<
     } as Partial<CallOptions>);
   }
 
-  override async stream(input: BaseLanguageModelInput, options?: CallOptions) {
-    return super.stream(
-      input,
-      this._combineCallOptions(options) as CallOptions
-    );
-  }
-
-  override async invoke(input: BaseLanguageModelInput, options?: CallOptions) {
-    return super.invoke(
-      input,
-      this._combineCallOptions(options) as CallOptions
-    );
-  }
-
   /** @ignore */
   _combineLLMOutput(...llmOutputs: OpenAILLMOutput[]): OpenAILLMOutput {
     return llmOutputs.reduce<{
@@ -1114,7 +1083,7 @@ export abstract class BaseChatOpenAI<
     // eslint-disable-next-line @typescript-eslint/no-explicit-any
     | Record<string, any>,
     config?: StructuredOutputMethodOptions<false>
-  ): Runnable<BaseLanguageModelInput, RunOutput>;
+  ): BaseChatModel<CallOptions, RunOutput>;
 
   withStructuredOutput<
     // eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -1125,7 +1094,7 @@ export abstract class BaseChatOpenAI<
     // eslint-disable-next-line @typescript-eslint/no-explicit-any
     | Record<string, any>,
     config?: StructuredOutputMethodOptions<true>
-  ): Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;
+  ): BaseChatModel<CallOptions, { raw: BaseMessage; parsed: RunOutput }>;
 
   withStructuredOutput<
     // eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -1137,8 +1106,8 @@ export abstract class BaseChatOpenAI<
     | Record<string, any>,
     config?: StructuredOutputMethodOptions<boolean>
   ):
-    | Runnable<BaseLanguageModelInput, RunOutput>
-    | Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;
+    | BaseChatModel<CallOptions, RunOutput>
+    | BaseChatModel<CallOptions, { raw: BaseMessage; parsed: RunOutput }>;
 
   /**
    * Add structured output to the model.
@@ -1165,7 +1134,7 @@ export abstract class BaseChatOpenAI<
     outputSchema: InteropZodType<RunOutput> | Record<string, any>,
     config?: StructuredOutputMethodOptions<boolean>
   ) {
-    let llm: Runnable<BaseLanguageModelInput, AIMessageChunk, CallOptions>;
+    let llm: this;
     let outputParser: Runnable<AIMessageChunk, RunOutput>;
 
     const { schema, name, includeRaw } = {
@@ -1313,26 +1282,16 @@ export abstract class BaseChatOpenAI<
     }
 
     if (!includeRaw) {
-      return llm.pipe(outputParser) as Runnable<
-        BaseLanguageModelInput,
-        RunOutput
-      >;
+      return llm.withOutputParser(outputParser);
     }
 
-    const parserAssign = RunnablePassthrough.assign({
-      // eslint-disable-next-line @typescript-eslint/no-explicit-any
-      parsed: (input: any, config) => outputParser.invoke(input.raw, config),
-    });
-    const parserNone = RunnablePassthrough.assign({
-      parsed: () => null,
-    });
-    const parsedWithFallback = parserAssign.withFallbacks({
-      fallbacks: [parserNone],
-    });
-    return RunnableSequence.from<
-      BaseLanguageModelInput,
-      { raw: BaseMessage; parsed: RunOutput }
-    >([{ raw: llm }, parsedWithFallback]);
+    const parserWithRaw = RunnableLambda.from(
+      async (message: AIMessageChunk, config) => ({
+        raw: message,
+        parsed: await outputParser.invoke(message, config).catch(() => null),
+      })
+    );
+    return llm.withOutputParser(parserWithRaw);
   }
 }
 
@@ -3362,11 +3321,4 @@ export class ChatOpenAI<
       runManager
     );
   }
-
-  override withConfig(
-    config: Partial<CallOptions>
-  ): Runnable<BaseLanguageModelInput, AIMessageChunk, CallOptions> {
-    this.defaultOptions = { ...this.defaultOptions, ...config };
-    return this;
-  }
 }
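Finally, a hedged sketch of the net effect for OpenAI users: with the ad-hoc `invoke`/`stream`/`withConfig` overrides deleted, defaults set via `withConfig()` flow through the shared `BaseChatModel` merge for both code paths.

```ts
import { ChatOpenAI } from "@langchain/openai";

async function demo() {
  const model = new ChatOpenAI({ model: "gpt-4o-mini" }).withConfig({
    stop: ["\n\n"],
  });

  // Both paths see the merged defaults via the base-class option merge.
  const message = await model.invoke("Say hi");
  console.log(message.content);

  for await (const chunk of await model.stream("Count to three")) {
    process.stdout.write(String(chunk.content));
  }
}
```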