Skip to content

Commit 8a31612

Browse files
cornellgitcopybara-github
authored andcommitted
feat(skill): Add BashTool
PiperOrigin-RevId: 875899505
1 parent ebbc114 commit 8a31612

File tree

2 files changed

+379
-0
lines changed

2 files changed

+379
-0
lines changed

src/google/adk/tools/bash_tool.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tool to execute bash commands."""
16+
17+
from __future__ import annotations
18+
19+
import dataclasses
20+
import pathlib
21+
import shlex
22+
import subprocess
23+
from typing import Any
24+
from typing import Optional
25+
26+
from google.genai import types
27+
28+
from .. import features
29+
from .base_tool import BaseTool
30+
from .tool_context import ToolContext
31+
32+
33+
@dataclasses.dataclass(frozen=True)
34+
class BashToolPolicy:
35+
"""Configuration for allowed bash commands based on prefix matching.
36+
37+
Set allowed_command_prefixes to ("*",) to allow all commands (default),
38+
or explicitly list allowed prefixes.
39+
"""
40+
41+
allowed_command_prefixes: tuple[str, ...] = ("*",)
42+
43+
44+
def _validate_command(command: str, policy: BashToolPolicy) -> Optional[str]:
45+
"""Validates a bash command against the permitted prefixes."""
46+
stripped = command.strip()
47+
if not stripped:
48+
return "Command is required."
49+
50+
if "*" in policy.allowed_command_prefixes:
51+
return None
52+
53+
for prefix in policy.allowed_command_prefixes:
54+
if stripped.startswith(prefix):
55+
return None
56+
57+
allowed = ", ".join(policy.allowed_command_prefixes)
58+
return f"Command blocked. Permitted prefixes are: {allowed}"
59+
60+
61+
@features.experimental(features.FeatureName.SKILL_TOOLSET)
62+
class ExecuteBashTool(BaseTool):
63+
"""Tool to execute a validated bash command within a workspace directory."""
64+
65+
def __init__(
66+
self,
67+
*,
68+
workspace: pathlib.Path | None = None,
69+
policy: Optional[BashToolPolicy] = None,
70+
):
71+
if workspace is None:
72+
workspace = pathlib.Path.cwd()
73+
policy = policy or BashToolPolicy()
74+
allowed_hint = (
75+
"any command"
76+
if "*" in policy.allowed_command_prefixes
77+
else (
78+
"commands matching prefixes:"
79+
f" {', '.join(policy.allowed_command_prefixes)}"
80+
)
81+
)
82+
super().__init__(
83+
name="execute_bash",
84+
description=(
85+
"Executes a bash command with the working directory set to the"
86+
f" workspace. Allowed: {allowed_hint}. All commands require user"
87+
" confirmation."
88+
),
89+
)
90+
self._workspace = workspace
91+
self._policy = policy
92+
93+
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
94+
return types.FunctionDeclaration(
95+
name=self.name,
96+
description=self.description,
97+
parameters_json_schema={
98+
"type": "object",
99+
"properties": {
100+
"command": {
101+
"type": "string",
102+
"description": "The bash command to execute.",
103+
},
104+
},
105+
"required": ["command"],
106+
},
107+
)
108+
109+
async def run_async(
110+
self, *, args: dict[str, Any], tool_context: ToolContext
111+
) -> Any:
112+
command = args.get("command")
113+
if not command:
114+
return {"error": "Command is required."}
115+
116+
# Static validation.
117+
error = _validate_command(command, self._policy)
118+
if error:
119+
return {"error": error}
120+
121+
# Always request user confirmation.
122+
if not tool_context.tool_confirmation:
123+
tool_context.request_confirmation(
124+
hint=f"Please approve or reject the bash command: {command}",
125+
)
126+
tool_context.actions.skip_summarization = True
127+
return {
128+
"error": (
129+
"This tool call requires confirmation, please approve or reject."
130+
)
131+
}
132+
elif not tool_context.tool_confirmation.confirmed:
133+
return {"error": "This tool call is rejected."}
134+
135+
try:
136+
result = subprocess.run(
137+
shlex.split(command),
138+
shell=False,
139+
cwd=str(self._workspace),
140+
capture_output=True,
141+
text=True,
142+
timeout=30,
143+
)
144+
return {
145+
"stdout": result.stdout,
146+
"stderr": result.stderr,
147+
"returncode": result.returncode,
148+
}
149+
except subprocess.TimeoutExpired:
150+
return {"error": "Command timed out after 30 seconds."}
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest import mock
16+
17+
from google.adk.tools import bash_tool
18+
from google.adk.tools import tool_context
19+
from google.adk.tools.tool_confirmation import ToolConfirmation
20+
import pytest
21+
22+
23+
@pytest.fixture
24+
def workspace(tmp_path):
25+
"""Creates a workspace mirroring the anthropics/skills PDF skill layout."""
26+
# Skill: pdf/
27+
skill_dir = tmp_path / "pdf"
28+
skill_dir.mkdir()
29+
(skill_dir / "SKILL.md").write_text(
30+
"---\nname: pdf\n"
31+
"description: Use this skill whenever the user wants to do"
32+
" anything with PDF files.\n"
33+
"---\n# PDF Processing Guide\n\n## Overview\n"
34+
"This guide covers PDF processing operations."
35+
)
36+
scripts = skill_dir / "scripts"
37+
scripts.mkdir()
38+
(scripts / "extract_form_structure.py").write_text(
39+
"import sys; print(f'extracting from {sys.argv[1]}')"
40+
)
41+
(scripts / "fill_pdf_form_with_annotations.py").write_text(
42+
"print('filling form')"
43+
)
44+
references = skill_dir / "references"
45+
references.mkdir()
46+
(references / "REFERENCE.md").write_text("# Reference\nDetailed docs.")
47+
# A loose file at workspace root (not inside a skill).
48+
(tmp_path / "sample.pdf").write_bytes(b"%PDF-1.4 fake")
49+
return tmp_path
50+
51+
52+
@pytest.fixture
53+
def tool_context_no_confirmation():
54+
"""ToolContext with no confirmation (initial call)."""
55+
ctx = mock.create_autospec(tool_context.ToolContext, instance=True)
56+
ctx.tool_confirmation = None
57+
ctx.actions = mock.MagicMock()
58+
return ctx
59+
60+
61+
@pytest.fixture
62+
def tool_context_confirmed():
63+
"""ToolContext with confirmation approved."""
64+
ctx = mock.create_autospec(tool_context.ToolContext, instance=True)
65+
confirmation = mock.create_autospec(ToolConfirmation, instance=True)
66+
confirmation.confirmed = True
67+
ctx.tool_confirmation = confirmation
68+
ctx.actions = mock.MagicMock()
69+
return ctx
70+
71+
72+
@pytest.fixture
73+
def tool_context_rejected():
74+
"""ToolContext with confirmation rejected."""
75+
ctx = mock.create_autospec(tool_context.ToolContext, instance=True)
76+
confirmation = mock.create_autospec(ToolConfirmation, instance=True)
77+
confirmation.confirmed = False
78+
ctx.tool_confirmation = confirmation
79+
ctx.actions = mock.MagicMock()
80+
return ctx
81+
82+
83+
# --- _validate_command tests ---
84+
85+
86+
class TestValidateCommand:
87+
88+
def test_empty_command(self):
89+
policy = bash_tool.BashToolPolicy()
90+
assert bash_tool._validate_command("", policy) is not None
91+
assert bash_tool._validate_command(" ", policy) is not None
92+
93+
def test_default_policy_allows_everything(self):
94+
policy = bash_tool.BashToolPolicy()
95+
assert bash_tool._validate_command("rm -rf /", policy) is None
96+
assert bash_tool._validate_command("cat /etc/passwd", policy) is None
97+
assert bash_tool._validate_command("sudo curl", policy) is None
98+
99+
def test_restricted_policy_allows_prefixes(self):
100+
policy = bash_tool.BashToolPolicy(allowed_command_prefixes=("ls", "cat"))
101+
assert bash_tool._validate_command("ls -la", policy) is None
102+
assert bash_tool._validate_command("cat file.txt", policy) is None
103+
104+
def test_restricted_policy_blocks_others(self):
105+
policy = bash_tool.BashToolPolicy(allowed_command_prefixes=("ls", "cat"))
106+
assert bash_tool._validate_command("rm -rf .", policy) is not None
107+
assert bash_tool._validate_command("tree", policy) is not None
108+
assert "Permitted prefixes are: ls, cat" in bash_tool._validate_command(
109+
"tree", policy
110+
)
111+
112+
113+
class TestExecuteBashTool:
114+
115+
@pytest.mark.asyncio
116+
async def test_requests_confirmation(
117+
self, workspace, tool_context_no_confirmation
118+
):
119+
tool = bash_tool.ExecuteBashTool(workspace=workspace)
120+
result = await tool.run_async(
121+
args={"command": "ls"},
122+
tool_context=tool_context_no_confirmation,
123+
)
124+
assert "error" in result
125+
assert "requires confirmation" in result["error"]
126+
tool_context_no_confirmation.request_confirmation.assert_called_once()
127+
128+
@pytest.mark.asyncio
129+
async def test_rejected(self, workspace, tool_context_rejected):
130+
tool = bash_tool.ExecuteBashTool(workspace=workspace)
131+
result = await tool.run_async(
132+
args={"command": "ls"}, tool_context=tool_context_rejected
133+
)
134+
assert result == {"error": "This tool call is rejected."}
135+
136+
@pytest.mark.asyncio
137+
async def test_executes_when_confirmed(
138+
self, workspace, tool_context_confirmed
139+
):
140+
tool = bash_tool.ExecuteBashTool(workspace=workspace)
141+
result = await tool.run_async(
142+
args={"command": "ls"},
143+
tool_context=tool_context_confirmed,
144+
)
145+
assert result["returncode"] == 0
146+
assert "pdf" in result["stdout"]
147+
148+
@pytest.mark.asyncio
149+
async def test_cat_skill_md(self, workspace, tool_context_confirmed):
150+
tool = bash_tool.ExecuteBashTool(workspace=workspace)
151+
result = await tool.run_async(
152+
args={"command": "cat pdf/SKILL.md"},
153+
tool_context=tool_context_confirmed,
154+
)
155+
assert "PDF Processing Guide" in result["stdout"]
156+
157+
@pytest.mark.asyncio
158+
async def test_python_script(self, workspace, tool_context_confirmed):
159+
tool = bash_tool.ExecuteBashTool(workspace=workspace)
160+
result = await tool.run_async(
161+
args={
162+
"command": "python3 pdf/scripts/extract_form_structure.py test.pdf"
163+
},
164+
tool_context=tool_context_confirmed,
165+
)
166+
assert "extracting from test.pdf" in result["stdout"]
167+
assert result["returncode"] == 0
168+
169+
@pytest.mark.asyncio
170+
async def test_blocks_disallowed_by_policy(
171+
self, workspace, tool_context_no_confirmation
172+
):
173+
policy = bash_tool.BashToolPolicy(allowed_command_prefixes=("ls",))
174+
tool = bash_tool.ExecuteBashTool(workspace=workspace, policy=policy)
175+
result = await tool.run_async(
176+
args={"command": "rm -rf ."},
177+
tool_context=tool_context_no_confirmation,
178+
)
179+
assert "error" in result
180+
assert "Permitted prefixes are: ls" in result["error"]
181+
tool_context_no_confirmation.request_confirmation.assert_not_called()
182+
183+
@pytest.mark.asyncio
184+
async def test_captures_stderr(self, workspace, tool_context_confirmed):
185+
tool = bash_tool.ExecuteBashTool(workspace=workspace)
186+
result = await tool.run_async(
187+
args={"command": "python3 -c 'import sys; sys.stderr.write(\"err\")'"},
188+
tool_context=tool_context_confirmed,
189+
)
190+
assert "err" in result["stderr"]
191+
192+
@pytest.mark.asyncio
193+
async def test_nonzero_returncode(self, workspace, tool_context_confirmed):
194+
tool = bash_tool.ExecuteBashTool(workspace=workspace)
195+
result = await tool.run_async(
196+
args={"command": "python3 -c 'exit(42)'"},
197+
tool_context=tool_context_confirmed,
198+
)
199+
assert result["returncode"] == 42
200+
201+
@pytest.mark.asyncio
202+
async def test_timeout(self, workspace, tool_context_confirmed):
203+
tool = bash_tool.ExecuteBashTool(workspace=workspace)
204+
with mock.patch(
205+
"google.adk.tools.bash_tool.subprocess.run",
206+
side_effect=__import__("subprocess").TimeoutExpired("cmd", 30),
207+
):
208+
result = await tool.run_async(
209+
args={"command": "python scripts/do_thing.py"},
210+
tool_context=tool_context_confirmed,
211+
)
212+
assert "error" in result
213+
assert "timed out" in result["error"].lower()
214+
215+
@pytest.mark.asyncio
216+
async def test_cwd_is_workspace(self, workspace, tool_context_confirmed):
217+
tool = bash_tool.ExecuteBashTool(workspace=workspace)
218+
result = await tool.run_async(
219+
args={"command": "python3 -c 'import os; print(os.getcwd())'"},
220+
tool_context=tool_context_confirmed,
221+
)
222+
assert result["stdout"].strip() == str(workspace)
223+
224+
@pytest.mark.asyncio
225+
async def test_no_command(self, workspace, tool_context_confirmed):
226+
tool = bash_tool.ExecuteBashTool(workspace=workspace)
227+
result = await tool.run_async(args={}, tool_context=tool_context_confirmed)
228+
assert "error" in result
229+
assert "required" in result["error"].lower()

0 commit comments

Comments
 (0)