Skip to content

Commit bef1d72

Browse files
committed
Further updates, create lowlevel test for result caching and improve test logic for higher level session join test (currently skipped due to subtle bug in test)
1 parent ea0048c commit bef1d72

File tree

4 files changed

+212
-66
lines changed

4 files changed

+212
-66
lines changed

src/mcp/server/lowlevel/result_cache.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class InProgress:
2626
user: AuthenticatedUser | None = None
2727
future: Future[types.CallToolResult] | None = None
2828
sessions: dict[int, ServerSession] = field(default_factory=lambda: {})
29-
29+
session_progress: dict[int, types.ProgressToken | None] = field(default_factory=lambda: {})
3030

3131
class ResultCache:
3232
"""
@@ -72,7 +72,7 @@ async def __aexit__(
7272
) -> bool | None:
7373
await anyio.to_thread.run_sync(lambda: self._portal_provider.__exit__)
7474

75-
async def add_call(
75+
async def start_call(
7676
self,
7777
call: Callable[[types.CallToolRequest], Awaitable[types.ServerResult]],
7878
req: types.CallToolAsyncRequest,
@@ -101,6 +101,7 @@ async def call_tool():
101101

102102
in_progress.user = user_context.get()
103103
in_progress.sessions[id(ctx.session)] = ctx.session
104+
in_progress.session_progress[id(ctx.session)] = None if req.params.meta is None else req.params.meta.progressToken
104105
self._session_lookup[id(ctx.session)] = in_progress.token
105106
in_progress.future = self._portal.start_task_soon(call_tool)
106107
result = types.CallToolAsyncResult(
@@ -129,6 +130,7 @@ async def join_call(
129130
logger.debug(f"Received join from {id(ctx.session)}")
130131
self._session_lookup[id(ctx.session)] = req.params.token
131132
in_progress.sessions[id(ctx.session)] = ctx.session
133+
in_progress.session_progress[id(ctx.session)] = None if req.params.meta is None else req.params.meta.progressToken
132134
return types.CallToolAsyncResult(token=req.params.token, accepted=True)
133135
else:
134136
# TODO consider sending error via get result
@@ -167,10 +169,15 @@ async def get_result(self, req: types.GetToolAsyncResultRequest):
167169
)
168170
else:
169171
# TODO add timeout to get async result
170-
# return isPending=True if timesout
171-
result = in_progress.future.result()
172-
logger.debug(f"Found result {result}")
173-
return result
172+
try:
173+
result = in_progress.future.result(1)
174+
logger.debug(f"Found result {result}")
175+
return result
176+
except TimeoutError:
177+
return types.CallToolResult(
178+
content=[],
179+
isPending=True,
180+
)
174181
else:
175182
return types.CallToolResult(
176183
content=[types.TextContent(type="text", text="Permission denied")],
@@ -180,6 +187,7 @@ async def get_result(self, req: types.GetToolAsyncResultRequest):
180187
async def notification_hook(
181188
self, session: ServerSession, notification: types.ServerNotification
182189
):
190+
logger.debug(f"received {notification} from {id(session)}")
183191
if type(notification.root) is types.ProgressNotification:
184192
# async with self._lock:
185193
async_token = self._session_lookup.get(id(session))
@@ -196,8 +204,12 @@ async def notification_hook(
196204
logger.debug(f"Checking {session_id} == {id(session)}")
197205
if not session_id == id(session):
198206
logger.debug(f"Sending progress to {id(other_session)}")
207+
progress_token = in_progress.session_progress.get(id(other_session))
208+
assert progress_token is not None
199209
await other_session.send_progress_notification(
200-
progress_token=1,
210+
# TODO this token is incorrect
211+
# it needs to be collected from original request
212+
progress_token=progress_token,
201213
progress=notification.root.params.progress,
202214
total=notification.root.params.total,
203215
message=notification.root.params.message,

src/mcp/server/lowlevel/server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ async def handler(req: types.CallToolRequest):
432432

433433
async def async_call_handler(req: types.CallToolAsyncRequest):
434434
ctx = request_ctx.get()
435-
result = await self.result_cache.add_call(handler, req, ctx)
435+
result = await self.result_cache.start_call(handler, req, ctx)
436436
return types.ServerResult(result)
437437

438438
async def async_join_handler(req: types.JoinCallToolAsyncRequest):
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import pytest
2+
from mcp import types
3+
from mcp.server.lowlevel.result_cache import ResultCache
4+
from unittest.mock import AsyncMock, Mock, patch
5+
from contextlib import AsyncExitStack
6+
7+
@pytest.mark.anyio
8+
async def test_async_call():
9+
"""Tests basic async call"""
10+
async def test_call(call: types.CallToolRequest) -> types.ServerResult:
11+
return types.ServerResult(types.CallToolResult(
12+
content=[types.TextContent(
13+
type="text",
14+
text="test"
15+
)]
16+
))
17+
async_call = types.CallToolAsyncRequest(
18+
method="tools/async/call",
19+
params=types.CallToolAsyncRequestParams(
20+
name="test"
21+
)
22+
)
23+
24+
mock_session = AsyncMock()
25+
mock_context = Mock()
26+
mock_context.session = mock_session
27+
result_cache = ResultCache(max_size=1, max_keep_alive=1)
28+
async with AsyncExitStack() as stack:
29+
await stack.enter_async_context(result_cache)
30+
async_call_ref = await result_cache.start_call(test_call, async_call, mock_context)
31+
assert async_call_ref.token is not None
32+
33+
result = await result_cache.get_result(types.GetToolAsyncResultRequest(
34+
method="tools/async/get",
35+
params=types.GetToolAsyncResultRequestParams(
36+
token = async_call_ref.token
37+
)
38+
))
39+
40+
assert not result.isError
41+
assert not result.isPending
42+
assert len(result.content) == 1
43+
assert type(result.content[0]) is types.TextContent
44+
assert result.content[0].text == "test"
45+
46+
@pytest.mark.anyio
47+
async def test_async_join_call_progress():
48+
"""Tests basic async call"""
49+
async def test_call(call: types.CallToolRequest) -> types.ServerResult:
50+
return types.ServerResult(types.CallToolResult(
51+
content=[types.TextContent(
52+
type="text",
53+
text="test"
54+
)]
55+
))
56+
async_call = types.CallToolAsyncRequest(
57+
method="tools/async/call",
58+
params=types.CallToolAsyncRequestParams(
59+
name="test"
60+
)
61+
)
62+
63+
mock_session_1 = AsyncMock()
64+
mock_context_1 = Mock()
65+
mock_context_1.session = mock_session_1
66+
67+
mock_session_2 = AsyncMock()
68+
mock_context_2 = Mock()
69+
70+
mock_context_2.session = mock_session_2
71+
mock_session_2.send_progress_notification.result = None
72+
73+
result_cache = ResultCache(max_size=1, max_keep_alive=1)
74+
async with AsyncExitStack() as stack:
75+
await stack.enter_async_context(result_cache)
76+
async_call_ref = await result_cache.start_call(test_call, async_call, mock_context_1)
77+
assert async_call_ref.token is not None
78+
79+
await result_cache.join_call(
80+
req=types.JoinCallToolAsyncRequest(
81+
method="tools/async/join",
82+
params=types.JoinCallToolRequestParams(
83+
token=async_call_ref.token,
84+
_meta = types.RequestParams.Meta(
85+
progressToken="test"
86+
)
87+
)
88+
),
89+
ctx=mock_context_2
90+
)
91+
assert async_call_ref.token is not None
92+
await result_cache.notification_hook(
93+
session=mock_session_1,
94+
notification=types.ServerNotification(types.ProgressNotification(
95+
method="notifications/progress",
96+
params=types.ProgressNotificationParams(
97+
progressToken="test",
98+
progress=1
99+
)
100+
)))
101+
102+
result = await result_cache.get_result(types.GetToolAsyncResultRequest(
103+
method="tools/async/get",
104+
params=types.GetToolAsyncResultRequestParams(
105+
token = async_call_ref.token
106+
)
107+
))
108+
109+
assert not result.isError
110+
assert not result.isPending
111+
assert len(result.content) == 1
112+
assert type(result.content[0]) is types.TextContent
113+
assert result.content[0].text == "test"
114+
mock_context_1.send_progress_notification.assert_not_called()
115+
mock_session_2.send_progress_notification.assert_called_with(
116+
progress_token="test",
117+
progress=1.0,
118+
total=None,
119+
message=None,
120+
resource_uri = None
121+
)

0 commit comments

Comments
 (0)