Skip to content

Commit 09e0cab

Browse files
committed
logging improvements
1 parent 5111c92 commit 09e0cab

File tree

4 files changed

+59
-59
lines changed

4 files changed

+59
-59
lines changed
Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
"""
2-
Message Queue Module for MCP Server
2+
Message Dispatch Module for MCP Server
33
4-
This module implements queue interfaces for handling
4+
This module implements dispatch interfaces for handling
55
messages between clients and servers.
66
"""
77

8-
from mcp.server.message_queue.base import InMemoryMessageQueue, MessageQueue
8+
from mcp.server.message_queue.base import InMemoryMessageDispatch, MessageDispatch
99

1010
# Try to import Redis implementation if available
1111
try:
12-
from mcp.server.message_queue.redis import RedisMessageQueue
12+
from mcp.server.message_queue.redis import RedisMessageDispatch
1313
except ImportError:
14-
RedisMessageQueue = None
14+
RedisMessageDispatch = None
1515

16-
__all__ = ["MessageQueue", "InMemoryMessageQueue", "RedisMessageQueue"]
16+
__all__ = ["MessageDispatch", "InMemoryMessageDispatch", "RedisMessageDispatch"]

src/mcp/server/message_queue/base.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from contextlib import asynccontextmanager
44
from typing import Protocol, runtime_checkable
55
from uuid import UUID
6+
from pydantic import ValidationError
67

78
import mcp.types as types
89

@@ -20,13 +21,13 @@ class MessageDispatch(Protocol):
2021
"""
2122

2223
async def publish_message(
23-
self, session_id: UUID, message: types.JSONRPCMessage | Exception
24+
self, session_id: UUID, message: types.JSONRPCMessage | str
2425
) -> bool:
2526
"""Publish a message for the specified session.
2627
2728
Args:
2829
session_id: The UUID of the session this message is for
29-
message: The message to publish
30+
message: The message to publish (JSONRPCMessage or str for invalid JSON)
3031
3132
Returns:
3233
bool: True if message was published, False if session not found
@@ -67,31 +68,40 @@ def __init__(self) -> None:
6768
# We don't need a separate _active_sessions set since _callbacks already tracks this
6869

6970
async def publish_message(
70-
self, session_id: UUID, message: types.JSONRPCMessage | Exception
71+
self, session_id: UUID, message: types.JSONRPCMessage | str
7172
) -> bool:
7273
"""Publish a message for the specified session."""
7374
if session_id not in self._callbacks:
74-
logger.warning(f"Message received for unknown session {session_id}")
75+
logger.warning(f"Message dropped: unknown session {session_id}")
7576
return False
76-
77-
# Call the callback directly
78-
await self._callbacks[session_id](message)
79-
logger.debug(f"Called callback for session {session_id}")
80-
77+
78+
# For string messages, attempt parsing and recreate original ValidationError if invalid
79+
if isinstance(message, str):
80+
try:
81+
callback_argument = types.JSONRPCMessage.model_validate_json(message)
82+
except ValidationError as exc:
83+
callback_argument = exc
84+
else:
85+
callback_argument = message
86+
87+
# Call the callback with either valid message or recreated ValidationError
88+
await self._callbacks[session_id](callback_argument)
89+
90+
logger.debug(f"Message dispatched to session {session_id}")
8191
return True
8292

8393
@asynccontextmanager
8494
async def subscribe(self, session_id: UUID, callback: MessageCallback):
8595
"""Request-scoped context manager that subscribes to messages for a session."""
8696
self._callbacks[session_id] = callback
87-
logger.debug(f"Registered session {session_id} with callback")
97+
logger.debug(f"Subscribing to messages for session {session_id}")
8898

8999
try:
90100
yield
91101
finally:
92102
if session_id in self._callbacks:
93103
del self._callbacks[session_id]
94-
logger.debug(f"Unregistered session {session_id}")
104+
logger.debug(f"Unsubscribed from session {session_id}")
95105

96106
async def session_exists(self, session_id: UUID) -> bool:
97107
"""Check if a session exists."""

src/mcp/server/message_queue/redis.py

Lines changed: 21 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
import json
21
import logging
32
from contextlib import asynccontextmanager
43
from typing import Any, cast
54
from uuid import UUID
5+
from pydantic import ValidationError
66

77
import anyio
88
from anyio import CapacityLimiter, lowlevel
@@ -44,7 +44,7 @@ def __init__(
4444
self._callbacks: dict[UUID, MessageCallback] = {}
4545
# Ensures only one polling task runs at a time for message handling
4646
self._limiter = CapacityLimiter(1)
47-
logger.debug(f"Initialized Redis message dispatch with URL: {redis_url}")
47+
logger.debug(f"Redis message dispatch initialized: {redis_url}")
4848

4949
def _session_channel(self, session_id: UUID) -> str:
5050
"""Get the Redis channel for a session."""
@@ -58,7 +58,7 @@ async def subscribe(self, session_id: UUID, callback: MessageCallback):
5858
channel = self._session_channel(session_id)
5959
await self._pubsub.subscribe(channel) # type: ignore
6060

61-
logger.debug(f"Registered session {session_id} in Redis with callback")
61+
logger.debug(f"Subscribing to Redis channel for session {session_id}")
6262
async with anyio.create_task_group() as tg:
6363
tg.start_soon(self._listen_for_messages)
6464
try:
@@ -68,7 +68,7 @@ async def subscribe(self, session_id: UUID, callback: MessageCallback):
6868
await self._pubsub.unsubscribe(channel) # type: ignore
6969
await self._redis.srem(self._active_sessions_key, session_id.hex)
7070
del self._callbacks[session_id]
71-
logger.debug(f"Unregistered session {session_id} from Redis")
71+
logger.debug(f"Unsubscribed from Redis channel for session {session_id}")
7272

7373
async def _listen_for_messages(self) -> None:
7474
"""Background task that listens for messages on subscribed channels."""
@@ -85,61 +85,49 @@ async def _listen_for_messages(self) -> None:
8585
# Extract session ID from channel name
8686
channel: str = cast(str, message["channel"])
8787
if not channel.startswith(self._prefix):
88+
logger.debug(f"Ignoring message from non-MCP channel: {channel}")
8889
continue
8990

9091
session_hex = channel.split(":")[-1]
9192
try:
9293
session_id = UUID(hex=session_hex)
9394
except ValueError:
94-
logger.error(f"Invalid session channel: {channel}")
95+
logger.error(f"Received message for invalid session channel: {channel}")
9596
continue
9697

9798
data: str = cast(str, message["data"])
98-
msg: None | types.JSONRPCMessage | Exception = None
9999
try:
100-
json_data = json.loads(data)
101-
if not isinstance(json_data, dict):
102-
logger.error(f"Received non-dict JSON data: {type(json_data)}")
100+
if session_id not in self._callbacks:
101+
logger.warning(f"Message dropped: no callback for session {session_id}")
103102
continue
104103

105-
json_dict: dict[str, Any] = json_data
106-
if json_dict.get("_exception", False):
107-
msg = Exception(
108-
f"{json_dict['type']}: {json_dict['message']}"
109-
)
110-
else:
104+
# Try to parse as valid message or recreate original ValidationError
105+
try:
111106
msg = types.JSONRPCMessage.model_validate_json(data)
112-
113-
if msg:
114-
if session_id in self._callbacks:
115-
await self._callbacks[session_id](msg)
116-
else:
117-
logger.warning(f"No callback registered for session {session_id}")
107+
await self._callbacks[session_id](msg)
108+
except ValidationError as exc:
109+
# Pass the identical validation error that would have occurred originally
110+
await self._callbacks[session_id](exc)
118111
except Exception as e:
119-
logger.error(f"Failed to process message: {e}")
112+
logger.error(f"Error processing message for session {session_id}: {e}")
120113

121114
async def publish_message(
122-
self, session_id: UUID, message: types.JSONRPCMessage | Exception
115+
self, session_id: UUID, message: types.JSONRPCMessage | str
123116
) -> bool:
124117
"""Publish a message for the specified session."""
125118
if not await self.session_exists(session_id):
126-
logger.warning(f"Message received for unknown session {session_id}")
119+
logger.warning(f"Message dropped: unknown session {session_id}")
127120
return False
128121

129-
if isinstance(message, Exception):
130-
data = json.dumps(
131-
{
132-
"_exception": True,
133-
"type": type(message).__name__,
134-
"message": str(message),
135-
}
136-
)
122+
# Pass raw JSON strings directly, preserving validation errors
123+
if isinstance(message, str):
124+
data = message
137125
else:
138126
data = message.model_dump_json()
139127

140128
channel = self._session_channel(session_id)
141129
await self._redis.publish(channel, data) # type: ignore[attr-defined]
142-
logger.debug(f"Published message to Redis channel for session {session_id}")
130+
logger.debug(f"Message published to Redis channel for session {session_id}")
143131
return True
144132

145133
async def session_exists(self, session_id: UUID) -> bool:

src/mcp/server/sse.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ async def handle_sse(request):
4646
from starlette.types import Receive, Scope, Send
4747

4848
import mcp.types as types
49-
from mcp.server.message_queue import InMemoryMessageQueue, MessageQueue
49+
from mcp.server.message_queue import InMemoryMessageDispatch, MessageDispatch
5050

5151
logger = logging.getLogger(__name__)
5252

@@ -64,23 +64,23 @@ class SseServerTransport:
6464
"""
6565

6666
_endpoint: str
67-
_message_queue: MessageQueue
67+
_message_dispatch: MessageDispatch
6868

6969
def __init__(
70-
self, endpoint: str, message_queue: MessageQueue | None = None
70+
self, endpoint: str, message_dispatch: MessageDispatch | None = None
7171
) -> None:
7272
"""
7373
Creates a new SSE server transport, which will direct the client to POST
7474
messages to the relative or absolute URL given.
7575
7676
Args:
7777
endpoint: The endpoint URL for SSE connections
78-
message_queue: Optional message queue to use
78+
message_dispatch: Optional message dispatch to use
7979
"""
8080

8181
super().__init__()
8282
self._endpoint = endpoint
83-
self._message_dispatch = message_queue or InMemoryMessageDispatch()
83+
self._message_dispatch = message_dispatch or InMemoryMessageDispatch()
8484
logger.debug(f"SseServerTransport initialized with endpoint: {endpoint}")
8585

8686
@asynccontextmanager
@@ -137,7 +137,7 @@ async def sse_writer():
137137
logger.debug("Starting SSE response task")
138138
tg.start_soon(response, scope, receive, send)
139139

140-
async with self._message_queue.active_for_request(
140+
async with self._message_dispatch.subscribe(
141141
session_id, message_callback
142142
):
143143
logger.debug("Yielding read and write streams")
@@ -163,7 +163,7 @@ async def handle_post_message(
163163
response = Response("Invalid session ID", status_code=400)
164164
return await response(scope, receive, send)
165165

166-
if not await self._message_queue.session_exists(session_id):
166+
if not await self._message_dispatch.session_exists(session_id):
167167
logger.warning(f"Could not find session for ID: {session_id}")
168168
response = Response("Could not find session", status_code=404)
169169
return await response(scope, receive, send)
@@ -178,10 +178,12 @@ async def handle_post_message(
178178
logger.error(f"Failed to parse message: {err}")
179179
response = Response("Could not parse message", status_code=400)
180180
await response(scope, receive, send)
181-
await self._message_queue.publish_message(session_id, err)
181+
# Pass raw JSON string through dispatch; original ValidationError will be recreated when
182+
# the receiver tries to validate the same invalid JSON
183+
await self._message_dispatch.publish_message(session_id, body.decode())
182184
return
183185

184186
logger.debug(f"Publishing message for session {session_id}: {message}")
185187
response = Response("Accepted", status_code=202)
186188
await response(scope, receive, send)
187-
await self._message_queue.publish_message(session_id, message)
189+
await self._message_dispatch.publish_message(session_id, message)

0 commit comments

Comments
 (0)