From a85b991c7aa33e95ebe3e786b4f583c9f42abdca Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 11 Feb 2025 15:36:27 -0600 Subject: [PATCH 1/2] Quick and dirty stab at on_tool_result()/on_tool_request() callbacks --- chatlas/_chat.py | 111 ++++++++++++++++++++++++++++++++++++++++---- chatlas/_content.py | 20 ++++++-- chatlas/_google.py | 1 + 3 files changed, 120 insertions(+), 12 deletions(-) diff --git a/chatlas/_chat.py b/chatlas/_chat.py index a75f43d4..d9f97c33 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -16,6 +16,7 @@ Optional, Sequence, TypeVar, + cast, overload, ) @@ -40,7 +41,7 @@ from ._tools import Tool from ._turn import Turn, user_turn from ._typing_extensions import TypedDict -from ._utils import html_escape +from ._utils import html_escape, is_async_callable, wrap_async class AnyTypeDict(TypedDict, total=False): @@ -89,6 +90,14 @@ def __init__( self.provider = provider self._turns: list[Turn] = list(turns or []) self._tools: dict[str, Tool] = {} + self._on_tool_request: Optional[ + Callable[[ContentToolRequest], None] + | Callable[[ContentToolRequest], Awaitable[None]] + ] = None + self._on_tool_result: Optional[ + Callable[[ContentToolResult], None] + | Callable[[ContentToolResult], Awaitable[None]] + ] = None self._echo_options: EchoOptions = { "rich_markdown": {}, "rich_console": {}, @@ -906,6 +915,34 @@ def add(a: int, b: int) -> int: tool = Tool(func, model=model) self._tools[tool.name] = tool + def on_tool_request( + self, + func: Callable[[ContentToolRequest], None] + | Callable[[ContentToolRequest], Awaitable[None]], + ): + """ + Register a function to be called when a tool is requested. + + This function will be called with a single argument, a `ContentToolRequest` + object, which contains the tool name and the input parameters for the tool. + """ + self._on_tool_request = func + + def on_tool_result( + self, + func: Callable[[ContentToolResult], None] + | Callable[[ContentToolResult], Awaitable[None]], + ): + """ + Register a function to be called when a tool result is received. + + This function will be called with a single argument, a `ContentToolResult` + object, which contains the tool name and the output of the tool. + + TODO: explain how to check for errors in the tool result + """ + self._on_tool_result = func + def export( self, filename: str | Path, @@ -1205,12 +1242,30 @@ def _invoke_tools(self) -> Turn | None: if turn is None: return None + on_request = self._on_tool_request + if on_request is not None and is_async_callable(on_request): + raise ValueError( + "Cannot use async on_tool_request callback in a synchronous chat" + ) + + on_result = self._on_tool_result + if on_result is not None and is_async_callable(on_result): + raise ValueError( + "Cannot use async on_tool_result callback in a synchronous chat" + ) + + on_result = cast(Callable[[ContentToolResult], None], on_result) + results: list[ContentToolResult] = [] for x in turn.contents: if isinstance(x, ContentToolRequest): + if on_request is not None: + on_request(x) tool_def = self._tools.get(x.name, None) func = tool_def.func if tool_def is not None else None - results.append(self._invoke_tool(func, x.arguments, x.id)) + results.append( + self._invoke_tool(func, x.arguments, x.id, x.name, on_result) + ) if not results: return None @@ -1222,12 +1277,28 @@ async def _invoke_tools_async(self) -> Turn | None: if turn is None: return None + on_request = self._on_tool_request + if on_request is not None: + on_request = wrap_async(on_request) + + on_result = self._on_tool_result + if on_result is not None: + on_result = wrap_async(on_result) + + on_result = cast(Callable[[ContentToolResult], Awaitable[None]], on_result) + results: list[ContentToolResult] = [] for x in turn.contents: if isinstance(x, ContentToolRequest): + if on_request is not None: + await on_request(x) tool_def = self._tools.get(x.name, None) func = tool_def.func if tool_def is not None else None - results.append(await self._invoke_tool_async(func, x.arguments, x.id)) + results.append( + await self._invoke_tool_async( + func, x.arguments, x.id, x.name, on_result + ) + ) if not results: return None @@ -1239,40 +1310,62 @@ def _invoke_tool( func: Callable[..., Any] | None, arguments: object, id_: str, + name: str, + on_result: Optional[Callable[[ContentToolResult], None]] = None, ) -> ContentToolResult: if func is None: - return ContentToolResult(id_, None, "Unknown tool") + res = ContentToolResult(id_, name, None, "Unknown tool") + if on_result is not None: + on_result(res) + return res + res = None try: if isinstance(arguments, dict): result = func(**arguments) else: result = func(arguments) - return ContentToolResult(id_, result, None) + res = ContentToolResult(id_, name, result, None) except Exception as e: log_tool_error(func.__name__, str(arguments), e) - return ContentToolResult(id_, None, str(e)) + res = ContentToolResult(id_, name, None, str(e)) + + if on_result is not None: + on_result(res) + + return res @staticmethod async def _invoke_tool_async( func: Callable[..., Awaitable[Any]] | None, arguments: object, id_: str, + name: str, + on_result: Optional[Callable[[ContentToolResult], Awaitable[None]]] = None, ) -> ContentToolResult: if func is None: - return ContentToolResult(id_, None, "Unknown tool") + res = ContentToolResult(id_, name, None, "Unknown tool") + if on_result is not None: + await on_result(res) + return res + res = None try: if isinstance(arguments, dict): result = await func(**arguments) else: result = await func(arguments) - return ContentToolResult(id_, result, None) + res = ContentToolResult(id_, name, result, None) except Exception as e: log_tool_error(func.__name__, str(arguments), e) - return ContentToolResult(id_, None, str(e)) + res = ContentToolResult(id_, name, None, str(e)) + + if on_result is not None: + await on_result(res) + + return res def _markdown_display( self, echo: Literal["text", "all", "none"] diff --git a/chatlas/_content.py b/chatlas/_content.py index 76ea2bde..9515eb49 100644 --- a/chatlas/_content.py +++ b/chatlas/_content.py @@ -185,14 +185,26 @@ class ContentToolResult(Content): Parameters ---------- id - The unique identifier of the tool request. + The unique identifier for the tool result. + name + The name of the tool/function that was called. value The value returned by the tool/function. error An error message if the tool/function call failed. + + Note + ---- + If the tool/function call failed, the `value` field will be `None` and the + `error` field will contain the error message. + If the tool/function call succeeded, the `value` field will contain the + return value and the `error` field will be `None`. + To get the actual result sent to the model assistant, use the `get_final_value()` + method. """ id: str + name: str value: Any = None error: Optional[str] = None @@ -206,7 +218,7 @@ def _get_value_and_language(self) -> tuple[str, str]: return str(self.value), "" def __str__(self): - comment = f"# tool result ({self.id})" + comment = f"# tool ({self.name}) result ({self.id})" value, language = self._get_value_and_language() return f"""```{language}\n{comment}\n{value}\n```""" @@ -216,7 +228,9 @@ def _repr_markdown_(self): def __repr__(self, indent: int = 0): res = " " * indent - res += f"" diff --git a/chatlas/_google.py b/chatlas/_google.py index 313048f7..317a964b 100644 --- a/chatlas/_google.py +++ b/chatlas/_google.py @@ -460,6 +460,7 @@ def _as_turn( func = part.function_response contents.append( ContentToolResult( + func.name, func.name, value=func.response, ) From dd06cb4791d4638a340ec662b284bc647488b8f1 Mon Sep 17 00:00:00 2001 From: Carson Date: Tue, 11 Feb 2025 15:58:51 -0600 Subject: [PATCH 2/2] Fix link to hex logo --- docs/index.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/index.py b/docs/index.py index b745c041..6ee509f6 100644 --- a/docs/index.py +++ b/docs/index.py @@ -14,6 +14,10 @@ {readme_src} """ +# The root for the README is the home directory, but for the Quarto site, it is the docs directory +index_src = index_src.replace('src="docs/', 'src="') +index_src = index_src.replace("src='docs/", "src='") + index = docs_dir / "index.qmd" with open(index, "w") as f: