diff --git a/pyproject.toml b/pyproject.toml
index 35665fc5..c617a8f0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,7 @@ llama_index = [
 ]
 
 openai = [
-    "openai-agents>=0.2.6, <1",
+    "openai-agents>=0.3.0, <1",
 ]
 
 smolagents = [
diff --git a/src/any_agent/frameworks/openai.py b/src/any_agent/frameworks/openai.py
index ef9e5d6a..ca0fa9e9 100644
--- a/src/any_agent/frameworks/openai.py
+++ b/src/any_agent/frameworks/openai.py
@@ -1,30 +1,381 @@
-import math
-from typing import TYPE_CHECKING, Any
+from __future__ import annotations
 
-from pydantic import BaseModel
+import math
+import time
+from copy import copy
+from typing import TYPE_CHECKING, Any, Literal, cast, overload
 
 from any_agent.config import AgentConfig, AgentFramework
 
 from .any_agent import AnyAgent
 
 try:
-    from agents import (
-        Agent,
-        Model,
-        ModelSettings,
-        Runner,
+    import any_llm
+    from agents import Agent, FunctionTool, Model, ModelSettings, Runner, Tool, Usage
+    from agents.exceptions import UserError
+    from agents.items import ModelResponse
+    from agents.models.chatcmpl_converter import Converter as BaseConverter
+    from agents.models.chatcmpl_stream_handler import ChatCmplStreamHandler
+    from agents.models.fake_id import FAKE_RESPONSES_ID
+    from agents.tracing import generation_span
+    from agents.util._json import _to_dump_compatible
+    from any_llm import AnyLLM
+    from openai import NOT_GIVEN, NotGiven, Omit
+    from openai.types.responses import Response
+    from openai.types.responses.response_usage import (
+        InputTokensDetails,
+        OutputTokensDetails,
     )
-    from agents.extensions.models.litellm_model import LitellmModel
-
-    DEFAULT_MODEL_TYPE = LitellmModel
+    omit = Omit()
 
     agents_available = True
 except ImportError:
     agents_available = False
 
-
 if TYPE_CHECKING:
-    from agents import Model
+    from collections.abc import AsyncIterator
+
+    from agents.agent_output import AgentOutputSchemaBase
+    from agents.handoffs import Handoff
+    from agents.items import TResponseInputItem, TResponseStreamEvent
+    from agents.models.interface import ModelTracing
+    from agents.tracing.span_data import GenerationSpanData
+    from agents.tracing.spans import Span
+    from openai.types.chat import ChatCompletionToolParam
+    from pydantic import BaseModel
+
+
+class Converter(BaseConverter):
+    """Same converter as agents.models.chatcmpl_converter.Converter but with strict mode enabled."""
+
+    @classmethod
+    def tool_to_openai(cls, tool: Tool) -> ChatCompletionToolParam:
+        if isinstance(tool, FunctionTool):
+            return {
+                "type": "function",
+                "function": {
+                    "name": tool.name,
+                    "description": tool.description or "",
+                    "parameters": tool.params_json_schema,
+                    "strict": tool.strict_json_schema,  # adding missing field in BaseConverter
+                },
+            }
+
+        msg = (
+            "Hosted tools are not supported with the ChatCompletions API."
+            f" Got tool type: {type(tool)}, tool: {tool}"
+        )
+        raise UserError(msg)
+
+
+class AnyllmModel(Model):
+    """Enables using any model via AnyLLM.
+
+    AnyLLM allows you to access OpenAI, Anthropic, Gemini, Mistral, and many other models.
+    See supported providers/models here: https://mozilla-ai.github.io/any-llm/providers/
+    """
+
+    def __init__(
+        self,
+        model: str,
+        base_url: str | None = None,
+        api_key: str | None = None,
+    ):
+        provider, model_id = AnyLLM.split_model_provider(model)
+        self.model = model
+        self.base_url = base_url
+        self.api_key = api_key
+        self.llm = AnyLLM.create(provider=provider, api_key=api_key, api_base=base_url)
+        self.model_id = model_id
+
+    async def get_response(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],  # noqa: A002
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        tracing: ModelTracing,
+        previous_response_id: str | None = None,  # unused
+        conversation_id: str | None = None,  # unused
+        prompt: Any | None = None,
+    ) -> ModelResponse:
+        with generation_span(
+            model=str(self.model),
+            model_config=model_settings.to_json_dict()
+            | {"base_url": str(self.base_url or ""), "model_impl": "anyllm"},
+            disabled=tracing.is_disabled(),
+        ) as span_generation:
+            response = await self._fetch_response(
+                system_instructions,
+                input,
+                model_settings,
+                tools,
+                output_schema,
+                handoffs,
+                span_generation,
+                tracing,
+                stream=False,
+                prompt=prompt,
+            )
+
+            assert isinstance(response.choices[0], any_llm.types.completion.Choice)
+
+            usage = Usage()
+            if hasattr(response, "usage") and response.usage:
+                response_usage = response.usage
+                usage = Usage(
+                    requests=1,
+                    input_tokens=response_usage.prompt_tokens,
+                    output_tokens=response_usage.completion_tokens,
+                    total_tokens=response_usage.total_tokens,
+                    input_tokens_details=InputTokensDetails(
+                        cached_tokens=(
+                            getattr(response_usage, "prompt_tokens_details", None)
+                            and getattr(
+                                response_usage.prompt_tokens_details, "cached_tokens", 0
+                            )
+                        )
+                        or 0
+                    ),
+                    output_tokens_details=OutputTokensDetails(
+                        reasoning_tokens=(
+                            getattr(response_usage, "completion_tokens_details", None)
+                            and getattr(
+                                response_usage.completion_tokens_details,
+                                "reasoning_tokens",
+                                0,
+                            )
+                        )
+                        or 0
+                    ),
+                )
+
+            if tracing.include_data():
+                span_generation.span_data.output = [
+                    response.choices[0].message.model_dump()
+                ]
+            span_generation.span_data.usage = {
+                "input_tokens": usage.input_tokens,
+                "output_tokens": usage.output_tokens,
+            }
+
+            items = Converter.message_to_output_items(response.choices[0].message)
+
+            return ModelResponse(
+                output=items,
+                usage=usage,
+                response_id=None,
+            )
+
+    async def stream_response(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],  # noqa: A002
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        tracing: ModelTracing,
+        previous_response_id: str | None = None,  # unused
+        conversation_id: str | None = None,  # unused
+        prompt: Any | None = None,
+    ) -> AsyncIterator[TResponseStreamEvent]:
+        with generation_span(
+            model=str(self.model),
+            model_config=model_settings.to_json_dict()
+            | {"base_url": str(self.base_url or ""), "model_impl": "anyllm"},
+            disabled=tracing.is_disabled(),
+        ) as span_generation:
+            response, stream = await self._fetch_response(
+                system_instructions,
+                input,
+                model_settings,
+                tools,
+                output_schema,
+                handoffs,
+                span_generation,
+                tracing,
+                stream=True,
+                prompt=prompt,
+            )
+
+            final_response: Response | None = None
+            async for chunk in ChatCmplStreamHandler.handle_stream(response, stream):  # type: ignore[arg-type]
+                yield chunk
+
+                if chunk.type == "response.completed":
+                    final_response = chunk.response
+
+            if tracing.include_data() and final_response:
+                span_generation.span_data.output = [final_response.model_dump()]
+
+            if final_response and final_response.usage:
+                span_generation.span_data.usage = {
+                    "input_tokens": final_response.usage.input_tokens,
+                    "output_tokens": final_response.usage.output_tokens,
+                }
+
+    @overload
+    async def _fetch_response(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        span: Span[GenerationSpanData],
+        tracing: ModelTracing,
+        stream: Literal[True],
+        prompt: Any | None = None,
+    ) -> tuple[
+        Response, AsyncIterator[any_llm.types.completion.ChatCompletionChunk]
+    ]: ...
+
+    @overload
+    async def _fetch_response(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        span: Span[GenerationSpanData],
+        tracing: ModelTracing,
+        stream: Literal[False],
+        prompt: Any | None = None,
+    ) -> any_llm.types.completion.ChatCompletion: ...
+
+    async def _fetch_response(
+        self,
+        system_instructions: str | None,
+        input: str | list[TResponseInputItem],  # noqa: A002
+        model_settings: ModelSettings,
+        tools: list[Tool],
+        output_schema: AgentOutputSchemaBase | None,
+        handoffs: list[Handoff],
+        span: Span[GenerationSpanData],
+        tracing: ModelTracing,
+        stream: bool = False,
+        prompt: Any | None = None,
+    ) -> (
+        any_llm.types.completion.ChatCompletion
+        | tuple[Response, AsyncIterator[any_llm.types.completion.ChatCompletionChunk]]
+    ):
+        converted_messages = Converter.items_to_messages(input)
+
+        if system_instructions:
+            converted_messages.insert(
+                0,
+                {
+                    "content": system_instructions,
+                    "role": "system",
+                },
+            )
+        converted_messages = _to_dump_compatible(converted_messages)
+
+        if tracing.include_data():
+            span.span_data.input = converted_messages
+
+        parallel_tool_calls = (
+            True
+            if model_settings.parallel_tool_calls and tools and len(tools) > 0
+            else False
+            if model_settings.parallel_tool_calls is False
+            else None
+        )
+        tool_choice = Converter.convert_tool_choice(model_settings.tool_choice)
+        response_format = Converter.convert_response_format(output_schema)
+
+        converted_tools = (
+            [Converter.tool_to_openai(tool) for tool in tools] if tools else []
+        )
+
+        for handoff in handoffs:
+            converted_tools.append(Converter.convert_handoff_tool(handoff))
+
+        converted_tools = _to_dump_compatible(converted_tools)
+
+        reasoning_effort = (
+            model_settings.reasoning.effort if model_settings.reasoning else None
+        )
+
+        stream_options = None
+        if stream and model_settings.include_usage is not None:
+            stream_options = {"include_usage": model_settings.include_usage}
+
+        extra_kwargs = {}
+        if model_settings.extra_query:
+            extra_kwargs["extra_query"] = copy(model_settings.extra_query)
+        if model_settings.metadata:
+            extra_kwargs["metadata"] = copy(model_settings.metadata)
+        if model_settings.extra_body and isinstance(model_settings.extra_body, dict):
+            extra_kwargs.update(model_settings.extra_body)
+
+        # Add kwargs from model_settings.extra_args, filtering out None values
+        if model_settings.extra_args:
+            extra_kwargs.update(model_settings.extra_args)
+
+        ret = await self.llm.acompletion(
+            model=self.model_id,
+            messages=converted_messages,  # type: ignore[arg-type]
+            tools=converted_tools,
+            temperature=model_settings.temperature,
+            top_p=model_settings.top_p,
+            frequency_penalty=model_settings.frequency_penalty,
+            presence_penalty=model_settings.presence_penalty,
+            max_tokens=model_settings.max_tokens,
+            tool_choice=self._remove_not_given(tool_choice),
+            response_format=self._remove_not_given(response_format),
+            parallel_tool_calls=parallel_tool_calls,
+            stream=stream,
+            stream_options=stream_options,
+            reasoning_effort=reasoning_effort,
+            top_logprobs=model_settings.top_logprobs,
+            **extra_kwargs,  # type: ignore[arg-type]
+        )
+
+        if isinstance(ret, any_llm.types.completion.ChatCompletion):
+            return ret
+
+        # If we reach here AND stream=False, something went wrong!
+        if not stream:
+            msg = (
+                f"Expected any_llm.types.completion.ChatCompletion but got {type(ret)}"
+            )
+            raise TypeError(msg)
+
+        tool_choice_value = (
+            cast("Literal['auto', 'required', 'none']", tool_choice)
+            if tool_choice not in (NOT_GIVEN, omit)
+            and not isinstance(tool_choice, (NotGiven, type(omit)))
+            else "auto"
+        )
+        response = Response(
+            id=FAKE_RESPONSES_ID,
+            created_at=time.time(),
+            model=self.model,
+            object="response",
+            output=[],
+            tool_choice=tool_choice_value,
+            top_p=model_settings.top_p,
+            temperature=model_settings.temperature,
+            tools=[],
+            parallel_tool_calls=parallel_tool_calls or False,
+            reasoning=model_settings.reasoning,
+        )
+        return response, ret
+
+    def _remove_not_given(self, value: Any) -> Any:
+        if isinstance(value, (NotGiven, type(omit))):
+            return None
+        return value
+
+
+DEFAULT_MODEL_TYPE = AnyllmModel
 
 
 class OpenAIAgent(AnyAgent):
@@ -41,12 +392,16 @@ def framework(self) -> AgentFramework:
 
     def _get_model(
         self,
        agent_config: AgentConfig,
-    ) -> "Model":
+    ) -> Model:
        """Get the model configuration for an OpenAI agent."""
        model_type = agent_config.model_type or DEFAULT_MODEL_TYPE
+        model_args = agent_config.model_args or {}
+        base_url = agent_config.api_base or cast(
+            "str | None", model_args.get("api_base")
+        )
        return model_type(
            model=agent_config.model_id,
-            base_url=agent_config.api_base,
+            base_url=base_url,
            api_key=agent_config.api_key,
        )
diff --git a/src/any_agent/testing/helpers.py b/src/any_agent/testing/helpers.py
index 3370623a..c1e0b816 100644
--- a/src/any_agent/testing/helpers.py
+++ b/src/any_agent/testing/helpers.py
@@ -16,7 +16,7 @@
     AgentFramework.LANGCHAIN: "litellm.acompletion",
     AgentFramework.TINYAGENT: "any_agent.frameworks.tinyagent.acompletion",
     AgentFramework.AGNO: "any_agent.frameworks.agno.acompletion",
-    AgentFramework.OPENAI: "litellm.acompletion",
+    AgentFramework.OPENAI: "any_llm.AnyLLM.acompletion",
     AgentFramework.SMOLAGENTS: "any_llm.completion",
     AgentFramework.LLAMA_INDEX: "litellm.acompletion",
 }
diff --git a/src/any_agent/tools/wrappers.py b/src/any_agent/tools/wrappers.py
index 8067987d..96cc8610 100644
--- a/src/any_agent/tools/wrappers.py
+++ b/src/any_agent/tools/wrappers.py
@@ -43,7 +43,9 @@ def _wrap_tool_openai(tool: "Tool | AgentTool") -> "AgentTool":
     if isinstance(tool, AgentTool):  # type: ignore[arg-type, misc]
         return tool  # type: ignore[return-value]
 
-    return function_tool(tool)  # type: ignore[arg-type]
+    # Enabling strict mode required else
+    # throws error "Only strict function tools can be auto-parsed"
+    return function_tool(tool, strict_mode=True)  # type: ignore[arg-type]
 
 
 def _wrap_tool_langchain(tool: "Tool | LangchainTool") -> "LangchainTool":
diff --git a/tests/conftest.py b/tests/conftest.py
index 65c70758..4bc961c9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,7 +9,7 @@
 
 import pytest
 from litellm.types.utils import ModelResponse
-from openai.types.chat.chat_completion import ChatCompletion
+from any_llm.types.completion import ChatCompletion
 
 from any_agent.config import AgentFramework
 from any_agent.logging import setup_logger
@@ -169,21 +169,25 @@
 def mock_any_llm_response() -> ChatCompletion:
     """Fixture to create a standard mock any-llm response"""
     return ChatCompletion.model_validate(
         {
-            "id": "44bb9c60ab374897825da5edfbd15126",
+            "id": "chatcmpl-BWnfbHWPsQp05roQ06LAD1mZ9tOjT",
             "choices": [
                 {
                     "finish_reason": "stop",
                     "index": 0,
                     "message": {
-                        "content": "Hello! 😊 How can I assist you today?",
+                        "content": "The state capital of Pennsylvania is Harrisburg.",
                         "role": "assistant",
                     },
                 }
             ],
-            "created": 1754648476,
+            "created": 1747157127,
             "model": "mistral-small-latest",
             "object": "chat.completion",
-            "usage": {"completion_tokens": 13, "prompt_tokens": 5, "total_tokens": 18},
+            "usage": {
+                "completion_tokens": 11,
+                "prompt_tokens": 138,
+                "total_tokens": 149,
+            },
         }
     )
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
new file mode 100644
index 00000000..9cf7161d
--- /dev/null
+++ b/tests/unit/conftest.py
@@ -0,0 +1,20 @@
+import os
+from collections.abc import Generator
+
+import pytest
+
+from any_agent.config import AgentFramework
+
+
+@pytest.fixture(autouse=True)
+def mock_api_keys_for_unit_tests(
+    request: pytest.FixtureRequest,
+) -> Generator[None, None, None]:
+    """Automatically provide dummy API keys for unit tests to avoid API key requirements."""
+    if "agent_framework" in request.fixturenames:
+        agent_framework = request.getfixturevalue("agent_framework")
+        # Only set dummy API key if we're in a test that uses the agent_framework fixture
+        # and the framework is OPENAI (which uses any-llm with the AnyLLM.create class-based interface)
+        if agent_framework == AgentFramework.OPENAI:
+            os.environ["MISTRAL_API_KEY"] = "dummy-mistral-key-for-unit-tests"
+    yield  # noqa: PT022
diff --git a/tests/unit/frameworks/test_any_agent.py b/tests/unit/frameworks/test_any_agent.py
index 47015d5e..0b7f45c8 100644
--- a/tests/unit/frameworks/test_any_agent.py
+++ b/tests/unit/frameworks/test_any_agent.py
@@ -52,6 +52,7 @@ def test_create_any_with_invalid_string() -> None:
 def test_model_args(
     agent_framework: AgentFramework,
     mock_litellm_response: Any,
+    mock_any_llm_response: Any,
 ) -> None:
     if agent_framework == AgentFramework.LLAMA_INDEX:
         pytest.skip("LlamaIndex agent uses a litellm streaming syntax")
@@ -59,7 +60,11 @@
     agent = create_agent_with_model_args(agent_framework)
 
     import_path = LLM_IMPORT_PATHS[agent_framework]
-    with patch(import_path, return_value=mock_litellm_response) as mock_llm:
+    mock_response = (
+        mock_litellm_response if "litellm" in import_path else mock_any_llm_response
+    )
+
+    with patch(import_path, return_value=mock_response) as mock_llm:
         result = agent.run(TEST_QUERY)
         assert EXPECTED_OUTPUT == result.final_output
         assert mock_llm.call_args.kwargs["temperature"] == TEST_TEMPERATURE
diff --git a/tests/unit/frameworks/test_openai.py b/tests/unit/frameworks/test_openai.py
index 6d58085c..321dbe10 100644
--- a/tests/unit/frameworks/test_openai.py
+++ b/tests/unit/frameworks/test_openai.py
@@ -13,25 +13,25 @@ def test_load_openai_default() -> None:
     mock_agent = MagicMock()
     mock_function_tool = MagicMock()
-    mock_litellm_model = MagicMock()
+    mock_model = MagicMock()
 
     with (
         patch("any_agent.frameworks.openai.Agent", mock_agent),
         patch("agents.function_tool", mock_function_tool),
-        patch("any_agent.frameworks.openai.DEFAULT_MODEL_TYPE", mock_litellm_model),
+        patch("any_agent.frameworks.openai.DEFAULT_MODEL_TYPE", mock_model),
     ):
         AnyAgent.create(
             AgentFramework.OPENAI, AgentConfig(model_id="mistral/mistral-small-latest")
         )
-        mock_litellm_model.assert_called_once_with(
+        mock_model.assert_called_once_with(
             model="mistral/mistral-small-latest",
             base_url=None,
             api_key=None,
         )
         mock_agent.assert_called_once_with(
             name="any_agent",
-            model=mock_litellm_model.return_value,
+            model=mock_model.return_value,
             instructions=None,
             tools=[],
             mcp_servers=[],
@@ -40,12 +40,12 @@
 
 def test_openai_with_api_base() -> None:
     mock_agent = MagicMock()
-    litllm_model_mock = MagicMock()
+    mock_model = MagicMock()
     with (
         patch("any_agent.frameworks.openai.Agent", mock_agent),
         patch(
             "any_agent.frameworks.openai.DEFAULT_MODEL_TYPE",
-            litllm_model_mock,
+            mock_model,
         ),
     ):
         AnyAgent.create(
@@ -54,7 +54,7 @@
             model_id="mistral/mistral-small-latest", model_args={}, api_base="FOO"
         ),
     )
-    litllm_model_mock.assert_called_once_with(
+    mock_model.assert_called_once_with(
         model="mistral/mistral-small-latest",
         base_url="FOO",
         api_key=None,
@@ -63,12 +63,12 @@
 
 def test_openai_with_api_key() -> None:
     mock_agent = MagicMock()
-    litellm_model_mock = MagicMock()
+    mock_model = MagicMock()
     with (
         patch("any_agent.frameworks.openai.Agent", mock_agent),
         patch(
             "any_agent.frameworks.openai.DEFAULT_MODEL_TYPE",
-            litellm_model_mock,
+            mock_model,
         ),
     ):
         AnyAgent.create(
@@ -77,7 +77,7 @@
             model_id="mistral/mistral-small-latest", model_args={}, api_key="FOO"
         ),
     )
-    litellm_model_mock.assert_called_once_with(
+    mock_model.assert_called_once_with(
         model="mistral/mistral-small-latest",
         base_url=None,
         api_key="FOO",
@@ -87,13 +87,13 @@
 
 def test_load_openai_with_mcp_server() -> None:
     mock_agent = MagicMock()
     mock_function_tool = MagicMock()
-    mock_litellm_model = MagicMock()
+    mock_model = MagicMock()
     mock_wrap_tools = MagicMock()
 
     with (
         patch("any_agent.frameworks.openai.Agent", mock_agent),
         patch("agents.function_tool", mock_function_tool),
-        patch("any_agent.frameworks.openai.DEFAULT_MODEL_TYPE", mock_litellm_model),
+        patch("any_agent.frameworks.openai.DEFAULT_MODEL_TYPE", mock_model),
         patch.object(AnyAgent, "_load_tools", mock_wrap_tools),
     ):
@@ -124,7 +124,7 @@ async def side_effect(tools):  # type: ignore[no-untyped-def]
     # No separate MCP servers with new architecture
     mock_agent.assert_called_once_with(
         name="any_agent",
-        model=mock_litellm_model.return_value,
+        model=mock_model.return_value,
         instructions=None,
         tools=[
             mock_function_tool(search_web),
diff --git a/tests/unit/tools/test_exception_run.py b/tests/unit/tools/test_exception_run.py
index af7b029d..7b0409b4 100644
--- a/tests/unit/tools/test_exception_run.py
+++ b/tests/unit/tools/test_exception_run.py
@@ -10,6 +10,13 @@
     ModelResponseStream,
     StreamingChoices,
 )
+from any_llm.types.completion import (
+    ChatCompletion,
+    ChatCompletionMessage,
+    ChatCompletionMessageFunctionToolCall,
+    Choice,
+    Function as AnyllmFunction,
+)
 
 from any_agent import (
     AgentConfig,
@@ -83,6 +90,7 @@ def search_web(query: str) -> str:
 If you need detailed comparisons or specific recommendations, I can help with that as well. Would you like me to do that?
 """
+    # for frameworks using litellm, we need ModelResponse
     fake_give_up_response = ModelResponse(
         choices=[Choices(message=Message(content=give_up))]
     )
 
@@ -119,6 +127,47 @@
         choices=[Choices(message=Message(tool_calls=[tool_call]))]
     )
 
+    # for frameworks using any_llm, we need ChatCompletion
+    fake_give_up_response_anyllm = ChatCompletion(
+        id="chatcmpl-test",
+        choices=[
+            Choice(
+                finish_reason="stop",
+                index=0,
+                message=ChatCompletionMessage(content=give_up, role="assistant"),
+            )
+        ],
+        created=1747157127,
+        model="mistral-small-latest",
+        object="chat.completion",
+    )
+
+    fake_tool_fail_response_anyllm = ChatCompletion(
+        id="chatcmpl-test",
+        choices=[
+            Choice(
+                finish_reason="tool_calls",
+                index=0,
+                message=ChatCompletionMessage(
+                    role="assistant",
+                    tool_calls=[
+                        ChatCompletionMessageFunctionToolCall(
+                            id="call_12345xyz",
+                            type="function",
+                            function=AnyllmFunction(
+                                name="search_web",
+                                arguments='{"query": "which agent framework is the best"}',
+                            ),
+                        )
+                    ],
+                ),
+            )
+        ],
+        created=1747157127,
+        model="mistral-small-latest",
+        object="chat.completion",
+    )
+
     fake_tool_fail_chunk = ModelResponseStream(
         choices=[StreamingChoices(delta=Delta(tool_calls=[tool_call]))]
     )
@@ -133,11 +182,14 @@ async def async_fake_give_up_chunk() -> AsyncIterator[ModelResponseStream]:
         [async_fake_tool_fail_chunk(), async_fake_give_up_chunk()]
     )
 
+    import_path = LLM_IMPORT_PATHS[agent_framework]
+    uses_anyllm = "litellm" not in import_path
+
     with (
-        patch(LLM_IMPORT_PATHS[agent_framework]) as litellm_mock,
+        patch(import_path) as llm_mock,
     ):
         if agent_framework in (AgentFramework.LLAMA_INDEX):
-            litellm_mock.side_effect = streaming.next
+            llm_mock.side_effect = streaming.next
         elif agent_framework in (AgentFramework.SMOLAGENTS):
             # For smolagents, we need to handle the ReAct pattern properly
             # First call should be the tool failure, then final_answer calls
@@ -147,15 +199,25 @@ def smolagents_mock_generator() -> Iterator[ModelResponse]:
                 while True:
                     yield fake_smolagents_final_answer_response
 
-            litellm_mock.side_effect = smolagents_mock_generator()
+            llm_mock.side_effect = smolagents_mock_generator()
         else:
-            # For other frameworks, just use the simple approach
-            def other_mock_generator() -> Iterator[ModelResponse]:
-                yield fake_tool_fail_response
-                while True:
-                    yield fake_give_up_response
+            # For other frameworks, use the appropriate mock type (anyllm or litellm)
+            if uses_anyllm:
+
+                def anyllm_mock_generator() -> Iterator[ChatCompletion]:
+                    yield fake_tool_fail_response_anyllm
+                    while True:
+                        yield fake_give_up_response_anyllm
+
+                llm_mock.side_effect = anyllm_mock_generator()
+            else:
+
+                def other_mock_generator() -> Iterator[ModelResponse]:
+                    yield fake_tool_fail_response
+                    while True:
+                        yield fake_give_up_response
 
-            litellm_mock.side_effect = other_mock_generator()
+                llm_mock.side_effect = other_mock_generator()
 
         agent_trace = agent.run(
             "Check in the web which agent framework is the best.",
diff --git a/tests/unit/tools/test_unit_wrappers.py b/tests/unit/tools/test_unit_wrappers.py
index 13eb085d..b551dbfd 100644
--- a/tests/unit/tools/test_unit_wrappers.py
+++ b/tests/unit/tools/test_unit_wrappers.py
@@ -58,13 +58,13 @@ def test_wrap_tool_openai() -> None:
     wrapper = MagicMock()
     with patch("agents.function_tool", wrapper):
         _wrap_tool_openai(foo)
-        wrapper.assert_called_with(foo)
+        wrapper.assert_called_with(foo, strict_mode=True)
 
 
 def test_wrap_tool_openai_already_wrapped() -> None:
     from agents import function_tool
 
-    wrapped = function_tool(foo)
+    wrapped = function_tool(foo, strict_mode=True)
     wrapper = MagicMock()
     with patch("agents.function_tool", wrapper):
         _wrap_tool_openai(wrapped)
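
Usage sketch (not part of the patch): the change above swaps the OpenAI Agents SDK adapter from LitellmModel to the any-llm backed AnyllmModel, and the updated unit tests (test_load_openai_default, test_model_args) exercise it roughly as shown below. The provider prefix in model_id is split by AnyLLM.split_model_provider, so the matching provider API key must be available in the environment.

# Hedged sketch mirroring test_load_openai_default / test_model_args above;
# the mock fixture content ("state capital of Pennsylvania") suggests the query.
# Assumes any_agent's top-level exports for AnyAgent and AgentConfig.
from any_agent import AgentConfig, AnyAgent
from any_agent.config import AgentFramework

# With this PR, AgentFramework.OPENAI builds an AnyllmModel (any-llm backend)
# instead of LitellmModel; "mistral/..." requires MISTRAL_API_KEY to be set.
agent = AnyAgent.create(
    AgentFramework.OPENAI,
    AgentConfig(model_id="mistral/mistral-small-latest"),
)
result = agent.run("What is the state capital of Pennsylvania?")
print(result.final_output)  # e.g. "The state capital of Pennsylvania is Harrisburg."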