
Commit 8fce8e6

back to b484284
1 parent 215cc42 commit 8fce8e6


3 files changed: +59 -253 lines changed


src/mcp/server/message_queue/base.py

Lines changed: 1 addition & 50 deletions
@@ -4,7 +4,7 @@
 from typing import Protocol, runtime_checkable
 from uuid import UUID
 
-from pydantic import BaseModel, ValidationError
+from pydantic import ValidationError
 
 import mcp.types as types
 
@@ -13,18 +13,6 @@
 MessageCallback = Callable[[types.JSONRPCMessage | Exception], Awaitable[None]]
 
 
-class MessageWrapper(BaseModel):
-    message_id: str
-    payload: str
-
-    def get_json_rpc_message(self) -> types.JSONRPCMessage | ValidationError:
-        """Parse the payload into a JSONRPCMessage or return ValidationError."""
-        try:
-            return types.JSONRPCMessage.model_validate_json(self.payload)
-        except ValidationError as exc:
-            return exc
-
-
 @runtime_checkable
 class MessageDispatch(Protocol):
     """Abstract interface for SSE message dispatching.
@@ -47,28 +35,6 @@ async def publish_message(
         """
         ...
 
-    async def publish_message_sync(
-        self,
-        session_id: UUID,
-        message: types.JSONRPCMessage | str,
-        timeout: float = 120.0,
-    ) -> bool:
-        """Publish a message for the specified session and wait for confirmation.
-
-        This method blocks until the message has been fully consumed by the subscriber,
-        or until the timeout is reached.
-
-        Args:
-            session_id: The UUID of the session this message is for
-            message: The message to publish (JSONRPCMessage or str for invalid JSON)
-            timeout: Maximum time to wait for consumption in seconds
-
-        Returns:
-            bool: True if message was published and consumed, False otherwise
-        """
-        # Default implementation falls back to standard publish
-        return await self.publish_message(session_id, message)
-
     @asynccontextmanager
     async def subscribe(self, session_id: UUID, callback: MessageCallback):
         """Request-scoped context manager that subscribes to messages for a session.
@@ -125,21 +91,6 @@ async def publish_message(
         logger.debug(f"Message dispatched to session {session_id}")
         return True
 
-    async def publish_message_sync(
-        self,
-        session_id: UUID,
-        message: types.JSONRPCMessage | str,
-        timeout: float = 30.0,
-    ) -> bool:
-        """Publish a message for the specified session and wait for consumption.
-
-        For InMemoryMessageDispatch, this is the same as publish_message since
-        the callback is executed synchronously.
-        """
-        # For in-memory dispatch, the message is processed immediately
-        # so we can just call the regular publish method
-        return await self.publish_message(session_id, message)
-
     @asynccontextmanager
     async def subscribe(self, session_id: UUID, callback: MessageCallback):
         """Request-scoped context manager that subscribes to messages for a session."""

src/mcp/server/message_queue/redis.py

Lines changed: 55 additions & 191 deletions
@@ -1,14 +1,14 @@
 import logging
 from contextlib import asynccontextmanager
 from typing import Any, cast
-from uuid import UUID, uuid4
+from uuid import UUID
 
 import anyio
-from anyio import CancelScope, CapacityLimiter, Event, lowlevel
-from anyio.abc import TaskGroup
+from anyio import CapacityLimiter, lowlevel
+from pydantic import ValidationError
 
 import mcp.types as types
-from mcp.server.message_queue.base import MessageCallback, MessageWrapper
+from mcp.server.message_queue.base import MessageCallback
 
 try:
     import redis.asyncio as redis
@@ -42,234 +42,98 @@ def __init__(
         self._prefix = prefix
         self._active_sessions_key = f"{prefix}active_sessions"
         self._callbacks: dict[UUID, MessageCallback] = {}
-        self._handlers: dict[UUID, TaskGroup] = {}
+        # Ensures only one polling task runs at a time for message handling
         self._limiter = CapacityLimiter(1)
-        self._ack_events: dict[str, Event] = {}
-
         logger.debug(f"Redis message dispatch initialized: {redis_url}")
 
     def _session_channel(self, session_id: UUID) -> str:
         """Get the Redis channel for a session."""
         return f"{self._prefix}session:{session_id.hex}"
 
-    def _ack_channel(self, session_id: UUID) -> str:
-        """Get the acknowledgment channel for a session."""
-        return f"{self._prefix}ack:{session_id.hex}"
-
     @asynccontextmanager
     async def subscribe(self, session_id: UUID, callback: MessageCallback):
         """Request-scoped context manager that subscribes to messages for a session."""
         await self._redis.sadd(self._active_sessions_key, session_id.hex)
         self._callbacks[session_id] = callback
+        channel = self._session_channel(session_id)
+        await self._pubsub.subscribe(channel)  # type: ignore
 
-        session_channel = self._session_channel(session_id)
-        ack_channel = self._ack_channel(session_id)
-
-        await self._pubsub.subscribe(session_channel)  # type: ignore
-        await self._pubsub.subscribe(ack_channel)  # type: ignore
-
-        logger.debug(f"Subscribing to Redis channels for session {session_id}")
-
-        # Store the task group for the session
+        logger.debug(f"Subscribing to Redis channel for session {session_id}")
         async with anyio.create_task_group() as tg:
-            self._handlers[session_id] = tg
             tg.start_soon(self._listen_for_messages)
             try:
                 yield
             finally:
                 tg.cancel_scope.cancel()
-                await self._pubsub.unsubscribe(session_channel)  # type: ignore
-                await self._pubsub.unsubscribe(ack_channel)  # type: ignore
+                await self._pubsub.unsubscribe(channel)  # type: ignore
                 await self._redis.srem(self._active_sessions_key, session_id.hex)
                 del self._callbacks[session_id]
-                logger.debug(
-                    f"Unsubscribed from Redis channels for session {session_id}"
-                )
-                del self._handlers[session_id]
-
-    def _parse_ack_channel(self, channel: str) -> UUID | None:
-        """Parse and validate an acknowledgment channel, returning session_id."""
-        ack_prefix = f"{self._prefix}ack:"
-        if not channel.startswith(ack_prefix):
-            return None
-
-        # Extract exactly what should be a UUID hex after the prefix
-        session_hex = channel[len(ack_prefix):]
-        if len(session_hex) != 32:  # Standard UUID hex length
-            logger.error(f"Invalid UUID length in ack channel: {channel}")
-            return None
-
-        try:
-            session_id = UUID(hex=session_hex)
-            expected_channel = self._ack_channel(session_id)
-            if channel != expected_channel:
-                logger.error(f"Channel mismatch: got {channel}, expected {expected_channel}")
-                return None
-            return session_id
-        except ValueError:
-            logger.error(f"Invalid UUID hex in ack channel: {channel}")
-            return None
-
-    def _parse_session_channel(self, channel: str) -> UUID | None:
-        """Parse and validate a session channel, returning session_id."""
-        session_prefix = f"{self._prefix}session:"
-        if not channel.startswith(session_prefix):
-            return None
-
-        # Extract exactly what should be a UUID hex after the prefix
-        session_hex = channel[len(session_prefix):]
-        if len(session_hex) != 32:  # Standard UUID hex length
-            logger.error(f"Invalid UUID length in session channel: {channel}")
-            return None
-
-        try:
-            session_id = UUID(hex=session_hex)
-            expected_channel = self._session_channel(session_id)
-            if channel != expected_channel:
-                logger.error(f"Channel mismatch: got {channel}, expected {expected_channel}")
-                return None
-            return session_id
-        except ValueError:
-            logger.error(f"Invalid UUID hex in session channel: {channel}")
-            return None
+            logger.debug(f"Unsubscribed from Redis channel: {session_id}")
 
     async def _listen_for_messages(self) -> None:
         """Background task that listens for messages on subscribed channels."""
         async with self._limiter:
             while True:
                 await lowlevel.checkpoint()
-                # Shield message retrieval from cancellation to ensure no messages are
-                # lost when a session disconnects during processing.
-                with CancelScope(shield=True):
-                    redis_message: (  # type: ignore
-                        None | dict[str, Any]
-                    ) = await self._pubsub.get_message(  # type: ignore
-                        ignore_subscribe_messages=True,
-                        timeout=0.1,  # type: ignore
-                    )
-                    if redis_message is None:
-                        continue
-
-                    channel: str = cast(str, redis_message["channel"])
-                    data: str = cast(str, redis_message["data"])
-
-                    # Determine which session this message is for
-                    session_id = None
-                    if channel.startswith(f"{self._prefix}ack:"):
-                        session_id = self._parse_ack_channel(channel)
-                    elif channel.startswith(f"{self._prefix}session:"):
-                        session_id = self._parse_session_channel(channel)
-
-                    if session_id is None:
-                        logger.debug(f"Ignoring message from channel: {channel}")
+                message: None | dict[str, Any] = await self._pubsub.get_message(  # type: ignore
+                    ignore_subscribe_messages=True,
+                    timeout=None,  # type: ignore
+                )
+                if message is None:
+                    continue
+
+                channel: str = cast(str, message["channel"])
+                expected_prefix = f"{self._prefix}session:"
+
+                if not channel.startswith(expected_prefix):
+                    logger.debug(f"Ignoring message from non-MCP channel: {channel}")
+                    continue
+
+                session_hex = channel[len(expected_prefix) :]
+                try:
+                    session_id = UUID(hex=session_hex)
+                    expected_channel = self._session_channel(session_id)
+                    if channel != expected_channel:
+                        logger.error(f"Channel format mismatch: {channel}")
                         continue
-
-                    if session_id not in self._handlers:
-                        logger.warning(f"Dropping message for non-existent session: {session_id}")
+                except ValueError:
+                    logger.error(f"Invalid UUID in channel: {channel}")
+                    continue
+
+                data: str = cast(str, message["data"])
+                try:
+                    if session_id not in self._callbacks:
+                        logger.warning(f"Message dropped: no callback for {session_id}")
                         continue
-
-                    session_tg = self._handlers[session_id]
-                    if channel.startswith(f"{self._prefix}ack:"):
-                        session_tg.start_soon(self._handle_ack_message, channel, data)
-                    else:
-                        session_tg.start_soon(self._handle_session_message, channel, data)
-
-    async def _handle_ack_message(self, channel: str, data: str) -> None:
-        """Handle acknowledgment messages received on ack channels."""
-        session_id = self._parse_ack_channel(channel)
-        if session_id is None:
-            return
-
-        # Extract message ID from data
-        message_id = data.strip()
-        if message_id in self._ack_events:
-            logger.debug(f"Received acknowledgment for message: {message_id}")
-            self._ack_events[message_id].set()
-
-    async def _handle_session_message(self, channel: str, data: str) -> None:
-        """Handle regular messages received on session channels."""
-        session_id = self._parse_session_channel(channel)
-        if session_id is None:
-            return
 
-        if session_id not in self._callbacks:
-            logger.warning(f"Message dropped: no callback for {session_id}")
-            return
-
-        try:
-            wrapper = MessageWrapper.model_validate_json(data)
-            result = wrapper.get_json_rpc_message()
-            await self._callbacks[session_id](result)
-            await self._send_acknowledgment(session_id, wrapper.message_id)
-
-        except Exception as e:
-            logger.error(f"Error processing message for {session_id}: {e}")
-
-    async def _send_acknowledgment(self, session_id: UUID, message_id: str) -> None:
-        """Send an acknowledgment for a message that was successfully processed."""
-        ack_channel = self._ack_channel(session_id)
-        await self._redis.publish(ack_channel, message_id)  # type: ignore
-        logger.debug(
-            f"Sent acknowledgment for message {message_id} to session {session_id}"
-        )
+                    # Try to parse as valid message or recreate original ValidationError
+                    try:
+                        msg = types.JSONRPCMessage.model_validate_json(data)
+                        await self._callbacks[session_id](msg)
+                    except ValidationError as exc:
+                        # Pass the identical validation error that would have occurred
+                        await self._callbacks[session_id](exc)
+                except Exception as e:
+                    logger.error(f"Error processing message for {session_id}: {e}")
 
     async def publish_message(
-        self,
-        session_id: UUID,
-        message: types.JSONRPCMessage | str,
-        message_id: str | None = None,
-    ) -> str | None:
+        self, session_id: UUID, message: types.JSONRPCMessage | str
+    ) -> bool:
         """Publish a message for the specified session."""
         if not await self.session_exists(session_id):
             logger.warning(f"Message dropped: unknown session {session_id}")
-            return None
+            return False
 
         # Pass raw JSON strings directly, preserving validation errors
-        message_id = message_id or str(uuid4())
         if isinstance(message, str):
-            wrapper = MessageWrapper(message_id=message_id, payload=message)
+            data = message
         else:
-            wrapper = MessageWrapper(
-                message_id=message_id, payload=message.model_dump_json()
-            )
+            data = message.model_dump_json()
 
         channel = self._session_channel(session_id)
-        await self._redis.publish(channel, wrapper.model_dump_json())  # type: ignore
-        logger.debug(
-            f"Message {message_id} published to Redis channel for session {session_id}"
-        )
-        return message_id
-
-    async def publish_message_sync(
-        self,
-        session_id: UUID,
-        message: types.JSONRPCMessage | str,
-        timeout: float = 120.0,
-    ) -> bool:
-        """Publish a message and wait for acknowledgment of processing."""
-        message_id = str(uuid4())
-        ack_event = Event()
-        self._ack_events[message_id] = ack_event
-
-        try:
-            published_id = await self.publish_message(session_id, message, message_id)
-            if published_id is None:
-                return False
-
-            with anyio.fail_after(timeout):
-                await ack_event.wait()
-                logger.debug(f"Received acknowledgment for message {message_id}")
-                return True
-
-        except TimeoutError:
-            logger.warning(
-                f"Timed out waiting for acknowledgment of message {message_id}"
-            )
-            return False
-
-        finally:
-            if message_id in self._ack_events:
-                del self._ack_events[message_id]
+        await self._redis.publish(channel, data)  # type: ignore[attr-defined]
+        logger.debug(f"Message published to Redis channel for session {session_id}")
+        return True
 
     async def session_exists(self, session_id: UUID) -> bool:
         """Check if a session exists."""
