diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_create_agent_tool_validation.py b/libs/langchain_v1/tests/unit_tests/agents/test_create_agent_tool_validation.py index d903b22c9be01..0923e7e7b1373 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_create_agent_tool_validation.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_create_agent_tool_validation.py @@ -1,5 +1,5 @@ import sys -from typing import Annotated +from typing import Annotated, Any import pytest from langchain_core.messages import HumanMessage @@ -29,7 +29,7 @@ def test_tool_invocation_error_excludes_injected_state() -> None: """ # Define a custom state schema with injected data - class TestState(AgentState): + class TestState(AgentState[Any]): secret_data: str # Example of state data not controlled by LLM @dec_tool @@ -95,7 +95,7 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None: """ # Define a custom state schema - class TestState(AgentState): + class TestState(AgentState[Any]): internal_data: str @dec_tool @@ -194,10 +194,10 @@ async def test_create_agent_error_content_with_multiple_params() -> None: This ensures the LLM receives focused, actionable feedback. """ - class TestState(AgentState): + class TestState(AgentState[Any]): user_id: str api_key: str - session_data: dict + session_data: dict[str, Any] @dec_tool def complex_tool( @@ -310,7 +310,7 @@ async def test_create_agent_error_only_model_controllable_params() -> None: absent from error messages. This provides focused feedback to the LLM. """ - class StateWithSecrets(AgentState): + class StateWithSecrets(AgentState[Any]): password: str # Example of data not controlled by LLM @dec_tool diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_injected_runtime_create_agent.py b/libs/langchain_v1/tests/unit_tests/agents/test_injected_runtime_create_agent.py index 2be8d92306586..1e44f5f109690 100644 --- a/libs/langchain_v1/tests/unit_tests/agents/test_injected_runtime_create_agent.py +++ b/libs/langchain_v1/tests/unit_tests/agents/test_injected_runtime_create_agent.py @@ -16,7 +16,7 @@ from __future__ import annotations -from typing import Annotated, Any +from typing import TYPE_CHECKING, Annotated, Any from langchain_core.messages import HumanMessage, ToolMessage from langchain_core.tools import tool @@ -29,11 +29,14 @@ from .model import FakeToolCallingModel +if TYPE_CHECKING: + from langgraph.runtime import Runtime + def test_tool_runtime_basic_injection() -> None: """Test basic ToolRuntime injection in tools with create_agent.""" # Track what was injected - injected_data = {} + injected_data: dict[str, Any] = {} @tool def runtime_tool(x: int, runtime: ToolRuntime) -> str: @@ -79,7 +82,7 @@ def runtime_tool(x: int, runtime: ToolRuntime) -> str: async def test_tool_runtime_async_injection() -> None: """Test ToolRuntime injection works with async tools.""" - injected_data = {} + injected_data: dict[str, Any] = {} @tool async def async_runtime_tool(x: int, runtime: ToolRuntime) -> str: @@ -194,7 +197,7 @@ def check_runtime_tool(runtime: ToolRuntime) -> str: def test_tool_runtime_with_multiple_tools() -> None: """Test multiple tools can all access ToolRuntime.""" - call_log = [] + call_log: list[tuple[str, str | None, int | str]] = [] @tool def tool_a(x: int, runtime: ToolRuntime) -> str: @@ -241,7 +244,7 @@ def tool_b(y: str, runtime: ToolRuntime) -> str: def test_tool_runtime_config_access() -> None: """Test tools can access config through ToolRuntime.""" - config_data = {} + config_data: dict[str, Any] = {} @tool def config_tool(x: int, runtime: ToolRuntime) -> str: @@ -281,7 +284,7 @@ def config_tool(x: int, runtime: ToolRuntime) -> str: def test_tool_runtime_with_custom_state() -> None: """Test ToolRuntime works with custom state schemas.""" - class CustomState(AgentState): + class CustomState(AgentState[Any]): custom_field: str runtime_state = {} @@ -463,11 +466,11 @@ def test_tool_runtime_with_middleware() -> None: runtime_calls = [] class TestMiddleware(AgentMiddleware): - def before_model(self, state, runtime) -> dict[str, Any]: + def before_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any]: middleware_calls.append("before_model") return {} - def after_model(self, state, runtime) -> dict[str, Any]: + def after_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any]: middleware_calls.append("after_model") return {} @@ -514,11 +517,7 @@ def test_tool_runtime_type_hints() -> None: def typed_runtime_tool(x: int, runtime: ToolRuntime) -> str: """Tool with runtime access.""" # Access state dict - verify we can access standard state fields - if isinstance(runtime.state, dict): - # Count messages in state - typed_runtime["message_count"] = len(runtime.state.get("messages", [])) - else: - typed_runtime["message_count"] = len(getattr(runtime.state, "messages", [])) + typed_runtime["message_count"] = len(runtime.state.get("messages", [])) return f"Typed: {x}" agent = create_agent( @@ -545,7 +544,7 @@ def typed_runtime_tool(x: int, runtime: ToolRuntime) -> str: def test_tool_runtime_name_based_injection() -> None: """Test that parameter named 'runtime' gets injected without type annotation.""" - injected_data = {} + injected_data: dict[str, Any] = {} @tool def name_based_tool(x: int, runtime: Any) -> str: @@ -600,7 +599,7 @@ def test_combined_injected_state_runtime_store() -> None: injected_data = {} # Custom state schema with additional fields - class CustomState(AgentState): + class CustomState(AgentState[Any]): user_id: str session_id: str @@ -666,6 +665,7 @@ def multi_injection_tool( # Verify the tool's args schema only includes LLM-controlled parameters tool_args_schema = multi_injection_tool.args_schema + assert isinstance(tool_args_schema, dict) assert "location" in tool_args_schema["properties"] assert "state" not in tool_args_schema["properties"] assert "runtime" not in tool_args_schema["properties"] @@ -717,7 +717,7 @@ async def test_combined_injected_state_runtime_store_async() -> None: injected_data = {} # Custom state schema - class CustomState(AgentState): + class CustomState(AgentState[Any]): api_key: str request_id: str @@ -791,6 +791,7 @@ async def async_multi_injection_tool( # Verify the tool's args schema only includes LLM-controlled parameters tool_args_schema = async_multi_injection_tool.args_schema + assert isinstance(tool_args_schema, dict) assert "query" in tool_args_schema["properties"] assert "max_results" in tool_args_schema["properties"] assert "state" not in tool_args_schema["properties"]