diff --git a/packages/toolbox-core/integration.cloudbuild.yaml b/packages/toolbox-core/integration.cloudbuild.yaml index 325abd6c..a30c66ab 100644 --- a/packages/toolbox-core/integration.cloudbuild.yaml +++ b/packages/toolbox-core/integration.cloudbuild.yaml @@ -47,5 +47,5 @@ options: logging: CLOUD_LOGGING_ONLY substitutions: _VERSION: '3.13' - _TOOLBOX_VERSION: '0.16.0' + _TOOLBOX_VERSION: '0.17.0' _TOOLBOX_MANIFEST_VERSION: '34' diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/base.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/base.py index c4263602..a8887d07 100644 --- a/packages/toolbox-core/src/toolbox_core/mcp_transport/base.py +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/base.py @@ -58,6 +58,24 @@ def base_url(self) -> str: return self._mcp_base_url def _convert_tool_schema(self, tool_data: dict) -> ToolSchema: + """ + Safely converts the raw tool dictionary from the server into a ToolSchema object, + robustly handling optional authentication metadata. + """ + param_auth = None + invoke_auth = [] + + if "_meta" in tool_data and isinstance(tool_data["_meta"], dict): + meta = tool_data["_meta"] + if "toolbox/authParam" in meta and isinstance( + meta["toolbox/authParam"], dict + ): + param_auth = meta["toolbox/authParam"] + if "toolbox/authInvoke" in meta and isinstance( + meta["toolbox/authInvoke"], list + ): + invoke_auth = meta["toolbox/authInvoke"] + parameters = [] input_schema = tool_data.get("inputSchema", {}) properties = input_schema.get("properties", {}) @@ -71,6 +89,10 @@ def _convert_tool_schema(self, tool_data: dict) -> ToolSchema: ) else: additional_props = True + if param_auth and name in param_auth: + auth_sources = param_auth[name] + else: + auth_sources = None parameters.append( ParameterSchema( name=name, @@ -78,10 +100,15 @@ def _convert_tool_schema(self, tool_data: dict) -> ToolSchema: description=schema.get("description", ""), required=name in required, additionalProperties=additional_props, + authSources=auth_sources, ) ) - return ToolSchema(description=tool_data["description"], parameters=parameters) + return ToolSchema( + description=tool_data["description"], + parameters=parameters, + authRequired=invoke_auth, + ) async def _list_tools( self, diff --git a/packages/toolbox-core/tests/mcp_transport/test_base.py b/packages/toolbox-core/tests/mcp_transport/test_base.py index 33fccb4b..48f0abd4 100644 --- a/packages/toolbox-core/tests/mcp_transport/test_base.py +++ b/packages/toolbox-core/tests/mcp_transport/test_base.py @@ -14,14 +14,14 @@ import asyncio from typing import Any -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, Mock import pytest import pytest_asyncio from aiohttp import ClientSession from toolbox_core.mcp_transport.base import _McpHttpTransportBase -from toolbox_core.protocol import ManifestSchema +from toolbox_core.protocol import ManifestSchema, ToolSchema class ConcreteTransport(_McpHttpTransportBase): @@ -161,6 +161,19 @@ def test_convert_tool_schema(self, transport): assert location_param.required is True assert location_param.description == "The city." + def test_convert_tool_schema_with_auth(self, transport): + """Test schema conversion with authentication metadata.""" + tool_data = { + "name": "drive_tool", + "description": "A tool that requires auth.", + "inputSchema": {"type": "object", "properties": {}}, + "_meta": { + "toolbox/authInvoke": ["google"], + }, + } + tool_schema = transport._convert_tool_schema(tool_data) + assert tool_schema.authRequired == ["google"] + @pytest.mark.asyncio async def test_tools_list_success(self, transport): transport._server_version = "1.0.0" diff --git a/packages/toolbox-core/tests/test_e2e_mcp.py b/packages/toolbox-core/tests/test_e2e_mcp.py index d4c64e21..9680aa87 100644 --- a/packages/toolbox-core/tests/test_e2e_mcp.py +++ b/packages/toolbox-core/tests/test_e2e_mcp.py @@ -133,6 +133,99 @@ async def test_bind_params_callable( assert "row4" not in response +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +class TestAuth: + async def test_run_tool_unauth_with_auth( + self, toolbox: ToolboxClient, auth_token2: str + ): + """Tests running a tool that doesn't require auth, with auth provided.""" + + with pytest.raises( + ValueError, + match=rf"Validation failed for tool 'get-row-by-id': unused auth tokens: my-test-auth", + ): + await toolbox.load_tool( + "get-row-by-id", + auth_token_getters={"my-test-auth": lambda: auth_token2}, + ) + + async def test_run_tool_no_auth(self, toolbox: ToolboxClient): + """Tests running a tool requiring auth without providing auth.""" + tool = await toolbox.load_tool("get-row-by-id-auth") + with pytest.raises( + PermissionError, + match="One or more of the following authn services are required to invoke this tool: my-test-auth", + ): + await tool(id="2") + + async def test_run_tool_wrong_auth(self, toolbox: ToolboxClient, auth_token2: str): + """Tests running a tool with incorrect auth. The tool + requires a different authentication than the one provided.""" + tool = await toolbox.load_tool("get-row-by-id-auth") + auth_tool = tool.add_auth_token_getters({"my-test-auth": lambda: auth_token2}) + with pytest.raises( + Exception, + match="Unauthorized", + ): + await auth_tool(id="2") + + async def test_run_tool_auth(self, toolbox: ToolboxClient, auth_token1: str): + """Tests running a tool with correct auth.""" + tool = await toolbox.load_tool("get-row-by-id-auth") + auth_tool = tool.add_auth_token_getters({"my-test-auth": lambda: auth_token1}) + response = await auth_tool(id="2") + assert "row2" in response + + @pytest.mark.asyncio + async def test_run_tool_async_auth(self, toolbox: ToolboxClient, auth_token1: str): + """Tests running a tool with correct auth using an async token getter.""" + tool = await toolbox.load_tool("get-row-by-id-auth") + + async def get_token_asynchronously(): + return auth_token1 + + auth_tool = tool.add_auth_token_getters( + {"my-test-auth": get_token_asynchronously} + ) + response = await auth_tool(id="2") + assert "row2" in response + + async def test_run_tool_param_auth_no_auth(self, toolbox: ToolboxClient): + """Tests running a tool with a param requiring auth, without auth.""" + tool = await toolbox.load_tool("get-row-by-email-auth") + with pytest.raises( + PermissionError, + match="One or more of the following authn services are required to invoke this tool: my-test-auth", + ): + await tool() + + async def test_run_tool_param_auth(self, toolbox: ToolboxClient, auth_token1: str): + """Tests running a tool with a param requiring auth, with correct auth.""" + tool = await toolbox.load_tool( + "get-row-by-email-auth", + auth_token_getters={"my-test-auth": lambda: auth_token1}, + ) + response = await tool() + assert "row4" in response + assert "row5" in response + assert "row6" in response + + async def test_run_tool_param_auth_no_field( + self, toolbox: ToolboxClient, auth_token1: str + ): + """Tests running a tool with a param requiring auth, with insufficient auth.""" + tool = await toolbox.load_tool( + "get-row-by-content-auth", + auth_token_getters={"my-test-auth": lambda: auth_token1}, + ) + with pytest.raises( + Exception, + match="no field named row_data in claims", + ): + await tool() + + @pytest.mark.asyncio @pytest.mark.usefixtures("toolbox_server") class TestOptionalParams: