Skip to content

Commit 215cc42

Browse files
committed
wip
1 parent 1e81f36 commit 215cc42

File tree

1 file changed

+84
-60
lines changed

1 file changed

+84
-60
lines changed

src/mcp/server/message_queue/redis.py

Lines changed: 84 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(
4242
self._prefix = prefix
4343
self._active_sessions_key = f"{prefix}active_sessions"
4444
self._callbacks: dict[UUID, MessageCallback] = {}
45+
self._handlers: dict[UUID, TaskGroup] = {}
4546
self._limiter = CapacityLimiter(1)
4647
self._ack_events: dict[str, Event] = {}
4748

@@ -69,24 +70,70 @@ async def subscribe(self, session_id: UUID, callback: MessageCallback):
6970

7071
logger.debug(f"Subscribing to Redis channels for session {session_id}")
7172

72-
# Two nested task groups ensure proper cleanup: the inner one cancels the
73-
# listener, while the outer one allows any handlers to complete before exiting.
74-
async with anyio.create_task_group() as tg_handler:
75-
async with anyio.create_task_group() as tg:
76-
tg.start_soon(self._listen_for_messages, tg_handler)
77-
try:
78-
yield
79-
finally:
80-
tg.cancel_scope.cancel()
81-
await self._pubsub.unsubscribe(session_channel) # type: ignore
82-
await self._pubsub.unsubscribe(ack_channel) # type: ignore
83-
await self._redis.srem(self._active_sessions_key, session_id.hex)
84-
del self._callbacks[session_id]
85-
logger.debug(
86-
f"Unsubscribed from Redis channels for session {session_id}"
87-
)
73+
# Store the task group for the session
74+
async with anyio.create_task_group() as tg:
75+
self._handlers[session_id] = tg
76+
tg.start_soon(self._listen_for_messages)
77+
try:
78+
yield
79+
finally:
80+
tg.cancel_scope.cancel()
81+
await self._pubsub.unsubscribe(session_channel) # type: ignore
82+
await self._pubsub.unsubscribe(ack_channel) # type: ignore
83+
await self._redis.srem(self._active_sessions_key, session_id.hex)
84+
del self._callbacks[session_id]
85+
logger.debug(
86+
f"Unsubscribed from Redis channels for session {session_id}"
87+
)
88+
del self._handlers[session_id]
89+
90+
def _parse_ack_channel(self, channel: str) -> UUID | None:
91+
"""Parse and validate an acknowledgment channel, returning session_id."""
92+
ack_prefix = f"{self._prefix}ack:"
93+
if not channel.startswith(ack_prefix):
94+
return None
95+
96+
# Extract exactly what should be a UUID hex after the prefix
97+
session_hex = channel[len(ack_prefix):]
98+
if len(session_hex) != 32: # Standard UUID hex length
99+
logger.error(f"Invalid UUID length in ack channel: {channel}")
100+
return None
101+
102+
try:
103+
session_id = UUID(hex=session_hex)
104+
expected_channel = self._ack_channel(session_id)
105+
if channel != expected_channel:
106+
logger.error(f"Channel mismatch: got {channel}, expected {expected_channel}")
107+
return None
108+
return session_id
109+
except ValueError:
110+
logger.error(f"Invalid UUID hex in ack channel: {channel}")
111+
return None
112+
113+
def _parse_session_channel(self, channel: str) -> UUID | None:
114+
"""Parse and validate a session channel, returning session_id."""
115+
session_prefix = f"{self._prefix}session:"
116+
if not channel.startswith(session_prefix):
117+
return None
118+
119+
# Extract exactly what should be a UUID hex after the prefix
120+
session_hex = channel[len(session_prefix):]
121+
if len(session_hex) != 32: # Standard UUID hex length
122+
logger.error(f"Invalid UUID length in session channel: {channel}")
123+
return None
124+
125+
try:
126+
session_id = UUID(hex=session_hex)
127+
expected_channel = self._session_channel(session_id)
128+
if channel != expected_channel:
129+
logger.error(f"Channel mismatch: got {channel}, expected {expected_channel}")
130+
return None
131+
return session_id
132+
except ValueError:
133+
logger.error(f"Invalid UUID hex in session channel: {channel}")
134+
return None
88135

89-
async def _listen_for_messages(self, tg_handler: TaskGroup) -> None:
136+
async def _listen_for_messages(self) -> None:
90137
"""Background task that listens for messages on subscribed channels."""
91138
async with self._limiter:
92139
while True:
@@ -106,41 +153,31 @@ async def _listen_for_messages(self, tg_handler: TaskGroup) -> None:
106153
channel: str = cast(str, redis_message["channel"])
107154
data: str = cast(str, redis_message["data"])
108155

109-
# Handle acknowledgment messages
156+
# Determine which session this message is for
157+
session_id = None
110158
if channel.startswith(f"{self._prefix}ack:"):
111-
tg_handler.start_soon(self._handle_ack_message, channel, data)
112-
continue
113-
114-
# Handle session messages
159+
session_id = self._parse_ack_channel(channel)
115160
elif channel.startswith(f"{self._prefix}session:"):
116-
tg_handler.start_soon(
117-
self._handle_session_message, channel, data
118-
)
161+
session_id = self._parse_session_channel(channel)
162+
163+
if session_id is None:
164+
logger.debug(f"Ignoring message from channel: {channel}")
119165
continue
120-
121-
# Ignore other channels
166+
167+
if session_id not in self._handlers:
168+
logger.warning(f"Dropping message for non-existent session: {session_id}")
169+
continue
170+
171+
session_tg = self._handlers[session_id]
172+
if channel.startswith(f"{self._prefix}ack:"):
173+
session_tg.start_soon(self._handle_ack_message, channel, data)
122174
else:
123-
logger.debug(
124-
f"Ignoring message from non-MCP channel: {channel}"
125-
)
175+
session_tg.start_soon(self._handle_session_message, channel, data)
126176

127177
async def _handle_ack_message(self, channel: str, data: str) -> None:
128178
"""Handle acknowledgment messages received on ack channels."""
129-
ack_prefix = f"{self._prefix}ack:"
130-
if not channel.startswith(ack_prefix):
131-
return
132-
133-
session_hex = channel[len(ack_prefix) :]
134-
try:
135-
session_id = UUID(hex=session_hex)
136-
expected_channel = self._ack_channel(session_id)
137-
if channel != expected_channel:
138-
logger.error(
139-
f"Channel mismatch: got {channel}, expected {expected_channel}"
140-
)
141-
return
142-
except ValueError:
143-
logger.error(f"Invalid UUID hex in ack channel: {channel}")
179+
session_id = self._parse_ack_channel(channel)
180+
if session_id is None:
144181
return
145182

146183
# Extract message ID from data
@@ -151,21 +188,8 @@ async def _handle_ack_message(self, channel: str, data: str) -> None:
151188

152189
async def _handle_session_message(self, channel: str, data: str) -> None:
153190
"""Handle regular messages received on session channels."""
154-
session_prefix = f"{self._prefix}session:"
155-
if not channel.startswith(session_prefix):
156-
return
157-
158-
session_hex = channel[len(session_prefix) :]
159-
try:
160-
session_id = UUID(hex=session_hex)
161-
expected_channel = self._session_channel(session_id)
162-
if channel != expected_channel:
163-
logger.error(
164-
f"Channel mismatch: got {channel}, expected {expected_channel}"
165-
)
166-
return
167-
except ValueError:
168-
logger.error(f"Invalid UUID hex in session channel: {channel}")
191+
session_id = self._parse_session_channel(channel)
192+
if session_id is None:
169193
return
170194

171195
if session_id not in self._callbacks:

0 commit comments

Comments
 (0)