Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions libs/langchain_v1/tests/unit_tests/agents/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from dataclasses import asdict, is_dataclass
from typing import (
Any,
Generic,
Literal,
TypeVar,
)

from langchain_core.callbacks import CallbackManagerForLLMRun
Expand All @@ -19,13 +17,12 @@
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from pydantic import BaseModel
from typing_extensions import override

StructuredResponseT = TypeVar("StructuredResponseT")


class FakeToolCallingModel(BaseChatModel, Generic[StructuredResponseT]):
tool_calls: list[list[ToolCall]] | list[list[dict]] | None = None
structured_response: StructuredResponseT | None = None
class FakeToolCallingModel(BaseChatModel):
tool_calls: list[list[ToolCall]] | list[list[dict[str, Any]]] | None = None
structured_response: Any | None = None
index: int = 0
tool_style: Literal["openai", "anthropic"] = "openai"

Expand All @@ -52,7 +49,9 @@ def _generate(
if is_native and not tool_calls:
if isinstance(self.structured_response, BaseModel):
content_obj = self.structured_response.model_dump()
elif is_dataclass(self.structured_response):
elif is_dataclass(self.structured_response) and not isinstance(
self.structured_response, type
):
content_obj = asdict(self.structured_response)
elif isinstance(self.structured_response, dict):
content_obj = self.structured_response
Expand All @@ -71,11 +70,14 @@ def _generate(
def _llm_type(self) -> str:
return "fake-tool-call-model"

@override
def bind_tools(
self,
tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
tools: Sequence[dict[str, Any] | type | Callable[..., Any] | BaseTool],
*,
tool_choice: str | None = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
) -> Runnable[LanguageModelInput, AIMessage]:
if len(tools) == 0:
msg = "Must provide at least one tool"
raise ValueError(msg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@
# return "The weather is sunny and 75°F."

# expected_structured_response = WeatherResponse(temperature=75)
# model = FakeToolCallingModel[WeatherResponse](
# model = FakeToolCallingModel(
# tool_calls=tool_calls, structured_response=expected_structured_response
# )
# agent = create_agent(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def test_union_of_types(self) -> None:
],
]

model = FakeToolCallingModel[WeatherBaseModel | LocationResponse](tool_calls=tool_calls)
model = FakeToolCallingModel(tool_calls=tool_calls)

agent = create_agent(
model,
Expand Down Expand Up @@ -655,7 +655,7 @@ def test_pydantic_model(self) -> None:
[{"args": {}, "id": "1", "name": "get_weather"}],
]

model = FakeToolCallingModel[WeatherBaseModel](
model = FakeToolCallingModel(
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC
)

Expand All @@ -678,7 +678,7 @@ def test_validation_error_with_invalid_response(self) -> None:
]

# But we're using WeatherBaseModel which has different field requirements
model = FakeToolCallingModel[dict](
model = FakeToolCallingModel(
tool_calls=tool_calls,
structured_response={"invalid": "data"}, # Wrong structure
)
Expand All @@ -699,7 +699,7 @@ def test_dataclass(self) -> None:
[{"args": {}, "id": "1", "name": "get_weather"}],
]

model = FakeToolCallingModel[WeatherDataclass](
model = FakeToolCallingModel(
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DATACLASS
)

Expand All @@ -719,7 +719,7 @@ def test_typed_dict(self) -> None:
[{"args": {}, "id": "1", "name": "get_weather"}],
]

model = FakeToolCallingModel[WeatherTypedDict](
model = FakeToolCallingModel(
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
)

Expand All @@ -737,7 +737,7 @@ def test_json_schema(self) -> None:
[{"args": {}, "id": "1", "name": "get_weather"}],
]

model = FakeToolCallingModel[dict](
model = FakeToolCallingModel(
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_DICT
)

Expand Down Expand Up @@ -858,7 +858,7 @@ def test_union_of_types() -> None:
],
]

model = FakeToolCallingModel[WeatherBaseModel | LocationResponse](
model = FakeToolCallingModel(
tool_calls=tool_calls, structured_response=EXPECTED_WEATHER_PYDANTIC
)

Expand Down