Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: select errors that will be caught when raised within a tool #5488

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -548,18 +548,25 @@ async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken
) -> FunctionExecutionResult:
"""Execute a tool call and return the result."""
error_prefix_string = "Error: "
if not self._tools + self._handoff_tools:
return FunctionExecutionResult(
content=f"{error_prefix_string}No tools are available.", call_id=tool_call.id, is_error=True
)
tool = next((t for t in self._tools + self._handoff_tools if t.name == tool_call.name), None)
if tool is None:
return FunctionExecutionResult(
content=f"{error_prefix_string}The tool '{tool_call.name}' is not available.",
call_id=tool_call.id,
is_error=True,
)
try:
if not self._tools + self._handoff_tools:
raise ValueError("No tools are available.")
tool = next((t for t in self._tools + self._handoff_tools if t.name == tool_call.name), None)
if tool is None:
raise ValueError(f"The tool '{tool_call.name}' is not available.")
arguments = json.loads(tool_call.arguments)
result = await tool.run_json(arguments, cancellation_token)
result_as_str = tool.return_value_as_string(result)
return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id, is_error=False)
except Exception as e:
return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id, is_error=True)
except tool.return_error_types as e:
return FunctionExecutionResult(content=f"{error_prefix_string}{e}", call_id=tool_call.id, is_error=True)

async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""Reset the assistant agent to its initialization state."""
Expand Down
11 changes: 10 additions & 1 deletion python/packages/autogen-core/src/autogen_core/tools/_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Dict, Generic, Mapping, Protocol, Type, TypedDict, TypeVar, cast, runtime_checkable
from typing import Any, Dict, Generic, Mapping, Protocol, Tuple, Type, TypedDict, TypeVar, cast, runtime_checkable

import jsonref
from pydantic import BaseModel
Expand Down Expand Up @@ -37,6 +37,9 @@ def description(self) -> str: ...
@property
def schema(self) -> ToolSchema: ...

@property
def return_error_types(self) -> Tuple[Type[Exception], ...] | Type[Exception]: ...

def args_type(self) -> Type[BaseModel]: ...

def return_type(self) -> Type[Any]: ...
Expand Down Expand Up @@ -66,12 +69,14 @@ def __init__(
return_type: Type[ReturnT],
name: str,
description: str,
return_error_types: Tuple[Type[Exception], ...] | Type[Exception] = Exception,
) -> None:
self._args_type = args_type
# Normalize Annotated to the base type.
self._return_type = normalize_annotated_type(return_type)
self._name = name
self._description = description
self._return_error_types = return_error_types

@property
def schema(self) -> ToolSchema:
Expand Down Expand Up @@ -103,6 +108,10 @@ def name(self) -> str:
def description(self) -> str:
return self._description

@property
def return_error_types(self) -> Tuple[Type[Exception], ...] | Type[Exception]:
return self._return_error_types

def args_type(self) -> Type[BaseModel]:
return self._args_type

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import functools
import warnings
from textwrap import dedent
from typing import Any, Callable, Sequence
from typing import Any, Callable, Sequence, Tuple, Type

from pydantic import BaseModel
from typing_extensions import Self
Expand Down Expand Up @@ -47,6 +47,8 @@ class FunctionTool(BaseTool[BaseModel, BaseModel], Component[FunctionToolConfig]
it does and the context in which it should be called.
name (str, optional): An optional custom name for the tool. Defaults to
the function's original name if not provided.
return_error_types (Tuple[Type[Exception], ...] | Type[Exception]): Select the error types that are forwarded to the
agent. Error types not selected will raise an error. Defaults to forwarding all errors.

Example:

Expand Down Expand Up @@ -83,7 +85,12 @@ async def example():
component_config_schema = FunctionToolConfig

def __init__(
self, func: Callable[..., Any], description: str, name: str | None = None, global_imports: Sequence[Import] = []
self,
func: Callable[..., Any],
description: str,
name: str | None = None,
global_imports: Sequence[Import] = [],
return_error_types: Tuple[Type[Exception], ...] | Type[Exception] = Exception,
) -> None:
self._func = func
self._global_imports = global_imports
Expand All @@ -93,7 +100,7 @@ def __init__(
return_type = signature.return_annotation
self._has_cancellation_support = "cancellation_token" in signature.parameters

super().__init__(args_model, return_type, func_name, description)
super().__init__(args_model, return_type, func_name, description, return_error_types)

async def run(self, args: BaseModel, cancellation_token: CancellationToken) -> Any:
if asyncio.iscoroutinefunction(self._func):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,16 +379,24 @@ def _thread_id(self) -> str:
raise ValueError("Thread not initialized")
return self._thread.id

async def _execute_tool_call(self, tool_call: FunctionCall, cancellation_token: CancellationToken) -> str:
async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken
) -> FunctionExecutionResult:
"""Execute a tool call and return the result."""
if not self._original_tools:
raise ValueError("No tools are available.")
return FunctionExecutionResult(content="No tools are available.", call_id=tool_call.id, is_error=True)
tool = next((t for t in self._original_tools if t.name == tool_call.name), None)
if tool is None:
raise ValueError(f"The tool '{tool_call.name}' is not available.")
arguments = json.loads(tool_call.arguments)
result = await tool.run_json(arguments, cancellation_token)
return tool.return_value_as_string(result)
return FunctionExecutionResult(
content=f"The tool '{tool_call.name}' is not available.", call_id=tool_call.id, is_error=True
)
try:
arguments = json.loads(tool_call.arguments)
result = await tool.run_json(arguments, cancellation_token)
result_as_str = tool.return_value_as_string(result)
return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id, is_error=False)
except tool.return_error_types as e:
return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id, is_error=True)

async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
"""Handle incoming messages and return a response."""
Expand Down Expand Up @@ -460,15 +468,8 @@ async def on_messages_stream(
# Execute tool calls and get results
tool_outputs: List[FunctionExecutionResult] = []
for tool_call in tool_calls:
try:
result = await self._execute_tool_call(tool_call, cancellation_token)
is_error = False
except Exception as e:
result = f"Error: {e}"
is_error = True
tool_outputs.append(
FunctionExecutionResult(content=result, call_id=tool_call.id, is_error=is_error)
)
tool_output = await self._execute_tool_call(tool_call, cancellation_token)
tool_outputs.append(tool_output)

# Add tool result message to inner messages
tool_result_msg = ToolCallExecutionEvent(source=self.name, content=tool_outputs)
Expand Down