diff --git a/chatlas/_chat.py b/chatlas/_chat.py index 8286fa0f..c6ca7524 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -18,6 +18,7 @@ Optional, Sequence, TypeVar, + cast, overload, ) @@ -42,7 +43,7 @@ from ._tools import Tool from ._turn import Turn, user_turn from ._typing_extensions import TypedDict -from ._utils import html_escape, wrap_async +from ._utils import html_escape, is_async_callable, wrap_async class AnyTypeDict(TypedDict, total=False): @@ -91,6 +92,14 @@ def __init__( self.provider = provider self._turns: list[Turn] = list(turns or []) self._tools: dict[str, Tool] = {} + self._on_tool_request: Optional[ + Callable[[ContentToolRequest], None] + | Callable[[ContentToolRequest], Awaitable[None]] + ] = None + self._on_tool_result: Optional[ + Callable[[ContentToolResult], None] + | Callable[[ContentToolResult], Awaitable[None]] + ] = None self._echo_options: EchoOptions = { "rich_markdown": {}, "rich_console": {}, @@ -908,6 +917,34 @@ def add(a: int, b: int) -> int: tool = Tool(func, model=model) self._tools[tool.name] = tool + def on_tool_request( + self, + func: Callable[[ContentToolRequest], None] + | Callable[[ContentToolRequest], Awaitable[None]], + ): + """ + Register a function to be called when a tool is requested. + + This function will be called with a single argument, a `ContentToolRequest` + object, which contains the tool name and the input parameters for the tool. + """ + self._on_tool_request = func + + def on_tool_result( + self, + func: Callable[[ContentToolResult], None] + | Callable[[ContentToolResult], Awaitable[None]], + ): + """ + Register a function to be called when a tool result is received. + + This function will be called with a single argument, a `ContentToolResult` + object, which contains the tool name and the output of the tool. + + TODO: explain how to check for errors in the tool result + """ + self._on_tool_result = func + def export( self, filename: str | Path, @@ -1205,12 +1242,30 @@ def _invoke_tools(self) -> Turn | None: if turn is None: return None + on_request = self._on_tool_request + if on_request is not None and is_async_callable(on_request): + raise ValueError( + "Cannot use async on_tool_request callback in a synchronous chat" + ) + + on_result = self._on_tool_result + if on_result is not None and is_async_callable(on_result): + raise ValueError( + "Cannot use async on_tool_result callback in a synchronous chat" + ) + + on_result = cast(Callable[[ContentToolResult], None], on_result) + results: list[ContentToolResult] = [] for x in turn.contents: if isinstance(x, ContentToolRequest): + if on_request is not None: + on_request(x) tool_def = self._tools.get(x.name, None) func = tool_def.func if tool_def is not None else None - results.append(self._invoke_tool(func, x.arguments, x.id)) + results.append( + self._invoke_tool(func, x.arguments, x.id, x.name, on_result) + ) if not results: return None @@ -1222,9 +1277,21 @@ async def _invoke_tools_async(self) -> Turn | None: if turn is None: return None + on_request = self._on_tool_request + if on_request is not None: + on_request = wrap_async(on_request) + + on_result = self._on_tool_result + if on_result is not None: + on_result = wrap_async(on_result) + + on_result = cast(Callable[[ContentToolResult], Awaitable[None]], on_result) + results: list[ContentToolResult] = [] for x in turn.contents: if isinstance(x, ContentToolRequest): + if on_request is not None: + await on_request(x) tool_def = self._tools.get(x.name, None) func = None if tool_def: @@ -1232,7 +1299,11 @@ async def _invoke_tools_async(self) -> Turn | None: func = tool_def.func else: func = wrap_async(tool_def.func) - results.append(await self._invoke_tool_async(func, x.arguments, x.id)) + results.append( + await self._invoke_tool_async( + func, x.arguments, x.id, x.name, on_result + ) + ) if not results: return None @@ -1244,12 +1315,18 @@ def _invoke_tool( func: Callable[..., Any] | None, arguments: object, id_: str, + name: str, + on_result: Optional[Callable[[ContentToolResult], None]] = None, ) -> ContentToolResult: if func is None: - return ContentToolResult(id_, value=None, error="Unknown tool") + res = ContentToolResult(id_, value=None, error="Unknown tool", name=name) + if on_result is not None: + on_result(res) + return res name = func.__name__ + res = None try: if isinstance(arguments, dict): result = func(**arguments) @@ -1259,19 +1336,30 @@ def _invoke_tool( return ContentToolResult(id_, value=result, error=None, name=name) except Exception as e: log_tool_error(name, str(arguments), e) - return ContentToolResult(id_, value=None, error=str(e), name=name) + res = ContentToolResult(id_, value=None, error=str(e), name=name) + + if on_result is not None: + on_result(res) + + return res @staticmethod async def _invoke_tool_async( func: Callable[..., Awaitable[Any]] | None, arguments: object, id_: str, + name: str, + on_result: Optional[Callable[[ContentToolResult], Awaitable[None]]] = None, ) -> ContentToolResult: if func is None: - return ContentToolResult(id_, value=None, error="Unknown tool") + res = ContentToolResult(id_, value=None, error="Unknown tool", name=name) + if on_result is not None: + await on_result(res) + return res name = func.__name__ + res = None try: if isinstance(arguments, dict): result = await func(**arguments) @@ -1281,7 +1369,11 @@ async def _invoke_tool_async( return ContentToolResult(id_, value=result, error=None, name=name) except Exception as e: log_tool_error(func.__name__, str(arguments), e) - return ContentToolResult(id_, value=None, error=str(e), name=name) + res = ContentToolResult(id_, value=None, error=str(e), name=name) + if on_result is not None: + await on_result(res) + + return res def _markdown_display( self, echo: Literal["text", "all", "none"] diff --git a/chatlas/_content.py b/chatlas/_content.py index 56e313c2..dce62dfd 100644 --- a/chatlas/_content.py +++ b/chatlas/_content.py @@ -185,16 +185,28 @@ class ContentToolResult(Content): Parameters ---------- id - The unique identifier of the tool request. + The unique identifier for the tool result. + name + The name of the tool/function that was called. value The value returned by the tool/function. name The name of the tool/function that was called. error An error message if the tool/function call failed. + + Note + ---- + If the tool/function call failed, the `value` field will be `None` and the + `error` field will contain the error message. + If the tool/function call succeeded, the `value` field will contain the + return value and the `error` field will be `None`. + To get the actual result sent to the model assistant, use the `get_final_value()` + method. """ id: str + name: str value: Any = None name: Optional[str] = None error: Optional[str] = None @@ -209,7 +221,7 @@ def _get_value_and_language(self) -> tuple[str, str]: return str(self.value), "" def __str__(self): - comment = f"# tool result ({self.id})" + comment = f"# tool ({self.name}) result ({self.id})" value, language = self._get_value_and_language() return f"""```{language}\n{comment}\n{value}\n```""" @@ -219,7 +231,9 @@ def _repr_markdown_(self): def __repr__(self, indent: int = 0): res = " " * indent - res += f""