Skip to content

Commit adbc37f

Browse files
xuanyang15copybara-github
authored andcommitted
feat: Add progress_callback support to MCPTool and MCPToolset
Fixes: #3811 Co-authored-by: Xuan Yang <xygoogle@google.com> PiperOrigin-RevId: 866025995
1 parent 9b112e2 commit adbc37f

File tree

7 files changed

+695
-19
lines changed

7 files changed

+695
-19
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from . import agent
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Sample agent demonstrating MCP progress callback feature.
16+
17+
This sample shows how to use the progress_callback parameter in McpToolset
18+
to receive progress notifications from MCP servers during long-running tool
19+
executions.
20+
21+
There are two ways to use progress callbacks:
22+
23+
1. Simple callback (shared by all tools):
24+
Pass a ProgressFnT callback that receives (progress, total, message).
25+
26+
2. Factory function (per-tool callbacks with runtime context):
27+
Pass a ProgressCallbackFactory that takes (tool_name, callback_context, **kwargs)
28+
and returns a ProgressFnT or None. This allows different tools to have different
29+
progress handling logic, and the factory can access and modify session state
30+
via the CallbackContext. The **kwargs ensures forward compatibility for future
31+
parameters.
32+
33+
IMPORTANT: Progress callbacks only work when the MCP server actually sends
34+
progress notifications. Most simple MCP servers (like the filesystem server)
35+
do not send progress updates. This sample uses a mock server that demonstrates
36+
progress reporting.
37+
38+
Usage:
39+
adk run contributing/samples/mcp_progress_callback_agent
40+
41+
Then try:
42+
"Run the long running task with 5 steps"
43+
"Process these items: apple, banana, cherry"
44+
"""
45+
46+
import os
47+
import sys
48+
from typing import Any
49+
50+
from google.adk.agents.callback_context import CallbackContext
51+
from google.adk.agents.llm_agent import LlmAgent
52+
from google.adk.tools.mcp_tool import McpToolset
53+
from google.adk.tools.mcp_tool import StdioConnectionParams
54+
from mcp import StdioServerParameters
55+
from mcp.shared.session import ProgressFnT
56+
57+
_current_dir = os.path.dirname(os.path.abspath(__file__))
58+
_mock_server_path = os.path.join(_current_dir, "mock_progress_server.py")
59+
60+
61+
# Option 1: Simple shared callback
62+
async def simple_progress_callback(
63+
progress: float,
64+
total: float | None,
65+
message: str | None,
66+
) -> None:
67+
"""Handle progress notifications from MCP server.
68+
69+
This callback is shared by all tools in the toolset.
70+
"""
71+
if total is not None:
72+
percentage = (progress / total) * 100
73+
bar_length = 20
74+
filled = int(bar_length * progress / total)
75+
bar = "=" * filled + "-" * (bar_length - filled)
76+
print(f"[{bar}] {percentage:.0f}% ({progress}/{total}) {message or ''}")
77+
else:
78+
print(f"Progress: {progress} {f'- {message}' if message else ''}")
79+
80+
81+
# Option 2: Factory function for per-tool callbacks with runtime context
82+
def progress_callback_factory(
83+
tool_name: str,
84+
*,
85+
callback_context: CallbackContext | None = None,
86+
**kwargs: Any,
87+
) -> ProgressFnT | None:
88+
"""Create a progress callback for a specific tool.
89+
90+
This factory allows different tools to have different progress handling.
91+
It receives a CallbackContext for accessing and modifying runtime information
92+
like session state. The **kwargs parameter ensures forward compatibility.
93+
94+
Args:
95+
tool_name: The name of the MCP tool.
96+
callback_context: The callback context providing access to session,
97+
state, artifacts, and other runtime information. Allows modifying
98+
state via ctx.state['key'] = value. May be None if not available.
99+
**kwargs: Additional keyword arguments for future extensibility.
100+
101+
Returns:
102+
A progress callback function, or None if no callback is needed.
103+
"""
104+
# Example: Access session info from context (if available)
105+
session_id = "unknown"
106+
if callback_context and callback_context.session:
107+
session_id = callback_context.session.id
108+
109+
async def callback(
110+
progress: float,
111+
total: float | None,
112+
message: str | None,
113+
) -> None:
114+
# Include tool name and session info in the progress output
115+
prefix = f"[{tool_name}][session:{session_id}]"
116+
if total is not None:
117+
percentage = (progress / total) * 100
118+
bar_length = 20
119+
filled = int(bar_length * progress / total)
120+
bar = "=" * filled + "-" * (bar_length - filled)
121+
print(f"{prefix} [{bar}] {percentage:.0f}% {message or ''}")
122+
# Example: Store progress in state (callback_context allows modification)
123+
if callback_context:
124+
callback_context.state["last_progress"] = progress
125+
callback_context.state["last_total"] = total
126+
else:
127+
print(
128+
f"{prefix} Progress: {progress} {f'- {message}' if message else ''}"
129+
)
130+
131+
return callback
132+
133+
134+
root_agent = LlmAgent(
135+
model="gemini-2.5-flash",
136+
name="progress_demo_agent",
137+
instruction="""\
138+
You are a helpful assistant that can run long-running tasks.
139+
140+
Available tools:
141+
- long_running_task: Simulates a task with multiple steps. You can specify
142+
the number of steps and delay between them.
143+
- process_items: Processes a list of items one by one with progress updates.
144+
145+
When the user asks you to run a task, use these tools and the progress
146+
will be logged automatically.
147+
148+
Example requests:
149+
- "Run a long task with 5 steps"
150+
- "Process these items: apple, banana, cherry, date"
151+
""",
152+
tools=[
153+
McpToolset(
154+
connection_params=StdioConnectionParams(
155+
server_params=StdioServerParameters(
156+
command=sys.executable, # Use current Python interpreter
157+
args=[_mock_server_path],
158+
),
159+
timeout=60,
160+
),
161+
# Use factory function for per-tool callbacks (Option 2)
162+
# Or use simple_progress_callback for shared callback (Option 1)
163+
progress_callback=progress_callback_factory,
164+
)
165+
],
166+
)
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Mock MCP server that sends progress notifications.
16+
17+
This server demonstrates how MCP servers can send progress updates
18+
during long-running tool execution.
19+
20+
Run this server directly:
21+
python mock_progress_server.py
22+
23+
Or use it with the sample agent:
24+
See agent_with_mock_server.py
25+
"""
26+
27+
import asyncio
28+
29+
from mcp.server import Server
30+
from mcp.server.stdio import stdio_server
31+
from mcp.types import TextContent
32+
from mcp.types import Tool
33+
34+
server = Server("mock-progress-server")
35+
36+
37+
@server.list_tools()
38+
async def list_tools() -> list[Tool]:
39+
"""List available tools."""
40+
return [
41+
Tool(
42+
name="long_running_task",
43+
description=(
44+
"A simulated long-running task that reports progress. "
45+
"Use this to test progress callback functionality."
46+
),
47+
inputSchema={
48+
"type": "object",
49+
"properties": {
50+
"steps": {
51+
"type": "integer",
52+
"description": "Number of steps to simulate (default: 5)",
53+
"default": 5,
54+
},
55+
"delay": {
56+
"type": "number",
57+
"description": (
58+
"Delay in seconds between steps (default: 0.5)"
59+
),
60+
"default": 0.5,
61+
},
62+
},
63+
},
64+
),
65+
Tool(
66+
name="process_items",
67+
description="Process a list of items with progress reporting.",
68+
inputSchema={
69+
"type": "object",
70+
"properties": {
71+
"items": {
72+
"type": "array",
73+
"items": {"type": "string"},
74+
"description": "List of items to process",
75+
},
76+
},
77+
"required": ["items"],
78+
},
79+
),
80+
]
81+
82+
83+
@server.call_tool()
84+
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
85+
"""Handle tool calls with progress reporting."""
86+
ctx = server.request_context
87+
88+
if name == "long_running_task":
89+
steps = arguments.get("steps", 5)
90+
delay = arguments.get("delay", 0.5)
91+
92+
# Get progress token from request metadata
93+
progress_token = None
94+
if ctx.meta and hasattr(ctx.meta, "progressToken"):
95+
progress_token = ctx.meta.progressToken
96+
97+
for i in range(steps):
98+
# Simulate work
99+
await asyncio.sleep(delay)
100+
101+
# Send progress notification if client supports it
102+
if progress_token is not None:
103+
await ctx.session.send_progress_notification(
104+
progress_token=progress_token,
105+
progress=i + 1,
106+
total=steps,
107+
message=f"Completed step {i + 1} of {steps}",
108+
)
109+
110+
return [
111+
TextContent(
112+
type="text",
113+
text=f"Successfully completed {steps} steps!",
114+
)
115+
]
116+
117+
elif name == "process_items":
118+
items = arguments.get("items", [])
119+
total = len(items)
120+
121+
progress_token = None
122+
if ctx.meta and hasattr(ctx.meta, "progressToken"):
123+
progress_token = ctx.meta.progressToken
124+
125+
results = []
126+
for i, item in enumerate(items):
127+
# Simulate processing
128+
await asyncio.sleep(0.3)
129+
results.append(f"Processed: {item}")
130+
131+
# Send progress
132+
if progress_token is not None:
133+
await ctx.session.send_progress_notification(
134+
progress_token=progress_token,
135+
progress=i + 1,
136+
total=total,
137+
message=f"Processing item: {item}",
138+
)
139+
140+
return [
141+
TextContent(
142+
type="text",
143+
text="\n".join(results),
144+
)
145+
]
146+
147+
return [TextContent(type="text", text=f"Unknown tool: {name}")]
148+
149+
150+
async def main():
151+
"""Run the MCP server."""
152+
async with stdio_server() as (read_stream, write_stream):
153+
await server.run(
154+
read_stream,
155+
write_stream,
156+
server.create_initialization_options(),
157+
)
158+
159+
160+
if __name__ == "__main__":
161+
asyncio.run(main())

0 commit comments

Comments
 (0)