diff --git a/docs/agents.md b/docs/agents.md index 17589b3d..1c314739 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -130,3 +130,16 @@ robot_agent = pirate_agent.clone( instructions="Write like a robot", ) ``` + +## Forcing tool use + +Supplying a list of tools doesn't always mean the LLM will use a tool. You can force tool use by setting [`ModelSettings.tool_choice`][agents.model_settings.ModelSettings.tool_choice]. Valid values are: + +1. `auto`, which allows the LLM to decide whether or not to use a tool. +2. `required`, which requires the LLM to use a tool (but it can intelligently decide which tool). +3. `none`, which requires the LLM to _not_ use a tool. +4. Setting a specific string e.g. `my_tool`, which requires the LLM to use that specific tool. + +!!! note + + If requiring tool use, you should consider setting [`Agent.tool_use_behavior`] to stop the Agent from running when a tool output is produced. Otherwise, the Agent might run in an infinite loop, where the LLM produces a tool call , and the tool result is sent to the LLM, and this infinite loops because the LLM is always forced to use a tool. diff --git a/examples/agent_patterns/forcing_tool_use.py b/examples/agent_patterns/forcing_tool_use.py new file mode 100644 index 00000000..3f4e35ae --- /dev/null +++ b/examples/agent_patterns/forcing_tool_use.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import asyncio +from typing import Any, Literal + +from pydantic import BaseModel + +from agents import ( + Agent, + FunctionToolResult, + ModelSettings, + RunContextWrapper, + Runner, + ToolsToFinalOutputFunction, + ToolsToFinalOutputResult, + function_tool, +) + +""" +This example shows how to force the agent to use a tool. It uses `ModelSettings(tool_choice="required")` +to force the agent to use any tool. + +You can run it with 3 options: +1. `default`: The default behavior, which is to send the tool output to the LLM. In this case, + `tool_choice` is not set, because otherwise it would result in an infinite loop - the LLM would + call the tool, the tool would run and send the results to the LLM, and that would repeat + (because the model is forced to use a tool every time.) +2. `first_tool_result`: The first tool result is used as the final output. +3. `custom`: A custom tool use behavior function is used. The custom function receives all the tool + results, and chooses to use the first tool result to generate the final output. + +Usage: +python examples/agent_patterns/forcing_tool_use.py -t default +python examples/agent_patterns/forcing_tool_use.py -t first_tool +python examples/agent_patterns/forcing_tool_use.py -t custom +""" + + +class Weather(BaseModel): + city: str + temperature_range: str + conditions: str + + +@function_tool +def get_weather(city: str) -> Weather: + print("[debug] get_weather called") + return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind") + + +async def custom_tool_use_behavior( + context: RunContextWrapper[Any], results: list[FunctionToolResult] +) -> ToolsToFinalOutputResult: + weather: Weather = results[0].output + return ToolsToFinalOutputResult( + is_final_output=True, final_output=f"{weather.city} is {weather.conditions}." + ) + + +async def main(tool_use_behavior: Literal["default", "first_tool", "custom"] = "default"): + if tool_use_behavior == "default": + behavior: Literal["run_llm_again", "stop_on_first_tool"] | ToolsToFinalOutputFunction = ( + "run_llm_again" + ) + elif tool_use_behavior == "first_tool": + behavior = "stop_on_first_tool" + elif tool_use_behavior == "custom": + behavior = custom_tool_use_behavior + + agent = Agent( + name="Weather agent", + instructions="You are a helpful agent.", + tools=[get_weather], + tool_use_behavior=behavior, + model_settings=ModelSettings( + tool_choice="required" if tool_use_behavior != "default" else None + ), + ) + + result = await Runner.run(agent, input="What's the weather in Tokyo?") + print(result.final_output) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "-t", + "--tool-use-behavior", + type=str, + required=True, + choices=["default", "first_tool", "custom"], + help="The behavior to use for tool use. Default will cause tool outputs to be sent to the model. " + "first_tool_result will cause the first tool result to be used as the final output. " + "custom will use a custom tool use behavior function.", + ) + args = parser.parse_args() + asyncio.run(main(args.tool_use_behavior)) diff --git a/examples/basic/tools.py b/examples/basic/tools.py new file mode 100644 index 00000000..8936065a --- /dev/null +++ b/examples/basic/tools.py @@ -0,0 +1,34 @@ +import asyncio + +from pydantic import BaseModel + +from agents import Agent, Runner, function_tool + + +class Weather(BaseModel): + city: str + temperature_range: str + conditions: str + + +@function_tool +def get_weather(city: str) -> Weather: + print("[debug] get_weather called") + return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.") + + +agent = Agent( + name="Hello world", + instructions="You are a helpful agent.", + tools=[get_weather], +) + + +async def main(): + result = await Runner.run(agent, input="What's the weather in Tokyo?") + print(result.final_output) + # The weather in Tokyo is sunny. + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index 21a2f2a6..a7a1272c 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -5,7 +5,7 @@ from openai import AsyncOpenAI from . import _config -from .agent import Agent +from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult from .agent_output import AgentOutputSchema from .computer import AsyncComputer, Button, Computer, Environment from .exceptions import ( @@ -57,6 +57,7 @@ ComputerTool, FileSearchTool, FunctionTool, + FunctionToolResult, Tool, WebSearchTool, default_tool_error_function, @@ -137,6 +138,8 @@ def enable_verbose_stdout_logging(): __all__ = [ "Agent", + "ToolsToFinalOutputFunction", + "ToolsToFinalOutputResult", "Runner", "Model", "ModelProvider", @@ -190,6 +193,7 @@ def enable_verbose_stdout_logging(): "AgentUpdatedStreamEvent", "StreamEvent", "FunctionTool", + "FunctionToolResult", "ComputerTool", "FileSearchTool", "Tool", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index c0c0ebd0..2849538d 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -1,8 +1,10 @@ from __future__ import annotations import asyncio +import inspect +from collections.abc import Awaitable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from openai.types.responses import ( ResponseComputerToolCall, @@ -25,7 +27,7 @@ from openai.types.responses.response_input_param import ComputerCallOutput from openai.types.responses.response_reasoning_item import ResponseReasoningItem -from .agent import Agent +from .agent import Agent, ToolsToFinalOutputResult from .agent_output import AgentOutputSchema from .computer import AsyncComputer, Computer from .exceptions import AgentsException, ModelBehaviorError, UserError @@ -48,7 +50,7 @@ from .models.interface import ModelTracing from .run_context import RunContextWrapper, TContext from .stream_events import RunItemStreamEvent, StreamEvent -from .tool import ComputerTool, FunctionTool +from .tool import ComputerTool, FunctionTool, FunctionToolResult from .tracing import ( SpanError, Trace, @@ -70,6 +72,8 @@ class QueueCompleteSentinel: QUEUE_COMPLETE_SENTINEL = QueueCompleteSentinel() +_NOT_FINAL_OUTPUT = ToolsToFinalOutputResult(is_final_output=False, final_output=None) + @dataclass class ToolRunHandoff: @@ -199,7 +203,7 @@ async def execute_tools_and_side_effects( config=run_config, ), ) - new_step_items.extend(function_results) + new_step_items.extend([result.run_item for result in function_results]) new_step_items.extend(computer_results) # Second, check if there are any handoffs @@ -216,6 +220,36 @@ async def execute_tools_and_side_effects( run_config=run_config, ) + # Third, we'll check if the tool use should result in a final output + check_tool_use = await cls._check_for_final_output_from_tools( + agent=agent, + tool_results=function_results, + context_wrapper=context_wrapper, + config=run_config, + ) + + if check_tool_use.is_final_output: + # If the output type is str, then let's just stringify it + if not agent.output_type or agent.output_type is str: + check_tool_use.final_output = str(check_tool_use.final_output) + + if check_tool_use.final_output is None: + logger.error( + "Model returned a final output of None. Not raising an error because we assume" + "you know what you're doing." + ) + + return await cls.execute_final_output( + agent=agent, + original_input=original_input, + new_response=new_response, + pre_step_items=pre_step_items, + new_step_items=new_step_items, + final_output=check_tool_use.final_output, + hooks=hooks, + context_wrapper=context_wrapper, + ) + # Now we can check if the model also produced a final output message_items = [item for item in new_step_items if isinstance(item, MessageOutputItem)] @@ -355,10 +389,10 @@ async def execute_function_tool_calls( hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], config: RunConfig, - ) -> list[RunItem]: + ) -> list[FunctionToolResult]: async def run_single_tool( func_tool: FunctionTool, tool_call: ResponseFunctionToolCall - ) -> str: + ) -> Any: with function_span(func_tool.name) as span_fn: if config.trace_include_sensitive_data: span_fn.span_data.input = tool_call.arguments @@ -404,10 +438,14 @@ async def run_single_tool( results = await asyncio.gather(*tasks) return [ - ToolCallOutputItem( - output=str(result), - raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)), - agent=agent, + FunctionToolResult( + tool=tool_run.function_tool, + output=result, + run_item=ToolCallOutputItem( + output=result, + raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)), + agent=agent, + ), ) for tool_run, result in zip(tool_runs, results) ] @@ -646,6 +684,47 @@ def stream_step_result_to_queue( if event: queue.put_nowait(event) + @classmethod + async def _check_for_final_output_from_tools( + cls, + *, + agent: Agent[TContext], + tool_results: list[FunctionToolResult], + context_wrapper: RunContextWrapper[TContext], + config: RunConfig, + ) -> ToolsToFinalOutputResult: + """Returns (i, final_output).""" + if not tool_results: + return _NOT_FINAL_OUTPUT + + if agent.tool_use_behavior == "run_llm_again": + return _NOT_FINAL_OUTPUT + elif agent.tool_use_behavior == "stop_on_first_tool": + return ToolsToFinalOutputResult( + is_final_output=True, final_output=tool_results[0].output + ) + elif isinstance(agent.tool_use_behavior, dict): + names = agent.tool_use_behavior.get("stop_at_tool_names", []) + for tool_result in tool_results: + if tool_result.tool.name in names: + return ToolsToFinalOutputResult( + is_final_output=True, final_output=tool_result.output + ) + return ToolsToFinalOutputResult(is_final_output=False, final_output=None) + elif callable(agent.tool_use_behavior): + if inspect.iscoroutinefunction(agent.tool_use_behavior): + return await cast( + Awaitable[ToolsToFinalOutputResult], + agent.tool_use_behavior(context_wrapper, tool_results), + ) + else: + return cast( + ToolsToFinalOutputResult, agent.tool_use_behavior(context_wrapper, tool_results) + ) + + logger.error(f"Invalid tool_use_behavior: {agent.tool_use_behavior}") + raise UserError(f"Invalid tool_use_behavior: {agent.tool_use_behavior}") + class TraceCtxManager: """Creates a trace only if there is no current trace, and manages the trace lifecycle.""" diff --git a/src/agents/agent.py b/src/agents/agent.py index 3c4588e6..2723e678 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -4,7 +4,9 @@ import inspect from collections.abc import Awaitable from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Generic, cast +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast + +from typing_extensions import TypeAlias, TypedDict from .guardrail import InputGuardrail, OutputGuardrail from .handoffs import Handoff @@ -13,7 +15,7 @@ from .model_settings import ModelSettings from .models.interface import Model from .run_context import RunContextWrapper, TContext -from .tool import Tool, function_tool +from .tool import FunctionToolResult, Tool, function_tool from .util import _transforms from .util._types import MaybeAwaitable @@ -22,6 +24,33 @@ from .result import RunResult +@dataclass +class ToolsToFinalOutputResult: + is_final_output: bool + """Whether this is the final output. If False, the LLM will run again and receive the tool call + output. + """ + + final_output: Any | None = None + """The final output. Can be None if `is_final_output` is False, otherwise must match the + `output_type` of the agent. + """ + + +ToolsToFinalOutputFunction: TypeAlias = Callable[ + [RunContextWrapper[TContext], list[FunctionToolResult]], + MaybeAwaitable[ToolsToFinalOutputResult], +] +"""A function that takes a run context and a list of tool results, and returns a +`ToolToFinalOutputResult`. +""" + + +class StopAtTools(TypedDict): + stop_at_tool_names: list[str] + """A list of tool names, any of which will stop the agent from running further.""" + + @dataclass class Agent(Generic[TContext]): """An agent is an AI model configured with instructions, tools, guardrails, handoffs and more. @@ -95,6 +124,25 @@ class Agent(Generic[TContext]): """A class that receives callbacks on various lifecycle events for this agent. """ + tool_use_behavior: ( + Literal["run_llm_again", "stop_on_first_tool"] | StopAtTools | ToolsToFinalOutputFunction + ) = "run_llm_again" + """This lets you configure how tool use is handled. + - "run_llm_again": The default behavior. Tools are run, and then the LLM receives the results + and gets to respond. + - "stop_on_first_tool": The output of the first tool call is used as the final output. This + means that the LLM does not process the result of the tool call. + - A list of tool names: The agent will stop running if any of the tools in the list are called. + The final output will be the output of the first matching tool call. The LLM does not + process the result of the tool call. + - A function: If you pass a function, it will be called with the run context and the list of + tool results. It must return a `ToolToFinalOutputResult`, which determines whether the tool + calls result in a final output. + + NOTE: This configuration is specific to FunctionTools. Hosted tools, such as file search, + web search, etc are always processed by the LLM. + """ + def clone(self, **kwargs: Any) -> Agent[TContext]: """Make a copy of the agent, with the given arguments changed. For example, you could do: ``` diff --git a/src/agents/items.py b/src/agents/items.py index ffbeba02..c2af0dfc 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -129,8 +129,10 @@ class ToolCallOutputItem(RunItemBase[Union[FunctionCallOutput, ComputerCallOutpu raw_item: FunctionCallOutput | ComputerCallOutput """The raw item from the model.""" - output: str - """The output of the tool call.""" + output: Any + """The output of the tool call. This is whatever the tool call returned; the `raw_item` + contains a string representation of the output. + """ type: Literal["tool_call_output_item"] = "tool_call_output_item" diff --git a/src/agents/tool.py b/src/agents/tool.py index 3c309217..c1c16242 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -15,6 +15,7 @@ from .computer import AsyncComputer, Computer from .exceptions import ModelBehaviorError from .function_schema import DocstringStyle, function_schema +from .items import RunItem from .logger import logger from .run_context import RunContextWrapper from .tracing import SpanError @@ -29,6 +30,18 @@ ToolFunction = Union[ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]] +@dataclass +class FunctionToolResult: + tool: FunctionTool + """The tool that was run.""" + + output: Any + """The output of the tool.""" + + run_item: RunItem + """The run item that was produced as a result of the tool call.""" + + @dataclass class FunctionTool: """A tool that wraps a function. In most cases, you should use the `function_tool` helpers to @@ -44,15 +57,15 @@ class FunctionTool: params_json_schema: dict[str, Any] """The JSON schema for the tool's parameters.""" - on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[str]] + on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[Any]] """A function that invokes the tool with the given context and parameters. The params passed are: 1. The tool run context. 2. The arguments from the LLM, as a JSON string. - You must return a string representation of the tool output. In case of errors, you can either - raise an Exception (which will cause the run to fail) or return a string error message (which - will be sent back to the LLM). + You must return a string representation of the tool output, or something we can call `str()` on. + In case of errors, you can either raise an Exception (which will cause the run to fail) or + return a string error message (which will be sent back to the LLM). """ strict_json_schema: bool = True @@ -207,7 +220,7 @@ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: strict_json_schema=strict_mode, ) - async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str: + async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any: try: json_data: dict[str, Any] = json.loads(input) if input else {} except Exception as e: @@ -254,9 +267,9 @@ async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str: else: logger.debug(f"Tool {schema.name} returned {result}") - return str(result) + return result - async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str: + async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any: try: return await _on_invoke_tool_impl(ctx, input) except Exception as e: diff --git a/src/agents/tracing/span_data.py b/src/agents/tracing/span_data.py index 5e5d38cb..1a49d8e6 100644 --- a/src/agents/tracing/span_data.py +++ b/src/agents/tracing/span_data.py @@ -51,7 +51,7 @@ def export(self) -> dict[str, Any]: class FunctionSpanData(SpanData): __slots__ = ("name", "input", "output") - def __init__(self, name: str, input: str | None, output: str | None): + def __init__(self, name: str, input: str | None, output: Any | None): self.name = name self.input = input self.output = output @@ -65,7 +65,7 @@ def export(self) -> dict[str, Any]: "type": self.type, "name": self.name, "input": self.input, - "output": self.output, + "output": str(self.output) if self.output else None, } diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index c124915a..e8e060fd 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -21,6 +21,8 @@ UserError, handoff, ) +from agents.agent import ToolsToFinalOutputResult +from agents.tool import FunctionToolResult, function_tool from .fake_model import FakeModel from .test_responses import ( @@ -552,3 +554,83 @@ def guardrail_function( with pytest.raises(OutputGuardrailTripwireTriggered): await Runner.run(agent, input="user_message") + + +@function_tool +def test_tool_one(): + return Foo(bar="tool_one_result") + + +@function_tool +def test_tool_two(): + return "tool_two_result" + + +@pytest.mark.asyncio +async def test_tool_use_behavior_first_output(): + model = FakeModel() + agent = Agent( + name="test", + model=model, + tools=[get_function_tool("foo", "tool_result"), test_tool_one, test_tool_two], + tool_use_behavior="stop_on_first_tool", + output_type=Foo, + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [ + get_text_message("a_message"), + get_function_tool_call("test_tool_one", None), + get_function_tool_call("test_tool_two", None), + ], + ] + ) + + result = await Runner.run(agent, input="user_message") + + assert result.final_output == Foo(bar="tool_one_result"), ( + "should have used the first tool result" + ) + + +def custom_tool_use_behavior( + context: RunContextWrapper[Any], results: list[FunctionToolResult] +) -> ToolsToFinalOutputResult: + if "test_tool_one" in [result.tool.name for result in results]: + return ToolsToFinalOutputResult(is_final_output=True, final_output="the_final_output") + else: + return ToolsToFinalOutputResult(is_final_output=False, final_output=None) + + +@pytest.mark.asyncio +async def test_tool_use_behavior_custom_function(): + model = FakeModel() + agent = Agent( + name="test", + model=model, + tools=[get_function_tool("foo", "tool_result"), test_tool_one, test_tool_two], + tool_use_behavior=custom_tool_use_behavior, + ) + + model.add_multiple_turn_outputs( + [ + # First turn: a message and tool call + [ + get_text_message("a_message"), + get_function_tool_call("test_tool_two", None), + ], + # Second turn: a message and tool call + [ + get_text_message("a_message"), + get_function_tool_call("test_tool_one", None), + get_function_tool_call("test_tool_two", None), + ], + ] + ) + + result = await Runner.run(agent, input="user_message") + + assert len(result.raw_responses) == 2, "should have two model responses" + assert result.final_output == "the_final_output", "should have used the custom function" diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py index 6a78309b..0a57aea8 100644 --- a/tests/test_function_tool.py +++ b/tests/test_function_tool.py @@ -49,10 +49,10 @@ async def test_simple_function(): assert tool.name == "simple_function" result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}') - assert result == "6" + assert result == 6 result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1, "b": 2}') - assert result == "3" + assert result == 3 # Missing required argument should raise an error with pytest.raises(ModelBehaviorError): diff --git a/tests/test_tool_use_behavior.py b/tests/test_tool_use_behavior.py new file mode 100644 index 00000000..6a673b7a --- /dev/null +++ b/tests/test_tool_use_behavior.py @@ -0,0 +1,194 @@ +# Copyright + +from __future__ import annotations + +from typing import cast + +import pytest +from openai.types.responses.response_input_item_param import FunctionCallOutput + +from agents import ( + Agent, + FunctionToolResult, + RunConfig, + RunContextWrapper, + ToolCallOutputItem, + ToolsToFinalOutputResult, + UserError, +) +from agents._run_impl import RunImpl + +from .test_responses import get_function_tool + + +def _make_function_tool_result( + agent: Agent, output: str, tool_name: str | None = None +) -> FunctionToolResult: + # Construct a FunctionToolResult with the given output using a simple function tool. + tool = get_function_tool(tool_name or "dummy", return_value=output) + raw_item: FunctionCallOutput = cast( + FunctionCallOutput, + { + "call_id": "1", + "output": output, + "type": "function_call_output", + }, + ) + # For this test we don't care about the specific RunItem subclass, only the output field + run_item = ToolCallOutputItem(agent=agent, raw_item=raw_item, output=output) + return FunctionToolResult(tool=tool, output=output, run_item=run_item) + + +@pytest.mark.asyncio +async def test_no_tool_results_returns_not_final_output() -> None: + # If there are no tool results at all, tool_use_behavior should not produce a final output. + agent = Agent(name="test") + result = await RunImpl._check_for_final_output_from_tools( + agent=agent, + tool_results=[], + context_wrapper=RunContextWrapper(context=None), + config=RunConfig(), + ) + assert result.is_final_output is False + assert result.final_output is None + + +@pytest.mark.asyncio +async def test_run_llm_again_behavior() -> None: + # With the default run_llm_again behavior, even with tools we still expect to keep running. + agent = Agent(name="test", tool_use_behavior="run_llm_again") + tool_results = [_make_function_tool_result(agent, "ignored")] + result = await RunImpl._check_for_final_output_from_tools( + agent=agent, + tool_results=tool_results, + context_wrapper=RunContextWrapper(context=None), + config=RunConfig(), + ) + assert result.is_final_output is False + assert result.final_output is None + + +@pytest.mark.asyncio +async def test_stop_on_first_tool_behavior() -> None: + # When tool_use_behavior is stop_on_first_tool, we should surface first tool output as final. + agent = Agent(name="test", tool_use_behavior="stop_on_first_tool") + tool_results = [ + _make_function_tool_result(agent, "first_tool_output"), + _make_function_tool_result(agent, "ignored"), + ] + result = await RunImpl._check_for_final_output_from_tools( + agent=agent, + tool_results=tool_results, + context_wrapper=RunContextWrapper(context=None), + config=RunConfig(), + ) + assert result.is_final_output is True + assert result.final_output == "first_tool_output" + + +@pytest.mark.asyncio +async def test_custom_tool_use_behavior_sync() -> None: + """If tool_use_behavior is a sync function, we should call it and propagate its return.""" + + def behavior( + context: RunContextWrapper, results: list[FunctionToolResult] + ) -> ToolsToFinalOutputResult: + assert len(results) == 3 + return ToolsToFinalOutputResult(is_final_output=True, final_output="custom") + + agent = Agent(name="test", tool_use_behavior=behavior) + tool_results = [ + _make_function_tool_result(agent, "ignored1"), + _make_function_tool_result(agent, "ignored2"), + _make_function_tool_result(agent, "ignored3"), + ] + result = await RunImpl._check_for_final_output_from_tools( + agent=agent, + tool_results=tool_results, + context_wrapper=RunContextWrapper(context=None), + config=RunConfig(), + ) + assert result.is_final_output is True + assert result.final_output == "custom" + + +@pytest.mark.asyncio +async def test_custom_tool_use_behavior_async() -> None: + """If tool_use_behavior is an async function, we should await it and propagate its return.""" + + async def behavior( + context: RunContextWrapper, results: list[FunctionToolResult] + ) -> ToolsToFinalOutputResult: + assert len(results) == 3 + return ToolsToFinalOutputResult(is_final_output=True, final_output="async_custom") + + agent = Agent(name="test", tool_use_behavior=behavior) + tool_results = [ + _make_function_tool_result(agent, "ignored1"), + _make_function_tool_result(agent, "ignored2"), + _make_function_tool_result(agent, "ignored3"), + ] + result = await RunImpl._check_for_final_output_from_tools( + agent=agent, + tool_results=tool_results, + context_wrapper=RunContextWrapper(context=None), + config=RunConfig(), + ) + assert result.is_final_output is True + assert result.final_output == "async_custom" + + +@pytest.mark.asyncio +async def test_invalid_tool_use_behavior_raises() -> None: + """If tool_use_behavior is invalid, we should raise a UserError.""" + agent = Agent(name="test") + # Force an invalid value; mypy will complain, so ignore the type here. + agent.tool_use_behavior = "bad_value" # type: ignore[assignment] + tool_results = [_make_function_tool_result(agent, "ignored")] + with pytest.raises(UserError): + await RunImpl._check_for_final_output_from_tools( + agent=agent, + tool_results=tool_results, + context_wrapper=RunContextWrapper(context=None), + config=RunConfig(), + ) + + +@pytest.mark.asyncio +async def test_tool_names_to_stop_at_behavior() -> None: + agent = Agent( + name="test", + tools=[ + get_function_tool("tool1", return_value="tool1_output"), + get_function_tool("tool2", return_value="tool2_output"), + get_function_tool("tool3", return_value="tool3_output"), + ], + tool_use_behavior={"stop_at_tool_names": ["tool1"]}, + ) + + tool_results = [ + _make_function_tool_result(agent, "ignored1", "tool2"), + _make_function_tool_result(agent, "ignored3", "tool3"), + ] + result = await RunImpl._check_for_final_output_from_tools( + agent=agent, + tool_results=tool_results, + context_wrapper=RunContextWrapper(context=None), + config=RunConfig(), + ) + assert result.is_final_output is False, "We should not have stopped at tool1" + + # Now test with a tool that matches the list + tool_results = [ + _make_function_tool_result(agent, "output1", "tool1"), + _make_function_tool_result(agent, "ignored2", "tool2"), + _make_function_tool_result(agent, "ignored3", "tool3"), + ] + result = await RunImpl._check_for_final_output_from_tools( + agent=agent, + tool_results=tool_results, + context_wrapper=RunContextWrapper(context=None), + config=RunConfig(), + ) + assert result.is_final_output is True, "We should have stopped at tool1" + assert result.final_output == "output1"