|
1 | 1 | import logging |
2 | 2 | from contextlib import asynccontextmanager |
3 | 3 | from typing import Any, cast |
4 | | -from uuid import UUID, uuid4 |
| 4 | +from uuid import UUID |
5 | 5 |
|
6 | 6 | import anyio |
7 | | -from anyio import CancelScope, CapacityLimiter, Event, lowlevel |
8 | | -from anyio.abc import TaskGroup |
| 7 | +from anyio import CapacityLimiter, lowlevel |
| 8 | +from pydantic import ValidationError |
9 | 9 |
|
10 | 10 | import mcp.types as types |
11 | | -from mcp.server.message_queue.base import MessageCallback, MessageWrapper |
| 11 | +from mcp.server.message_queue.base import MessageCallback |
12 | 12 |
|
13 | 13 | try: |
14 | 14 | import redis.asyncio as redis |
@@ -42,234 +42,98 @@ def __init__( |
42 | 42 | self._prefix = prefix |
43 | 43 | self._active_sessions_key = f"{prefix}active_sessions" |
44 | 44 | self._callbacks: dict[UUID, MessageCallback] = {} |
45 | | - self._handlers: dict[UUID, TaskGroup] = {} |
| 45 | + # Ensures only one polling task runs at a time for message handling |
46 | 46 | self._limiter = CapacityLimiter(1) |
47 | | - self._ack_events: dict[str, Event] = {} |
48 | | - |
49 | 47 | logger.debug(f"Redis message dispatch initialized: {redis_url}") |
50 | 48 |
|
51 | 49 | def _session_channel(self, session_id: UUID) -> str: |
52 | 50 | """Get the Redis channel for a session.""" |
53 | 51 | return f"{self._prefix}session:{session_id.hex}" |
54 | 52 |
|
55 | | - def _ack_channel(self, session_id: UUID) -> str: |
56 | | - """Get the acknowledgment channel for a session.""" |
57 | | - return f"{self._prefix}ack:{session_id.hex}" |
58 | | - |
59 | 53 | @asynccontextmanager |
60 | 54 | async def subscribe(self, session_id: UUID, callback: MessageCallback): |
61 | 55 | """Request-scoped context manager that subscribes to messages for a session.""" |
62 | 56 | await self._redis.sadd(self._active_sessions_key, session_id.hex) |
63 | 57 | self._callbacks[session_id] = callback |
| 58 | + channel = self._session_channel(session_id) |
| 59 | + await self._pubsub.subscribe(channel) # type: ignore |
64 | 60 |
|
65 | | - session_channel = self._session_channel(session_id) |
66 | | - ack_channel = self._ack_channel(session_id) |
67 | | - |
68 | | - await self._pubsub.subscribe(session_channel) # type: ignore |
69 | | - await self._pubsub.subscribe(ack_channel) # type: ignore |
70 | | - |
71 | | - logger.debug(f"Subscribing to Redis channels for session {session_id}") |
72 | | - |
73 | | - # Store the task group for the session |
| 61 | + logger.debug(f"Subscribing to Redis channel for session {session_id}") |
74 | 62 | async with anyio.create_task_group() as tg: |
75 | | - self._handlers[session_id] = tg |
76 | 63 | tg.start_soon(self._listen_for_messages) |
77 | 64 | try: |
78 | 65 | yield |
79 | 66 | finally: |
80 | 67 | tg.cancel_scope.cancel() |
81 | | - await self._pubsub.unsubscribe(session_channel) # type: ignore |
82 | | - await self._pubsub.unsubscribe(ack_channel) # type: ignore |
| 68 | + await self._pubsub.unsubscribe(channel) # type: ignore |
83 | 69 | await self._redis.srem(self._active_sessions_key, session_id.hex) |
84 | 70 | del self._callbacks[session_id] |
85 | | - logger.debug( |
86 | | - f"Unsubscribed from Redis channels for session {session_id}" |
87 | | - ) |
88 | | - del self._handlers[session_id] |
89 | | - |
90 | | - def _parse_ack_channel(self, channel: str) -> UUID | None: |
91 | | - """Parse and validate an acknowledgment channel, returning session_id.""" |
92 | | - ack_prefix = f"{self._prefix}ack:" |
93 | | - if not channel.startswith(ack_prefix): |
94 | | - return None |
95 | | - |
96 | | - # Extract exactly what should be a UUID hex after the prefix |
97 | | - session_hex = channel[len(ack_prefix):] |
98 | | - if len(session_hex) != 32: # Standard UUID hex length |
99 | | - logger.error(f"Invalid UUID length in ack channel: {channel}") |
100 | | - return None |
101 | | - |
102 | | - try: |
103 | | - session_id = UUID(hex=session_hex) |
104 | | - expected_channel = self._ack_channel(session_id) |
105 | | - if channel != expected_channel: |
106 | | - logger.error(f"Channel mismatch: got {channel}, expected {expected_channel}") |
107 | | - return None |
108 | | - return session_id |
109 | | - except ValueError: |
110 | | - logger.error(f"Invalid UUID hex in ack channel: {channel}") |
111 | | - return None |
112 | | - |
113 | | - def _parse_session_channel(self, channel: str) -> UUID | None: |
114 | | - """Parse and validate a session channel, returning session_id.""" |
115 | | - session_prefix = f"{self._prefix}session:" |
116 | | - if not channel.startswith(session_prefix): |
117 | | - return None |
118 | | - |
119 | | - # Extract exactly what should be a UUID hex after the prefix |
120 | | - session_hex = channel[len(session_prefix):] |
121 | | - if len(session_hex) != 32: # Standard UUID hex length |
122 | | - logger.error(f"Invalid UUID length in session channel: {channel}") |
123 | | - return None |
124 | | - |
125 | | - try: |
126 | | - session_id = UUID(hex=session_hex) |
127 | | - expected_channel = self._session_channel(session_id) |
128 | | - if channel != expected_channel: |
129 | | - logger.error(f"Channel mismatch: got {channel}, expected {expected_channel}") |
130 | | - return None |
131 | | - return session_id |
132 | | - except ValueError: |
133 | | - logger.error(f"Invalid UUID hex in session channel: {channel}") |
134 | | - return None |
| 71 | + logger.debug(f"Unsubscribed from Redis channel: {session_id}") |
135 | 72 |
|
136 | 73 | async def _listen_for_messages(self) -> None: |
137 | 74 | """Background task that listens for messages on subscribed channels.""" |
138 | 75 | async with self._limiter: |
139 | 76 | while True: |
140 | 77 | await lowlevel.checkpoint() |
141 | | - # Shield message retrieval from cancellation to ensure no messages are |
142 | | - # lost when a session disconnects during processing. |
143 | | - with CancelScope(shield=True): |
144 | | - redis_message: ( # type: ignore |
145 | | - None | dict[str, Any] |
146 | | - ) = await self._pubsub.get_message( # type: ignore |
147 | | - ignore_subscribe_messages=True, |
148 | | - timeout=0.1, # type: ignore |
149 | | - ) |
150 | | - if redis_message is None: |
151 | | - continue |
152 | | - |
153 | | - channel: str = cast(str, redis_message["channel"]) |
154 | | - data: str = cast(str, redis_message["data"]) |
155 | | - |
156 | | - # Determine which session this message is for |
157 | | - session_id = None |
158 | | - if channel.startswith(f"{self._prefix}ack:"): |
159 | | - session_id = self._parse_ack_channel(channel) |
160 | | - elif channel.startswith(f"{self._prefix}session:"): |
161 | | - session_id = self._parse_session_channel(channel) |
162 | | - |
163 | | - if session_id is None: |
164 | | - logger.debug(f"Ignoring message from channel: {channel}") |
| 78 | + message: None | dict[str, Any] = await self._pubsub.get_message( # type: ignore |
| 79 | + ignore_subscribe_messages=True, |
| 80 | + timeout=None, # type: ignore |
| 81 | + ) |
| 82 | + if message is None: |
| 83 | + continue |
| 84 | + |
| 85 | + channel: str = cast(str, message["channel"]) |
| 86 | + expected_prefix = f"{self._prefix}session:" |
| 87 | + |
| 88 | + if not channel.startswith(expected_prefix): |
| 89 | + logger.debug(f"Ignoring message from non-MCP channel: {channel}") |
| 90 | + continue |
| 91 | + |
| 92 | + session_hex = channel[len(expected_prefix) :] |
| 93 | + try: |
| 94 | + session_id = UUID(hex=session_hex) |
| 95 | + expected_channel = self._session_channel(session_id) |
| 96 | + if channel != expected_channel: |
| 97 | + logger.error(f"Channel format mismatch: {channel}") |
165 | 98 | continue |
166 | | - |
167 | | - if session_id not in self._handlers: |
168 | | - logger.warning(f"Dropping message for non-existent session: {session_id}") |
| 99 | + except ValueError: |
| 100 | + logger.error(f"Invalid UUID in channel: {channel}") |
| 101 | + continue |
| 102 | + |
| 103 | + data: str = cast(str, message["data"]) |
| 104 | + try: |
| 105 | + if session_id not in self._callbacks: |
| 106 | + logger.warning(f"Message dropped: no callback for {session_id}") |
169 | 107 | continue |
170 | | - |
171 | | - session_tg = self._handlers[session_id] |
172 | | - if channel.startswith(f"{self._prefix}ack:"): |
173 | | - session_tg.start_soon(self._handle_ack_message, channel, data) |
174 | | - else: |
175 | | - session_tg.start_soon(self._handle_session_message, channel, data) |
176 | | - |
177 | | - async def _handle_ack_message(self, channel: str, data: str) -> None: |
178 | | - """Handle acknowledgment messages received on ack channels.""" |
179 | | - session_id = self._parse_ack_channel(channel) |
180 | | - if session_id is None: |
181 | | - return |
182 | | - |
183 | | - # Extract message ID from data |
184 | | - message_id = data.strip() |
185 | | - if message_id in self._ack_events: |
186 | | - logger.debug(f"Received acknowledgment for message: {message_id}") |
187 | | - self._ack_events[message_id].set() |
188 | | - |
189 | | - async def _handle_session_message(self, channel: str, data: str) -> None: |
190 | | - """Handle regular messages received on session channels.""" |
191 | | - session_id = self._parse_session_channel(channel) |
192 | | - if session_id is None: |
193 | | - return |
194 | 108 |
|
195 | | - if session_id not in self._callbacks: |
196 | | - logger.warning(f"Message dropped: no callback for {session_id}") |
197 | | - return |
198 | | - |
199 | | - try: |
200 | | - wrapper = MessageWrapper.model_validate_json(data) |
201 | | - result = wrapper.get_json_rpc_message() |
202 | | - await self._callbacks[session_id](result) |
203 | | - await self._send_acknowledgment(session_id, wrapper.message_id) |
204 | | - |
205 | | - except Exception as e: |
206 | | - logger.error(f"Error processing message for {session_id}: {e}") |
207 | | - |
208 | | - async def _send_acknowledgment(self, session_id: UUID, message_id: str) -> None: |
209 | | - """Send an acknowledgment for a message that was successfully processed.""" |
210 | | - ack_channel = self._ack_channel(session_id) |
211 | | - await self._redis.publish(ack_channel, message_id) # type: ignore |
212 | | - logger.debug( |
213 | | - f"Sent acknowledgment for message {message_id} to session {session_id}" |
214 | | - ) |
| 109 | + # Try to parse as valid message or recreate original ValidationError |
| 110 | + try: |
| 111 | + msg = types.JSONRPCMessage.model_validate_json(data) |
| 112 | + await self._callbacks[session_id](msg) |
| 113 | + except ValidationError as exc: |
| 114 | + # Pass the identical validation error that would have occurred |
| 115 | + await self._callbacks[session_id](exc) |
| 116 | + except Exception as e: |
| 117 | + logger.error(f"Error processing message for {session_id}: {e}") |
215 | 118 |
|
216 | 119 | async def publish_message( |
217 | | - self, |
218 | | - session_id: UUID, |
219 | | - message: types.JSONRPCMessage | str, |
220 | | - message_id: str | None = None, |
221 | | - ) -> str | None: |
| 120 | + self, session_id: UUID, message: types.JSONRPCMessage | str |
| 121 | + ) -> bool: |
222 | 122 | """Publish a message for the specified session.""" |
223 | 123 | if not await self.session_exists(session_id): |
224 | 124 | logger.warning(f"Message dropped: unknown session {session_id}") |
225 | | - return None |
| 125 | + return False |
226 | 126 |
|
227 | 127 | # Pass raw JSON strings directly, preserving validation errors |
228 | | - message_id = message_id or str(uuid4()) |
229 | 128 | if isinstance(message, str): |
230 | | - wrapper = MessageWrapper(message_id=message_id, payload=message) |
| 129 | + data = message |
231 | 130 | else: |
232 | | - wrapper = MessageWrapper( |
233 | | - message_id=message_id, payload=message.model_dump_json() |
234 | | - ) |
| 131 | + data = message.model_dump_json() |
235 | 132 |
|
236 | 133 | channel = self._session_channel(session_id) |
237 | | - await self._redis.publish(channel, wrapper.model_dump_json()) # type: ignore |
238 | | - logger.debug( |
239 | | - f"Message {message_id} published to Redis channel for session {session_id}" |
240 | | - ) |
241 | | - return message_id |
242 | | - |
243 | | - async def publish_message_sync( |
244 | | - self, |
245 | | - session_id: UUID, |
246 | | - message: types.JSONRPCMessage | str, |
247 | | - timeout: float = 120.0, |
248 | | - ) -> bool: |
249 | | - """Publish a message and wait for acknowledgment of processing.""" |
250 | | - message_id = str(uuid4()) |
251 | | - ack_event = Event() |
252 | | - self._ack_events[message_id] = ack_event |
253 | | - |
254 | | - try: |
255 | | - published_id = await self.publish_message(session_id, message, message_id) |
256 | | - if published_id is None: |
257 | | - return False |
258 | | - |
259 | | - with anyio.fail_after(timeout): |
260 | | - await ack_event.wait() |
261 | | - logger.debug(f"Received acknowledgment for message {message_id}") |
262 | | - return True |
263 | | - |
264 | | - except TimeoutError: |
265 | | - logger.warning( |
266 | | - f"Timed out waiting for acknowledgment of message {message_id}" |
267 | | - ) |
268 | | - return False |
269 | | - |
270 | | - finally: |
271 | | - if message_id in self._ack_events: |
272 | | - del self._ack_events[message_id] |
| 134 | + await self._redis.publish(channel, data) # type: ignore[attr-defined] |
| 135 | + logger.debug(f"Message published to Redis channel for session {session_id}") |
| 136 | + return True |
273 | 137 |
|
274 | 138 | async def session_exists(self, session_id: UUID) -> bool: |
275 | 139 | """Check if a session exists.""" |
|
0 commit comments