Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions libs/langchain_v1/langchain/agents/structured_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def __init__(
class ToolStrategy(Generic[SchemaT]):
"""Use a tool calling strategy for model responses."""

schema: type[SchemaT]
schema: type[SchemaT] | dict[str, Any]
"""Schema for the tool calls."""

schema_specs: list[_SchemaSpec[SchemaT]]
Expand All @@ -218,7 +218,7 @@ class ToolStrategy(Generic[SchemaT]):

def __init__(
self,
schema: type[SchemaT],
schema: type[SchemaT] | dict[str, Any],
*,
tool_message_content: str | None = None,
handle_errors: bool
Expand Down
1 change: 0 additions & 1 deletion libs/langchain_v1/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ ignore-var-parameters = true # ignore missing documentation for *args and **kwa
"ARG", # Arguments, needs to fix
]
"tests/unit_tests/agents/test_return_direct_spec.py" = ["F821"]
"tests/unit_tests/agents/test_responses_spec.py" = ["F821"]
"tests/unit_tests/agents/test_responses.py" = ["F821"]
"tests/unit_tests/agents/test_react_agent.py" = ["ALL"]

Expand Down
36 changes: 27 additions & 9 deletions libs/langchain_v1/tests/unit_tests/agents/test_responses_spec.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,34 @@
from __future__ import annotations

import os
from typing import (
TYPE_CHECKING,
Any,
)
from unittest.mock import MagicMock

import httpx
import pytest
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool
from pydantic import BaseModel, create_model

from langchain.agents import create_agent
from langchain.agents.structured_output import (
ToolStrategy,
)
from tests.unit_tests.agents.utils import BaseSchema, load_spec

if TYPE_CHECKING:
from collections.abc import Callable

# Skip this test since langgraph.prebuilt.responses is not available
pytest.skip("langgraph.prebuilt.responses not available", allow_module_level=True)

try:
from langchain_openai import ChatOpenAI
except ImportError:
skip_openai_integration_tests = True
else:
skip_openai_integration_tests = False
skip_openai_integration_tests = "OPENAI_API_KEY" not in os.environ

AGENT_PROMPT = "You are an HR assistant."

Expand All @@ -30,8 +48,8 @@ class AssertionByInvocation(BaseSchema):

class TestCase(BaseSchema):
name: str
response_format: Union[Dict[str, Any], List[Dict[str, Any]]]
assertions_by_invocation: List[AssertionByInvocation]
response_format: dict[str, Any] | list[dict[str, Any]]
assertions_by_invocation: list[AssertionByInvocation]


class Employee(BaseModel):
Expand All @@ -49,12 +67,12 @@ class Employee(BaseModel):
TEST_CASES = load_spec("responses", as_model=TestCase)


def _make_tool(fn, *, name: str, description: str):
def _make_tool(fn: Callable[..., str | None], *, name: str, description: str) -> dict[str, Any]:
mock = MagicMock(side_effect=lambda *, name: fn(name=name))
input_model = create_model(f"{name}_input", name=(str, ...))

@tool(name, description=description, args_schema=input_model)
def _wrapped(name: str):
def _wrapped(name: str) -> Any:
return mock(name=name)

return {"tool": _wrapped, "mock": mock}
Expand Down Expand Up @@ -106,7 +124,7 @@ def get_employee_department(*, name: str) -> str | None:

for assertion in case.assertions_by_invocation:

def on_request(request: httpx.Request) -> None:
def on_request(_request: httpx.Request) -> None:
nonlocal llm_request_count
llm_request_count += 1

Expand All @@ -123,7 +141,7 @@ def on_request(request: httpx.Request) -> None:
agent = create_agent(
model,
tools=[role_tool["tool"], dept_tool["tool"]],
prompt=AGENT_PROMPT,
system_prompt=AGENT_PROMPT,
response_format=tool_output,
)

Expand Down
6 changes: 5 additions & 1 deletion libs/langchain_v1/tests/unit_tests/agents/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from pathlib import Path
from typing import TypeVar

from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel
Expand All @@ -13,7 +14,10 @@ class BaseSchema(BaseModel):
)


def load_spec(spec_name: str, as_model: type[BaseModel]) -> list[BaseModel]:
_T = TypeVar("_T", bound=BaseModel)


def load_spec(spec_name: str, as_model: type[_T]) -> list[_T]:
with (Path(__file__).parent / "specifications" / f"{spec_name}.json").open(
"r", encoding="utf-8"
) as f:
Expand Down