66 changes: 35 additions & 31 deletions libs/langchain_v1/langchain/agents/factory.py
@@ -314,7 +314,7 @@ def _resolve_schema(schemas: set[type], schema_name: str, omit_flag: str | None
     return TypedDict(schema_name, all_annotations)  # type: ignore[operator]


-def _extract_metadata(type_: type) -> list:
+def _extract_metadata(type_: type) -> list[Any]:
     """Extract metadata from a field type, handling Required/NotRequired and Annotated wrappers."""
     # Handle Required[Annotated[...]] or NotRequired[Annotated[...]]
     if get_origin(type_) in {Required, NotRequired}:
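Note: the `list[Any]` return type above annotates the `typing` unwrap helper. A minimal standalone sketch of that unwrap logic (assuming Python 3.11+ semantics for `Required`/`NotRequired`; the library's actual body is collapsed in this diff):

```python
from typing import Annotated, Any, NotRequired, Required, get_args, get_origin

def extract_metadata(type_: type) -> list[Any]:
    """Peel Required/NotRequired, then collect Annotated metadata."""
    if get_origin(type_) in {Required, NotRequired}:
        # Required[X] / NotRequired[X] carry X as their sole argument.
        type_ = get_args(type_)[0]
    if get_origin(type_) is Annotated:
        # Annotated[T, m1, m2, ...] stores metadata after the first argument.
        return list(get_args(type_)[1:])
    return []

assert extract_metadata(Required[Annotated[int, "meta"]]) == ["meta"]
```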
@@ -364,7 +364,9 @@ def _get_can_jump_to(middleware: AgentMiddleware[Any, Any], hook_name: str) -> l
     return []


-def _supports_provider_strategy(model: str | BaseChatModel, tools: list | None = None) -> bool:
+def _supports_provider_strategy(
+    model: str | BaseChatModel, tools: list[BaseTool | dict[str, Any]] | None = None
+) -> bool:
     """Check if a model supports provider-specific structured output.

     Args:
@@ -403,7 +405,7 @@ def _supports_provider_strategy(model: str | BaseChatModel, tools: list | None =

 def _handle_structured_output_error(
     exception: Exception,
-    response_format: ResponseFormat,
+    response_format: ResponseFormat[Any],
 ) -> tuple[bool, str]:
     """Handle structured output error. Returns `(should_retry, retry_tool_message)`."""
     if not isinstance(response_format, ToolStrategy):
@@ -455,10 +457,10 @@ def compose_two(outer: ToolCallWrapper, inner: ToolCallWrapper) -> ToolCallWrapp

     def composed(
         request: ToolCallRequest,
-        execute: Callable[[ToolCallRequest], ToolMessage | Command],
-    ) -> ToolMessage | Command:
+        execute: Callable[[ToolCallRequest], ToolMessage | Command[Any]],
+    ) -> ToolMessage | Command[Any]:
         # Create a callable that invokes inner with the original execute
-        def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
+        def call_inner(req: ToolCallRequest) -> ToolMessage | Command[Any]:
             return inner(req, execute)

         # Outer can call call_inner multiple times
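Note: the closure pattern above is easier to see with concrete types stripped away. A self-contained toy, with plain strings standing in for `ToolCallRequest`/`ToolMessage`/`Command` (illustrative only, not the library's code):

```python
from collections.abc import Callable

Handler = Callable[[str], str]
Wrapper = Callable[[str, Handler], str]

def compose_two(outer: Wrapper, inner: Wrapper) -> Wrapper:
    def composed(request: str, execute: Handler) -> str:
        # Inner receives the original execute; outer sees inner as its handler.
        def call_inner(req: str) -> str:
            return inner(req, execute)
        return outer(request, call_inner)
    return composed

def tag(label: str) -> Wrapper:
    return lambda req, execute: f"{label}<{execute(req)}>"

handler = compose_two(tag("outer"), tag("inner"))
print(handler("x", str.upper))  # outer<inner<X>>
```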
@@ -477,14 +479,14 @@ def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
 def _chain_async_tool_call_wrappers(
     wrappers: Sequence[
         Callable[
-            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
-            Awaitable[ToolMessage | Command],
+            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]]],
+            Awaitable[ToolMessage | Command[Any]],
         ]
     ],
 ) -> (
     Callable[
-        [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
-        Awaitable[ToolMessage | Command],
+        [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]]],
+        Awaitable[ToolMessage | Command[Any]],
     ]
     | None
 ):
@@ -504,25 +506,25 @@ def _chain_async_tool_call_wrappers(

     def compose_two(
         outer: Callable[
-            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
-            Awaitable[ToolMessage | Command],
+            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]]],
+            Awaitable[ToolMessage | Command[Any]],
         ],
         inner: Callable[
-            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
-            Awaitable[ToolMessage | Command],
+            [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]]],
+            Awaitable[ToolMessage | Command[Any]],
         ],
     ) -> Callable[
-        [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]],
-        Awaitable[ToolMessage | Command],
+        [ToolCallRequest, Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]]],
+        Awaitable[ToolMessage | Command[Any]],
     ]:
         """Compose two async wrappers where outer wraps inner."""

         async def composed(
             request: ToolCallRequest,
-            execute: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
-        ) -> ToolMessage | Command:
+            execute: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command[Any]]],
+        ) -> ToolMessage | Command[Any]:
             # Create an async callable that invokes inner with the original execute
-            async def call_inner(req: ToolCallRequest) -> ToolMessage | Command:
+            async def call_inner(req: ToolCallRequest) -> ToolMessage | Command[Any]:
                 return await inner(req, execute)

             # Outer can call call_inner multiple times
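Note: the chaining that consumes `compose_two` is collapsed above; it can be expressed as a left fold, which makes the first wrapper in the sequence the outermost. A hedged async sketch using the same toy string types as before:

```python
from collections.abc import Awaitable, Callable
from functools import reduce

AsyncHandler = Callable[[str], Awaitable[str]]
AsyncWrapper = Callable[[str, AsyncHandler], Awaitable[str]]

def compose_two(outer: AsyncWrapper, inner: AsyncWrapper) -> AsyncWrapper:
    async def composed(request: str, execute: AsyncHandler) -> str:
        # Inner keeps the original execute; outer treats inner as its handler.
        async def call_inner(req: str) -> str:
            return await inner(req, execute)
        return await outer(request, call_inner)
    return composed

def chain(wrappers: list[AsyncWrapper]) -> AsyncWrapper | None:
    # reduce folds left, so wrappers[0] ends up outermost; an empty
    # sequence yields None, mirroring the `| None` return type above.
    return reduce(compose_two, wrappers) if wrappers else None
```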
@@ -540,7 +542,7 @@ async def call_inner(req: ToolCallRequest) -> ToolMessage | Command:

 def create_agent(
     model: str | BaseChatModel,
-    tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
+    tools: Sequence[BaseTool | Callable[..., Any] | dict[str, Any]] | None = None,
     *,
     system_prompt: str | SystemMessage | None = None,
     middleware: Sequence[AgentMiddleware[StateT_co, ContextT]] = (),
@@ -553,7 +555,7 @@ def create_agent(
     interrupt_after: list[str] | None = None,
     debug: bool = False,
     name: str | None = None,
-    cache: BaseCache | None = None,
+    cache: BaseCache[Any] | None = None,
 ) -> CompiledStateGraph[
     AgentState[ResponseT], ContextT, _InputAgentState, _OutputAgentState[ResponseT]
 ]:
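Note: for reviewers, a usage sketch against this signature (the model identifier, tool, and prompt are illustrative placeholders; `check_weather` echoes the docstring example referenced below):

```python
from langchain.agents import create_agent

def check_weather(location: str) -> str:
    """Return a canned forecast for the given location."""
    return f"It's always sunny in {location}!"

agent = create_agent(
    model="openai:gpt-4o-mini",  # assumed provider:model identifier
    tools=[check_weather],
    system_prompt="You are a helpful weather assistant.",
)
result = agent.invoke(
    {"messages": [{"role": "user", "content": "What's the weather in SF?"}]}
)
```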
@@ -704,7 +706,7 @@ def check_weather(location: str) -> str:
     # Raw schemas are wrapped in AutoStrategy to preserve auto-detection intent.
     # AutoStrategy is converted to ToolStrategy upfront to calculate tools during agent creation,
     # but may be replaced with ProviderStrategy later based on model capabilities.
-    initial_response_format: ToolStrategy | ProviderStrategy | AutoStrategy | None
+    initial_response_format: ToolStrategy[Any] | ProviderStrategy[Any] | AutoStrategy[Any] | None
     if response_format is None:
         initial_response_format = None
     elif isinstance(response_format, (ToolStrategy, ProviderStrategy)):
@@ -719,13 +721,13 @@ def check_weather(location: str) -> str:

     # For AutoStrategy, convert to ToolStrategy to setup tools upfront
     # (may be replaced with ProviderStrategy later based on model)
-    tool_strategy_for_setup: ToolStrategy | None = None
+    tool_strategy_for_setup: ToolStrategy[Any] | None = None
     if isinstance(initial_response_format, AutoStrategy):
         tool_strategy_for_setup = ToolStrategy(schema=initial_response_format.schema)
     elif isinstance(initial_response_format, ToolStrategy):
         tool_strategy_for_setup = initial_response_format

-    structured_output_tools: dict[str, OutputToolBinding] = {}
+    structured_output_tools: dict[str, OutputToolBinding[Any]] = {}
     if tool_strategy_for_setup:
         for response_schema in tool_strategy_for_setup.schema_specs:
             structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
@@ -872,7 +874,7 @@ def check_weather(location: str) -> str:
         )

     def _handle_model_output(
-        output: AIMessage, effective_response_format: ResponseFormat | None
+        output: AIMessage, effective_response_format: ResponseFormat[Any] | None
     ) -> dict[str, Any]:
         """Handle model output including structured responses.

@@ -975,7 +977,9 @@ def _handle_model_output(

         return {"messages": [output]}

-    def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat | None]:
+    def _get_bound_model(
+        request: ModelRequest,
+    ) -> tuple[Runnable[Any, Any], ResponseFormat[Any] | None]:
         """Get the model with appropriate tool bindings.

         Performs auto-detection of strategy if needed based on model capabilities.
@@ -1025,7 +1029,7 @@ def _get_bound_model(request: ModelRequest) -> tuple[Runnable, ResponseFormat |
             raise ValueError(msg)

         # Determine effective response format (auto-detect if needed)
-        effective_response_format: ResponseFormat | None
+        effective_response_format: ResponseFormat[Any] | None
         if isinstance(request.response_format, AutoStrategy):
             # User provided raw schema via AutoStrategy - auto-detect best strategy based on model
             if _supports_provider_strategy(request.model, tools=request.tools):
@@ -1119,7 +1123,7 @@ def _execute_model_sync(request: ModelRequest) -> ModelResponse:
             structured_response=structured_response,
         )

-    def model_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
+    def model_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Sync model request handler with sequential middleware processing."""
         request = ModelRequest(
             model=model,
@@ -1174,7 +1178,7 @@ async def _execute_model_async(request: ModelRequest) -> ModelResponse:
             structured_response=structured_response,
         )

-    async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any]:
+    async def amodel_node(state: AgentState[Any], runtime: Runtime[ContextT]) -> dict[str, Any]:
         """Async model request handler with sequential middleware processing."""
         request = ModelRequest(
             model=model,
@@ -1523,7 +1527,7 @@ def _fetch_last_ai_and_tool_messages(
 def _make_model_to_tools_edge(
     *,
     model_destination: str,
-    structured_output_tools: dict[str, OutputToolBinding],
+    structured_output_tools: dict[str, OutputToolBinding[Any]],
     end_destination: str,
 ) -> Callable[[dict[str, Any]], str | list[Send] | None]:
     def model_to_tools(
@@ -1607,7 +1611,7 @@ def _make_tools_to_model_edge(
     *,
     tool_node: ToolNode,
     model_destination: str,
-    structured_output_tools: dict[str, OutputToolBinding],
+    structured_output_tools: dict[str, OutputToolBinding[Any]],
     end_destination: str,
 ) -> Callable[[dict[str, Any]], str | None]:
     def tools_to_model(state: dict[str, Any]) -> str | None:
libs/langchain_v1/langchain/agents/middleware/human_in_the_loop.py
@@ -102,7 +102,9 @@ class HITLResponse(TypedDict):
 class _DescriptionFactory(Protocol):
     """Callable that generates a description for a tool call."""

-    def __call__(self, tool_call: ToolCall, state: AgentState, runtime: Runtime[ContextT]) -> str:
+    def __call__(
+        self, tool_call: ToolCall, state: AgentState[Any], runtime: Runtime[ContextT]
+    ) -> str:
         """Generate a description for a tool call."""
         ...

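Note: widening `__call__` keeps the Protocol purely structural. A generic sketch of what that buys, with stand-in types rather than the middleware's real ones:

```python
from typing import Any, Protocol

class DescriptionFactory(Protocol):
    def __call__(self, tool_call: dict[str, Any], state: dict[str, Any], runtime: Any) -> str:
        ...

def describe(tool_call: dict[str, Any], state: dict[str, Any], runtime: Any) -> str:
    # Any callable with a matching signature satisfies the Protocol
    # structurally; no subclassing required.
    return f"Calling {tool_call.get('name')} with {tool_call.get('args')}"

factory: DescriptionFactory = describe  # accepted by a static type checker
```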
@@ -203,7 +205,7 @@ def _create_action_and_config(
         self,
         tool_call: ToolCall,
         config: InterruptOnConfig,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime[ContextT],
     ) -> tuple[ActionRequest, ReviewConfig]:
         """Create an ActionRequest and ReviewConfig for a tool call."""
@@ -277,7 +279,9 @@ def _process_decision(
         )
         raise ValueError(msg)

-    def after_model(self, state: AgentState, runtime: Runtime[ContextT]) -> dict[str, Any] | None:
+    def after_model(
+        self, state: AgentState[Any], runtime: Runtime[ContextT]
+    ) -> dict[str, Any] | None:
         """Trigger interrupt flows for relevant tool calls after an `AIMessage`.

         Args:
@@ -363,7 +367,7 @@ def after_model(self, state: AgentState, runtime: Runtime[ContextT]) -> dict[str
         return {"messages": [last_ai_msg, *artificial_tool_messages]}

     async def aafter_model(
-        self, state: AgentState, runtime: Runtime[ContextT]
+        self, state: AgentState[Any], runtime: Runtime[ContextT]
     ) -> dict[str, Any] | None:
         """Async trigger interrupt flows for relevant tool calls after an `AIMessage`.

libs/langchain_v1/langchain/agents/middleware/model_call_limit.py
@@ -19,7 +19,7 @@
 from langgraph.runtime import Runtime


-class ModelCallLimitState(AgentState):
+class ModelCallLimitState(AgentState[Any]):
     """State schema for `ModelCallLimitMiddleware`.

     Extends `AgentState` with model call tracking fields.
8 changes: 4 additions & 4 deletions libs/langchain_v1/langchain/agents/middleware/pii.py
@@ -164,7 +164,7 @@ def _process_content(self, content: str) -> tuple[str, list[PIIMatch]]:
     @override
     def before_model(
         self,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime,
     ) -> dict[str, Any] | None:
         """Check user messages and tool results for PII before model invocation.
@@ -259,7 +259,7 @@ def before_model(
     @hook_config(can_jump_to=["end"])
     async def abefore_model(
         self,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime,
     ) -> dict[str, Any] | None:
         """Async check user messages and tool results for PII before model invocation.
@@ -280,7 +280,7 @@ async def abefore_model(
     @override
     def after_model(
         self,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime,
     ) -> dict[str, Any] | None:
         """Check AI messages for PII after model invocation.
@@ -339,7 +339,7 @@ def after_model(

     async def aafter_model(
         self,
-        state: AgentState,
+        state: AgentState[Any],
         runtime: Runtime,
     ) -> dict[str, Any] | None:
         """Async check AI messages for PII after model invocation.
4 changes: 2 additions & 2 deletions libs/langchain_v1/langchain/agents/middleware/shell_tool.py
@@ -78,7 +78,7 @@ class _SessionResources:
     session: ShellSession
     tempdir: tempfile.TemporaryDirectory[str] | None
     policy: BaseExecutionPolicy
-    finalizer: weakref.finalize = field(init=False, repr=False)
+    finalizer: weakref.finalize = field(init=False, repr=False)  # type: ignore[type-arg]

     def __post_init__(self) -> None:
         self.finalizer = weakref.finalize(
@@ -90,7 +90,7 @@ def __post_init__(self) -> None:
         )


-class ShellToolState(AgentState):
+class ShellToolState(AgentState[Any]):
     """Agent state extension for tracking shell session resources."""

     shell_session_resources: NotRequired[
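Note: the `type-arg` ignore above works around `weakref.finalize` being generic in recent typeshed stubs. The underlying pattern in a minimal sketch (illustrative class, not the middleware's):

```python
import tempfile
import weakref

class SessionResources:
    """Cleanup runs when the owner is collected or at interpreter exit."""

    def __init__(self) -> None:
        self.tempdir = tempfile.TemporaryDirectory()
        # The callback must not reference `self`, or the object can never
        # be collected; a bound method of the attribute is safe.
        self.finalizer = weakref.finalize(self, self.tempdir.cleanup)

resources = SessionResources()
resources.finalizer()            # explicit, idempotent cleanup
assert not resources.finalizer.alive
```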
libs/langchain_v1/langchain/agents/middleware/summarization.py
@@ -269,7 +269,7 @@ def __init__(
             raise ValueError(msg)

     @override
-    def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
+    def before_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None:
         """Process messages before model invocation, potentially triggering summarization.

         Args:
@@ -305,7 +305,9 @@ def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] |
         }

     @override
-    async def abefore_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
+    async def abefore_model(
+        self, state: AgentState[Any], runtime: Runtime
+    ) -> dict[str, Any] | None:
         """Process messages before model invocation, potentially triggering summarization.

         Args:
21 changes: 8 additions & 13 deletions libs/langchain_v1/langchain/agents/middleware/todo.py
@@ -35,7 +35,7 @@ class Todo(TypedDict):
     """The current status of the todo item."""


-class PlanningState(AgentState):
+class PlanningState(AgentState[Any]):
     """State schema for the todo middleware."""

     todos: Annotated[NotRequired[list[Todo]], OmitFromInput]
@@ -118,7 +118,9 @@ class PlanningState(AgentState):


 @tool(description=WRITE_TODOS_TOOL_DESCRIPTION)
-def write_todos(todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]) -> Command:
+def write_todos(
+    todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
+) -> Command[Any]:
     """Create and manage a structured task list for your current work session."""
     return Command(
         update={
@@ -178,7 +180,7 @@ def __init__(
         @tool(description=self.tool_description)
         def write_todos(
             todos: list[Todo], tool_call_id: Annotated[str, InjectedToolCallId]
-        ) -> Command:
+        ) -> Command[Any]:
             """Create and manage a structured task list for your current work session."""
             return Command(
                 update={
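Note: a tool that returns `Command` updates graph state directly, so it must supply its own `ToolMessage`. A minimal analogue of the pattern above (the `done` state key is hypothetical):

```python
from typing import Annotated, Any

from langchain_core.messages import ToolMessage
from langchain_core.tools import InjectedToolCallId, tool
from langgraph.types import Command

@tool
def mark_done(task: str, tool_call_id: Annotated[str, InjectedToolCallId]) -> Command[Any]:
    """Mark a task as completed."""
    return Command(
        update={
            "done": [task],  # hypothetical state key for illustration
            # A Command-returning tool emits the ToolMessage itself.
            "messages": [ToolMessage(f"Marked {task!r} done", tool_call_id=tool_call_id)],
        }
    )
```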
@@ -246,11 +248,7 @@ async def awrap_model_call(
         return await handler(request.override(system_message=new_system_message))

     @override
-    def after_model(
-        self,
-        state: AgentState,
-        runtime: Runtime,
-    ) -> dict[str, Any] | None:
+    def after_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None:
         """Check for parallel write_todos tool calls and return errors if detected.

         The todo list is designed to be updated at most once per model turn. Since
@@ -299,11 +297,8 @@ def after_model(

         return None

-    async def aafter_model(
-        self,
-        state: AgentState,
-        runtime: Runtime,
-    ) -> dict[str, Any] | None:
+    @override
+    async def aafter_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any] | None:
         """Check for parallel write_todos tool calls and return errors if detected.

         Async version of `after_model`. The todo list is designed to be updated at