11from collections .abc import Awaitable , Callable
2+ from concurrent .futures import Future
23from dataclasses import dataclass , field
34from logging import getLogger
45from time import time
6+ from types import TracebackType
57from typing import Any
68from uuid import uuid4
79
8- from anyio import Lock , create_task_group , move_on_after
9- from anyio .abc import TaskGroup
10- from cachetools import TTLCache
10+ import anyio
11+ import anyio .to_thread
12+ from anyio . from_thread import BlockingPortal , BlockingPortalProvider
1113
1214from mcp import types
1315from mcp .server .auth .middleware .auth_context import auth_context_var as user_context
1416from mcp .server .auth .middleware .bearer_auth import AuthenticatedUser
15- from mcp .shared .context import BaseSession , RequestContext , SessionT
17+ from mcp .server .session import ServerSession
18+ from mcp .shared .context import RequestContext
1619
1720logger = getLogger (__name__ )
1821
2124class InProgress :
2225 token : str
2326 user : AuthenticatedUser | None = None
24- task_group : TaskGroup | None = None
25- sessions : list [BaseSession [Any , Any , Any , Any , Any ]] = field (
26- default_factory = lambda : []
27- )
27+ future : Future [types .CallToolResult ] | None = None
28+ sessions : dict [int , ServerSession ] = field (default_factory = lambda : {})
2829
2930
3031class ResultCache :
@@ -33,16 +34,11 @@ class ResultCache:
3334 Its purpose is to act as a central point for managing in progress
3435 async calls, allowing multiple clients to join and receive progress
3536 updates, get results and/or cancel in progress calls
36- TODO CRITICAL properly support join nothing actually happens at the moment
37- TODO CRITICAL intercept progress notifications from original session and
38- pass to joined sessions
39- TODO MAJOR handle session closure gracefully -
40- at the moment old connections will hang around and cause problems later
37+ TODO CRITICAL keep_alive logic is not correct as per spec - results currently
38+ only kept for as long as longest session reintroduce TTL cache
4139 TODO MAJOR needs a lot more testing around edge cases/failure scenarios
42- TODO MINOR keep_alive logic is not correct as per spec - results are
43- cached for too long, probably better than too short
44- TODO ENHANCEMENT might look into more fine grained locks, one global lock
45- is a bottleneck though this could be delegated to other cache impls if external
40+ TODO MAJOR decide if async.Locks are required for integrity of internal
41+ data structures
4642 TODO ENHANCEMENT externalise cachetools to allow for other implementations
4743 e.g. redis etal for production scenarios
4844 TODO ENHANCEMENT may need to add an authorisation layer to decide if
@@ -52,119 +48,188 @@ class ResultCache:
5248 """
5349
5450 _in_progress : dict [types .AsyncToken , InProgress ]
51+ _session_lookup : dict [int , types .AsyncToken ]
52+ _portal : BlockingPortal
5553
5654 def __init__ (self , max_size : int , max_keep_alive : int ):
5755 self ._max_size = max_size
5856 self ._max_keep_alive = max_keep_alive
59- self ._result_cache = TTLCache [types .AsyncToken , types .CallToolResult ](
60- self ._max_size , self ._max_keep_alive
61- )
6257 self ._in_progress = {}
63- self ._lock = Lock ()
58+ self ._session_lookup = {}
59+ self ._portal_provider = BlockingPortalProvider ()
60+
61+ async def __aenter__ (self ):
62+ def create_portal ():
63+ self ._portal = self ._portal_provider .__enter__ ()
64+
65+ await anyio .to_thread .run_sync (create_portal )
66+
67+ async def __aexit__ (
68+ self ,
69+ exc_type : type [BaseException ] | None ,
70+ exc_val : BaseException | None ,
71+ exc_tb : TracebackType | None ,
72+ ) -> bool | None :
73+ await anyio .to_thread .run_sync (lambda : self ._portal_provider .__exit__ )
6474
6575 async def add_call (
6676 self ,
6777 call : Callable [[types .CallToolRequest ], Awaitable [types .ServerResult ]],
6878 req : types .CallToolAsyncRequest ,
69- ctx : RequestContext [SessionT , Any , Any ],
79+ ctx : RequestContext [ServerSession , Any , Any ],
7080 ) -> types .CallToolAsyncResult :
7181 in_progress = await self ._new_in_progress ()
7282 timeout = min (
7383 req .params .keepAlive or self ._max_keep_alive , self ._max_keep_alive
7484 )
7585
7686 async def call_tool ():
77- with move_on_after ( timeout ) as scope :
78- result = await call (
79- types . CallToolRequest (
80- method = "tools/call" ,
81- params = types . CallToolRequestParams (
82- name = req . params . name , arguments = req .params .arguments
83- ) ,
84- )
87+ result = await call (
88+ types . CallToolRequest (
89+ method = "tools/call" ,
90+ params = types . CallToolRequestParams (
91+ name = req . params . name ,
92+ arguments = req .params .arguments ,
93+ _meta = req . params . meta ,
94+ ),
8595 )
86- if not scope .cancel_called :
87- async with self ._lock :
88- assert type (result .root ) is types .CallToolResult
89- self ._result_cache [in_progress .token ] = result .root
90-
91- async with create_task_group () as tg :
92- tg .start_soon (call_tool )
93- in_progress .task_group = tg
94- in_progress .user = user_context .get ()
95- in_progress .sessions .append (ctx .session )
96- result = types .CallToolAsyncResult (
97- token = in_progress .token ,
98- recieved = round (time ()),
99- keepAlive = timeout ,
100- accepted = True ,
10196 )
102- return result
97+ # async with self._lock:
98+ assert type (result .root ) is types .CallToolResult
99+ logger .debug (f"Got result { result } " )
100+ return result .root
101+
102+ in_progress .user = user_context .get ()
103+ in_progress .sessions [id (ctx .session )] = ctx .session
104+ self ._session_lookup [id (ctx .session )] = in_progress .token
105+ in_progress .future = self ._portal .start_task_soon (call_tool )
106+ result = types .CallToolAsyncResult (
107+ token = in_progress .token ,
108+ recieved = round (time ()),
109+ keepAlive = timeout ,
110+ accepted = True ,
111+ )
112+ return result
103113
104114 async def join_call (
105115 self ,
106116 req : types .JoinCallToolAsyncRequest ,
107- ctx : RequestContext [SessionT , Any , Any ],
117+ ctx : RequestContext [ServerSession , Any , Any ],
108118 ) -> types .CallToolAsyncResult :
109- async with self ._lock :
110- in_progress = self ._in_progress .get (req .params .token )
111- if in_progress is None :
112- # TODO consider creating new token to allow client
113- # to get message describing why it wasn't accepted
114- return types .CallToolAsyncResult (accepted = False )
119+ # async with self._lock:
120+ in_progress = self ._in_progress .get (req .params .token )
121+ if in_progress is None :
122+ # TODO consider creating new token to allow client
123+ # to get message describing why it wasn't accepted
124+ logger .warning ("Discarding join request for unknown async token" )
125+ return types .CallToolAsyncResult (accepted = False )
126+ else :
127+ # TODO consider adding authorisation layer to make this decision
128+ if in_progress .user == user_context .get ():
129+ logger .debug (f"Received join from { id (ctx .session )} " )
130+ self ._session_lookup [id (ctx .session )] = req .params .token
131+ in_progress .sessions [id (ctx .session )] = ctx .session
132+ return types .CallToolAsyncResult (token = req .params .token , accepted = True )
115133 else :
116- # TODO consider adding authorisation layer to make this decision
117- if in_progress .user == user_context .get ():
118- in_progress .sessions .append (ctx .session )
119- return types .CallToolAsyncResult (accepted = True )
120- else :
121- # TODO consider creating new token to allow client
122- # to get message describing why it wasn't accepted
123- return types .CallToolAsyncResult (accepted = False )
134+ # TODO consider sending error via get result
135+ return types .CallToolAsyncResult (accepted = False )
124136
125137 async def cancel (self , notification : types .CancelToolAsyncNotification ) -> None :
126- async with self ._lock :
127- in_progress = self ._in_progress .get (notification .params .token )
128- if in_progress is not None and in_progress . task_group is not None :
129- if in_progress .user == user_context .get ():
130- in_progress .task_group .cancel_scope .cancel ()
131- del self ._in_progress [notification .params .token ]
132- else :
133- logger .warning (
134- "Permission denied for cancel notification received"
135- f"from { user_context .get ()} "
136- )
138+ # async with self._lock:
139+ in_progress = self ._in_progress .get (notification .params .token )
140+ if in_progress is not None :
141+ if in_progress .user == user_context .get ():
142+ # in_progress.task_group.cancel_scope.cancel()
143+ del self ._in_progress [notification .params .token ]
144+ else :
145+ logger .warning (
146+ "Permission denied for cancel notification received"
147+ f"from { user_context .get ()} "
148+ )
137149
138150 async def get_result (self , req : types .GetToolAsyncResultRequest ):
139- async with self ._lock :
140- in_progress = self ._in_progress .get (req .params .token )
141- if in_progress is None :
142- return types .CallToolResult (
143- content = [
144- types .TextContent (type = "text" , text = "Unknown progress token" )
145- ],
146- isError = True ,
147- )
148- else :
149- if in_progress .user == user_context .get ():
150- result = self ._result_cache .get (in_progress .token )
151- if result is None :
152- return types .CallToolResult (content = [], isPending = True )
153- else :
154- return result
155- else :
151+ logger .debug ("Getting result" )
152+ in_progress = self ._in_progress .get (req .params .token )
153+ logger .debug (f"Found in progress { in_progress } " )
154+ if in_progress is None :
155+ return types .CallToolResult (
156+ content = [types .TextContent (type = "text" , text = "Unknown progress token" )],
157+ isError = True ,
158+ )
159+ else :
160+ if in_progress .user == user_context .get ():
161+ if in_progress .future is None :
156162 return types .CallToolResult (
157163 content = [
158164 types .TextContent (type = "text" , text = "Permission denied" )
159165 ],
160166 isError = True ,
161167 )
168+ else :
169+ # TODO add timeout to get async result
170+ # return isPending=True if timesout
171+ result = in_progress .future .result ()
172+ logger .debug (f"Found result { result } " )
173+ return result
174+ else :
175+ return types .CallToolResult (
176+ content = [types .TextContent (type = "text" , text = "Permission denied" )],
177+ isError = True ,
178+ )
179+
180+ async def notification_hook (
181+ self , session : ServerSession , notification : types .ServerNotification
182+ ):
183+ if type (notification .root ) is types .ProgressNotification :
184+ # async with self._lock:
185+ async_token = self ._session_lookup .get (id (session ))
186+ if async_token is None :
187+ # not all sessions are async so just debug
188+ logger .debug ("Discarding progress notification from unknown session" )
189+ else :
190+ in_progress = self ._in_progress .get (async_token )
191+ if in_progress is None :
192+ # this should not happen
193+ logger .error ("Discarding progress notification, not async" )
194+ else :
195+ for session_id , other_session in in_progress .sessions .items ():
196+ logger .debug (f"Checking { session_id } == { id (session )} " )
197+ if not session_id == id (session ):
198+ logger .debug (f"Sending progress to { id (other_session )} " )
199+ await other_session .send_progress_notification (
200+ progress_token = 1 ,
201+ progress = notification .root .params .progress ,
202+ total = notification .root .params .total ,
203+ message = notification .root .params .message ,
204+ resource_uri = notification .root .params .resourceUri ,
205+ )
206+
207+ async def session_close_hook (self , session : ServerSession ):
208+ logger .debug (f"Closing { id (session )} " )
209+ dropped = self ._session_lookup .pop (id (session ), None )
210+ if dropped is None :
211+ logger .warning (f"Discarding callback from unknown session { id (session )} " )
212+ else :
213+ in_progress = self ._in_progress .get (dropped )
214+ if in_progress is None :
215+ logger .warning ("In progress not found" )
216+ else :
217+ found = in_progress .sessions .pop (id (session ), None )
218+ if found is None :
219+ logger .warning ("No session found" )
220+ if len (in_progress .sessions ) == 0 :
221+ self ._in_progress .pop (dropped , None )
222+ logger .debug ("In progress found" )
223+ if in_progress .future is None :
224+ logger .warning ("In progress future is none" )
225+ else :
226+ logger .debug ("Cancelled in progress future" )
227+ in_progress .future .cancel ()
162228
163229 async def _new_in_progress (self ) -> InProgress :
164- async with self ._lock :
165- while True :
166- token = str (uuid4 ())
167- if token not in self ._in_progress :
168- new_in_progress = InProgress (token )
169- self ._in_progress [token ] = new_in_progress
170- return new_in_progress
230+ while True :
231+ token = str (uuid4 ())
232+ if token not in self ._in_progress :
233+ new_in_progress = InProgress (token )
234+ self ._in_progress [token ] = new_in_progress
235+ return new_in_progress
0 commit comments