Skip to content

Commit ea0048c

Browse files
committed
Updates to support join
1 parent 646dd63 commit ea0048c

File tree

5 files changed

+427
-111
lines changed

5 files changed

+427
-111
lines changed
Lines changed: 161 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
from collections.abc import Awaitable, Callable
2+
from concurrent.futures import Future
23
from dataclasses import dataclass, field
34
from logging import getLogger
45
from time import time
6+
from types import TracebackType
57
from typing import Any
68
from uuid import uuid4
79

8-
from anyio import Lock, create_task_group, move_on_after
9-
from anyio.abc import TaskGroup
10-
from cachetools import TTLCache
10+
import anyio
11+
import anyio.to_thread
12+
from anyio.from_thread import BlockingPortal, BlockingPortalProvider
1113

1214
from mcp import types
1315
from mcp.server.auth.middleware.auth_context import auth_context_var as user_context
1416
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
15-
from mcp.shared.context import BaseSession, RequestContext, SessionT
17+
from mcp.server.session import ServerSession
18+
from mcp.shared.context import RequestContext
1619

1720
logger = getLogger(__name__)
1821

@@ -21,10 +24,8 @@
2124
class InProgress:
2225
token: str
2326
user: AuthenticatedUser | None = None
24-
task_group: TaskGroup | None = None
25-
sessions: list[BaseSession[Any, Any, Any, Any, Any]] = field(
26-
default_factory=lambda: []
27-
)
27+
future: Future[types.CallToolResult] | None = None
28+
sessions: dict[int, ServerSession] = field(default_factory=lambda: {})
2829

2930

3031
class ResultCache:
@@ -33,16 +34,11 @@ class ResultCache:
3334
Its purpose is to act as a central point for managing in progress
3435
async calls, allowing multiple clients to join and receive progress
3536
updates, get results and/or cancel in progress calls
36-
TODO CRITICAL properly support join nothing actually happens at the moment
37-
TODO CRITICAL intercept progress notifications from original session and
38-
pass to joined sessions
39-
TODO MAJOR handle session closure gracefully -
40-
at the moment old connections will hang around and cause problems later
37+
TODO CRITICAL keep_alive logic is not correct as per spec - results currently
38+
only kept for as long as longest session reintroduce TTL cache
4139
TODO MAJOR needs a lot more testing around edge cases/failure scenarios
42-
TODO MINOR keep_alive logic is not correct as per spec - results are
43-
cached for too long, probably better than too short
44-
TODO ENHANCEMENT might look into more fine grained locks, one global lock
45-
is a bottleneck though this could be delegated to other cache impls if external
40+
TODO MAJOR decide if async.Locks are required for integrity of internal
41+
data structures
4642
TODO ENHANCEMENT externalise cachetools to allow for other implementations
4743
e.g. redis etal for production scenarios
4844
TODO ENHANCEMENT may need to add an authorisation layer to decide if
@@ -52,119 +48,188 @@ class ResultCache:
5248
"""
5349

5450
_in_progress: dict[types.AsyncToken, InProgress]
51+
_session_lookup: dict[int, types.AsyncToken]
52+
_portal: BlockingPortal
5553

5654
def __init__(self, max_size: int, max_keep_alive: int):
5755
self._max_size = max_size
5856
self._max_keep_alive = max_keep_alive
59-
self._result_cache = TTLCache[types.AsyncToken, types.CallToolResult](
60-
self._max_size, self._max_keep_alive
61-
)
6257
self._in_progress = {}
63-
self._lock = Lock()
58+
self._session_lookup = {}
59+
self._portal_provider = BlockingPortalProvider()
60+
61+
async def __aenter__(self):
62+
def create_portal():
63+
self._portal = self._portal_provider.__enter__()
64+
65+
await anyio.to_thread.run_sync(create_portal)
66+
67+
async def __aexit__(
68+
self,
69+
exc_type: type[BaseException] | None,
70+
exc_val: BaseException | None,
71+
exc_tb: TracebackType | None,
72+
) -> bool | None:
73+
await anyio.to_thread.run_sync(lambda: self._portal_provider.__exit__)
6474

6575
async def add_call(
6676
self,
6777
call: Callable[[types.CallToolRequest], Awaitable[types.ServerResult]],
6878
req: types.CallToolAsyncRequest,
69-
ctx: RequestContext[SessionT, Any, Any],
79+
ctx: RequestContext[ServerSession, Any, Any],
7080
) -> types.CallToolAsyncResult:
7181
in_progress = await self._new_in_progress()
7282
timeout = min(
7383
req.params.keepAlive or self._max_keep_alive, self._max_keep_alive
7484
)
7585

7686
async def call_tool():
77-
with move_on_after(timeout) as scope:
78-
result = await call(
79-
types.CallToolRequest(
80-
method="tools/call",
81-
params=types.CallToolRequestParams(
82-
name=req.params.name, arguments=req.params.arguments
83-
),
84-
)
87+
result = await call(
88+
types.CallToolRequest(
89+
method="tools/call",
90+
params=types.CallToolRequestParams(
91+
name=req.params.name,
92+
arguments=req.params.arguments,
93+
_meta=req.params.meta,
94+
),
8595
)
86-
if not scope.cancel_called:
87-
async with self._lock:
88-
assert type(result.root) is types.CallToolResult
89-
self._result_cache[in_progress.token] = result.root
90-
91-
async with create_task_group() as tg:
92-
tg.start_soon(call_tool)
93-
in_progress.task_group = tg
94-
in_progress.user = user_context.get()
95-
in_progress.sessions.append(ctx.session)
96-
result = types.CallToolAsyncResult(
97-
token=in_progress.token,
98-
recieved=round(time()),
99-
keepAlive=timeout,
100-
accepted=True,
10196
)
102-
return result
97+
# async with self._lock:
98+
assert type(result.root) is types.CallToolResult
99+
logger.debug(f"Got result {result}")
100+
return result.root
101+
102+
in_progress.user = user_context.get()
103+
in_progress.sessions[id(ctx.session)] = ctx.session
104+
self._session_lookup[id(ctx.session)] = in_progress.token
105+
in_progress.future = self._portal.start_task_soon(call_tool)
106+
result = types.CallToolAsyncResult(
107+
token=in_progress.token,
108+
recieved=round(time()),
109+
keepAlive=timeout,
110+
accepted=True,
111+
)
112+
return result
103113

104114
async def join_call(
105115
self,
106116
req: types.JoinCallToolAsyncRequest,
107-
ctx: RequestContext[SessionT, Any, Any],
117+
ctx: RequestContext[ServerSession, Any, Any],
108118
) -> types.CallToolAsyncResult:
109-
async with self._lock:
110-
in_progress = self._in_progress.get(req.params.token)
111-
if in_progress is None:
112-
# TODO consider creating new token to allow client
113-
# to get message describing why it wasn't accepted
114-
return types.CallToolAsyncResult(accepted=False)
119+
# async with self._lock:
120+
in_progress = self._in_progress.get(req.params.token)
121+
if in_progress is None:
122+
# TODO consider creating new token to allow client
123+
# to get message describing why it wasn't accepted
124+
logger.warning("Discarding join request for unknown async token")
125+
return types.CallToolAsyncResult(accepted=False)
126+
else:
127+
# TODO consider adding authorisation layer to make this decision
128+
if in_progress.user == user_context.get():
129+
logger.debug(f"Received join from {id(ctx.session)}")
130+
self._session_lookup[id(ctx.session)] = req.params.token
131+
in_progress.sessions[id(ctx.session)] = ctx.session
132+
return types.CallToolAsyncResult(token=req.params.token, accepted=True)
115133
else:
116-
# TODO consider adding authorisation layer to make this decision
117-
if in_progress.user == user_context.get():
118-
in_progress.sessions.append(ctx.session)
119-
return types.CallToolAsyncResult(accepted=True)
120-
else:
121-
# TODO consider creating new token to allow client
122-
# to get message describing why it wasn't accepted
123-
return types.CallToolAsyncResult(accepted=False)
134+
# TODO consider sending error via get result
135+
return types.CallToolAsyncResult(accepted=False)
124136

125137
async def cancel(self, notification: types.CancelToolAsyncNotification) -> None:
126-
async with self._lock:
127-
in_progress = self._in_progress.get(notification.params.token)
128-
if in_progress is not None and in_progress.task_group is not None:
129-
if in_progress.user == user_context.get():
130-
in_progress.task_group.cancel_scope.cancel()
131-
del self._in_progress[notification.params.token]
132-
else:
133-
logger.warning(
134-
"Permission denied for cancel notification received"
135-
f"from {user_context.get()}"
136-
)
138+
# async with self._lock:
139+
in_progress = self._in_progress.get(notification.params.token)
140+
if in_progress is not None:
141+
if in_progress.user == user_context.get():
142+
# in_progress.task_group.cancel_scope.cancel()
143+
del self._in_progress[notification.params.token]
144+
else:
145+
logger.warning(
146+
"Permission denied for cancel notification received"
147+
f"from {user_context.get()}"
148+
)
137149

138150
async def get_result(self, req: types.GetToolAsyncResultRequest):
139-
async with self._lock:
140-
in_progress = self._in_progress.get(req.params.token)
141-
if in_progress is None:
142-
return types.CallToolResult(
143-
content=[
144-
types.TextContent(type="text", text="Unknown progress token")
145-
],
146-
isError=True,
147-
)
148-
else:
149-
if in_progress.user == user_context.get():
150-
result = self._result_cache.get(in_progress.token)
151-
if result is None:
152-
return types.CallToolResult(content=[], isPending=True)
153-
else:
154-
return result
155-
else:
151+
logger.debug("Getting result")
152+
in_progress = self._in_progress.get(req.params.token)
153+
logger.debug(f"Found in progress {in_progress}")
154+
if in_progress is None:
155+
return types.CallToolResult(
156+
content=[types.TextContent(type="text", text="Unknown progress token")],
157+
isError=True,
158+
)
159+
else:
160+
if in_progress.user == user_context.get():
161+
if in_progress.future is None:
156162
return types.CallToolResult(
157163
content=[
158164
types.TextContent(type="text", text="Permission denied")
159165
],
160166
isError=True,
161167
)
168+
else:
169+
# TODO add timeout to get async result
170+
# return isPending=True if timesout
171+
result = in_progress.future.result()
172+
logger.debug(f"Found result {result}")
173+
return result
174+
else:
175+
return types.CallToolResult(
176+
content=[types.TextContent(type="text", text="Permission denied")],
177+
isError=True,
178+
)
179+
180+
async def notification_hook(
181+
self, session: ServerSession, notification: types.ServerNotification
182+
):
183+
if type(notification.root) is types.ProgressNotification:
184+
# async with self._lock:
185+
async_token = self._session_lookup.get(id(session))
186+
if async_token is None:
187+
# not all sessions are async so just debug
188+
logger.debug("Discarding progress notification from unknown session")
189+
else:
190+
in_progress = self._in_progress.get(async_token)
191+
if in_progress is None:
192+
# this should not happen
193+
logger.error("Discarding progress notification, not async")
194+
else:
195+
for session_id, other_session in in_progress.sessions.items():
196+
logger.debug(f"Checking {session_id} == {id(session)}")
197+
if not session_id == id(session):
198+
logger.debug(f"Sending progress to {id(other_session)}")
199+
await other_session.send_progress_notification(
200+
progress_token=1,
201+
progress=notification.root.params.progress,
202+
total=notification.root.params.total,
203+
message=notification.root.params.message,
204+
resource_uri=notification.root.params.resourceUri,
205+
)
206+
207+
async def session_close_hook(self, session: ServerSession):
208+
logger.debug(f"Closing {id(session)}")
209+
dropped = self._session_lookup.pop(id(session), None)
210+
if dropped is None:
211+
logger.warning(f"Discarding callback from unknown session {id(session)}")
212+
else:
213+
in_progress = self._in_progress.get(dropped)
214+
if in_progress is None:
215+
logger.warning("In progress not found")
216+
else:
217+
found = in_progress.sessions.pop(id(session), None)
218+
if found is None:
219+
logger.warning("No session found")
220+
if len(in_progress.sessions) == 0:
221+
self._in_progress.pop(dropped, None)
222+
logger.debug("In progress found")
223+
if in_progress.future is None:
224+
logger.warning("In progress future is none")
225+
else:
226+
logger.debug("Cancelled in progress future")
227+
in_progress.future.cancel()
162228

163229
async def _new_in_progress(self) -> InProgress:
164-
async with self._lock:
165-
while True:
166-
token = str(uuid4())
167-
if token not in self._in_progress:
168-
new_in_progress = InProgress(token)
169-
self._in_progress[token] = new_in_progress
170-
return new_in_progress
230+
while True:
231+
token = str(uuid4())
232+
if token not in self._in_progress:
233+
new_in_progress = InProgress(token)
234+
self._in_progress[token] = new_in_progress
235+
return new_in_progress

src/mcp/server/lowlevel/server.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,8 +534,11 @@ async def run(
534534
write_stream,
535535
initialization_options,
536536
stateless=stateless,
537+
notification_hook=self.result_cache.notification_hook,
538+
session_close_hook=self.result_cache.session_close_hook,
537539
)
538540
)
541+
await stack.enter_async_context(self.result_cache)
539542

540543
async with anyio.create_task_group() as tg:
541544
async for message in session.incoming_messages:

src/mcp/server/session.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,15 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
3737
be instantiated directly by users of the MCP framework.
3838
"""
3939

40+
from collections.abc import Awaitable, Callable
4041
from enum import Enum
4142
from typing import Annotated, Any, TypeVar
4243

4344
import anyio
4445
import anyio.lowlevel
4546
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
4647
from pydantic.networks import AnyUrl, UrlConstraints
48+
from typing_extensions import Self
4749

4850
import mcp.types as types
4951
from mcp.server.models import InitializationOptions
@@ -88,9 +90,16 @@ def __init__(
8890
write_stream: MemoryObjectSendStream[SessionMessage],
8991
init_options: InitializationOptions,
9092
stateless: bool = False,
93+
notification_hook: Callable[[Self, types.ServerNotification], Awaitable[None]]
94+
| None = None,
95+
session_close_hook: Callable[[Self], Awaitable[None]] | None = None,
9196
) -> None:
9297
super().__init__(
93-
read_stream, write_stream, types.ClientRequest, types.ClientNotification
98+
read_stream,
99+
write_stream,
100+
types.ClientRequest,
101+
types.ClientNotification,
102+
notification_hook=notification_hook,
94103
)
95104
self._initialization_state = (
96105
InitializationState.Initialized
@@ -106,6 +115,12 @@ def __init__(
106115
lambda: self._incoming_message_stream_reader.aclose()
107116
)
108117

118+
async def call_session_close():
119+
if session_close_hook is not None:
120+
await session_close_hook(self)
121+
122+
self._exit_stack.push_async_callback(call_session_close)
123+
109124
@property
110125
def client_params(self) -> types.InitializeRequestParams | None:
111126
return self._client_params

0 commit comments

Comments
 (0)