@@ -42,6 +42,7 @@ def __init__(
4242 self ._prefix = prefix
4343 self ._active_sessions_key = f"{ prefix } active_sessions"
4444 self ._callbacks : dict [UUID , MessageCallback ] = {}
45+ self ._handlers : dict [UUID , TaskGroup ] = {}
4546 self ._limiter = CapacityLimiter (1 )
4647 self ._ack_events : dict [str , Event ] = {}
4748
@@ -69,24 +70,70 @@ async def subscribe(self, session_id: UUID, callback: MessageCallback):
6970
7071 logger .debug (f"Subscribing to Redis channels for session { session_id } " )
7172
72- # Two nested task groups ensure proper cleanup: the inner one cancels the
73- # listener, while the outer one allows any handlers to complete before exiting.
74- async with anyio .create_task_group () as tg_handler :
75- async with anyio .create_task_group () as tg :
76- tg .start_soon (self ._listen_for_messages , tg_handler )
77- try :
78- yield
79- finally :
80- tg .cancel_scope .cancel ()
81- await self ._pubsub .unsubscribe (session_channel ) # type: ignore
82- await self ._pubsub .unsubscribe (ack_channel ) # type: ignore
83- await self ._redis .srem (self ._active_sessions_key , session_id .hex )
84- del self ._callbacks [session_id ]
85- logger .debug (
86- f"Unsubscribed from Redis channels for session { session_id } "
87- )
73+ # Store the task group for the session
74+ async with anyio .create_task_group () as tg :
75+ self ._handlers [session_id ] = tg
76+ tg .start_soon (self ._listen_for_messages )
77+ try :
78+ yield
79+ finally :
80+ tg .cancel_scope .cancel ()
81+ await self ._pubsub .unsubscribe (session_channel ) # type: ignore
82+ await self ._pubsub .unsubscribe (ack_channel ) # type: ignore
83+ await self ._redis .srem (self ._active_sessions_key , session_id .hex )
84+ del self ._callbacks [session_id ]
85+ logger .debug (
86+ f"Unsubscribed from Redis channels for session { session_id } "
87+ )
88+ del self ._handlers [session_id ]
89+
90+ def _parse_ack_channel (self , channel : str ) -> UUID | None :
91+ """Parse and validate an acknowledgment channel, returning session_id."""
92+ ack_prefix = f"{ self ._prefix } ack:"
93+ if not channel .startswith (ack_prefix ):
94+ return None
95+
96+ # Extract exactly what should be a UUID hex after the prefix
97+ session_hex = channel [len (ack_prefix ):]
98+ if len (session_hex ) != 32 : # Standard UUID hex length
99+ logger .error (f"Invalid UUID length in ack channel: { channel } " )
100+ return None
101+
102+ try :
103+ session_id = UUID (hex = session_hex )
104+ expected_channel = self ._ack_channel (session_id )
105+ if channel != expected_channel :
106+ logger .error (f"Channel mismatch: got { channel } , expected { expected_channel } " )
107+ return None
108+ return session_id
109+ except ValueError :
110+ logger .error (f"Invalid UUID hex in ack channel: { channel } " )
111+ return None
112+
113+ def _parse_session_channel (self , channel : str ) -> UUID | None :
114+ """Parse and validate a session channel, returning session_id."""
115+ session_prefix = f"{ self ._prefix } session:"
116+ if not channel .startswith (session_prefix ):
117+ return None
118+
119+ # Extract exactly what should be a UUID hex after the prefix
120+ session_hex = channel [len (session_prefix ):]
121+ if len (session_hex ) != 32 : # Standard UUID hex length
122+ logger .error (f"Invalid UUID length in session channel: { channel } " )
123+ return None
124+
125+ try :
126+ session_id = UUID (hex = session_hex )
127+ expected_channel = self ._session_channel (session_id )
128+ if channel != expected_channel :
129+ logger .error (f"Channel mismatch: got { channel } , expected { expected_channel } " )
130+ return None
131+ return session_id
132+ except ValueError :
133+ logger .error (f"Invalid UUID hex in session channel: { channel } " )
134+ return None
88135
89- async def _listen_for_messages (self , tg_handler : TaskGroup ) -> None :
136+ async def _listen_for_messages (self ) -> None :
90137 """Background task that listens for messages on subscribed channels."""
91138 async with self ._limiter :
92139 while True :
@@ -106,41 +153,31 @@ async def _listen_for_messages(self, tg_handler: TaskGroup) -> None:
106153 channel : str = cast (str , redis_message ["channel" ])
107154 data : str = cast (str , redis_message ["data" ])
108155
109- # Handle acknowledgment messages
156+ # Determine which session this message is for
157+ session_id = None
110158 if channel .startswith (f"{ self ._prefix } ack:" ):
111- tg_handler .start_soon (self ._handle_ack_message , channel , data )
112- continue
113-
114- # Handle session messages
159+ session_id = self ._parse_ack_channel (channel )
115160 elif channel .startswith (f"{ self ._prefix } session:" ):
116- tg_handler .start_soon (
117- self ._handle_session_message , channel , data
118- )
161+ session_id = self ._parse_session_channel (channel )
162+
163+ if session_id is None :
164+ logger .debug (f"Ignoring message from channel: { channel } " )
119165 continue
120-
121- # Ignore other channels
166+
167+ if session_id not in self ._handlers :
168+ logger .warning (f"Dropping message for non-existent session: { session_id } " )
169+ continue
170+
171+ session_tg = self ._handlers [session_id ]
172+ if channel .startswith (f"{ self ._prefix } ack:" ):
173+ session_tg .start_soon (self ._handle_ack_message , channel , data )
122174 else :
123- logger .debug (
124- f"Ignoring message from non-MCP channel: { channel } "
125- )
175+ session_tg .start_soon (self ._handle_session_message , channel , data )
126176
127177 async def _handle_ack_message (self , channel : str , data : str ) -> None :
128178 """Handle acknowledgment messages received on ack channels."""
129- ack_prefix = f"{ self ._prefix } ack:"
130- if not channel .startswith (ack_prefix ):
131- return
132-
133- session_hex = channel [len (ack_prefix ) :]
134- try :
135- session_id = UUID (hex = session_hex )
136- expected_channel = self ._ack_channel (session_id )
137- if channel != expected_channel :
138- logger .error (
139- f"Channel mismatch: got { channel } , expected { expected_channel } "
140- )
141- return
142- except ValueError :
143- logger .error (f"Invalid UUID hex in ack channel: { channel } " )
179+ session_id = self ._parse_ack_channel (channel )
180+ if session_id is None :
144181 return
145182
146183 # Extract message ID from data
@@ -151,21 +188,8 @@ async def _handle_ack_message(self, channel: str, data: str) -> None:
151188
152189 async def _handle_session_message (self , channel : str , data : str ) -> None :
153190 """Handle regular messages received on session channels."""
154- session_prefix = f"{ self ._prefix } session:"
155- if not channel .startswith (session_prefix ):
156- return
157-
158- session_hex = channel [len (session_prefix ) :]
159- try :
160- session_id = UUID (hex = session_hex )
161- expected_channel = self ._session_channel (session_id )
162- if channel != expected_channel :
163- logger .error (
164- f"Channel mismatch: got { channel } , expected { expected_channel } "
165- )
166- return
167- except ValueError :
168- logger .error (f"Invalid UUID hex in session channel: { channel } " )
191+ session_id = self ._parse_session_channel (channel )
192+ if session_id is None :
169193 return
170194
171195 if session_id not in self ._callbacks :
0 commit comments