44from uuid import UUID
55
66import anyio
7- from anyio import CapacityLimiter , lowlevel
7+ from anyio import CancelScope , CapacityLimiter , lowlevel
8+ from anyio .abc import TaskGroup
89from pydantic import ValidationError
910
1011import mcp .types as types
@@ -42,6 +43,7 @@ def __init__(
4243 self ._prefix = prefix
4344 self ._active_sessions_key = f"{ prefix } active_sessions"
4445 self ._callbacks : dict [UUID , MessageCallback ] = {}
46+ self ._task_groups : dict [UUID , TaskGroup ] = {}
4547 # Ensures only one polling task runs at a time for message handling
4648 self ._limiter = CapacityLimiter (1 )
4749 logger .debug (f"Redis message dispatch initialized: { redis_url } " )
@@ -60,6 +62,7 @@ async def subscribe(self, session_id: UUID, callback: MessageCallback):
6062
6163 logger .debug (f"Subscribing to Redis channel for session { session_id } " )
6264 async with anyio .create_task_group () as tg :
65+ self ._task_groups [session_id ] = tg
6366 tg .start_soon (self ._listen_for_messages )
6467 try :
6568 yield
@@ -68,53 +71,75 @@ async def subscribe(self, session_id: UUID, callback: MessageCallback):
6871 await self ._pubsub .unsubscribe (channel ) # type: ignore
6972 await self ._redis .srem (self ._active_sessions_key , session_id .hex )
7073 del self ._callbacks [session_id ]
74+ del self ._task_groups [session_id ]
7175 logger .debug (f"Unsubscribed from Redis channel: { session_id } " )
7276
77+ def _extract_session_id (self , channel : str ) -> UUID | None :
78+ """Extract and validate session ID from channel."""
79+ expected_prefix = f"{ self ._prefix } session:"
80+ if not channel .startswith (expected_prefix ):
81+ return None
82+
83+ session_hex = channel [len (expected_prefix ) :]
84+ try :
85+ session_id = UUID (hex = session_hex )
86+ if channel != self ._session_channel (session_id ):
87+ logger .error (f"Channel format mismatch: { channel } " )
88+ return None
89+ return session_id
90+ except ValueError :
91+ logger .error (f"Invalid UUID in channel: { channel } " )
92+ return None
93+
7394 async def _listen_for_messages (self ) -> None :
7495 """Background task that listens for messages on subscribed channels."""
7596 async with self ._limiter :
7697 while True :
7798 await lowlevel .checkpoint ()
78- message : None | dict [str , Any ] = await self ._pubsub .get_message ( # type: ignore
79- ignore_subscribe_messages = True ,
80- timeout = None , # type: ignore
81- )
82- if message is None :
83- continue
84-
85- channel : str = cast (str , message ["channel" ])
86- expected_prefix = f"{ self ._prefix } session:"
87-
88- if not channel .startswith (expected_prefix ):
89- logger .debug (f"Ignoring message from non-MCP channel: { channel } " )
90- continue
91-
92- session_hex = channel [len (expected_prefix ) :]
93- try :
94- session_id = UUID (hex = session_hex )
95- expected_channel = self ._session_channel (session_id )
96- if channel != expected_channel :
97- logger .error (f"Channel format mismatch: { channel } " )
99+ with CancelScope (shield = True ):
100+ message : None | dict [str , Any ] = await self ._pubsub .get_message ( # type: ignore
101+ ignore_subscribe_messages = True ,
102+ timeout = 0.1 , # type: ignore
103+ )
104+ if message is None :
98105 continue
99- except ValueError :
100- logger .error (f"Invalid UUID in channel: { channel } " )
101- continue
102-
103- data : str = cast (str , message ["data" ])
104- try :
105- if session_id not in self ._callbacks :
106- logger .warning (f"Message dropped: no callback for { session_id } " )
106+
107+ channel : str = cast (str , message ["channel" ])
108+ session_id = self ._extract_session_id (channel )
109+ if session_id is None :
110+ logger .debug (f"Ignoring message from non-MCP channel: { channel } " )
107111 continue
108112
109- # Try to parse as valid message or recreate original ValidationError
113+ data : str = cast ( str , message [ "data" ])
110114 try :
111- msg = types .JSONRPCMessage .model_validate_json (data )
112- await self ._callbacks [session_id ](msg )
113- except ValidationError as exc :
114- # Pass the identical validation error that would have occurred
115- await self ._callbacks [session_id ](exc )
116- except Exception as e :
117- logger .error (f"Error processing message for { session_id } : { e } " )
115+ if session_id in self ._task_groups :
116+ self ._task_groups [session_id ].start_soon (
117+ self ._handle_message , session_id , data
118+ )
119+ else :
120+ logger .warning (
121+ f"Message dropped: no task group for session: { session_id } "
122+ )
123+ except Exception as e :
124+ logger .error (f"Error processing message for { session_id } : { e } " )
125+
126+ async def _handle_message (self , session_id : UUID , data : str ) -> None :
127+ """Process a message from Redis in the session's task group."""
128+ if session_id not in self ._callbacks :
129+ logger .warning (f"Message dropped: callback removed for { session_id } " )
130+ return
131+
132+ try :
133+ # Parse message or pass validation error to callback
134+ msg_or_error = None
135+ try :
136+ msg_or_error = types .JSONRPCMessage .model_validate_json (data )
137+ except ValidationError as exc :
138+ msg_or_error = exc
139+
140+ await self ._callbacks [session_id ](msg_or_error )
141+ except Exception as e :
142+ logger .error (f"Error in message handler for { session_id } : { e } " )
118143
119144 async def publish_message (
120145 self , session_id : UUID , message : types .JSONRPCMessage | str
0 commit comments