Skip to content

Commit 52e066a

Browse files
committed
add progress notificaiton to the client
1 parent 36d89b1 commit 52e066a

File tree

4 files changed

+111
-44
lines changed

4 files changed

+111
-44
lines changed

src/mcp/client/session.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import mcp.types as types
99
from mcp.shared.context import RequestContext
1010
from mcp.shared.message import SessionMessage
11-
from mcp.shared.session import BaseSession, RequestResponder
11+
from mcp.shared.session import BaseSession, ProgressCallbackFnT, RequestResponder
1212
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1313

1414
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
@@ -270,18 +270,23 @@ async def call_tool(
270270
name: str,
271271
arguments: dict[str, Any] | None = None,
272272
read_timeout_seconds: timedelta | None = None,
273+
progress_callback: ProgressCallbackFnT | None = None,
273274
) -> types.CallToolResult:
274-
"""Send a tools/call request."""
275+
"""Send a tools/call request with optional progress callback support."""
275276

276277
return await self.send_request(
277278
types.ClientRequest(
278279
types.CallToolRequest(
279280
method="tools/call",
280-
params=types.CallToolRequestParams(name=name, arguments=arguments),
281+
params=types.CallToolRequestParams(
282+
name=name,
283+
arguments=arguments,
284+
),
281285
)
282286
),
283287
types.CallToolResult,
284288
request_read_timeout_seconds=read_timeout_seconds,
289+
progress_callback=progress_callback,
285290
)
286291

287292
async def list_prompts(self, cursor: str | None = None) -> types.ListPromptsResult:

src/mcp/server/fastmcp/server.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -963,7 +963,6 @@ async def report_progress(
963963
total: Optional total value e.g. 100
964964
message: Optional message e.g. Starting render...
965965
"""
966-
967966
progress_token = (
968967
self.request_context.meta.progressToken
969968
if self.request_context.meta

src/mcp/shared/session.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from contextlib import AsyncExitStack
44
from datetime import timedelta
55
from types import TracebackType
6-
from typing import Any, Generic, TypeVar
6+
from typing import Any, Generic, Protocol, TypeVar
77

88
import anyio
99
import httpx
@@ -24,6 +24,7 @@
2424
JSONRPCNotification,
2525
JSONRPCRequest,
2626
JSONRPCResponse,
27+
ProgressNotification,
2728
RequestParams,
2829
ServerNotification,
2930
ServerRequest,
@@ -42,6 +43,22 @@
4243
RequestId = str | int
4344

4445

46+
class ProgressCallbackFnT(Protocol):
47+
"""Protocol for progress notification callbacks."""
48+
49+
def __call__(
50+
self, progress: float, total: float | None, message: str | None
51+
) -> None:
52+
"""Called when progress updates are received.
53+
54+
Args:
55+
progress: Current progress value
56+
total: Total progress value (if known), None if indeterminate
57+
message: Optional progress message
58+
"""
59+
...
60+
61+
4562
class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
4663
"""Handles responding to MCP requests and manages request lifecycle.
4764
@@ -169,6 +186,7 @@ class BaseSession(
169186
]
170187
_request_id: int
171188
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
189+
_progress_callbacks: dict[RequestId, ProgressCallbackFnT]
172190

173191
def __init__(
174192
self,
@@ -187,6 +205,7 @@ def __init__(
187205
self._receive_notification_type = receive_notification_type
188206
self._session_read_timeout_seconds = read_timeout_seconds
189207
self._in_flight = {}
208+
self._progress_callbacks = {}
190209
self._exit_stack = AsyncExitStack()
191210

192211
async def __aenter__(self) -> Self:
@@ -214,6 +233,7 @@ async def send_request(
214233
result_type: type[ReceiveResultT],
215234
request_read_timeout_seconds: timedelta | None = None,
216235
metadata: MessageMetadata = None,
236+
progress_callback: ProgressCallbackFnT | None = None,
217237
) -> ReceiveResultT:
218238
"""
219239
Sends a request and wait for a response. Raises an McpError if the
@@ -231,15 +251,25 @@ async def send_request(
231251
](1)
232252
self._response_streams[request_id] = response_stream
233253

254+
# Set up progress token if progress callback is provided
255+
request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
256+
if progress_callback is not None:
257+
# Use request_id as progress token
258+
if "params" not in request_data:
259+
request_data["params"] = {}
260+
if "_meta" not in request_data["params"]:
261+
request_data["params"]["_meta"] = {}
262+
request_data["params"]["_meta"]["progressToken"] = request_id
263+
# Store the callback for this request
264+
self._progress_callbacks[request_id] = progress_callback
265+
234266
try:
235267
jsonrpc_request = JSONRPCRequest(
236268
jsonrpc="2.0",
237269
id=request_id,
238-
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
270+
**request_data,
239271
)
240272

241-
# TODO: Support progress callbacks
242-
243273
await self._write_stream.send(
244274
SessionMessage(
245275
message=JSONRPCMessage(jsonrpc_request), metadata=metadata
@@ -275,6 +305,7 @@ async def send_request(
275305

276306
finally:
277307
self._response_streams.pop(request_id, None)
308+
self._progress_callbacks.pop(request_id, None)
278309
await response_stream.aclose()
279310
await response_stream_reader.aclose()
280311

@@ -333,7 +364,6 @@ async def _receive_loop(self) -> None:
333364
by_alias=True, mode="json", exclude_none=True
334365
)
335366
)
336-
337367
responder = RequestResponder(
338368
request_id=message.message.root.id,
339369
request_meta=validated_request.root.params.meta
@@ -362,6 +392,18 @@ async def _receive_loop(self) -> None:
362392
cancelled_id = notification.root.params.requestId
363393
if cancelled_id in self._in_flight:
364394
await self._in_flight[cancelled_id].cancel()
395+
# Handle progress notifications
396+
elif isinstance(notification.root, ProgressNotification):
397+
progress_token = notification.root.params.progressToken
398+
# If there is a progress callback for this token,
399+
# call it with the progress information
400+
if progress_token in self._progress_callbacks:
401+
callback = self._progress_callbacks[progress_token]
402+
callback(
403+
notification.root.params.progress,
404+
notification.root.params.total,
405+
notification.root.params.message,
406+
)
365407
else:
366408
await self._received_notification(notification)
367409
await self._handle_incoming(notification)

tests/server/fastmcp/test_integration.py

Lines changed: 56 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,23 @@
1313
import anyio
1414
import pytest
1515
import uvicorn
16+
from pydantic import AnyUrl
1617

18+
import mcp.types as types
1719
from mcp.client.session import ClientSession
1820
from mcp.client.sse import sse_client
1921
from mcp.client.streamable_http import streamablehttp_client
2022
from mcp.server.fastmcp import FastMCP
21-
import mcp.types as types
22-
from mcp.types import InitializeResult, TextContent, TextResourceContents
23-
from pydantic import AnyUrl
23+
from mcp.shared.context import RequestContext
24+
from mcp.types import (
25+
CreateMessageRequestParams,
26+
CreateMessageResult,
27+
GetPromptResult,
28+
InitializeResult,
29+
ReadResourceResult,
30+
TextContent,
31+
TextResourceContents,
32+
)
2433

2534

2635
@pytest.fixture
@@ -625,15 +634,6 @@ async def test_fastmcp_all_features_sse(
625634
comprehensive_server: None, comprehensive_server_url: str
626635
) -> None:
627636
"""Test all MCP features work correctly with SSE transport."""
628-
from mcp.types import (
629-
GetPromptResult,
630-
ReadResourceResult,
631-
CreateMessageResult,
632-
CreateMessageRequestParams,
633-
SamplingMessage,
634-
TextContent,
635-
)
636-
from mcp.shared.context import RequestContext
637637

638638
# Create notification collector
639639
collector = NotificationCollector()
@@ -690,20 +690,38 @@ async def message_handler(message):
690690
assert tool_result.content[0].text == "Echo: hello"
691691

692692
# 2. Tool with context (logging and progress)
693-
# Test with progress token to capture progress notifications
693+
# Test progress callback functionality
694+
progress_updates = []
695+
696+
def progress_callback(
697+
progress: float, total: float | None, message: str | None
698+
) -> None:
699+
"""Collect progress updates for testing."""
700+
progress_updates.append((progress, total, message))
701+
print(f"Progress: {progress}/{total} - {message}")
702+
703+
params = {
704+
"message": "test",
705+
"steps": 3,
706+
}
694707
tool_result = await session.call_tool(
695708
"tool_with_context",
696-
{
697-
"message": "test",
698-
"steps": 3,
699-
"_meta": {"progressToken": "sse_test_token"},
700-
},
709+
params,
710+
progress_callback=progress_callback,
701711
)
702712
assert len(tool_result.content) == 1
703-
assert len(collector.progress_notifications) > 0
704713
assert isinstance(tool_result.content[0], TextContent)
705714
assert "Processed 'test' in 3 steps" in tool_result.content[0].text
706715

716+
# Verify progress callback was called
717+
assert len(progress_updates) == 3
718+
for i, (progress, total, message) in enumerate(progress_updates):
719+
expected_progress = (i + 1) / 3
720+
assert abs(progress - expected_progress) < 0.01
721+
assert total == 1.0
722+
assert message is not None
723+
assert f"step {i + 1} of 3" in message
724+
707725
# Verify we received log messages from the tool
708726
# Note: Progress notifications require special handling in the MCP client
709727
# that's not implemented by default, so we focus on testing logging
@@ -832,15 +850,6 @@ async def test_fastmcp_all_features_streamable_http(
832850
comprehensive_streamable_http_server: None, comprehensive_http_server_url: str
833851
) -> None:
834852
"""Test all MCP features work correctly with StreamableHTTP transport."""
835-
from mcp.types import (
836-
GetPromptResult,
837-
ReadResourceResult,
838-
CreateMessageResult,
839-
CreateMessageRequestParams,
840-
SamplingMessage,
841-
TextContent,
842-
)
843-
from mcp.shared.context import RequestContext
844853

845854
# Create notification collector
846855
collector = NotificationCollector()
@@ -902,21 +911,33 @@ async def message_handler(message):
902911
assert tool_result.content[0].text == "Echo: hello"
903912

904913
# 2. Tool with context (logging and progress)
905-
# Test with progress token to capture progress notifications
914+
# Test progress callback functionality over HTTP
915+
progress_updates_http = []
916+
917+
def progress_callback_http(
918+
progress: float, total: float | None, message: str | None
919+
) -> None:
920+
"""Collect progress updates for HTTP testing."""
921+
progress_updates_http.append((progress, total, message))
922+
print(f"HTTP Progress: {progress}/{total} - {message}")
923+
906924
await session.call_tool(
907925
"tool_with_context",
908926
{
909927
"message": "http_test",
910928
"steps": 2,
911929
},
930+
progress_callback=progress_callback_http,
912931
)
913932

914-
# Verify we received progress notifications
915-
assert len(collector.progress_notifications) > 0
916-
assert any(
917-
p.progressToken == "http_test_token"
918-
for p in collector.progress_notifications
919-
)
933+
# Verify progress callback was called over HTTP
934+
assert len(progress_updates_http) == 2
935+
for i, (progress, total, message) in enumerate(progress_updates_http):
936+
expected_progress = (i + 1) / 2
937+
assert abs(progress - expected_progress) < 0.01
938+
assert total == 1.0
939+
assert message is not None
940+
assert f"step {i + 1} of 2" in message
920941

921942
# 3. Test sampling tool
922943
sampling_result = await session.call_tool(

0 commit comments

Comments
 (0)