33 changes: 20 additions & 13 deletions libs/langchain_v1/tests/unit_tests/agents/test_response_format.py
@@ -8,15 +8,20 @@

import pytest
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage as CoreAIMessage
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables import Runnable
from pydantic import BaseModel, Field
from typing_extensions import TypedDict

from langchain.agents import create_agent
from langchain.agents.middleware.types import AgentMiddleware, ModelRequest
from langchain.agents.middleware.types import (
AgentMiddleware,
ModelCallResult,
ModelRequest,
ModelResponse,
)
from langchain.agents.structured_output import (
MultipleStructuredOutputsError,
ProviderStrategy,
@@ -96,14 +101,14 @@ def get_location() -> str:


# Standardized test data
WEATHER_DATA = {"temperature": 75.0, "condition": "sunny"}
LOCATION_DATA = {"city": "New York", "country": "USA"}
WEATHER_DATA: dict[str, float | str] = {"temperature": 75.0, "condition": "sunny"}
LOCATION_DATA: dict[str, str] = {"city": "New York", "country": "USA"}

# Standardized expected responses
EXPECTED_WEATHER_PYDANTIC = WeatherBaseModel(**WEATHER_DATA)
EXPECTED_WEATHER_DATACLASS = WeatherDataclass(**WEATHER_DATA)
EXPECTED_WEATHER_PYDANTIC = WeatherBaseModel(temperature=75.0, condition="sunny")
EXPECTED_WEATHER_DATACLASS = WeatherDataclass(temperature=75.0, condition="sunny")
EXPECTED_WEATHER_DICT: WeatherTypedDict = {"temperature": 75.0, "condition": "sunny"}
EXPECTED_LOCATION = LocationResponse(**LOCATION_DATA)
EXPECTED_LOCATION = LocationResponse(city="New York", country="USA")
EXPECTED_LOCATION_DICT: LocationTypedDict = {"city": "New York", "country": "USA"}


@@ -780,9 +785,9 @@ class CustomModel(GenericFakeChatModel):

def bind_tools(
self,
tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
tools: Sequence[dict[str, Any] | type[BaseModel] | Callable[..., Any] | BaseTool],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
) -> Runnable[LanguageModelInput, AIMessage]:
# Record every tool binding event.
self.tool_bindings.append(tools)
return self
@@ -802,15 +807,17 @@ class ModelSwappingMiddleware(AgentMiddleware):
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], CoreAIMessage],
) -> CoreAIMessage:
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
# Replace the model with our custom test model
return handler(request.override(model=model))

# Track which model is checked for provider strategy support
calls = []

def mock_supports_provider_strategy(model, tools) -> bool:
def mock_supports_provider_strategy(
model: str | BaseChatModel, tools: list[Any] | None = None
) -> bool:
"""Track which model is checked and return True for ProviderStrategy."""
calls.append(model)
return True
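For context on the hunk above: a minimal standalone sketch (not part of the PR) of the `wrap_model_call` hook signature the updated test exercises, using the `ModelCallResult`/`ModelResponse` types now imported at the top of the file. The class name and `replacement_model` are illustrative placeholders, not code from the PR.

```python
# Illustrative sketch only -- mirrors the ModelSwappingMiddleware in the test above.
from collections.abc import Callable

from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage

from langchain.agents.middleware.types import (
    AgentMiddleware,
    ModelCallResult,
    ModelRequest,
    ModelResponse,
)

# Hypothetical stand-in for whatever model the middleware should swap in.
replacement_model = GenericFakeChatModel(messages=iter([AIMessage(content="ok")]))


class SwapModelMiddleware(AgentMiddleware):
    """Swap the model on every call, as the test's middleware does."""

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelCallResult:
        # Point the request at the replacement model, then let the framework
        # continue the model call via the handler.
        return handler(request.override(model=replacement_model))
```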
23 changes: 13 additions & 10 deletions libs/langchain_v1/tests/unit_tests/agents/test_state_schema.py
@@ -6,7 +6,7 @@

from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING, Any

from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
@@ -20,6 +20,9 @@

from .model import FakeToolCallingModel

if TYPE_CHECKING:
from langgraph.runtime import Runtime


@tool
def simple_tool(x: int) -> str:
@@ -30,7 +33,7 @@ def simple_tool(x: int) -> str:
def test_state_schema_single_custom_field() -> None:
"""Test that a single custom state field is preserved through agent execution."""

class CustomState(AgentState):
class CustomState(AgentState[Any]):
custom_field: str

agent = create_agent(
@@ -50,7 +53,7 @@ class CustomState(AgentState):
def test_state_schema_multiple_custom_fields() -> None:
"""Test that multiple custom state fields are preserved through agent execution."""

class CustomState(AgentState):
class CustomState(AgentState[Any]):
user_id: str
session_id: str
context: str
@@ -81,7 +84,7 @@ class CustomState(AgentState):
def test_state_schema_with_tool_runtime() -> None:
"""Test that custom state fields are accessible via ToolRuntime."""

class ExtendedState(AgentState):
class ExtendedState(AgentState[Any]):
counter: int

runtime_data = {}
@@ -109,19 +112,19 @@ def counter_tool(x: int, runtime: ToolRuntime) -> str:
def test_state_schema_with_middleware() -> None:
"""Test that state_schema merges with middleware state schemas."""

class UserState(AgentState):
class UserState(AgentState[Any]):
user_name: str

class MiddlewareState(AgentState):
class MiddlewareState(AgentState[Any]):
middleware_data: str

middleware_calls = []

class TestMiddleware(AgentMiddleware):
class TestMiddleware(AgentMiddleware[MiddlewareState, None]):
state_schema = MiddlewareState

def before_model(self, state, runtime) -> dict[str, Any]:
middleware_calls.append(state.get("middleware_data", ""))
def before_model(self, state: MiddlewareState, runtime: Runtime) -> dict[str, Any]:
middleware_calls.append(state["middleware_data"])
return {}

agent = create_agent(
@@ -165,7 +168,7 @@ def test_state_schema_none_uses_default() -> None:
async def test_state_schema_async() -> None:
"""Test that state_schema works with async agents."""

class AsyncState(AgentState):
class AsyncState(AgentState[Any]):
async_field: str

@tool
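As a companion to the hunks above: a minimal sketch (not part of the PR) of the parameterized `AgentState[Any]` / `AgentMiddleware[StateT, ContextT]` pattern the updated tests use. The state class, its field, and the `AgentState` import path are assumptions for illustration only.

```python
# Illustrative sketch only -- mirrors the typed state/middleware pattern above.
from __future__ import annotations

from typing import TYPE_CHECKING, Any

# Assumed import path for AgentState; adjust to where it lives in your version.
from langchain.agents.middleware.types import AgentMiddleware, AgentState

if TYPE_CHECKING:
    from langgraph.runtime import Runtime


class SessionState(AgentState[Any]):
    """Hypothetical custom state with one extra field."""

    session_note: str


class SessionMiddleware(AgentMiddleware[SessionState, None]):
    state_schema = SessionState

    def before_model(self, state: SessionState, runtime: Runtime) -> dict[str, Any]:
        # Plain key access, matching the updated test's state["middleware_data"].
        return {"session_note": state["session_note"].upper()}
```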