diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py b/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py
index 44573309069e7..c5ec9b4dc7f91 100644
--- a/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py
+++ b/libs/langchain_v1/tests/unit_tests/agents/test_response_format.py
@@ -8,15 +8,20 @@

 import pytest
 from langchain_core.language_models import LanguageModelInput
+from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
-from langchain_core.messages import AIMessage as CoreAIMessage
-from langchain_core.messages import BaseMessage, HumanMessage
+from langchain_core.messages import HumanMessage
 from langchain_core.runnables import Runnable
 from pydantic import BaseModel, Field
 from typing_extensions import TypedDict

 from langchain.agents import create_agent
-from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
+from langchain.agents.middleware.types import (
+    AgentMiddleware,
+    ModelCallResult,
+    ModelRequest,
+    ModelResponse,
+)
 from langchain.agents.structured_output import (
     MultipleStructuredOutputsError,
     ProviderStrategy,
@@ -96,14 +101,14 @@ def get_location() -> str:


 # Standardized test data
-WEATHER_DATA = {"temperature": 75.0, "condition": "sunny"}
-LOCATION_DATA = {"city": "New York", "country": "USA"}
+WEATHER_DATA: dict[str, float | str] = {"temperature": 75.0, "condition": "sunny"}
+LOCATION_DATA: dict[str, str] = {"city": "New York", "country": "USA"}

 # Standardized expected responses
-EXPECTED_WEATHER_PYDANTIC = WeatherBaseModel(**WEATHER_DATA)
-EXPECTED_WEATHER_DATACLASS = WeatherDataclass(**WEATHER_DATA)
+EXPECTED_WEATHER_PYDANTIC = WeatherBaseModel(temperature=75.0, condition="sunny")
+EXPECTED_WEATHER_DATACLASS = WeatherDataclass(temperature=75.0, condition="sunny")
 EXPECTED_WEATHER_DICT: WeatherTypedDict = {"temperature": 75.0, "condition": "sunny"}
-EXPECTED_LOCATION = LocationResponse(**LOCATION_DATA)
+EXPECTED_LOCATION = LocationResponse(city="New York", country="USA")
 EXPECTED_LOCATION_DICT: LocationTypedDict = {"city": "New York", "country": "USA"}


@@ -780,9 +785,9 @@ class CustomModel(GenericFakeChatModel):

         def bind_tools(
             self,
-            tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
+            tools: Sequence[dict[str, Any] | type[BaseModel] | Callable[..., Any] | BaseTool],
             **kwargs: Any,
-        ) -> Runnable[LanguageModelInput, BaseMessage]:
+        ) -> Runnable[LanguageModelInput, AIMessage]:
             # Record every tool binding event.
             self.tool_bindings.append(tools)
             return self
@@ -802,15 +807,17 @@ class ModelSwappingMiddleware(AgentMiddleware):
         def wrap_model_call(
             self,
             request: ModelRequest,
-            handler: Callable[[ModelRequest], CoreAIMessage],
-        ) -> CoreAIMessage:
+            handler: Callable[[ModelRequest], ModelResponse],
+        ) -> ModelCallResult:
             # Replace the model with our custom test model
             return handler(request.override(model=model))

     # Track which model is checked for provider strategy support
     calls = []

-    def mock_supports_provider_strategy(model, tools) -> bool:
+    def mock_supports_provider_strategy(
+        model: str | BaseChatModel, tools: list[Any] | None = None
+    ) -> bool:
         """Track which model is checked and return True for ProviderStrategy."""
         calls.append(model)
         return True
diff --git a/libs/langchain_v1/tests/unit_tests/agents/test_state_schema.py b/libs/langchain_v1/tests/unit_tests/agents/test_state_schema.py
index 5e21103e9da66..f45eb90c12345 100644
--- a/libs/langchain_v1/tests/unit_tests/agents/test_state_schema.py
+++ b/libs/langchain_v1/tests/unit_tests/agents/test_state_schema.py
@@ -6,7 +6,7 @@

 from __future__ import annotations

-from typing import Any
+from typing import TYPE_CHECKING, Any

 from langchain_core.messages import HumanMessage
 from langchain_core.tools import tool
@@ -20,6 +20,9 @@

 from .model import FakeToolCallingModel

+if TYPE_CHECKING:
+    from langgraph.runtime import Runtime
+

 @tool
 def simple_tool(x: int) -> str:
@@ -30,7 +33,7 @@ def simple_tool(x: int) -> str:
 def test_state_schema_single_custom_field() -> None:
     """Test that a single custom state field is preserved through agent execution."""

-    class CustomState(AgentState):
+    class CustomState(AgentState[Any]):
         custom_field: str

     agent = create_agent(
@@ -50,7 +53,7 @@ class CustomState(AgentState):
 def test_state_schema_multiple_custom_fields() -> None:
     """Test that multiple custom state fields are preserved through agent execution."""

-    class CustomState(AgentState):
+    class CustomState(AgentState[Any]):
         user_id: str
         session_id: str
         context: str
@@ -81,7 +84,7 @@ class CustomState(AgentState):
 def test_state_schema_with_tool_runtime() -> None:
     """Test that custom state fields are accessible via ToolRuntime."""

-    class ExtendedState(AgentState):
+    class ExtendedState(AgentState[Any]):
         counter: int

     runtime_data = {}
@@ -109,19 +112,19 @@ def counter_tool(x: int, runtime: ToolRuntime) -> str:
 def test_state_schema_with_middleware() -> None:
     """Test that state_schema merges with middleware state schemas."""

-    class UserState(AgentState):
+    class UserState(AgentState[Any]):
         user_name: str

-    class MiddlewareState(AgentState):
+    class MiddlewareState(AgentState[Any]):
         middleware_data: str

     middleware_calls = []

-    class TestMiddleware(AgentMiddleware):
+    class TestMiddleware(AgentMiddleware[MiddlewareState, None]):
         state_schema = MiddlewareState

-        def before_model(self, state, runtime) -> dict[str, Any]:
-            middleware_calls.append(state.get("middleware_data", ""))
+        def before_model(self, state: MiddlewareState, runtime: Runtime) -> dict[str, Any]:
+            middleware_calls.append(state["middleware_data"])
             return {}

     agent = create_agent(
@@ -165,7 +168,7 @@ def test_state_schema_none_uses_default() -> None:
 async def test_state_schema_async() -> None:
     """Test that state_schema works with async agents."""

-    class AsyncState(AgentState):
+    class AsyncState(AgentState[Any]):
         async_field: str

     @tool