11import json
22import logging
3+ from contextlib import asynccontextmanager
4+ from typing import Any , cast
35from uuid import UUID
46
57import anyio
6- from anyio import CapacityLimiter
8+ from anyio import CapacityLimiter , from_thread
79import mcp .types as types
810from mcp .server .message_queue .base import MessageCallback
9- from typing import Any , cast
10- from anyio import from_thread
11- from contextlib import asynccontextmanager
1211
1312
1413try :
@@ -38,8 +37,8 @@ def __init__(
3837 redis_url: Redis connection string
3938 prefix: Key prefix for Redis channels to avoid collisions
4039 """
41- self ._redis = redis .from_url (redis_url , decode_responses = True ) # type: ignore
42- self ._pubsub = self ._redis .pubsub (ignore_subscribe_messages = True ) # type: ignore
40+ self ._redis = redis .from_url (redis_url , decode_responses = True ) # type: ignore
41+ self ._pubsub = self ._redis .pubsub (ignore_subscribe_messages = True ) # type: ignore
4342 self ._prefix = prefix
4443 self ._active_sessions_key = f"{ prefix } active_sessions"
4544 self ._callbacks : dict [UUID , MessageCallback ] = {}
@@ -53,11 +52,10 @@ def _session_channel(self, session_id: UUID) -> str:
5352 @asynccontextmanager
5453 async def active_for_request (self , session_id : UUID , callback : MessageCallback ):
5554 """Request-scoped context manager that ensures the listener task is running."""
56-
5755 await self ._redis .sadd (self ._active_sessions_key , session_id .hex )
5856 self ._callbacks [session_id ] = callback
5957 channel = self ._session_channel (session_id )
60- await self ._pubsub .subscribe (channel ) # type: ignore
58+ await self ._pubsub .subscribe (channel ) # type: ignore
6159
6260 logger .debug (f"Registered session { session_id } in Redis with callback" )
6361 async with anyio .create_task_group () as tg :
@@ -66,7 +64,7 @@ async def active_for_request(self, session_id: UUID, callback: MessageCallback):
6664 yield
6765 finally :
6866 tg .cancel_scope .cancel ()
69- await self ._pubsub .unsubscribe (channel ) # type: ignore
67+ await self ._pubsub .unsubscribe (channel ) # type: ignore
7068 await self ._redis .srem (self ._active_sessions_key , session_id .hex )
7169 del self ._callbacks [session_id ]
7270 logger .debug (f"Unregistered session { session_id } from Redis" )
@@ -75,40 +73,41 @@ async def _listen_for_messages(self) -> None:
7573 """Background task that listens for messages on subscribed channels."""
7674 async with self ._limiter :
7775 while True :
78- message : None | dict [str , Any ] = await self ._pubsub .get_message ( # type: ignore
76+ message : None | dict [str , Any ] = await self ._pubsub .get_message ( # type: ignore
7977 ignore_subscribe_messages = True
8078 )
81- if message is not None :
82- # Extract session ID from channel name
83- channel : str = cast (str , message ["channel" ])
84- if not channel .startswith (self ._prefix ):
85- continue
86-
87- session_hex = channel .split (":" )[- 1 ]
88- try :
89- session_id = UUID (hex = session_hex )
90- except ValueError :
91- logger .error (f"Invalid session channel: { channel } " )
92- continue
93-
94- # Deserialize the message
95- data : str = cast (str , message ["data" ])
96- msg : None | types .JSONRPCMessage | Exception = None
97- try :
98- json_data = json .loads (data )
99- if isinstance (json_data , dict ):
100- json_dict : dict [str , Any ] = json_data
101- if json_dict .get ("_exception" , False ):
102- msg = Exception (
103- f"{ json_dict ['type' ]} : { json_dict ['message' ]} "
104- )
105- else :
106- msg = types .JSONRPCMessage .model_validate_json (data )
107-
108- if msg and session_id in self ._callbacks :
109- from_thread .run (self ._callbacks [session_id ], msg )
110- except Exception as e :
111- logger .error (f"Failed to process message: { e } " )
79+ if message is None :
80+ continue
81+
82+ # Extract session ID from channel name
83+ channel : str = cast (str , message ["channel" ])
84+ if not channel .startswith (self ._prefix ):
85+ continue
86+
87+ session_hex = channel .split (":" )[- 1 ]
88+ try :
89+ session_id = UUID (hex = session_hex )
90+ except ValueError :
91+ logger .error (f"Invalid session channel: { channel } " )
92+ continue
93+
94+ data : str = cast (str , message ["data" ])
95+ msg : None | types .JSONRPCMessage | Exception = None
96+ try :
97+ json_data = json .loads (data )
98+ if isinstance (json_data , dict ):
99+ json_dict : dict [str , Any ] = json_data
100+ if json_dict .get ("_exception" , False ):
101+ msg = Exception (
102+ f"{ json_dict ['type' ]} : { json_dict ['message' ]} "
103+ )
104+ else :
105+ msg = types .JSONRPCMessage .model_validate_json (data )
106+
107+ if msg and session_id in self ._callbacks :
108+ from_thread .run (self ._callbacks [session_id ], msg )
109+ except Exception as e :
110+ logger .error (f"Failed to process message: { e } " )
112111
113112 async def publish_message (
114113 self , session_id : UUID , message : types .JSONRPCMessage | Exception
@@ -136,8 +135,6 @@ async def publish_message(
136135
137136 async def session_exists (self , session_id : UUID ) -> bool :
138137 """Check if a session exists."""
139- # Explicitly annotate the result as bool to help the type checker
140- result = bool (
138+ return bool (
141139 await self ._redis .sismember (self ._active_sessions_key , session_id .hex ) # type: ignore[attr-defined]
142140 )
143- return result
0 commit comments