Skip to content

Commit d625782

Browse files
committed
fixes
1 parent efe6da9 commit d625782

File tree

5 files changed

+61
-204
lines changed

5 files changed

+61
-204
lines changed

README.md

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -397,22 +397,23 @@ For more information on mounting applications in Starlette, see the [Starlette d
397397
By default, the SSE server uses an in-memory message queue for incoming POST messages. For production deployments or distributed scenarios, you can use Redis:
398398

399399
```python
400+
# Using the built-in Redis message queue
400401
from mcp.server.fastmcp import FastMCP
402+
from mcp.server.message_queue import RedisMessageQueue
401403

402-
mcp = FastMCP(
403-
"My App",
404-
settings={
405-
"message_queue": "redis",
406-
"redis_url": "redis://localhost:6379/0",
407-
"redis_prefix": "mcp:queue:",
408-
},
404+
# Create a Redis message queue
405+
redis_queue = RedisMessageQueue(
406+
redis_url="redis://localhost:6379/0", prefix="mcp:pubsub:"
409407
)
408+
409+
# Pass the message queue instance to the server
410+
mcp = FastMCP("My App", message_queue=redis_queue)
410411
```
411412

412413
To use Redis, add the Redis dependency:
413414

414415
```bash
415-
uv add "mcp[redis]"
416+
uv add redis types-redis
416417
```
417418

418419
## Examples

pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ dependencies = [
3030
"sse-starlette>=1.6.1",
3131
"pydantic-settings>=2.5.2",
3232
"uvicorn>=0.23.1",
33-
"redis==5.2.1",
34-
"types-redis==4.6.0.20241004",
3533
]
3634

3735
[project.optional-dependencies]

src/mcp/server/fastmcp/server.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from mcp.types import Resource as MCPResource
5050
from mcp.types import ResourceTemplate as MCPResourceTemplate
5151
from mcp.types import Tool as MCPTool
52+
from mcp.server.message_queue import MessageQueue
5253

5354
logger = get_logger(__name__)
5455

@@ -77,9 +78,7 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
7778
message_path: str = "/messages/"
7879

7980
# SSE message queue settings
80-
message_queue: Literal["memory", "redis"] = "memory"
81-
redis_url: str = "redis://localhost:6379/0"
82-
redis_prefix: str = "mcp:queue:"
81+
message_queue: MessageQueue | None = Field(None, description="Custom message queue instance")
8382

8483
# resource settings
8584
warn_on_duplicate_resources: bool = True
@@ -484,25 +483,14 @@ async def run_sse_async(self) -> None:
484483

485484
def sse_app(self) -> Starlette:
486485
"""Return an instance of the SSE server app."""
487-
message_queue = None
488-
if self.settings.message_queue == "redis":
489-
try:
490-
from mcp.server.message_queue import RedisMessageQueue
491-
492-
message_queue = RedisMessageQueue(
493-
redis_url=self.settings.redis_url, prefix=self.settings.redis_prefix
494-
)
495-
logger.info(f"Using Redis message queue at {self.settings.redis_url}")
496-
except ImportError:
497-
logger.error(
498-
"Redis message queue requested but 'redis' package not installed. "
499-
)
500-
raise
501-
else:
486+
# Use a custom provided message queue if available
487+
message_queue = self.settings.message_queue
488+
489+
# If no message queue is provided, create an in-memory queue as default
490+
if message_queue is None:
502491
from mcp.server.message_queue import InMemoryMessageQueue
503-
504492
message_queue = InMemoryMessageQueue()
505-
logger.info("Using in-memory message queue")
493+
logger.info("Using default in-memory message queue")
506494

507495
sse = SseServerTransport(
508496
self.settings.message_path, message_queue=message_queue

src/mcp/server/message_queue/redis.py

Lines changed: 41 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import json
22
import logging
3+
from contextlib import asynccontextmanager
4+
from typing import Any, cast
35
from uuid import UUID
46

57
import anyio
6-
from anyio import CapacityLimiter
8+
from anyio import CapacityLimiter, from_thread
79
import mcp.types as types
810
from mcp.server.message_queue.base import MessageCallback
9-
from typing import Any, cast
10-
from anyio import from_thread
11-
from contextlib import asynccontextmanager
1211

1312

1413
try:
@@ -38,8 +37,8 @@ def __init__(
3837
redis_url: Redis connection string
3938
prefix: Key prefix for Redis channels to avoid collisions
4039
"""
41-
self._redis = redis.from_url(redis_url, decode_responses=True) # type: ignore
42-
self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True) # type: ignore
40+
self._redis = redis.from_url(redis_url, decode_responses=True) # type: ignore
41+
self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True) # type: ignore
4342
self._prefix = prefix
4443
self._active_sessions_key = f"{prefix}active_sessions"
4544
self._callbacks: dict[UUID, MessageCallback] = {}
@@ -53,11 +52,10 @@ def _session_channel(self, session_id: UUID) -> str:
5352
@asynccontextmanager
5453
async def active_for_request(self, session_id: UUID, callback: MessageCallback):
5554
"""Request-scoped context manager that ensures the listener task is running."""
56-
5755
await self._redis.sadd(self._active_sessions_key, session_id.hex)
5856
self._callbacks[session_id] = callback
5957
channel = self._session_channel(session_id)
60-
await self._pubsub.subscribe(channel) # type: ignore
58+
await self._pubsub.subscribe(channel) # type: ignore
6159

6260
logger.debug(f"Registered session {session_id} in Redis with callback")
6361
async with anyio.create_task_group() as tg:
@@ -66,7 +64,7 @@ async def active_for_request(self, session_id: UUID, callback: MessageCallback):
6664
yield
6765
finally:
6866
tg.cancel_scope.cancel()
69-
await self._pubsub.unsubscribe(channel) # type: ignore
67+
await self._pubsub.unsubscribe(channel) # type: ignore
7068
await self._redis.srem(self._active_sessions_key, session_id.hex)
7169
del self._callbacks[session_id]
7270
logger.debug(f"Unregistered session {session_id} from Redis")
@@ -75,40 +73,41 @@ async def _listen_for_messages(self) -> None:
7573
"""Background task that listens for messages on subscribed channels."""
7674
async with self._limiter:
7775
while True:
78-
message: None | dict[str, Any] = await self._pubsub.get_message( # type: ignore
76+
message: None | dict[str, Any] = await self._pubsub.get_message( # type: ignore
7977
ignore_subscribe_messages=True
8078
)
81-
if message is not None:
82-
# Extract session ID from channel name
83-
channel: str = cast(str, message["channel"])
84-
if not channel.startswith(self._prefix):
85-
continue
86-
87-
session_hex = channel.split(":")[-1]
88-
try:
89-
session_id = UUID(hex=session_hex)
90-
except ValueError:
91-
logger.error(f"Invalid session channel: {channel}")
92-
continue
93-
94-
# Deserialize the message
95-
data: str = cast(str, message["data"])
96-
msg: None | types.JSONRPCMessage | Exception = None
97-
try:
98-
json_data = json.loads(data)
99-
if isinstance(json_data, dict):
100-
json_dict: dict[str, Any] = json_data
101-
if json_dict.get("_exception", False):
102-
msg = Exception(
103-
f"{json_dict['type']}: {json_dict['message']}"
104-
)
105-
else:
106-
msg = types.JSONRPCMessage.model_validate_json(data)
107-
108-
if msg and session_id in self._callbacks:
109-
from_thread.run(self._callbacks[session_id], msg)
110-
except Exception as e:
111-
logger.error(f"Failed to process message: {e}")
79+
if message is None:
80+
continue
81+
82+
# Extract session ID from channel name
83+
channel: str = cast(str, message["channel"])
84+
if not channel.startswith(self._prefix):
85+
continue
86+
87+
session_hex = channel.split(":")[-1]
88+
try:
89+
session_id = UUID(hex=session_hex)
90+
except ValueError:
91+
logger.error(f"Invalid session channel: {channel}")
92+
continue
93+
94+
data: str = cast(str, message["data"])
95+
msg: None | types.JSONRPCMessage | Exception = None
96+
try:
97+
json_data = json.loads(data)
98+
if isinstance(json_data, dict):
99+
json_dict: dict[str, Any] = json_data
100+
if json_dict.get("_exception", False):
101+
msg = Exception(
102+
f"{json_dict['type']}: {json_dict['message']}"
103+
)
104+
else:
105+
msg = types.JSONRPCMessage.model_validate_json(data)
106+
107+
if msg and session_id in self._callbacks:
108+
from_thread.run(self._callbacks[session_id], msg)
109+
except Exception as e:
110+
logger.error(f"Failed to process message: {e}")
112111

113112
async def publish_message(
114113
self, session_id: UUID, message: types.JSONRPCMessage | Exception
@@ -136,8 +135,6 @@ async def publish_message(
136135

137136
async def session_exists(self, session_id: UUID) -> bool:
138137
"""Check if a session exists."""
139-
# Explicitly annotate the result as bool to help the type checker
140-
result = bool(
138+
return bool(
141139
await self._redis.sismember(self._active_sessions_key, session_id.hex) # type: ignore[attr-defined]
142140
)
143-
return result

0 commit comments

Comments
 (0)