diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py
index 818ab06f5f846..57d230078f9d4 100644
--- a/libs/langchain_v1/langchain/agents/factory.py
+++ b/libs/langchain_v1/langchain/agents/factory.py
@@ -314,7 +314,7 @@ def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None
     return TypedDict(schema_name, all_annotations)  # type: ignore[operator]
 
 
-def _extract_metadata(type_: type) -> list:
+def _extract_metadata(type_: type) -> list[Any]:
     """Extract metadata from a field type, handling Required/NotRequired and Annotated wrappers."""
     # Handle Required[Annotated[...]] or NotRequired[Annotated[...]]
     if get_origin(type_) in {Required, NotRequired}:
@@ -364,7 +364,9 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
     return []
 
 
-def _supports_provider_strategy(model: str | BaseChatModel, tools: list | None = None) -> bool:
+def _supports_provider_strategy(
+    model: str | BaseChatModel, tools: list[BaseTool | dict[str, Any]] | None = None
+) -> bool:
     """Check if a model supports provider-specific structured output.
 
     Args:
@@ -403,7 +405,7 @@ def _supports_provider_strategy(model: str | BaseChatModel, tools: list | None =
 
 def _handle_structured_output_error(
     exception: Exception,
-    response_format: ResponseFormat,
+    response_format: ResponseFormat[Any],
 ) -> tuple[bool, str]:
     """Handle structured output error. Returns `(should_retry, retry_tool_message)`."""
     if not isinstance(response_format, ToolStrategy):
@@ -455,10 +457,10 @@ def compose_two(outer: ToolCallWrapper, inner: ToolCallWrapper) -> ToolCallWrapp
 
         def composed(
             request: ToolCallRequest,
-            execute: Callable[[ToolCallRequest], ToolMessage | Command],
-        ) -> ToolMessage | Command:
+            execute: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
+        ) -> ToolMessage | Command[Any]:
             # Create a callable that invokes inner with the original execute
-            def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
+            def call_inner(req: ToolCallRequest) -> ToolMessage | Command[Any]:
                 return inner(req, execute)
 
             # Outer can call call_inner multiple times
@@ -477,14 +479,14 @@ def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
 def _chain_async_tool_call_wrappers(
     wrappers: Sequence[
         Callable[
-            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
-            Awaitable[ToolMessage | Command],
+            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]]],
+            Awaitable[ToolMessage | Command[Any]],
         ]
     ],
 ) -> (
     Callable[
-        [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
-        Awaitable[ToolMessage | Command],
+        [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]]],
+        Awaitable[ToolMessage | Command[Any]],
     ]
     | None
 ):
@@ -504,25 +506,25 @@ def _chain_async_tool_call_wrappers(
     def compose_two(
         outer: Callable[
-            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
-            Awaitable[ToolMessage | Command],
+            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]]],
+            Awaitable[ToolMessage | Command[Any]],
         ],
         inner: Callable[
-            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
-            Awaitable[ToolMessage | Command],
+            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]]],
+            Awaitable[ToolMessage | Command[Any]],
         ],
     ) -> Callable[
-        [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
-        Awaitable[ToolMessage | Command],
+        [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]]],
+        Awaitable[ToolMessage | Command[Any]],
     ]:
         """Compose two async wrappers where outer wraps inner."""
 
         async def composed(
             request: ToolCallRequest,
-            execute: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
-        ) -> ToolMessage | Command:
+            execute: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
+        ) -> ToolMessage | Command[Any]:
             # Create an async callable that invokes inner with the original execute
-            async def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
+            async def call_inner(req: ToolCallRequest) -> ToolMessage | Command[Any]:
                 return await inner(req, execute)
 
             # Outer can call call_inner multiple times
@@ -540,7 +542,7 @@ async def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
 def create_agent(
     model: str | BaseChatModel,
-    tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
+    tools: Sequence[BaseTool | Callable[..., Any] | dict[str, Any]] | None = None,
     *,
     system_prompt: str | SystemMessage | None = None,
     middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
@@ -553,7 +555,7 @@
     interrupt_after: list[str] | None = None,
     debug: bool = False,
     name: str | None = None,
-    cache: BaseCache | None = None,
+    cache: BaseCache[Any] | None = None,
 ) -> CompiledStateGraph[
     AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
 ]:
@@ -704,7 +706,7 @@ def check_weather(location: str) -> str:
     # Raw schemas are wrapped in AutoStrategy to preserve auto-detection intent.
     # AutoStrategy is converted to ToolStrategy upfront to calculate tools during agent creation,
     # but may be replaced with ProviderStrategy later based on model capabilities.
-    initial_response_format: ToolStrategy | ProviderStrategy | AutoStrategy | None
+    initial_response_format: ToolStrategy[Any] | ProviderStrategy[Any] | AutoStrategy[Any] | None
     if response_format is None:
         initial_response_format = None
     elif isinstance(response_format, (ToolStrategy, ProviderStrategy)):
@@ -719,13 +721,13 @@ def check_weather(location: str) -> str:
     # For AutoStrategy, convert to ToolStrategy to setup tools upfront
     # (may be replaced with ProviderStrategy later based on model)
-    tool_strategy_for_setup: ToolStrategy | None = None
+    tool_strategy_for_setup: ToolStrategy[Any] | None = None
     if isinstance(initial_response_format, AutoStrategy):
         tool_strategy_for_setup = ToolStrategy(schema=initial_response_format.schema)
     elif isinstance(initial_response_format, ToolStrategy):
         tool_strategy_for_setup = initial_response_format
 
-    structured_output_tools: dict[str, OutputToolBinding] = {}
+    structured_output_tools: dict[str, OutputToolBinding[Any]] = {}
     if tool_strategy_for_setup:
         for response_schema in tool_strategy_for_setup.schema_specs:
             structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
@@ -872,7 +874,7 @@ def check_weather(location: str) -> str:
     )
 
     def _handle_model_output(
-        output: AIMessage, effective_response_format: ResponseFormat | None
+        output: AIMessage, effective_response_format: ResponseFormat[Any] | None
     ) -> dict[str, Any]:
         """Handle model output including structured responses.
@@ -975,7 +977,9 @@ def _handle_model_output(
         return {"messages": [output]}
 
-    def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
+    def _get_bound_model(
+        request: ModelRequest,
+    ) -> tuple[Runnable[Any, Any], ResponseFormat[Any] | None]:
         """Get the model with appropriate tool bindings.
 
         Performs auto-detection of strategy if needed based on model capabilities.
@@ -1025,7 +1029,7 @@ def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat |
             raise ValueError(msg)
 
         # Determine effective response format (auto-detect if needed)
-        effective_response_format: ResponseFormat | None
+        effective_response_format: ResponseFormat[Any] | None
         if isinstance(request.response_format, AutoStrategy):
             # User provided raw schema via AutoStrategy - auto-detect best strategy based on model
             if _supports_provider_strategy(request.model, tools=request.tools):
@@ -1119,7 +1123,7 @@ def _execute_model_sync(request: ModelRequest) -> ModelResponse:
             structured_response=structured_response,
         )
 
-    def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
+    def model_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Sync model request handler with sequential middleware processing."""
         request = ModelRequest(
             model=model,
@@ -1174,7 +1178,7 @@ async def _execute_model_async(request: ModelRequest) -> ModelResponse:
             structured_response=structured_response,
         )
 
-    async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
+    async def amodel_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Async model request handler with sequential middleware processing."""
         request = ModelRequest(
             model=model,
@@ -1523,7 +1527,7 @@ def _fetch_last_ai_and_tool_messages(
 def _make_model_to_tools_edge(
     *,
     model_destination: str,
-    structured_output_tools: dict[str, OutputToolBinding],
+    structured_output_tools: dict[str, OutputToolBinding[Any]],
     end_destination: str,
 ) -> Callable[[dict[str, Any]], str | list[Send] | None]:
     def model_to_tools(
@@ -1607,7 +1611,7 @@ def _make_tools_to_model_edge(
     *,
     tool_node: ToolNode,
     model_destination: str,
-    structured_output_tools: dict[str, OutputToolBinding],
+    structured_output_tools: dict[str, OutputToolBinding[Any]],
     end_destination: str,
 ) -> Callable[[dict[str, Any]], str | None]:
     def tools_to_model(state: dict[str, Any]) -> str | None:
diff --git a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py
index 3547bc83f2a0b..7fbb6e964df51 100644
--- a/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py
+++ b/libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py
@@ -102,7 +102,9 @@ class HITLResponse(TypedDict):
 class _DescriptionFactory(Protocol):
     """Callable that generates a description for a tool call."""
 
-    def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime[ContextT]) -> str:
+    def __call__(
+        self, tool_call: ToolCall, state: AgentState[Any], runtime: Runtime[ContextT]
+    ) -> str:
         """Generate a description for a tool call."""
         ...
@@ -203,7 +205,7 @@ def _create_action_and_config(
         self,
         tool_call: ToolCall,
         config: InterruptOnConfig,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime[ContextT],
     ) -> tuple[ActionRequest, ReviewConfig]:
         """Create an ActionRequest and ReviewConfig for a tool call."""
@@ -277,7 +279,9 @@ def _process_decision(
             )
             raise ValueError(msg)
 
-    def after_model(self, state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
+    def after_model(
+        self, state: AgentState[Any], runtime: Runtime[ContextT]
+    ) -> dict[str, Any] | None:
         """Trigger interrupt flows for relevant tool calls after an `AIMessage`.
 
         Args:
@@ -363,7 +367,7 @@ def after_model(self, state: AgentState, runtime: Runtime[ContextT]) -> dict[str
         return {"messages": [last_ai_msg, *artificial_tool_messages]}
 
     async def aafter_model(
-        self, state: AgentState, runtime: Runtime[ContextT]
+        self, state: AgentState[Any], runtime: Runtime[ContextT]
     ) -> dict[str, Any] | None:
         """Async trigger interrupt flows for relevant tool calls after an `AIMessage`.
 
diff --git a/libs/langchain_v1/langchain/agents/middleware/model_call_limit.py b/libs/langchain_v1/langchain/agents/middleware/model_call_limit.py
index 2085045a98af7..3f1295260fb4a 100644
--- a/libs/langchain_v1/langchain/agents/middleware/model_call_limit.py
+++ b/libs/langchain_v1/langchain/agents/middleware/model_call_limit.py
@@ -19,7 +19,7 @@
 from langgraph.runtime import Runtime
 
 
-class ModelCallLimitState(AgentState):
+class ModelCallLimitState(AgentState[Any]):
     """State schema for `ModelCallLimitMiddleware`.
 
     Extends `AgentState` with model call tracking fields.
diff --git a/libs/langchain_v1/langchain/agents/middleware/pii.py b/libs/langchain_v1/langchain/agents/middleware/pii.py
index 446f74097fada..06b5a764e69bc 100644
--- a/libs/langchain_v1/langchain/agents/middleware/pii.py
+++ b/libs/langchain_v1/langchain/agents/middleware/pii.py
@@ -164,7 +164,7 @@ def _process_content(self, content: str) -> tuple[str, list[PIIMatch]]:
     @override
     def before_model(
         self,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime,
     ) -> dict[str, Any] | None:
         """Check user messages and tool results for PII before model invocation.
@@ -259,7 +259,7 @@ def before_model(
     @hook_config(can_jump_to=["end"])
     async def abefore_model(
         self,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime,
     ) -> dict[str, Any] | None:
         """Async check user messages and tool results for PII before model invocation.
@@ -280,7 +280,7 @@ async def abefore_model(
     @override
     def after_model(
         self,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime,
     ) -> dict[str, Any] | None:
         """Check AI messages for PII after model invocation.
@@ -339,7 +339,7 @@ def after_model(
 
     async def aafter_model(
         self,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime,
     ) -> dict[str, Any] | None:
         """Async check AI messages for PII after model invocation.
diff --git a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py
index de26623950c1a..f7d5861a1b95e 100644
--- a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py
+++ b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py
@@ -78,7 +78,7 @@ class _SessionResources:
     session: ShellSession
     tempdir: tempfile.TemporaryDirectory[str] | None
     policy: BaseExecutionPolicy
-    finalizer: weakref.finalize = field(init=False, repr=False)
+    finalizer: weakref.finalize = field(init=False, repr=False)  # type: ignore[type-arg]
 
     def __post_init__(self) -> None:
         self.finalizer = weakref.finalize(
@@ -90,7 +90,7 @@ def __post_init__(self) -> None:
         )
 
 
-class ShellToolState(AgentState):
+class ShellToolState(AgentState[Any]):
     """Agent state extension for tracking shell session resources."""
 
     shell_session_resources: NotRequired[
diff --git a/libs/langchain_v1/langchain/agents/middleware/summarization.py b/libs/langchain_v1/langchain/agents/middleware/summarization.py
index d3546c65bb600..d2c8971354ef2 100644
--- a/libs/langchain_v1/langchain/agents/middleware/summarization.py
+++ b/libs/langchain_v1/langchain/agents/middleware/summarization.py
@@ -269,7 +269,7 @@ def __init__(
             raise ValueError(msg)
 
     @override
-    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
+    def before_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None:
         """Process messages before model invocation, potentially triggering summarization.
 
         Args:
@@ -305,7 +305,9 @@ def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] |
         }
 
     @override
-    async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
+    async def abefore_model(
+        self, state: AgentState[Any], runtime: Runtime
+    ) -> dict[str, Any] | None:
         """Process messages before model invocation, potentially triggering summarization.
 
         Args:
diff --git a/libs/langchain_v1/langchain/agents/middleware/todo.py b/libs/langchain_v1/langchain/agents/middleware/todo.py
index 564622d6fede8..1f1d0e9e57f6e 100644
--- a/libs/langchain_v1/langchain/agents/middleware/todo.py
+++ b/libs/langchain_v1/langchain/agents/middleware/todo.py
@@ -35,7 +35,7 @@ class Todo(TypedDict):
     """The current status of the todo item."""
 
 
-class PlanningState(AgentState):
+class PlanningState(AgentState[Any]):
     """State schema for the todo middleware."""
 
     todos: Annotated[NotRequired[list[Todo]], OmitFromInput]
@@ -118,7 +118,9 @@ class PlanningState(AgentState):
 
 
 @tool(description=WRITE_TODOS_TOOL_DESCRIPTION)
-def write_todos(todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
+def write_todos(
+    todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
+) -> Command[Any]:
     """Create and manage a structured task list for your current work session."""
     return Command(
         update={
@@ -178,7 +180,7 @@ def __init__(
         @tool(description=self.tool_description)
         def write_todos(
             todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
-        ) -> Command:
+        ) -> Command[Any]:
             """Create and manage a structured task list for your current work session."""
             return Command(
                 update={
@@ -246,11 +248,7 @@ async def awrap_model_call(
         return await handler(request.override(system_message=new_system_message))
 
     @override
-    def after_model(
-        self,
-        state: AgentState,
-        runtime: Runtime,
-    ) -> dict[str, Any] | None:
+    def after_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None:
         """Check for parallel write_todos tool calls and return errors if detected.
 
         The todo list is designed to be updated at most once per model turn. Since
@@ -299,11 +297,8 @@ def after_model(
 
         return None
 
-    async def aafter_model(
-        self,
-        state: AgentState,
-        runtime: Runtime,
-    ) -> dict[str, Any] | None:
+    @override
+    async def aafter_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None:
         """Check for parallel write_todos tool calls and return errors if detected.
 
         Async version of `after_model`. The todo list is designed to be updated at
diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_emulator.py b/libs/langchain_v1/langchain/agents/middleware/tool_emulator.py
index 16b5e57c56644..967ece0361191 100644
--- a/libs/langchain_v1/langchain/agents/middleware/tool_emulator.py
+++ b/libs/langchain_v1/langchain/agents/middleware/tool_emulator.py
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import HumanMessage, ToolMessage
@@ -109,8 +109,8 @@ def __init__(
     def wrap_tool_call(
         self,
         request: ToolCallRequest,
-        handler: Callable[[ToolCallRequest], ToolMessage | Command],
-    ) -> ToolMessage | Command:
+        handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
+    ) -> ToolMessage | Command[Any]:
         """Emulate tool execution using LLM if tool should be emulated.
 
         Args:
@@ -159,8 +159,8 @@ def wrap_tool_call(
     async def awrap_tool_call(
         self,
         request: ToolCallRequest,
-        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
-    ) -> ToolMessage | Command:
+        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
+    ) -> ToolMessage | Command[Any]:
         """Async version of `wrap_tool_call`.
 
         Emulate tool execution using LLM if tool should be emulated.
diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_retry.py b/libs/langchain_v1/langchain/agents/middleware/tool_retry.py
index 7ef7313f46fa7..c162b61c96c9d 100644
--- a/libs/langchain_v1/langchain/agents/middleware/tool_retry.py
+++ b/libs/langchain_v1/langchain/agents/middleware/tool_retry.py
@@ -5,7 +5,7 @@
 import asyncio
 import time
 import warnings
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 from langchain_core.messages import ToolMessage
 
@@ -288,8 +288,8 @@ def _handle_failure(
     def wrap_tool_call(
         self,
         request: ToolCallRequest,
-        handler: Callable[[ToolCallRequest], ToolMessage | Command],
-    ) -> ToolMessage | Command:
+        handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
+    ) -> ToolMessage | Command[Any]:
         """Intercept tool execution and retry on failure.
 
         Args:
@@ -346,8 +346,8 @@ def wrap_tool_call(
     async def awrap_tool_call(
         self,
         request: ToolCallRequest,
-        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
-    ) -> ToolMessage | Command:
+        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
+    ) -> ToolMessage | Command[Any]:
         """Intercept and control async tool execution with retry logic.
 
         Args:
diff --git a/libs/langchain_v1/langchain/agents/middleware/tool_selection.py b/libs/langchain_v1/langchain/agents/middleware/tool_selection.py
index 24a2c79bded64..cbe926c15fdc7 100644
--- a/libs/langchain_v1/langchain/agents/middleware/tool_selection.py
+++ b/libs/langchain_v1/langchain/agents/middleware/tool_selection.py
@@ -4,12 +4,7 @@
 import logging
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Annotated, Literal, Union
-
-if TYPE_CHECKING:
-    from collections.abc import Awaitable, Callable
-
-    from langchain.tools import BaseTool
+from typing import TYPE_CHECKING, Annotated, Any, Literal, Union
 
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import HumanMessage
@@ -24,6 +19,11 @@
 )
 from langchain.chat_models.base import init_chat_model
 
+if TYPE_CHECKING:
+    from collections.abc import Awaitable, Callable
+
+    from langchain.tools import BaseTool
+
 logger = logging.getLogger(__name__)
 
 DEFAULT_SYSTEM_PROMPT = (
@@ -42,7 +42,7 @@ class _SelectionRequest:
     valid_tool_names: list[str]
 
 
-def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter:
+def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]:
     """Create a structured output schema for tool selection.
 
     Args:
@@ -227,7 +227,7 @@ def _prepare_selection_request(self, request: ModelRequest) -> _SelectionRequest
     def _process_selection_response(
         self,
-        response: dict,
+        response: dict[str, Any],
         available_tools: list[BaseTool],
         valid_tool_names: list[str],
         request: ModelRequest,
diff --git a/libs/langchain_v1/langchain/agents/middleware/types.py b/libs/langchain_v1/langchain/agents/middleware/types.py
index b4f96c0e8c8dc..cc09f91b7e705 100644
--- a/libs/langchain_v1/langchain/agents/middleware/types.py
+++ b/libs/langchain_v1/langchain/agents/middleware/types.py
@@ -78,10 +78,10 @@ class _ModelRequestOverrides(TypedDict, total=False):
     system_message: SystemMessage | None
     messages: list[AnyMessage]
     tool_choice: Any | None
-    tools: list[BaseTool | dict]
-    response_format: ResponseFormat | None
+    tools: list[BaseTool | dict[str, Any]]
+    response_format: ResponseFormat[Any] | None
     model_settings: dict[str, Any]
-    state: AgentState
+    state: AgentState[Any]
 
 
 @dataclass(init=False)
@@ -92,9 +92,9 @@ class ModelRequest:
     messages: list[AnyMessage]  # excluding system message
     system_message: SystemMessage | None
     tool_choice: Any | None
-    tools: list[BaseTool | dict]
-    response_format: ResponseFormat | None
-    state: AgentState
+    tools: list[BaseTool | dict[str, Any]]
+    response_format: ResponseFormat[Any] | None
+    state: AgentState[Any]
     runtime: Runtime[ContextT]  # type: ignore[valid-type]
     model_settings: dict[str, Any] = field(default_factory=dict)
 
@@ -106,9 +106,9 @@ def __init__(
         system_message: SystemMessage | None = None,
         system_prompt: str | None = None,
         tool_choice: Any | None = None,
-        tools: list[BaseTool | dict] | None = None,
-        response_format: ResponseFormat | None = None,
-        state: AgentState | None = None,
+        tools: list[BaseTool | dict[str, Any]] | None = None,
+        response_format: ResponseFormat[Any] | None = None,
+        state: AgentState[Any] | None = None,
         runtime: Runtime[ContextT] | None = None,
         model_settings: dict[str, Any] | None = None,
     ) -> None:
@@ -321,7 +321,7 @@ class AgentState(TypedDict, Generic[ResponseT]):
 class _InputAgentState(TypedDict):  # noqa: PYI049
     """Input state schema for the agent."""
 
-    messages: Required[Annotated[list[AnyMessage | dict], add_messages]]
+    messages: Required[Annotated[list[AnyMessage | dict[str, Any]], add_messages]]
 
 
 class _OutputAgentState(TypedDict, Generic[ResponseT]):  # noqa: PYI049
@@ -331,9 +331,13 @@ class _OutputAgentState(TypedDict, Generic[ResponseT]):  # noqa: PYI049
     structured_response: NotRequired[ResponseT]
 
 
-StateT = TypeVar("StateT", bound=AgentState, default=AgentState)
-StateT_co = TypeVar("StateT_co", bound=AgentState, default=AgentState, covariant=True)
-StateT_contra = TypeVar("StateT_contra", bound=AgentState, contravariant=True)
+StateT = TypeVar("StateT", bound=AgentState[Any], default=AgentState[Any])
+StateT_co = TypeVar("StateT_co", bound=AgentState[Any], default=AgentState[Any], covariant=True)
+StateT_contra = TypeVar("StateT_contra", bound=AgentState[Any], contravariant=True)
+
+
+class _DefaultAgentState(AgentState[Any]):
+    """AgentMiddleware default state."""
 
 
 class AgentMiddleware(Generic[StateT, ContextT]):
@@ -343,7 +347,7 @@ class AgentMiddleware(Generic[StateT, ContextT]):
     between steps in the main agent loop.
     """
 
-    state_schema: type[StateT] = cast("type[StateT]", AgentState)
+    state_schema: type[StateT] = cast("type[StateT]", _DefaultAgentState)
     """The schema for state passed to the middleware nodes."""
 
     tools: Sequence[BaseTool]
@@ -603,8 +607,8 @@ async def aafter_agent(
     def wrap_tool_call(
         self,
         request: ToolCallRequest,
-        handler: Callable[[ToolCallRequest], ToolMessage | Command],
-    ) -> ToolMessage | Command:
+        handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
+    ) -> ToolMessage | Command[Any]:
         """Intercept tool execution for retries, monitoring, or modification.
 
         Async version is `awrap_tool_call`
@@ -685,8 +689,8 @@ def wrap_tool_call(self, request, handler):
     async def awrap_tool_call(
         self,
         request: ToolCallRequest,
-        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
-    ) -> ToolMessage | Command:
+        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
+    ) -> ToolMessage | Command[Any]:
         """Intercept and control async tool execution via handler callback.
 
         The handler callback executes the tool call and returns a `ToolMessage` or
@@ -757,7 +761,7 @@ class _CallableWithStateAndRuntime(Protocol[StateT_contra, ContextT]):
     def __call__(
         self, state: StateT_contra, runtime: Runtime[ContextT]
-    ) -> dict[str, Any] | Command | None | Awaitable[dict[str, Any] | Command | None]:
+    ) -> dict[str, Any] | Command[Any] | None | Awaitable[dict[str, Any] | Command[Any] | None]:
         """Perform some logic with the state and runtime."""
         ...
@@ -798,8 +802,8 @@ class _CallableReturningToolResponse(Protocol):
     def __call__(
         self,
         request: ToolCallRequest,
-        handler: Callable[[ToolCallRequest], ToolMessage | Command],
-    ) -> ToolMessage | Command:
+        handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
+    ) -> ToolMessage | Command[Any]:
         """Intercept tool execution via handler callback."""
         ...
@@ -981,7 +985,7 @@ async def async_wrapped(
             _self: AgentMiddleware[StateT, ContextT],
             state: StateT,
             runtime: Runtime[ContextT],
-        ) -> dict[str, Any] | Command | None:
+        ) -> dict[str, Any] | Command[Any] | None:
             return await func(state, runtime)  # type: ignore[misc]
 
         # Preserve can_jump_to metadata on the wrapped function
@@ -1006,7 +1010,7 @@ def wrapped(
             _self: AgentMiddleware[StateT, ContextT],
             state: StateT,
             runtime: Runtime[ContextT],
-        ) -> dict[str, Any] | Command | None:
+        ) -> dict[str, Any] | Command[Any] | None:
             return func(state, runtime)  # type: ignore[return-value]
 
         # Preserve can_jump_to metadata on the wrapped function
@@ -1141,7 +1145,7 @@ async def async_wrapped(
             _self: AgentMiddleware[StateT, ContextT],
             state: StateT,
             runtime: Runtime[ContextT],
-        ) -> dict[str, Any] | Command | None:
+        ) -> dict[str, Any] | Command[Any] | None:
             return await func(state, runtime)  # type: ignore[misc]
 
         # Preserve can_jump_to metadata on the wrapped function
@@ -1164,7 +1168,7 @@ def wrapped(
             _self: AgentMiddleware[StateT, ContextT],
             state: StateT,
             runtime: Runtime[ContextT],
-        ) -> dict[str, Any] | Command | None:
+        ) -> dict[str, Any] | Command[Any] | None:
             return func(state, runtime)  # type: ignore[return-value]
 
         # Preserve can_jump_to metadata on the wrapped function
@@ -1332,7 +1336,7 @@ async def async_wrapped(
             _self: AgentMiddleware[StateT, ContextT],
             state: StateT,
             runtime: Runtime[ContextT],
-        ) -> dict[str, Any] | Command | None:
+        ) -> dict[str, Any] | Command[Any] | None:
             return await func(state, runtime)  # type: ignore[misc]
 
         # Preserve can_jump_to metadata on the wrapped function
@@ -1357,7 +1361,7 @@ def wrapped(
             _self: AgentMiddleware[StateT, ContextT],
             state: StateT,
             runtime: Runtime[ContextT],
-        ) -> dict[str, Any] | Command | None:
+        ) -> dict[str, Any] | Command[Any] | None:
             return func(state, runtime)  # type: ignore[return-value]
 
         # Preserve can_jump_to metadata on the wrapped function
@@ -1493,7 +1497,7 @@ async def async_wrapped(
             _self: AgentMiddleware[StateT, ContextT],
             state: StateT,
             runtime: Runtime[ContextT],
-        ) -> dict[str, Any] | Command | None:
+        ) -> dict[str, Any] | Command[Any] | None:
             return await func(state, runtime)  # type: ignore[misc]
 
         # Preserve can_jump_to metadata on the wrapped function
@@ -1516,7 +1520,7 @@ def wrapped(
             _self: AgentMiddleware[StateT, ContextT],
             state: StateT,
             runtime: Runtime[ContextT],
-        ) -> dict[str, Any] | Command | None:
+        ) -> dict[str, Any] | Command[Any] | None:
             return func(state, runtime)  # type: ignore[return-value]
 
         # Preserve can_jump_to metadata on the wrapped function
@@ -1964,8 +1968,8 @@ def decorator(
         async def async_wrapped(
             _self: AgentMiddleware,
             request: ToolCallRequest,
-            handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
-        ) -> ToolMessage | Command:
+            handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
+        ) -> ToolMessage | Command[Any]:
             return await func(request, handler)  # type: ignore[arg-type,misc]
 
         middleware_name = name or cast(
@@ -1985,8 +1989,8 @@ async def async_wrapped(
         def wrapped(
             _self: AgentMiddleware,
             request: ToolCallRequest,
-            handler: Callable[[ToolCallRequest], ToolMessage | Command],
-        ) -> ToolMessage | Command:
+            handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
+        ) -> ToolMessage | Command[Any]:
             return func(request, handler)
 
         middleware_name = name or cast("str", getattr(func, "__name__", "WrapToolCallMiddleware"))
diff --git a/libs/langchain_v1/langchain/agents/structured_output.py b/libs/langchain_v1/langchain/agents/structured_output.py
index 3a019cfab6473..2683f6515fd1b 100644
--- a/libs/langchain_v1/langchain/agents/structured_output.py
+++ b/libs/langchain_v1/langchain/agents/structured_output.py
@@ -75,7 +75,7 @@ def __init__(self, tool_name: str, source: Exception, ai_message: AIMessage) ->
 
 
 def _parse_with_schema(
-    schema: type[SchemaT] | dict, schema_kind: SchemaKind, data: dict[str, Any]
+    schema: type[SchemaT] | dict[str, Any], schema_kind: SchemaKind, data: dict[str, Any]
 ) -> Any:
     """Parse data using for any supported schema type.
diff --git a/libs/langchain_v1/langchain/chat_models/base.py b/libs/langchain_v1/langchain/chat_models/base.py
index 940023caaf6c6..f8a46585f4ad3 100644
--- a/libs/langchain_v1/langchain/chat_models/base.py
+++ b/libs/langchain_v1/langchain/chat_models/base.py
@@ -581,12 +581,12 @@ class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
     def __init__(
         self,
         *,
-        default_config: dict | None = None,
+        default_config: dict[str, Any] | None = None,
         configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = "any",
         config_prefix: str = "",
-        queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
+        queued_declarative_operations: Sequence[tuple[str, tuple[Any, ...], dict[str, Any]]] = (),
     ) -> None:
-        self._default_config: dict = default_config or {}
+        self._default_config: dict[str, Any] = default_config or {}
         self._configurable_fields: Literal["any"] | list[str] = (
             "any" if configurable_fields == "any" else list(configurable_fields)
         )
@@ -595,8 +595,10 @@ def __init__(
             if config_prefix and not config_prefix.endswith("_")
             else config_prefix
         )
-        self._queued_declarative_operations: list[tuple[str, tuple, dict]] = list(
-            queued_declarative_operations,
-        )
+        self._queued_declarative_operations: list[tuple[str, tuple[Any, ...], dict[str, Any]]] = (
+            list(
+                queued_declarative_operations,
+            )
+        )
 
     def __getattr__(self, name: str) -> Any:
@@ -629,14 +631,14 @@ def queue(*args: Any, **kwargs: Any) -> _ConfigurableModel:
             msg += "."
             raise AttributeError(msg)
 
-    def _model(self, config: RunnableConfig | None = None) -> Runnable:
+    def _model(self, config: RunnableConfig | None = None) -> Runnable[Any, Any]:
         params = {**self._default_config, **self._model_params(config)}
         model = _init_chat_model_helper(**params)
         for name, args, kwargs in self._queued_declarative_operations:
             model = getattr(model, name)(*args, **kwargs)
         return model
 
-    def _model_params(self, config: RunnableConfig | None) -> dict:
+    def _model_params(self, config: RunnableConfig | None) -> dict[str, Any]:
         config = ensure_config(config)
         model_params = {
             _remove_prefix(k, self._config_prefix): v
@@ -962,7 +964,7 @@ async def astream_events(
     # Explicitly added to satisfy downstream linters.
     def bind_tools(
         self,
-        tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
+        tools: Sequence[dict[str, Any] | type[BaseModel] | Callable[..., Any] | BaseTool],
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, AIMessage]:
         return self.__getattr__("bind_tools")(tools, **kwargs)
@@ -970,7 +972,7 @@ def bind_tools(
     # Explicitly added to satisfy downstream linters.
     def with_structured_output(
         self,
-        schema: dict | type[BaseModel],
+        schema: dict[str, Any] | type[BaseModel],
         **kwargs: Any,
-    ) -> Runnable[LanguageModelInput, dict | BaseModel]:
+    ) -> Runnable[LanguageModelInput, dict[str, Any] | BaseModel]:
         return self.__getattr__("with_structured_output")(schema, **kwargs)
diff --git a/libs/langchain_v1/pyproject.toml b/libs/langchain_v1/pyproject.toml
index 3fdbf8b0e5ac0..ed8524981f38c 100644
--- a/libs/langchain_v1/pyproject.toml
+++ b/libs/langchain_v1/pyproject.toml
@@ -95,7 +95,6 @@ warn_unreachable = true
 exclude = ["tests/unit_tests/agents/*"]
 
 # TODO: activate for 'strict' checking
-disallow_any_generics = false
 warn_return_any = false
 
 [[tool.mypy.overrides]]
diff --git a/libs/langchain_v1/tests/integration_tests/chat_models/test_base.py b/libs/langchain_v1/tests/integration_tests/chat_models/test_base.py
index ee47dde3bc97e..9d1f22eab2efb 100644
--- a/libs/langchain_v1/tests/integration_tests/chat_models/test_base.py
+++ b/libs/langchain_v1/tests/integration_tests/chat_models/test_base.py
@@ -1,4 +1,4 @@
-from typing import cast
+from typing import Any, cast
 
 import pytest
 from langchain_core.language_models import BaseChatModel
@@ -41,7 +41,7 @@ def chat_model_class(self) -> type[BaseChatModel]:
         return cast("type[BaseChatModel]", init_chat_model)
 
     @property
-    def chat_model_params(self) -> dict:
+    def chat_model_params(self) -> dict[str, Any]:
         return {"model": "gpt-4o", "configurable_fields": "any"}
 
     @property
diff --git a/libs/langchain_v1/tests/unit_tests/conftest.py b/libs/langchain_v1/tests/unit_tests/conftest.py
index 0b38ae788f71b..7279844a956a8 100644
--- a/libs/langchain_v1/tests/unit_tests/conftest.py
+++ b/libs/langchain_v1/tests/unit_tests/conftest.py
@@ -23,7 +23,7 @@ def remove_request_headers(request: Any) -> Any:
     return request
 
 
-def remove_response_headers(response: dict) -> dict:
+def remove_response_headers(response: dict[str, Any]) -> dict[str, Any]:
     """Remove sensitive headers from the response."""
     for k in response["headers"]:
         response["headers"][k] = "**REDACTED**"
@@ -31,7 +31,7 @@ def remove_response_headers(response: dict) -> dict:
 
 
 @pytest.fixture(scope="session")
-def vcr_config() -> dict:
+def vcr_config() -> dict[str, Any]:
     """Extend the default configuration coming from langchain_tests."""
     config = base_vcr_config()
     config.setdefault("filter_headers", []).extend(_EXTRA_HEADERS)
@@ -42,7 +42,7 @@ def vcr_config() -> dict:
     return config
 
 
-def pytest_recording_configure(config: dict, vcr: VCR) -> None:  # noqa: ARG001
+def pytest_recording_configure(config: dict[str, Any], vcr: VCR) -> None:  # noqa: ARG001
     vcr.register_persister(CustomPersister())
     vcr.register_serializer("yaml.gz", CustomSerializer())
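Reviewer note (not part of the patch): every hunk above parameterizes a previously bare generic (`list`, `dict`, `tuple`, `Command`, `AgentState`, `ResponseFormat`, ...) so that the `disallow_any_generics = false` opt-out removed from `pyproject.toml` is no longer needed. A minimal sketch of what user-defined middleware looks like once the stricter mypy setting applies — `AuditMiddleware` and its behavior are hypothetical, and the import paths are assumed from the file layout in this diff; only the method signature mirrors the code above:

from collections.abc import Callable
from typing import Any

from langchain.agents.middleware.types import AgentMiddleware, AgentState, ToolCallRequest
from langchain_core.messages import ToolMessage
from langgraph.types import Command


class AuditMiddleware(AgentMiddleware[AgentState[Any], Any]):
    """Hypothetical middleware that logs each tool result (illustration only)."""

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        # Bare `Command` would now be rejected; it is written `Command[Any]`,
        # exactly as in the signatures changed above.
        handler: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
    ) -> ToolMessage | Command[Any]:
        result = handler(request)  # delegate to the real tool execution
        print(f"tool call finished with a {type(result).__name__}")
        return result

The same pattern applies to state: extensions now subclass `AgentState[Any]` (as `PlanningState`, `ShellToolState`, and `ModelCallLimitState` do above) rather than bare `AgentState`.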