diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py index cc09f91b7e705..37fd1cc385e4e 100644 --- a/libs/langchain_v1/langchain/agents/middleware/types.py +++ b/libs/langchain_v1/langchain/agents/middleware/types.py @@ -2,6 +2,7 @@ from __future__ import annotations +import warnings from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass, field, replace from inspect import iscoroutinefunction @@ -11,20 +12,10 @@ Any, Generic, Literal, - Protocol, cast, overload, ) -if TYPE_CHECKING: - from collections.abc import Awaitable - - from langgraph.types import Command - -# Needed as top level import for Pydantic schema generation on AgentState -import warnings -from typing import TypeAlias - from langchain_core.messages import ( AIMessage, AnyMessage, @@ -35,13 +26,14 @@ from langgraph.channels.ephemeral_value import EphemeralValue from langgraph.graph.message import add_messages from langgraph.prebuilt.tool_node import ToolCallRequest, ToolCallWrapper +from langgraph.runtime import Runtime +from langgraph.types import Command from langgraph.typing import ContextT from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack if TYPE_CHECKING: from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.tools import BaseTool - from langgraph.runtime import Runtime from langchain.agents.structured_output import ResponseFormat @@ -279,7 +271,7 @@ class ModelResponse: # Type alias for middleware return type - allows returning either full response or just AIMessage -ModelCallResult: TypeAlias = ModelResponse | AIMessage +ModelCallResult = ModelResponse | AIMessage """`TypeAlias` for model call handler return value. Middleware can return either: @@ -756,56 +748,44 @@ async def awrap_tool_call(self, request, handler): raise NotImplementedError(msg) -class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]): - """Callable with `AgentState` and `Runtime` as arguments.""" - - def __call__( - self, state: StateT_contra, runtime: Runtime[ContextT] - ) -> dict[str, Any] | Command[Any] | None | Awaitable[dict[str, Any] | Command[Any] | None]: - """Perform some logic with the state and runtime.""" - ... - - -class _CallableReturningSystemMessage(Protocol[StateT_contra, ContextT]): # type: ignore[misc] - """Callable that returns a prompt string or SystemMessage given `ModelRequest`.""" - - def __call__( - self, request: ModelRequest - ) -> str | SystemMessage | Awaitable[str | SystemMessage]: - """Generate a system prompt string or SystemMessage based on the request.""" - ... - - -class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # type: ignore[misc] - """Callable for model call interception with handler callback. - - Receives handler callback to execute model and returns `ModelResponse` or - `AIMessage`. - """ - - def __call__( - self, - request: ModelRequest, - handler: Callable[[ModelRequest], ModelResponse], - ) -> ModelCallResult: - """Intercept model execution via handler callback.""" - ... - +_SyncCallableWithStateAndRuntime = Callable[ + [StateT_contra, Runtime[ContextT]], dict[str, Any] | Command[Any] | None +] +_AsyncCallableWithStateAndRuntime = Callable[ + [StateT_contra, Runtime[ContextT]], Awaitable[dict[str, Any] | Command[Any] | None] +] +_CallableWithStateAndRuntime = ( + _SyncCallableWithStateAndRuntime[StateT_contra, ContextT] + | _AsyncCallableWithStateAndRuntime[StateT_contra, ContextT] +) -class _CallableReturningToolResponse(Protocol): - """Callable for tool call interception with handler callback. +_SyncCallableReturningSystemMessage = Callable[[ModelRequest], str | SystemMessage] +_AsyncCallableReturningSystemMessage = Callable[[ModelRequest], Awaitable[str | SystemMessage]] +_CallableReturningSystemMessage = ( + _SyncCallableReturningSystemMessage | _AsyncCallableReturningSystemMessage +) - Receives handler callback to execute tool and returns final `ToolMessage` or - `Command`. - """ +_SyncCallableReturningModelResponse = Callable[ + [ModelRequest, Callable[[ModelRequest], ModelResponse]], ModelCallResult +] +_AsyncCallableReturningModelResponse = Callable[ + [ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]], Awaitable[ModelCallResult] +] +_CallableReturningModelResponse = ( + _SyncCallableReturningModelResponse | _AsyncCallableReturningModelResponse +) - def __call__( - self, - request: ToolCallRequest, - handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]], - ) -> ToolMessage | Command[Any]: - """Intercept tool execution via handler callback.""" - ... +_SyncCallableReturningToolResponse = Callable[ + [ToolCallRequest, Callable[[ToolCallRequest], ToolMessage | Command[Any]]], + ToolMessage | Command[Any], +] +_AsyncCallableReturningToolResponse = Callable[ + [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]]], + Awaitable[ToolMessage | Command[Any]], +] +_CallableReturningToolResponse = ( + _SyncCallableReturningToolResponse | _AsyncCallableReturningToolResponse +) CallableT = TypeVar("CallableT", bound=Callable[..., Any]) @@ -986,7 +966,9 @@ async def async_wrapped( state: StateT, runtime: Runtime[ContextT], ) -> dict[str, Any] | Command[Any] | None: - return await func(state, runtime) # type: ignore[misc] + return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)( + state, runtime + ) # Preserve can_jump_to metadata on the wrapped function if func_can_jump_to: @@ -1011,7 +993,7 @@ def wrapped( state: StateT, runtime: Runtime[ContextT], ) -> dict[str, Any] | Command[Any] | None: - return func(state, runtime) # type: ignore[return-value] + return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime) # Preserve can_jump_to metadata on the wrapped function if func_can_jump_to: @@ -1146,7 +1128,9 @@ async def async_wrapped( state: StateT, runtime: Runtime[ContextT], ) -> dict[str, Any] | Command[Any] | None: - return await func(state, runtime) # type: ignore[misc] + return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)( + state, runtime + ) # Preserve can_jump_to metadata on the wrapped function if func_can_jump_to: @@ -1169,7 +1153,7 @@ def wrapped( state: StateT, runtime: Runtime[ContextT], ) -> dict[str, Any] | Command[Any] | None: - return func(state, runtime) # type: ignore[return-value] + return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime) # Preserve can_jump_to metadata on the wrapped function if func_can_jump_to: @@ -1337,7 +1321,9 @@ async def async_wrapped( state: StateT, runtime: Runtime[ContextT], ) -> dict[str, Any] | Command[Any] | None: - return await func(state, runtime) # type: ignore[misc] + return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)( + state, runtime + ) # Preserve can_jump_to metadata on the wrapped function if func_can_jump_to: @@ -1362,7 +1348,7 @@ def wrapped( state: StateT, runtime: Runtime[ContextT], ) -> dict[str, Any] | Command[Any] | None: - return func(state, runtime) # type: ignore[return-value] + return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime) # Preserve can_jump_to metadata on the wrapped function if func_can_jump_to: @@ -1498,7 +1484,9 @@ async def async_wrapped( state: StateT, runtime: Runtime[ContextT], ) -> dict[str, Any] | Command[Any] | None: - return await func(state, runtime) # type: ignore[misc] + return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)( + state, runtime + ) # Preserve can_jump_to metadata on the wrapped function if func_can_jump_to: @@ -1521,7 +1509,7 @@ def wrapped( state: StateT, runtime: Runtime[ContextT], ) -> dict[str, Any] | Command[Any] | None: - return func(state, runtime) # type: ignore[return-value] + return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime) # Preserve can_jump_to metadata on the wrapped function if func_can_jump_to: @@ -1547,7 +1535,7 @@ def wrapped( @overload def dynamic_prompt( - func: _CallableReturningSystemMessage[StateT, ContextT], + func: _CallableReturningSystemMessage, ) -> AgentMiddleware[StateT, ContextT]: ... @@ -1555,16 +1543,16 @@ def dynamic_prompt( def dynamic_prompt( func: None = None, ) -> Callable[ - [_CallableReturningSystemMessage[StateT, ContextT]], + [_CallableReturningSystemMessage], AgentMiddleware[StateT, ContextT], ]: ... def dynamic_prompt( - func: _CallableReturningSystemMessage[StateT, ContextT] | None = None, + func: _CallableReturningSystemMessage | None = None, ) -> ( Callable[ - [_CallableReturningSystemMessage[StateT, ContextT]], + [_CallableReturningSystemMessage], AgentMiddleware[StateT, ContextT], ] | AgentMiddleware[StateT, ContextT] @@ -1618,7 +1606,7 @@ def context_aware_prompt(request: ModelRequest) -> str: """ def decorator( - func: _CallableReturningSystemMessage[StateT, ContextT], + func: _CallableReturningSystemMessage, ) -> AgentMiddleware[StateT, ContextT]: is_async = iscoroutinefunction(func) @@ -1629,7 +1617,7 @@ async def async_wrapped( request: ModelRequest, handler: Callable[[ModelRequest], Awaitable[ModelResponse]], ) -> ModelCallResult: - prompt = await func(request) # type: ignore[misc] + prompt = await cast("_AsyncCallableReturningSystemMessage", func)(request) if isinstance(prompt, SystemMessage): request = request.override(system_message=prompt) else: @@ -1653,7 +1641,7 @@ def wrapped( request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse], ) -> ModelCallResult: - prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request) + prompt = cast("_SyncCallableReturningSystemMessage", func)(request) if isinstance(prompt, SystemMessage): request = request.override(system_message=prompt) else: @@ -1666,7 +1654,7 @@ async def async_wrapped_from_sync( handler: Callable[[ModelRequest], Awaitable[ModelResponse]], ) -> ModelCallResult: # Delegate to sync function - prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request) + prompt = cast("_SyncCallableReturningSystemMessage", func)(request) if isinstance(prompt, SystemMessage): request = request.override(system_message=prompt) else: @@ -1693,7 +1681,7 @@ async def async_wrapped_from_sync( @overload def wrap_model_call( - func: _CallableReturningModelResponse[StateT, ContextT], + func: _CallableReturningModelResponse, ) -> AgentMiddleware[StateT, ContextT]: ... @@ -1705,20 +1693,20 @@ def wrap_model_call( tools: list[BaseTool] | None = None, name: str | None = None, ) -> Callable[ - [_CallableReturningModelResponse[StateT, ContextT]], + [_CallableReturningModelResponse], AgentMiddleware[StateT, ContextT], ]: ... def wrap_model_call( - func: _CallableReturningModelResponse[StateT, ContextT] | None = None, + func: _CallableReturningModelResponse | None = None, *, state_schema: type[StateT] | None = None, tools: list[BaseTool] | None = None, name: str | None = None, ) -> ( Callable[ - [_CallableReturningModelResponse[StateT, ContextT]], + [_CallableReturningModelResponse], AgentMiddleware[StateT, ContextT], ] | AgentMiddleware[StateT, ContextT] @@ -1799,7 +1787,7 @@ def simple_response(request, handler): """ def decorator( - func: _CallableReturningModelResponse[StateT, ContextT], + func: _CallableReturningModelResponse, ) -> AgentMiddleware[StateT, ContextT]: is_async = iscoroutinefunction(func) @@ -1810,7 +1798,7 @@ async def async_wrapped( request: ModelRequest, handler: Callable[[ModelRequest], Awaitable[ModelResponse]], ) -> ModelCallResult: - return await func(request, handler) # type: ignore[misc, arg-type] + return await cast("_AsyncCallableReturningModelResponse", func)(request, handler) middleware_name = name or cast( "str", getattr(func, "__name__", "WrapModelCallMiddleware") @@ -1831,7 +1819,7 @@ def wrapped( request: ModelRequest, handler: Callable[[ModelRequest], ModelResponse], ) -> ModelCallResult: - return func(request, handler) + return cast("_SyncCallableReturningModelResponse", func)(request, handler) middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware")) @@ -1970,7 +1958,7 @@ async def async_wrapped( request: ToolCallRequest, handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]], ) -> ToolMessage | Command[Any]: - return await func(request, handler) # type: ignore[arg-type,misc] + return await cast("_AsyncCallableReturningToolResponse", func)(request, handler) middleware_name = name or cast( "str", getattr(func, "__name__", "WrapToolCallMiddleware") @@ -1991,7 +1979,7 @@ def wrapped( request: ToolCallRequest, handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]], ) -> ToolMessage | Command[Any]: - return func(request, handler) + return cast("_SyncCallableReturningToolResponse", func)(request, handler) middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))