Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 72 additions & 84 deletions libs/langchain_v1/langchain/agents/middleware/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import warnings
from collections.abc import Awaitable, Callable, Sequence
from dataclasses import dataclass, field, replace
from inspect import iscoroutinefunction
Expand All @@ -11,20 +12,10 @@
Any,
Generic,
Literal,
Protocol,
cast,
overload,
)

if TYPE_CHECKING:
from collections.abc import Awaitable

from langgraph.types import Command

# Needed as top level import for Pydantic schema generation on AgentState
import warnings
from typing import TypeAlias

from langchain_core.messages import (
AIMessage,
AnyMessage,
Expand All @@ -35,13 +26,14 @@
from langgraph.channels.ephemeral_value import EphemeralValue
from langgraph.graph.message import add_messages
from langgraph.prebuilt.tool_node import ToolCallRequest, ToolCallWrapper
from langgraph.runtime import Runtime
from langgraph.types import Command
from langgraph.typing import ContextT
from typing_extensions import NotRequired, Required, TypedDict, TypeVar, Unpack

if TYPE_CHECKING:
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.tools import BaseTool
from langgraph.runtime import Runtime

from langchain.agents.structured_output import ResponseFormat

Expand Down Expand Up @@ -279,7 +271,7 @@ class ModelResponse:


# Type alias for middleware return type - allows returning either full response or just AIMessage
ModelCallResult: TypeAlias = ModelResponse | AIMessage
ModelCallResult = ModelResponse | AIMessage
"""`TypeAlias` for model call handler return value.

Middleware can return either:
Expand Down Expand Up @@ -756,56 +748,44 @@ async def awrap_tool_call(self, request, handler):
raise NotImplementedError(msg)


class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
"""Callable with `AgentState` and `Runtime` as arguments."""

def __call__(
self, state: StateT_contra, runtime: Runtime[ContextT]
) -> dict[str, Any] | Command[Any] | None | Awaitable[dict[str, Any] | Command[Any] | None]:
"""Perform some logic with the state and runtime."""
...


class _CallableReturningSystemMessage(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
"""Callable that returns a prompt string or SystemMessage given `ModelRequest`."""

def __call__(
self, request: ModelRequest
) -> str | SystemMessage | Awaitable[str | SystemMessage]:
"""Generate a system prompt string or SystemMessage based on the request."""
...


class _CallableReturningModelResponse(Protocol[StateT_contra, ContextT]): # type: ignore[misc]
"""Callable for model call interception with handler callback.

Receives handler callback to execute model and returns `ModelResponse` or
`AIMessage`.
"""

def __call__(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
"""Intercept model execution via handler callback."""
...

_SyncCallableWithStateAndRuntime = Callable[
[StateT_contra, Runtime[ContextT]], dict[str, Any] | Command[Any] | None
]
_AsyncCallableWithStateAndRuntime = Callable[
[StateT_contra, Runtime[ContextT]], Awaitable[dict[str, Any] | Command[Any] | None]
]
_CallableWithStateAndRuntime = (
_SyncCallableWithStateAndRuntime[StateT_contra, ContextT]
| _AsyncCallableWithStateAndRuntime[StateT_contra, ContextT]
)

class _CallableReturningToolResponse(Protocol):
"""Callable for tool call interception with handler callback.
_SyncCallableReturningSystemMessage = Callable[[ModelRequest], str | SystemMessage]
_AsyncCallableReturningSystemMessage = Callable[[ModelRequest], Awaitable[str | SystemMessage]]
_CallableReturningSystemMessage = (
_SyncCallableReturningSystemMessage | _AsyncCallableReturningSystemMessage
)

Receives handler callback to execute tool and returns final `ToolMessage` or
`Command`.
"""
_SyncCallableReturningModelResponse = Callable[
[ModelRequest, Callable[[ModelRequest], ModelResponse]], ModelCallResult
]
_AsyncCallableReturningModelResponse = Callable[
[ModelRequest, Callable[[ModelRequest], Awaitable[ModelResponse]]], Awaitable[ModelCallResult]
]
_CallableReturningModelResponse = (
_SyncCallableReturningModelResponse | _AsyncCallableReturningModelResponse
)

def __call__(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
) -> ToolMessage | Command[Any]:
"""Intercept tool execution via handler callback."""
...
_SyncCallableReturningToolResponse = Callable[
[ToolCallRequest, Callable[[ToolCallRequest], ToolMessage | Command[Any]]],
ToolMessage | Command[Any],
]
_AsyncCallableReturningToolResponse = Callable[
[ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]]],
Awaitable[ToolMessage | Command[Any]],
]
_CallableReturningToolResponse = (
_SyncCallableReturningToolResponse | _AsyncCallableReturningToolResponse
)


CallableT = TypeVar("CallableT", bound=Callable[..., Any])
Expand Down Expand Up @@ -986,7 +966,9 @@ async def async_wrapped(
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command[Any] | None:
return await func(state, runtime) # type: ignore[misc]
return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)(
state, runtime
)

# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
Expand All @@ -1011,7 +993,7 @@ def wrapped(
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command[Any] | None:
return func(state, runtime) # type: ignore[return-value]
return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime)

# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
Expand Down Expand Up @@ -1146,7 +1128,9 @@ async def async_wrapped(
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command[Any] | None:
return await func(state, runtime) # type: ignore[misc]
return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)(
state, runtime
)

# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
Expand All @@ -1169,7 +1153,7 @@ def wrapped(
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command[Any] | None:
return func(state, runtime) # type: ignore[return-value]
return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime)

# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
Expand Down Expand Up @@ -1337,7 +1321,9 @@ async def async_wrapped(
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command[Any] | None:
return await func(state, runtime) # type: ignore[misc]
return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)(
state, runtime
)

# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
Expand All @@ -1362,7 +1348,7 @@ def wrapped(
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command[Any] | None:
return func(state, runtime) # type: ignore[return-value]
return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime)

# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
Expand Down Expand Up @@ -1498,7 +1484,9 @@ async def async_wrapped(
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command[Any] | None:
return await func(state, runtime) # type: ignore[misc]
return await cast("_AsyncCallableWithStateAndRuntime[StateT, ContextT]", func)(
state, runtime
)

# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
Expand All @@ -1521,7 +1509,7 @@ def wrapped(
state: StateT,
runtime: Runtime[ContextT],
) -> dict[str, Any] | Command[Any] | None:
return func(state, runtime) # type: ignore[return-value]
return cast("_SyncCallableWithStateAndRuntime[StateT, ContextT]", func)(state, runtime)

# Preserve can_jump_to metadata on the wrapped function
if func_can_jump_to:
Expand All @@ -1547,24 +1535,24 @@ def wrapped(

@overload
def dynamic_prompt(
func: _CallableReturningSystemMessage[StateT, ContextT],
func: _CallableReturningSystemMessage,
) -> AgentMiddleware[StateT, ContextT]: ...


@overload
def dynamic_prompt(
func: None = None,
) -> Callable[
[_CallableReturningSystemMessage[StateT, ContextT]],
[_CallableReturningSystemMessage],
AgentMiddleware[StateT, ContextT],
]: ...


def dynamic_prompt(
func: _CallableReturningSystemMessage[StateT, ContextT] | None = None,
func: _CallableReturningSystemMessage | None = None,
) -> (
Callable[
[_CallableReturningSystemMessage[StateT, ContextT]],
[_CallableReturningSystemMessage],
AgentMiddleware[StateT, ContextT],
]
| AgentMiddleware[StateT, ContextT]
Expand Down Expand Up @@ -1618,7 +1606,7 @@ def context_aware_prompt(request: ModelRequest) -> str:
"""

def decorator(
func: _CallableReturningSystemMessage[StateT, ContextT],
func: _CallableReturningSystemMessage,
) -> AgentMiddleware[StateT, ContextT]:
is_async = iscoroutinefunction(func)

Expand All @@ -1629,7 +1617,7 @@ async def async_wrapped(
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
prompt = await func(request) # type: ignore[misc]
prompt = await cast("_AsyncCallableReturningSystemMessage", func)(request)
if isinstance(prompt, SystemMessage):
request = request.override(system_message=prompt)
else:
Expand All @@ -1653,7 +1641,7 @@ def wrapped(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request)
prompt = cast("_SyncCallableReturningSystemMessage", func)(request)
if isinstance(prompt, SystemMessage):
request = request.override(system_message=prompt)
else:
Expand All @@ -1666,7 +1654,7 @@ async def async_wrapped_from_sync(
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
# Delegate to sync function
prompt = cast("Callable[[ModelRequest], SystemMessage | str]", func)(request)
prompt = cast("_SyncCallableReturningSystemMessage", func)(request)
if isinstance(prompt, SystemMessage):
request = request.override(system_message=prompt)
else:
Expand All @@ -1693,7 +1681,7 @@ async def async_wrapped_from_sync(

@overload
def wrap_model_call(
func: _CallableReturningModelResponse[StateT, ContextT],
func: _CallableReturningModelResponse,
) -> AgentMiddleware[StateT, ContextT]: ...


Expand All @@ -1705,20 +1693,20 @@ def wrap_model_call(
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> Callable[
[_CallableReturningModelResponse[StateT, ContextT]],
[_CallableReturningModelResponse],
AgentMiddleware[StateT, ContextT],
]: ...


def wrap_model_call(
func: _CallableReturningModelResponse[StateT, ContextT] | None = None,
func: _CallableReturningModelResponse | None = None,
*,
state_schema: type[StateT] | None = None,
tools: list[BaseTool] | None = None,
name: str | None = None,
) -> (
Callable[
[_CallableReturningModelResponse[StateT, ContextT]],
[_CallableReturningModelResponse],
AgentMiddleware[StateT, ContextT],
]
| AgentMiddleware[StateT, ContextT]
Expand Down Expand Up @@ -1799,7 +1787,7 @@ def simple_response(request, handler):
"""

def decorator(
func: _CallableReturningModelResponse[StateT, ContextT],
func: _CallableReturningModelResponse,
) -> AgentMiddleware[StateT, ContextT]:
is_async = iscoroutinefunction(func)

Expand All @@ -1810,7 +1798,7 @@ async def async_wrapped(
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
return await func(request, handler) # type: ignore[misc, arg-type]
return await cast("_AsyncCallableReturningModelResponse", func)(request, handler)

middleware_name = name or cast(
"str", getattr(func, "__name__", "WrapModelCallMiddleware")
Expand All @@ -1831,7 +1819,7 @@ def wrapped(
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
return func(request, handler)
return cast("_SyncCallableReturningModelResponse", func)(request, handler)

middleware_name = name or cast("str", getattr(func, "__name__", "WrapModelCallMiddleware"))

Expand Down Expand Up @@ -1970,7 +1958,7 @@ async def async_wrapped(
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
) -> ToolMessage | Command[Any]:
return await func(request, handler) # type: ignore[arg-type,misc]
return await cast("_AsyncCallableReturningToolResponse", func)(request, handler)

middleware_name = name or cast(
"str", getattr(func, "__name__", "WrapToolCallMiddleware")
Expand All @@ -1991,7 +1979,7 @@ def wrapped(
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
) -> ToolMessage | Command[Any]:
return func(request, handler)
return cast("_SyncCallableReturningToolResponse", func)(request, handler)

middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))

Expand Down