Skip to content

Commit b2893e6

Browse files
committed
push message handling onto corresponding SSE session task group
1 parent 8fce8e6 commit b2893e6

File tree

1 file changed

+62
-37
lines changed

1 file changed

+62
-37
lines changed

src/mcp/server/message_queue/redis.py

Lines changed: 62 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from uuid import UUID
55

66
import anyio
7-
from anyio import CapacityLimiter, lowlevel
7+
from anyio import CancelScope, CapacityLimiter, lowlevel
8+
from anyio.abc import TaskGroup
89
from pydantic import ValidationError
910

1011
import mcp.types as types
@@ -42,6 +43,7 @@ def __init__(
4243
self._prefix = prefix
4344
self._active_sessions_key = f"{prefix}active_sessions"
4445
self._callbacks: dict[UUID, MessageCallback] = {}
46+
self._task_groups: dict[UUID, TaskGroup] = {}
4547
# Ensures only one polling task runs at a time for message handling
4648
self._limiter = CapacityLimiter(1)
4749
logger.debug(f"Redis message dispatch initialized: {redis_url}")
@@ -60,6 +62,7 @@ async def subscribe(self, session_id: UUID, callback: MessageCallback):
6062

6163
logger.debug(f"Subscribing to Redis channel for session {session_id}")
6264
async with anyio.create_task_group() as tg:
65+
self._task_groups[session_id] = tg
6366
tg.start_soon(self._listen_for_messages)
6467
try:
6568
yield
@@ -68,53 +71,75 @@ async def subscribe(self, session_id: UUID, callback: MessageCallback):
6871
await self._pubsub.unsubscribe(channel) # type: ignore
6972
await self._redis.srem(self._active_sessions_key, session_id.hex)
7073
del self._callbacks[session_id]
74+
del self._task_groups[session_id]
7175
logger.debug(f"Unsubscribed from Redis channel: {session_id}")
7276

77+
def _extract_session_id(self, channel: str) -> UUID | None:
78+
"""Extract and validate session ID from channel."""
79+
expected_prefix = f"{self._prefix}session:"
80+
if not channel.startswith(expected_prefix):
81+
return None
82+
83+
session_hex = channel[len(expected_prefix) :]
84+
try:
85+
session_id = UUID(hex=session_hex)
86+
if channel != self._session_channel(session_id):
87+
logger.error(f"Channel format mismatch: {channel}")
88+
return None
89+
return session_id
90+
except ValueError:
91+
logger.error(f"Invalid UUID in channel: {channel}")
92+
return None
93+
7394
async def _listen_for_messages(self) -> None:
7495
"""Background task that listens for messages on subscribed channels."""
7596
async with self._limiter:
7697
while True:
7798
await lowlevel.checkpoint()
78-
message: None | dict[str, Any] = await self._pubsub.get_message( # type: ignore
79-
ignore_subscribe_messages=True,
80-
timeout=None, # type: ignore
81-
)
82-
if message is None:
83-
continue
84-
85-
channel: str = cast(str, message["channel"])
86-
expected_prefix = f"{self._prefix}session:"
87-
88-
if not channel.startswith(expected_prefix):
89-
logger.debug(f"Ignoring message from non-MCP channel: {channel}")
90-
continue
91-
92-
session_hex = channel[len(expected_prefix) :]
93-
try:
94-
session_id = UUID(hex=session_hex)
95-
expected_channel = self._session_channel(session_id)
96-
if channel != expected_channel:
97-
logger.error(f"Channel format mismatch: {channel}")
99+
with CancelScope(shield=True):
100+
message: None | dict[str, Any] = await self._pubsub.get_message( # type: ignore
101+
ignore_subscribe_messages=True,
102+
timeout=0.1, # type: ignore
103+
)
104+
if message is None:
98105
continue
99-
except ValueError:
100-
logger.error(f"Invalid UUID in channel: {channel}")
101-
continue
102-
103-
data: str = cast(str, message["data"])
104-
try:
105-
if session_id not in self._callbacks:
106-
logger.warning(f"Message dropped: no callback for {session_id}")
106+
107+
channel: str = cast(str, message["channel"])
108+
session_id = self._extract_session_id(channel)
109+
if session_id is None:
110+
logger.debug(f"Ignoring message from non-MCP channel: {channel}")
107111
continue
108112

109-
# Try to parse as valid message or recreate original ValidationError
113+
data: str = cast(str, message["data"])
110114
try:
111-
msg = types.JSONRPCMessage.model_validate_json(data)
112-
await self._callbacks[session_id](msg)
113-
except ValidationError as exc:
114-
# Pass the identical validation error that would have occurred
115-
await self._callbacks[session_id](exc)
116-
except Exception as e:
117-
logger.error(f"Error processing message for {session_id}: {e}")
115+
if session_id in self._task_groups:
116+
self._task_groups[session_id].start_soon(
117+
self._handle_message, session_id, data
118+
)
119+
else:
120+
logger.warning(
121+
f"Message dropped: no task group for session: {session_id}"
122+
)
123+
except Exception as e:
124+
logger.error(f"Error processing message for {session_id}: {e}")
125+
126+
async def _handle_message(self, session_id: UUID, data: str) -> None:
127+
"""Process a message from Redis in the session's task group."""
128+
if session_id not in self._callbacks:
129+
logger.warning(f"Message dropped: callback removed for {session_id}")
130+
return
131+
132+
try:
133+
# Parse message or pass validation error to callback
134+
msg_or_error = None
135+
try:
136+
msg_or_error = types.JSONRPCMessage.model_validate_json(data)
137+
except ValidationError as exc:
138+
msg_or_error = exc
139+
140+
await self._callbacks[session_id](msg_or_error)
141+
except Exception as e:
142+
logger.error(f"Error in message handler for {session_id}: {e}")
118143

119144
async def publish_message(
120145
self, session_id: UUID, message: types.JSONRPCMessage | str

0 commit comments

Comments
 (0)