22import logging
33from uuid import UUID
44
5+ import anyio
6+ from anyio import CapacityLimiter
57import mcp .types as types
8+ from mcp .server .message_queue .base import MessageCallback
9+ from typing import Any , cast
10+ from anyio import from_thread
11+ from contextlib import asynccontextmanager
12+
613
714try :
8- import redis .asyncio as redis # type: ignore[import]
15+ import redis .asyncio as redis
916except ImportError :
1017 raise ImportError (
1118 "Redis support requires the 'redis' package. "
1623
1724
1825class RedisMessageQueue :
19- """Redis implementation of the MessageQueue interface.
26+ """Redis implementation of the MessageQueue interface using pubsub .
2027
21- This implementation uses Redis lists to store messages for each session.
22- Redis provides persistence and allows multiple servers to share the same queue .
28+ This implementation uses Redis pubsub for real-time message distribution across
29+ multiple servers handling the same sessions .
2330 """
2431
2532 def __init__ (
26- self , redis_url : str = "redis://localhost:6379/0" , prefix : str = "mcp:queue :"
33+ self , redis_url : str = "redis://localhost:6379/0" , prefix : str = "mcp:pubsub :"
2734 ) -> None :
2835 """Initialize Redis message queue.
2936
3037 Args:
3138 redis_url: Redis connection string
32- prefix: Key prefix for Redis keys to avoid collisions
39+ prefix: Key prefix for Redis channels to avoid collisions
3340 """
34- self ._redis = redis .Redis .from_url (redis_url , decode_responses = True ) # type: ignore[attr-defined]
41+ self ._redis = redis .from_url (redis_url , decode_responses = True ) # type: ignore
42+ self ._pubsub = self ._redis .pubsub (ignore_subscribe_messages = True ) # type: ignore
3543 self ._prefix = prefix
3644 self ._active_sessions_key = f"{ prefix } active_sessions"
45+ self ._callbacks : dict [UUID , MessageCallback ] = {}
46+ self ._limiter = CapacityLimiter (1 )
3747 logger .debug (f"Initialized Redis message queue with URL: { redis_url } " )
3848
39- def _session_queue_key (self , session_id : UUID ) -> str :
40- """Get the Redis key for a session's message queue ."""
49+ def _session_channel (self , session_id : UUID ) -> str :
50+ """Get the Redis channel for a session."""
4151 return f"{ self ._prefix } session:{ session_id .hex } "
4252
43- async def add_message (
53+ @asynccontextmanager
54+ async def active_for_request (self , session_id : UUID , callback : MessageCallback ):
55+ """Request-scoped context manager that ensures the listener task is running."""
56+
57+ await self ._redis .sadd (self ._active_sessions_key , session_id .hex )
58+ self ._callbacks [session_id ] = callback
59+ channel = self ._session_channel (session_id )
60+ await self ._pubsub .subscribe (channel ) # type: ignore
61+
62+ logger .debug (f"Registered session { session_id } in Redis with callback" )
63+ async with anyio .create_task_group () as tg :
64+ tg .start_soon (self ._listen_for_messages )
65+ try :
66+ yield
67+ finally :
68+ tg .cancel_scope .cancel ()
69+ await self ._pubsub .unsubscribe (channel ) # type: ignore
70+ await self ._redis .srem (self ._active_sessions_key , session_id .hex )
71+ del self ._callbacks [session_id ]
72+ logger .debug (f"Unregistered session { session_id } from Redis" )
73+
74+ async def _listen_for_messages (self ) -> None :
75+ """Background task that listens for messages on subscribed channels."""
76+ async with self ._limiter :
77+ while True :
78+ message : None | dict [str , Any ] = await self ._pubsub .get_message ( # type: ignore
79+ ignore_subscribe_messages = True
80+ )
81+ if message is not None :
82+ # Extract session ID from channel name
83+ channel : str = cast (str , message ["channel" ])
84+ if not channel .startswith (self ._prefix ):
85+ continue
86+
87+ session_hex = channel .split (":" )[- 1 ]
88+ try :
89+ session_id = UUID (hex = session_hex )
90+ except ValueError :
91+ logger .error (f"Invalid session channel: { channel } " )
92+ continue
93+
94+ # Deserialize the message
95+ data : str = cast (str , message ["data" ])
96+ msg : None | types .JSONRPCMessage | Exception = None
97+ try :
98+ json_data = json .loads (data )
99+ if isinstance (json_data , dict ):
100+ json_dict : dict [str , Any ] = json_data
101+ if json_dict .get ("_exception" , False ):
102+ msg = Exception (
103+ f"{ json_dict ['type' ]} : { json_dict ['message' ]} "
104+ )
105+ else :
106+ msg = types .JSONRPCMessage .model_validate_json (data )
107+
108+ if msg and session_id in self ._callbacks :
109+ from_thread .run (self ._callbacks [session_id ], msg )
110+ except Exception as e :
111+ logger .error (f"Failed to process message: { e } " )
112+
113+ async def publish_message (
44114 self , session_id : UUID , message : types .JSONRPCMessage | Exception
45115 ) -> bool :
46- """Add a message to the queue for the specified session."""
47- # Check if session exists
116+ """Publish a message for the specified session."""
48117 if not await self .session_exists (session_id ):
49118 logger .warning (f"Message received for unknown session { session_id } " )
50119 return False
51120
52- # Serialize the message
53121 if isinstance (message , Exception ):
54- # For exceptions, store them as special format
55122 data = json .dumps (
56123 {
57124 "_exception" : True ,
@@ -60,63 +127,13 @@ async def add_message(
60127 }
61128 )
62129 else :
63- data = message .model_dump_json (by_alias = True , exclude_none = True )
130+ data = message .model_dump_json ()
64131
65- # Push to the right side of the list (queue )
66- await self ._redis .rpush ( self . _session_queue_key ( session_id ) , data ) # type: ignore[attr-defined]
67- logger .debug (f"Added message to Redis queue for session { session_id } " )
132+ channel = self . _session_channel ( session_id )
133+ await self ._redis .publish ( channel , data ) # type: ignore[attr-defined]
134+ logger .debug (f"Published message to Redis channel for session { session_id } " )
68135 return True
69136
70- async def get_message (
71- self , session_id : UUID , timeout : float = 0.1
72- ) -> types .JSONRPCMessage | Exception | None :
73- """Get the next message for the specified session."""
74- # Check if session exists
75- if not await self .session_exists (session_id ):
76- return None
77-
78- # Pop from the left side of the list (queue)
79- # Use BLPOP with timeout to avoid busy waiting
80- result = await self ._redis .blpop ([self ._session_queue_key (session_id )], timeout ) # type: ignore[attr-defined]
81-
82- if not result :
83- return None
84-
85- # result is a tuple of (key, value)
86- _ , data = result # type: ignore[misc]
87-
88- # Deserialize the message
89- json_data = json .loads (data ) # type: ignore[arg-type]
90-
91- # Check if it's an exception
92- if isinstance (json_data , dict ):
93- exception_dict : dict [str , object ] = json_data
94- if exception_dict .get ("_exception" , False ):
95- return Exception (
96- f"{ exception_dict ['type' ]} : { exception_dict ['message' ]} "
97- )
98-
99- # Regular message
100- try :
101- return types .JSONRPCMessage .model_validate_json (data ) # type: ignore[arg-type]
102- except Exception as e :
103- logger .error (f"Failed to deserialize message: { e } " )
104- return None
105-
106- async def register_session (self , session_id : UUID ) -> None :
107- """Register a new session with the queue."""
108- # Add session ID to the set of active sessions
109- await self ._redis .sadd (self ._active_sessions_key , session_id .hex ) # type: ignore[attr-defined]
110- logger .debug (f"Registered session { session_id } in Redis" )
111-
112- async def unregister_session (self , session_id : UUID ) -> None :
113- """Unregister a session when it's closed."""
114- # Remove session ID from active sessions
115- await self ._redis .srem (self ._active_sessions_key , session_id .hex ) # type: ignore[attr-defined]
116- # Delete the session's message queue
117- await self ._redis .delete (self ._session_queue_key (session_id )) # type: ignore[attr-defined]
118- logger .debug (f"Unregistered session { session_id } from Redis" )
119-
120137 async def session_exists (self , session_id : UUID ) -> bool :
121138 """Check if a session exists."""
122139 # Explicitly annotate the result as bool to help the type checker
0 commit comments