Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 12 additions & 33 deletions packages/toolbox-core/src/toolbox_core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
from aiohttp import ClientSession
from deprecated import deprecated

from .protocol import ManifestSchema, ToolSchema
from .itransport import ITransport
from .protocol import ToolSchema
from .tool import ToolboxTool
from .toolbox_transport import ToolboxTransport
from .utils import identify_auth_requirements, resolve_value


Expand All @@ -33,9 +35,7 @@ class ToolboxClient:
is not provided.
"""

__base_url: str
__session: ClientSession
__manage_session: bool
__transport: ITransport

def __init__(
self,
Expand All @@ -56,15 +56,13 @@ def __init__(
should typically be managed externally.
client_headers: Headers to include in each request sent through this client.
"""
self.__base_url = url

# If no aiohttp.ClientSession is provided, make our own
self.__manage_session = False
manage_session = False
if session is None:
self.__manage_session = True
manage_session = True
session = ClientSession()
self.__session = session

self.__transport = ToolboxTransport(url, session, manage_session)
self.__client_headers = client_headers if client_headers is not None else {}

def __parse_tool(
Expand Down Expand Up @@ -103,8 +101,7 @@ def __parse_tool(
)

tool = ToolboxTool(
session=self.__session,
base_url=self.__base_url,
transport=self.__transport,
name=name,
description=schema.description,
# create a read-only values to prevent mutation
Expand Down Expand Up @@ -149,8 +146,7 @@ async def close(self):
If the session was provided externally during initialization, the caller
is responsible for its lifecycle.
"""
if self.__manage_session and not self.__session.closed:
await self.__session.close()
await self.__transport.close()

async def load_tool(
self,
Expand Down Expand Up @@ -191,16 +187,7 @@ async def load_tool(
for name, val in self.__client_headers.items()
}

# request the definition of the tool from the server
url = f"{self.__base_url}/api/tool/{name}"
async with self.__session.get(url, headers=resolved_headers) as response:
if not response.ok:
error_text = await response.text()
raise RuntimeError(
f"API request failed with status {response.status} ({response.reason}). Server response: {error_text}"
)
json = await response.json()
manifest: ManifestSchema = ManifestSchema(**json)
manifest = await self.__transport.tool_get(name, resolved_headers)

# parse the provided definition to a tool
if name not in manifest.tools:
Expand Down Expand Up @@ -274,16 +261,8 @@ async def load_toolset(
header_name: await resolve_value(original_headers[header_name])
for header_name in original_headers
}
# Request the definition of the toolset from the server
url = f"{self.__base_url}/api/toolset/{name or ''}"
async with self.__session.get(url, headers=resolved_headers) as response:
if not response.ok:
error_text = await response.text()
raise RuntimeError(
f"API request failed with status {response.status} ({response.reason}). Server response: {error_text}"
)
json = await response.json()
manifest: ManifestSchema = ManifestSchema(**json)

manifest = await self.__transport.tools_list(name, resolved_headers)

tools: list[ToolboxTool] = []
overall_used_auth_keys: set[str] = set()
Expand Down
58 changes: 58 additions & 0 deletions packages/toolbox-core/src/toolbox_core/itransport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Mapping, Optional

from .protocol import ManifestSchema


class ITransport(ABC):
"""Defines the contract for a 'smart' transport that handles both
protocol formatting and network communication.
"""

@property
@abstractmethod
def base_url(self) -> str:
"""The base URL for the transport."""
pass

@abstractmethod
async def tool_get(
self, tool_name: str, headers: Optional[Mapping[str, str]] = None
) -> ManifestSchema:
"""Gets a single tool from the server."""
pass

@abstractmethod
async def tools_list(
self,
toolset_name: Optional[str] = None,
headers: Optional[Mapping[str, str]] = None,
) -> ManifestSchema:
"""Lists available tools from the server."""
pass

@abstractmethod
async def tool_invoke(
self, tool_name: str, arguments: dict, headers: Mapping[str, str]
) -> dict:
"""Invokes a specific tool on the server."""
pass

@abstractmethod
async def close(self):
"""Closes any underlying connections."""
pass
40 changes: 14 additions & 26 deletions packages/toolbox-core/src/toolbox_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union
from warnings import warn

from aiohttp import ClientSession

from .itransport import ITransport
from .protocol import ParameterSchema
from .utils import (
create_func_docstring,
Expand All @@ -46,8 +45,7 @@ class ToolboxTool:

def __init__(
self,
session: ClientSession,
base_url: str,
transport: ITransport,
name: str,
description: str,
params: Sequence[ParameterSchema],
Expand All @@ -68,8 +66,7 @@ def __init__(
Toolbox server.

Args:
session: The `aiohttp.ClientSession` used for making API requests.
base_url: The base URL of the Toolbox server API.
transport: The transport used for making API requests.
name: The name of the remote tool.
description: The description of the remote tool.
params: The args of the tool.
Expand All @@ -84,9 +81,7 @@ def __init__(
client_headers: Client specific headers bound to the tool.
"""
# used to invoke the toolbox API
self.__session: ClientSession = session
self.__base_url: str = base_url
self.__url = f"{base_url}/api/tool/{name}/invoke"
self.__transport = transport
self.__description = description
self.__params = params
self.__pydantic_model = params_to_pydantic_model(name, self.__params)
Expand Down Expand Up @@ -137,9 +132,9 @@ def __init__(
# these over HTTP exposes the data to interception and unauthorized
# access. Always use HTTPS to ensure secure communication and protect
# user privacy.
if (
if self.__transport.base_url.startswith("http://") and (
required_authn_params or required_authz_tokens or client_headers
) and not self.__url.startswith("https://"):
):
warn(
"Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication."
)
Expand Down Expand Up @@ -184,8 +179,7 @@ def _client_headers(

def __copy(
self,
session: Optional[ClientSession] = None,
base_url: Optional[str] = None,
transport: Optional[ITransport] = None,
name: Optional[str] = None,
description: Optional[str] = None,
params: Optional[Sequence[ParameterSchema]] = None,
Expand All @@ -205,8 +199,7 @@ def __copy(
Creates a copy of the ToolboxTool, overriding specific fields.

Args:
session: The `aiohttp.ClientSession` used for making API requests.
base_url: The base URL of the Toolbox server API.
transport: The transport used for making API requests.
name: The name of the remote tool.
description: The description of the remote tool.
params: The args of the tool.
Expand All @@ -222,8 +215,7 @@ def __copy(
"""
check = lambda val, default: val if val is not None else default
return ToolboxTool(
session=check(session, self.__session),
base_url=check(base_url, self.__base_url),
transport=check(transport, self.__transport),
name=check(name, self.__name__),
description=check(description, self.__description),
params=check(params, self.__params),
Expand Down Expand Up @@ -304,15 +296,11 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
token_getter
)

async with self.__session.post(
self.__url,
json=payload,
headers=headers,
) as resp:
body = await resp.json()
if not resp.ok:
err = body.get("error", f"unexpected status from server: {resp.status}")
raise Exception(err)
body = await self.__transport.tool_invoke(
self.__name__,
payload,
headers,
)
return body.get("result", body)

def add_auth_token_getters(
Expand Down
81 changes: 81 additions & 0 deletions packages/toolbox-core/src/toolbox_core/toolbox_transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Mapping, Optional

from aiohttp import ClientSession

from .itransport import ITransport
from .protocol import ManifestSchema


class ToolboxTransport(ITransport):
"""Transport for the native Toolbox protocol."""

def __init__(self, base_url: str, session: ClientSession, manage_session: bool):
self.__base_url = base_url
self.__session = session
self.__manage_session = manage_session

@property
def base_url(self) -> str:
"""The base URL for the transport."""
return self.__base_url

async def tool_get(
self, tool_name: str, headers: Optional[Mapping[str, str]] = None
) -> ManifestSchema:
url = f"{self.__base_url}/api/tool/{tool_name}"
async with self.__session.get(url, headers=headers) as response:
if not response.ok:
error_text = await response.text()
raise RuntimeError(
f"API request failed with status {response.status} ({response.reason}). Server response: {error_text}"
)
json = await response.json()
return ManifestSchema(**json)

async def tools_list(
self,
toolset_name: Optional[str] = None,
headers: Optional[Mapping[str, str]] = None,
) -> ManifestSchema:
url = f"{self.__base_url}/api/toolset/{toolset_name or ''}"
async with self.__session.get(url, headers=headers) as response:
if not response.ok:
error_text = await response.text()
raise RuntimeError(
f"API request failed with status {response.status} ({response.reason}). Server response: {error_text}"
)
json = await response.json()
return ManifestSchema(**json)

async def tool_invoke(
self, tool_name: str, arguments: dict, headers: Mapping[str, str]
) -> dict:
url = f"{self.__base_url}/api/tool/{tool_name}/invoke"
async with self.__session.post(
url,
json=arguments,
headers=headers,
) as resp:
body = await resp.json()
if not resp.ok:
err = body.get("error", f"unexpected status from server: {resp.status}")
raise Exception(err)
return body

async def close(self):
if self.__manage_session and not self.__session.closed:
await self.__session.close()
Loading