1- import json
21import logging
32from contextlib import asynccontextmanager
43from typing import Any , cast
54from uuid import UUID
5+ from pydantic import ValidationError
66
77import anyio
88from anyio import CapacityLimiter , lowlevel
@@ -44,7 +44,7 @@ def __init__(
4444 self ._callbacks : dict [UUID , MessageCallback ] = {}
4545 # Ensures only one polling task runs at a time for message handling
4646 self ._limiter = CapacityLimiter (1 )
47- logger .debug (f"Initialized Redis message dispatch with URL : { redis_url } " )
47+ logger .debug (f"Redis message dispatch initialized : { redis_url } " )
4848
4949 def _session_channel (self , session_id : UUID ) -> str :
5050 """Get the Redis channel for a session."""
@@ -58,7 +58,7 @@ async def subscribe(self, session_id: UUID, callback: MessageCallback):
5858 channel = self ._session_channel (session_id )
5959 await self ._pubsub .subscribe (channel ) # type: ignore
6060
61- logger .debug (f"Registered session { session_id } in Redis with callback " )
61+ logger .debug (f"Subscribing to Redis channel for session { session_id } " )
6262 async with anyio .create_task_group () as tg :
6363 tg .start_soon (self ._listen_for_messages )
6464 try :
@@ -68,7 +68,7 @@ async def subscribe(self, session_id: UUID, callback: MessageCallback):
6868 await self ._pubsub .unsubscribe (channel ) # type: ignore
6969 await self ._redis .srem (self ._active_sessions_key , session_id .hex )
7070 del self ._callbacks [session_id ]
71- logger .debug (f"Unregistered session { session_id } from Redis " )
71+ logger .debug (f"Unsubscribed from Redis channel for session { session_id } " )
7272
7373 async def _listen_for_messages (self ) -> None :
7474 """Background task that listens for messages on subscribed channels."""
@@ -85,61 +85,49 @@ async def _listen_for_messages(self) -> None:
8585 # Extract session ID from channel name
8686 channel : str = cast (str , message ["channel" ])
8787 if not channel .startswith (self ._prefix ):
88+ logger .debug (f"Ignoring message from non-MCP channel: { channel } " )
8889 continue
8990
9091 session_hex = channel .split (":" )[- 1 ]
9192 try :
9293 session_id = UUID (hex = session_hex )
9394 except ValueError :
94- logger .error (f"Invalid session channel: { channel } " )
95+ logger .error (f"Received message for invalid session channel: { channel } " )
9596 continue
9697
9798 data : str = cast (str , message ["data" ])
98- msg : None | types .JSONRPCMessage | Exception = None
9999 try :
100- json_data = json .loads (data )
101- if not isinstance (json_data , dict ):
102- logger .error (f"Received non-dict JSON data: { type (json_data )} " )
100+ if session_id not in self ._callbacks :
101+ logger .warning (f"Message dropped: no callback for session { session_id } " )
103102 continue
104103
105- json_dict : dict [str , Any ] = json_data
106- if json_dict .get ("_exception" , False ):
107- msg = Exception (
108- f"{ json_dict ['type' ]} : { json_dict ['message' ]} "
109- )
110- else :
104+ # Try to parse as valid message or recreate original ValidationError
105+ try :
111106 msg = types .JSONRPCMessage .model_validate_json (data )
112-
113- if msg :
114- if session_id in self ._callbacks :
115- await self ._callbacks [session_id ](msg )
116- else :
117- logger .warning (f"No callback registered for session { session_id } " )
107+ await self ._callbacks [session_id ](msg )
108+ except ValidationError as exc :
109+ # Pass the identical validation error that would have occurred originally
110+ await self ._callbacks [session_id ](exc )
118111 except Exception as e :
119- logger .error (f"Failed to process message: { e } " )
112+ logger .error (f"Error processing message for session { session_id } : { e } " )
120113
121114 async def publish_message (
122- self , session_id : UUID , message : types .JSONRPCMessage | Exception
115+ self , session_id : UUID , message : types .JSONRPCMessage | str
123116 ) -> bool :
124117 """Publish a message for the specified session."""
125118 if not await self .session_exists (session_id ):
126- logger .warning (f"Message received for unknown session { session_id } " )
119+ logger .warning (f"Message dropped: unknown session { session_id } " )
127120 return False
128121
129- if isinstance (message , Exception ):
130- data = json .dumps (
131- {
132- "_exception" : True ,
133- "type" : type (message ).__name__ ,
134- "message" : str (message ),
135- }
136- )
122+ # Pass raw JSON strings directly, preserving validation errors
123+ if isinstance (message , str ):
124+ data = message
137125 else :
138126 data = message .model_dump_json ()
139127
140128 channel = self ._session_channel (session_id )
141129 await self ._redis .publish (channel , data ) # type: ignore[attr-defined]
142- logger .debug (f"Published message to Redis channel for session { session_id } " )
130+ logger .debug (f"Message published to Redis channel for session { session_id } " )
143131 return True
144132
145133 async def session_exists (self , session_id : UUID ) -> bool :
0 commit comments