Skip to content

Commit efe6da9

Browse files
committed
wip
1 parent fb44020 commit efe6da9

File tree

5 files changed

+423
-287
lines changed

5 files changed

+423
-287
lines changed

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,19 @@ dependencies = [
2525
"anyio>=4.5",
2626
"httpx>=0.27",
2727
"httpx-sse>=0.4",
28-
"pydantic>=2.7.2,<=2.10.1",
28+
"pydantic>=2.7.2,<3.0.0",
2929
"starlette>=0.27",
3030
"sse-starlette>=1.6.1",
3131
"pydantic-settings>=2.5.2",
3232
"uvicorn>=0.23.1",
33+
"redis==5.2.1",
34+
"types-redis==4.6.0.20241004",
3335
]
3436

3537
[project.optional-dependencies]
3638
rich = ["rich>=13.9.4"]
3739
cli = ["typer>=0.12.4", "python-dotenv>=1.0.0"]
3840
ws = ["websockets>=15.0.1"]
39-
redis = ["redis>=4.5.0"]
4041

4142
[project.scripts]
4243
mcp = "mcp.cli:app [cli]"

src/mcp/server/message_queue/base.py

Lines changed: 43 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,46 @@
11
import logging
2-
from typing import Protocol, runtime_checkable
2+
from typing import Protocol, runtime_checkable, Callable, Awaitable
33
from uuid import UUID
4+
from contextlib import asynccontextmanager
45

56
import mcp.types as types
67

78
logger = logging.getLogger(__name__)
89

10+
MessageCallback = Callable[[types.JSONRPCMessage | Exception], Awaitable[None]]
11+
912

1013
@runtime_checkable
1114
class MessageQueue(Protocol):
12-
"""Abstract interface for an SSE message queue.
15+
"""Abstract interface for SSE messaging.
1316
14-
This interface allows messages to be queued and processed by any SSE server instance
15-
enabling multiple servers to handle requests for the same session.
17+
This interface allows messages to be published to sessions and callbacks to be
18+
registered for message handling, enabling multiple servers to handle requests.
1619
"""
1720

18-
async def add_message(
21+
async def publish_message(
1922
self, session_id: UUID, message: types.JSONRPCMessage | Exception
2023
) -> bool:
21-
"""Add a message to the queue for the specified session.
24+
"""Publish a message for the specified session.
2225
2326
Args:
2427
session_id: The UUID of the session this message is for
25-
message: The message to queue
26-
27-
Returns:
28-
bool: True if message was accepted, False if session not found
29-
"""
30-
...
31-
32-
async def get_message(
33-
self, session_id: UUID, timeout: float = 0.1
34-
) -> types.JSONRPCMessage | Exception | None:
35-
"""Get the next message for the specified session.
36-
37-
Args:
38-
session_id: The UUID of the session to get messages for
39-
timeout: Maximum time to wait for a message, in seconds
28+
message: The message to publish
4029
4130
Returns:
42-
The next message or None if no message is available
43-
"""
44-
...
45-
46-
async def register_session(self, session_id: UUID) -> None:
47-
"""Register a new session with the queue.
48-
49-
Args:
50-
session_id: The UUID of the new session to register
31+
bool: True if message was published, False if session not found
5132
"""
5233
...
5334

54-
async def unregister_session(self, session_id: UUID) -> None:
55-
"""Unregister a session when it's closed.
56-
35+
@asynccontextmanager
36+
async def active_for_request(self, session_id: UUID, callback: MessageCallback):
37+
"""Request-scoped context manager that ensures the listener is active.
38+
5739
Args:
58-
session_id: The UUID of the session to unregister
40+
session_id: The UUID of the session to activate
41+
callback: Async callback function to handle messages for this session
5942
"""
60-
...
43+
yield
6144

6245
async def session_exists(self, session_id: UUID) -> bool:
6346
"""Check if a session exists.
@@ -74,57 +57,44 @@ async def session_exists(self, session_id: UUID) -> bool:
7457
class InMemoryMessageQueue:
7558
"""Default in-memory implementation of the MessageQueue interface.
7659
77-
This implementation keeps messages in memory for
78-
each session until they're retrieved.
60+
This implementation immediately calls registered callbacks when messages are received.
7961
"""
8062

8163
def __init__(self) -> None:
82-
self._message_queues: dict[UUID, list[types.JSONRPCMessage | Exception]] = {}
64+
self._callbacks: dict[UUID, MessageCallback] = {}
8365
self._active_sessions: set[UUID] = set()
8466

85-
async def add_message(
67+
async def publish_message(
8668
self, session_id: UUID, message: types.JSONRPCMessage | Exception
8769
) -> bool:
88-
"""Add a message to the queue for the specified session."""
89-
if session_id not in self._active_sessions:
70+
"""Publish a message for the specified session."""
71+
if not await self.session_exists(session_id):
9072
logger.warning(f"Message received for unknown session {session_id}")
9173
return False
9274

93-
if session_id not in self._message_queues:
94-
self._message_queues[session_id] = []
95-
96-
self._message_queues[session_id].append(message)
97-
logger.debug(f"Added message to queue for session {session_id}")
75+
# Call the callback directly if registered
76+
if session_id in self._callbacks:
77+
await self._callbacks[session_id](message)
78+
logger.debug(f"Called callback for session {session_id}")
79+
else:
80+
logger.warning(f"No callback registered for session {session_id}")
81+
9882
return True
9983

100-
async def get_message(
101-
self, session_id: UUID, timeout: float = 0.1
102-
) -> types.JSONRPCMessage | Exception | None:
103-
"""Get the next message for the specified session."""
104-
if session_id not in self._active_sessions:
105-
return None
106-
107-
queue = self._message_queues.get(session_id, [])
108-
if not queue:
109-
return None
110-
111-
message = queue.pop(0)
112-
if not queue: # Clean up empty queue
113-
del self._message_queues[session_id]
114-
115-
return message
116-
117-
async def register_session(self, session_id: UUID) -> None:
118-
"""Register a new session with the queue."""
84+
@asynccontextmanager
85+
async def active_for_request(self, session_id: UUID, callback: MessageCallback):
86+
"""Request-scoped context manager that ensures the listener is active."""
11987
self._active_sessions.add(session_id)
120-
logger.debug(f"Registered session {session_id}")
121-
122-
async def unregister_session(self, session_id: UUID) -> None:
123-
"""Unregister a session when it's closed."""
124-
self._active_sessions.discard(session_id)
125-
if session_id in self._message_queues:
126-
del self._message_queues[session_id]
127-
logger.debug(f"Unregistered session {session_id}")
88+
self._callbacks[session_id] = callback
89+
logger.debug(f"Registered session {session_id} with callback")
90+
91+
try:
92+
yield
93+
finally:
94+
self._active_sessions.discard(session_id)
95+
if session_id in self._callbacks:
96+
del self._callbacks[session_id]
97+
logger.debug(f"Unregistered session {session_id}")
12898

12999
async def session_exists(self, session_id: UUID) -> bool:
130100
"""Check if a session exists."""

src/mcp/server/message_queue/redis.py

Lines changed: 85 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,17 @@
22
import logging
33
from uuid import UUID
44

5+
import anyio
6+
from anyio import CapacityLimiter
57
import mcp.types as types
8+
from mcp.server.message_queue.base import MessageCallback
9+
from typing import Any, cast
10+
from anyio import from_thread
11+
from contextlib import asynccontextmanager
12+
613

714
try:
8-
import redis.asyncio as redis # type: ignore[import]
15+
import redis.asyncio as redis
916
except ImportError:
1017
raise ImportError(
1118
"Redis support requires the 'redis' package. "
@@ -16,42 +23,102 @@
1623

1724

1825
class RedisMessageQueue:
19-
"""Redis implementation of the MessageQueue interface.
26+
"""Redis implementation of the MessageQueue interface using pubsub.
2027
21-
This implementation uses Redis lists to store messages for each session.
22-
Redis provides persistence and allows multiple servers to share the same queue.
28+
This implementation uses Redis pubsub for real-time message distribution across
29+
multiple servers handling the same sessions.
2330
"""
2431

2532
def __init__(
26-
self, redis_url: str = "redis://localhost:6379/0", prefix: str = "mcp:queue:"
33+
self, redis_url: str = "redis://localhost:6379/0", prefix: str = "mcp:pubsub:"
2734
) -> None:
2835
"""Initialize Redis message queue.
2936
3037
Args:
3138
redis_url: Redis connection string
32-
prefix: Key prefix for Redis keys to avoid collisions
39+
prefix: Key prefix for Redis channels to avoid collisions
3340
"""
34-
self._redis = redis.Redis.from_url(redis_url, decode_responses=True) # type: ignore[attr-defined]
41+
self._redis = redis.from_url(redis_url, decode_responses=True) # type: ignore
42+
self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True) # type: ignore
3543
self._prefix = prefix
3644
self._active_sessions_key = f"{prefix}active_sessions"
45+
self._callbacks: dict[UUID, MessageCallback] = {}
46+
self._limiter = CapacityLimiter(1)
3747
logger.debug(f"Initialized Redis message queue with URL: {redis_url}")
3848

39-
def _session_queue_key(self, session_id: UUID) -> str:
40-
"""Get the Redis key for a session's message queue."""
49+
def _session_channel(self, session_id: UUID) -> str:
50+
"""Get the Redis channel for a session."""
4151
return f"{self._prefix}session:{session_id.hex}"
4252

43-
async def add_message(
53+
@asynccontextmanager
54+
async def active_for_request(self, session_id: UUID, callback: MessageCallback):
55+
"""Request-scoped context manager that ensures the listener task is running."""
56+
57+
await self._redis.sadd(self._active_sessions_key, session_id.hex)
58+
self._callbacks[session_id] = callback
59+
channel = self._session_channel(session_id)
60+
await self._pubsub.subscribe(channel) # type: ignore
61+
62+
logger.debug(f"Registered session {session_id} in Redis with callback")
63+
async with anyio.create_task_group() as tg:
64+
tg.start_soon(self._listen_for_messages)
65+
try:
66+
yield
67+
finally:
68+
tg.cancel_scope.cancel()
69+
await self._pubsub.unsubscribe(channel) # type: ignore
70+
await self._redis.srem(self._active_sessions_key, session_id.hex)
71+
del self._callbacks[session_id]
72+
logger.debug(f"Unregistered session {session_id} from Redis")
73+
74+
async def _listen_for_messages(self) -> None:
75+
"""Background task that listens for messages on subscribed channels."""
76+
async with self._limiter:
77+
while True:
78+
message: None | dict[str, Any] = await self._pubsub.get_message( # type: ignore
79+
ignore_subscribe_messages=True
80+
)
81+
if message is not None:
82+
# Extract session ID from channel name
83+
channel: str = cast(str, message["channel"])
84+
if not channel.startswith(self._prefix):
85+
continue
86+
87+
session_hex = channel.split(":")[-1]
88+
try:
89+
session_id = UUID(hex=session_hex)
90+
except ValueError:
91+
logger.error(f"Invalid session channel: {channel}")
92+
continue
93+
94+
# Deserialize the message
95+
data: str = cast(str, message["data"])
96+
msg: None | types.JSONRPCMessage | Exception = None
97+
try:
98+
json_data = json.loads(data)
99+
if isinstance(json_data, dict):
100+
json_dict: dict[str, Any] = json_data
101+
if json_dict.get("_exception", False):
102+
msg = Exception(
103+
f"{json_dict['type']}: {json_dict['message']}"
104+
)
105+
else:
106+
msg = types.JSONRPCMessage.model_validate_json(data)
107+
108+
if msg and session_id in self._callbacks:
109+
from_thread.run(self._callbacks[session_id], msg)
110+
except Exception as e:
111+
logger.error(f"Failed to process message: {e}")
112+
113+
async def publish_message(
44114
self, session_id: UUID, message: types.JSONRPCMessage | Exception
45115
) -> bool:
46-
"""Add a message to the queue for the specified session."""
47-
# Check if session exists
116+
"""Publish a message for the specified session."""
48117
if not await self.session_exists(session_id):
49118
logger.warning(f"Message received for unknown session {session_id}")
50119
return False
51120

52-
# Serialize the message
53121
if isinstance(message, Exception):
54-
# For exceptions, store them as special format
55122
data = json.dumps(
56123
{
57124
"_exception": True,
@@ -60,63 +127,13 @@ async def add_message(
60127
}
61128
)
62129
else:
63-
data = message.model_dump_json(by_alias=True, exclude_none=True)
130+
data = message.model_dump_json()
64131

65-
# Push to the right side of the list (queue)
66-
await self._redis.rpush(self._session_queue_key(session_id), data) # type: ignore[attr-defined]
67-
logger.debug(f"Added message to Redis queue for session {session_id}")
132+
channel = self._session_channel(session_id)
133+
await self._redis.publish(channel, data) # type: ignore[attr-defined]
134+
logger.debug(f"Published message to Redis channel for session {session_id}")
68135
return True
69136

70-
async def get_message(
71-
self, session_id: UUID, timeout: float = 0.1
72-
) -> types.JSONRPCMessage | Exception | None:
73-
"""Get the next message for the specified session."""
74-
# Check if session exists
75-
if not await self.session_exists(session_id):
76-
return None
77-
78-
# Pop from the left side of the list (queue)
79-
# Use BLPOP with timeout to avoid busy waiting
80-
result = await self._redis.blpop([self._session_queue_key(session_id)], timeout) # type: ignore[attr-defined]
81-
82-
if not result:
83-
return None
84-
85-
# result is a tuple of (key, value)
86-
_, data = result # type: ignore[misc]
87-
88-
# Deserialize the message
89-
json_data = json.loads(data) # type: ignore[arg-type]
90-
91-
# Check if it's an exception
92-
if isinstance(json_data, dict):
93-
exception_dict: dict[str, object] = json_data
94-
if exception_dict.get("_exception", False):
95-
return Exception(
96-
f"{exception_dict['type']}: {exception_dict['message']}"
97-
)
98-
99-
# Regular message
100-
try:
101-
return types.JSONRPCMessage.model_validate_json(data) # type: ignore[arg-type]
102-
except Exception as e:
103-
logger.error(f"Failed to deserialize message: {e}")
104-
return None
105-
106-
async def register_session(self, session_id: UUID) -> None:
107-
"""Register a new session with the queue."""
108-
# Add session ID to the set of active sessions
109-
await self._redis.sadd(self._active_sessions_key, session_id.hex) # type: ignore[attr-defined]
110-
logger.debug(f"Registered session {session_id} in Redis")
111-
112-
async def unregister_session(self, session_id: UUID) -> None:
113-
"""Unregister a session when it's closed."""
114-
# Remove session ID from active sessions
115-
await self._redis.srem(self._active_sessions_key, session_id.hex) # type: ignore[attr-defined]
116-
# Delete the session's message queue
117-
await self._redis.delete(self._session_queue_key(session_id)) # type: ignore[attr-defined]
118-
logger.debug(f"Unregistered session {session_id} from Redis")
119-
120137
async def session_exists(self, session_id: UUID) -> bool:
121138
"""Check if a session exists."""
122139
# Explicitly annotate the result as bool to help the type checker

0 commit comments

Comments
 (0)