Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from typing import Annotated
from typing import Annotated, Any

import pytest
from langchain_core.messages import HumanMessage
Expand Down Expand Up @@ -29,7 +29,7 @@ def test_tool_invocation_error_excludes_injected_state() -> None:
"""

# Define a custom state schema with injected data
class TestState(AgentState):
class TestState(AgentState[Any]):
secret_data: str # Example of state data not controlled by LLM

@dec_tool
Expand Down Expand Up @@ -95,7 +95,7 @@ async def test_tool_invocation_error_excludes_injected_state_async() -> None:
"""

# Define a custom state schema
class TestState(AgentState):
class TestState(AgentState[Any]):
internal_data: str

@dec_tool
Expand Down Expand Up @@ -194,10 +194,10 @@ async def test_create_agent_error_content_with_multiple_params() -> None:
This ensures the LLM receives focused, actionable feedback.
"""

class TestState(AgentState):
class TestState(AgentState[Any]):
user_id: str
api_key: str
session_data: dict
session_data: dict[str, Any]

@dec_tool
def complex_tool(
Expand Down Expand Up @@ -310,7 +310,7 @@ async def test_create_agent_error_only_model_controllable_params() -> None:
absent from error messages. This provides focused feedback to the LLM.
"""

class StateWithSecrets(AgentState):
class StateWithSecrets(AgentState[Any]):
password: str # Example of data not controlled by LLM

@dec_tool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from __future__ import annotations

from typing import Annotated, Any
from typing import TYPE_CHECKING, Annotated, Any

from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.tools import tool
Expand All @@ -29,11 +29,14 @@

from .model import FakeToolCallingModel

if TYPE_CHECKING:
from langgraph.runtime import Runtime


def test_tool_runtime_basic_injection() -> None:
"""Test basic ToolRuntime injection in tools with create_agent."""
# Track what was injected
injected_data = {}
injected_data: dict[str, Any] = {}

@tool
def runtime_tool(x: int, runtime: ToolRuntime) -> str:
Expand Down Expand Up @@ -79,7 +82,7 @@ def runtime_tool(x: int, runtime: ToolRuntime) -> str:

async def test_tool_runtime_async_injection() -> None:
"""Test ToolRuntime injection works with async tools."""
injected_data = {}
injected_data: dict[str, Any] = {}

@tool
async def async_runtime_tool(x: int, runtime: ToolRuntime) -> str:
Expand Down Expand Up @@ -194,7 +197,7 @@ def check_runtime_tool(runtime: ToolRuntime) -> str:

def test_tool_runtime_with_multiple_tools() -> None:
"""Test multiple tools can all access ToolRuntime."""
call_log = []
call_log: list[tuple[str, str | None, int | str]] = []

@tool
def tool_a(x: int, runtime: ToolRuntime) -> str:
Expand Down Expand Up @@ -241,7 +244,7 @@ def tool_b(y: str, runtime: ToolRuntime) -> str:

def test_tool_runtime_config_access() -> None:
"""Test tools can access config through ToolRuntime."""
config_data = {}
config_data: dict[str, Any] = {}

@tool
def config_tool(x: int, runtime: ToolRuntime) -> str:
Expand Down Expand Up @@ -281,7 +284,7 @@ def config_tool(x: int, runtime: ToolRuntime) -> str:
def test_tool_runtime_with_custom_state() -> None:
"""Test ToolRuntime works with custom state schemas."""

class CustomState(AgentState):
class CustomState(AgentState[Any]):
custom_field: str

runtime_state = {}
Expand Down Expand Up @@ -463,11 +466,11 @@ def test_tool_runtime_with_middleware() -> None:
runtime_calls = []

class TestMiddleware(AgentMiddleware):
def before_model(self, state, runtime) -> dict[str, Any]:
def before_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any]:
middleware_calls.append("before_model")
return {}

def after_model(self, state, runtime) -> dict[str, Any]:
def after_model(self, state: AgentState[Any], runtime: Runtime) -> dict[str, Any]:
middleware_calls.append("after_model")
return {}

Expand Down Expand Up @@ -514,11 +517,7 @@ def test_tool_runtime_type_hints() -> None:
def typed_runtime_tool(x: int, runtime: ToolRuntime) -> str:
"""Tool with runtime access."""
# Access state dict - verify we can access standard state fields
if isinstance(runtime.state, dict):
# Count messages in state
typed_runtime["message_count"] = len(runtime.state.get("messages", []))
else:
typed_runtime["message_count"] = len(getattr(runtime.state, "messages", []))
typed_runtime["message_count"] = len(runtime.state.get("messages", []))
return f"Typed: {x}"

agent = create_agent(
Expand All @@ -545,7 +544,7 @@ def typed_runtime_tool(x: int, runtime: ToolRuntime) -> str:

def test_tool_runtime_name_based_injection() -> None:
"""Test that parameter named 'runtime' gets injected without type annotation."""
injected_data = {}
injected_data: dict[str, Any] = {}

@tool
def name_based_tool(x: int, runtime: Any) -> str:
Expand Down Expand Up @@ -600,7 +599,7 @@ def test_combined_injected_state_runtime_store() -> None:
injected_data = {}

# Custom state schema with additional fields
class CustomState(AgentState):
class CustomState(AgentState[Any]):
user_id: str
session_id: str

Expand Down Expand Up @@ -666,6 +665,7 @@ def multi_injection_tool(

# Verify the tool's args schema only includes LLM-controlled parameters
tool_args_schema = multi_injection_tool.args_schema
assert isinstance(tool_args_schema, dict)
assert "location" in tool_args_schema["properties"]
assert "state" not in tool_args_schema["properties"]
assert "runtime" not in tool_args_schema["properties"]
Expand Down Expand Up @@ -717,7 +717,7 @@ async def test_combined_injected_state_runtime_store_async() -> None:
injected_data = {}

# Custom state schema
class CustomState(AgentState):
class CustomState(AgentState[Any]):
api_key: str
request_id: str

Expand Down Expand Up @@ -791,6 +791,7 @@ async def async_multi_injection_tool(

# Verify the tool's args schema only includes LLM-controlled parameters
tool_args_schema = async_multi_injection_tool.args_schema
assert isinstance(tool_args_schema, dict)
assert "query" in tool_args_schema["properties"]
assert "max_results" in tool_args_schema["properties"]
assert "state" not in tool_args_schema["properties"]
Expand Down