Skip to content

Commit 8698711

Browse files
committed
fix(tools): fail closed for high-risk tools without confirmation policy (#4625)
1 parent 3256a67 commit 8698711

File tree

5 files changed

+83
-0
lines changed

5 files changed

+83
-0
lines changed

src/google/adk/tools/function_tool.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def __init__(
4747
func: Callable[..., Any],
4848
*,
4949
require_confirmation: Union[bool, Callable[..., bool]] = False,
50+
is_high_risk: bool = False,
5051
):
5152
"""Initializes the FunctionTool. Extracts metadata from a callable object.
5253
@@ -56,6 +57,9 @@ def __init__(
5657
a callable that takes the function's arguments and returns a boolean. If
5758
the callable returns True, the tool will require confirmation from the
5859
user.
60+
is_high_risk: Whether the tool performs high-impact operations. High-risk
61+
tools fail closed unless an explicit confirmation policy resolves to
62+
`True`.
5963
"""
6064
name = ''
6165
doc = ''
@@ -82,6 +86,7 @@ def __init__(
8286
self.func = func
8387
self._ignore_params = ['tool_context', 'input_stream']
8488
self._require_confirmation = require_confirmation
89+
self._is_high_risk = is_high_risk
8590

8691
@override
8792
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
@@ -192,6 +197,15 @@ async def run_async(
192197
else:
193198
require_confirmation = bool(self._require_confirmation)
194199

200+
if self._is_high_risk and not require_confirmation:
201+
return {
202+
'error': (
203+
'This high-risk tool requires an explicit confirmation policy.'
204+
' Set require_confirmation=True or provide a callable policy'
205+
' that returns True.'
206+
)
207+
}
208+
195209
if require_confirmation:
196210
if not tool_context.tool_confirmation:
197211
args_to_show = args_to_call.copy()

src/google/adk/tools/mcp_tool/mcp_tool.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def __init__(
130130
auth_scheme: Optional[AuthScheme] = None,
131131
auth_credential: Optional[AuthCredential] = None,
132132
require_confirmation: Union[bool, Callable[..., bool]] = False,
133+
is_high_risk: bool = False,
133134
header_provider: Optional[
134135
Callable[[ReadonlyContext], Dict[str, str]]
135136
] = None,
@@ -151,6 +152,8 @@ def __init__(
151152
or a callable that takes the function's arguments and returns a
152153
boolean. If the callable returns True, the tool will require
153154
confirmation from the user.
155+
is_high_risk: Whether this tool is high-risk. High-risk tools fail
156+
closed unless an explicit confirmation policy resolves to `True`.
154157
header_provider: Optional function to provide dynamic headers.
155158
progress_callback: Optional callback to receive progress notifications
156159
from MCP server during long-running tool execution. Can be either:
@@ -178,6 +181,7 @@ def __init__(
178181
self._mcp_tool = mcp_tool
179182
self._mcp_session_manager = mcp_session_manager
180183
self._require_confirmation = require_confirmation
184+
self._is_high_risk = is_high_risk
181185
self._header_provider = header_provider
182186
self._progress_callback = progress_callback
183187

@@ -262,6 +266,15 @@ async def run_async(
262266
else:
263267
require_confirmation = bool(self._require_confirmation)
264268

269+
if self._is_high_risk and not require_confirmation:
270+
return {
271+
"error": (
272+
"This high-risk tool requires an explicit confirmation policy."
273+
" Set require_confirmation=True or provide a callable policy"
274+
" that returns True."
275+
)
276+
}
277+
265278
if require_confirmation:
266279
if not tool_context.tool_confirmation:
267280
args_to_show = args.copy()

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ def __init__(
107107
auth_scheme: Optional[AuthScheme] = None,
108108
auth_credential: Optional[AuthCredential] = None,
109109
require_confirmation: Union[bool, Callable[..., bool]] = False,
110+
is_high_risk: bool = False,
110111
header_provider: Optional[
111112
Callable[[ReadonlyContext], Dict[str, str]]
112113
] = None,
@@ -136,6 +137,9 @@ def __init__(
136137
auth_credential: The auth credential of the tool for tool calling
137138
require_confirmation: Whether tools in this toolset require confirmation.
138139
Can be a single boolean or a callable to apply to all tools.
140+
is_high_risk: Whether tools from this toolset are high-risk. High-risk
141+
tools fail closed unless an explicit confirmation policy resolves to
142+
`True`.
139143
header_provider: A callable that takes a ReadonlyContext and returns a
140144
dictionary of headers to be used for the MCP session.
141145
progress_callback: Optional callback to receive progress notifications
@@ -170,6 +174,7 @@ def __init__(
170174
self._auth_scheme = auth_scheme
171175
self._auth_credential = auth_credential
172176
self._require_confirmation = require_confirmation
177+
self._is_high_risk = is_high_risk
173178
# Store auth config as instance variable so ADK can populate
174179
# exchanged_auth_credential in-place before calling get_tools()
175180
self._auth_config: Optional[AuthConfig] = (
@@ -316,6 +321,7 @@ async def get_tools(
316321
auth_scheme=self._auth_scheme,
317322
auth_credential=self._auth_credential,
318323
require_confirmation=self._require_confirmation,
324+
is_high_risk=self._is_high_risk,
319325
header_provider=self._header_provider,
320326
progress_callback=self._progress_callback
321327
if hasattr(self, "_progress_callback")

tests/unittests/tools/mcp_tool/test_mcp_tool.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,30 @@ async def test_run_async_require_confirmation_true_no_confirmation(self):
685685
}
686686
tool_context.request_confirmation.assert_called_once()
687687

688+
@pytest.mark.asyncio
689+
async def test_run_async_high_risk_without_confirmation_policy_fails_closed(
690+
self,
691+
):
692+
"""Test that high-risk MCP tools fail closed without explicit policy."""
693+
tool = MCPTool(
694+
mcp_tool=self.mock_mcp_tool,
695+
mcp_session_manager=self.mock_session_manager,
696+
is_high_risk=True,
697+
)
698+
tool_context = Mock(spec=ToolContext)
699+
tool_context.tool_confirmation = None
700+
args = {"param1": "test_value"}
701+
702+
result = await tool.run_async(args=args, tool_context=tool_context)
703+
704+
assert result == {
705+
"error": (
706+
"This high-risk tool requires an explicit confirmation policy. Set"
707+
" require_confirmation=True or provide a callable policy that"
708+
" returns True."
709+
)
710+
}
711+
688712
@pytest.mark.asyncio
689713
async def test_run_async_require_confirmation_true_rejected(self):
690714
"""Test require_confirmation=True with rejection in context."""

tests/unittests/tools/test_function_tool.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,32 @@ def sample_func(arg1: str):
417417
assert result == {"received_arg": "hello"}
418418

419419

420+
@pytest.mark.asyncio
421+
async def test_run_async_high_risk_without_confirmation_policy_fails_closed():
422+
"""Test that high-risk tools fail closed without explicit confirmation policy."""
423+
424+
def sample_func(arg1: str):
425+
return {"received_arg": arg1}
426+
427+
tool = FunctionTool(sample_func, is_high_risk=True)
428+
mock_invocation_context = MagicMock(spec=InvocationContext)
429+
mock_invocation_context.session = MagicMock(spec=Session)
430+
mock_invocation_context.session.state = MagicMock()
431+
tool_context_mock = ToolContext(invocation_context=mock_invocation_context)
432+
433+
result = await tool.run_async(
434+
args={"arg1": "hello"},
435+
tool_context=tool_context_mock,
436+
)
437+
assert result == {
438+
"error": (
439+
"This high-risk tool requires an explicit confirmation policy. Set"
440+
" require_confirmation=True or provide a callable policy that returns"
441+
" True."
442+
)
443+
}
444+
445+
420446
@pytest.mark.asyncio
421447
async def test_run_async_parameter_filtering(mock_tool_context):
422448
"""Test that parameter filtering works correctly for functions with explicit parameters."""

0 commit comments

Comments
 (0)