1616import anyio
1717import httpx
1818from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
19- from httpx_sse import EventSource , aconnect_sse
19+ from httpx_sse import EventSource , ServerSentEvent , aconnect_sse
2020
2121from mcp .shared .message import ClientMessageMetadata , SessionMessage
2222from mcp .types import (
2626 JSONRPCNotification ,
2727 JSONRPCRequest ,
2828 JSONRPCResponse ,
29+ RequestId ,
2930)
3031
3132logger = logging .getLogger (__name__ )
3233
3334
34- MessageOrError = SessionMessage | Exception
35- StreamWriter = MemoryObjectSendStream [MessageOrError ]
35+ SessionMessageOrError = SessionMessage | Exception
36+ StreamWriter = MemoryObjectSendStream [SessionMessageOrError ]
3637StreamReader = MemoryObjectReceiveStream [SessionMessage ]
3738
3839
@@ -123,23 +124,21 @@ def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
123124 and message .root .method == "notifications/initialized"
124125 )
125126
126- def _extract_session_id_from_response (
127+ def _maybe_extract_session_id_from_response (
127128 self ,
128129 response : httpx .Response ,
129- is_initialization : bool ,
130130 ) -> None :
131131 """Extract and store session ID from response headers."""
132- if is_initialization :
133- new_session_id = response .headers .get (MCP_SESSION_ID )
134- if new_session_id :
135- self .session_id = new_session_id
136- logger .info (f"Received session ID: { self .session_id } " )
132+ new_session_id = response .headers .get (MCP_SESSION_ID )
133+ if new_session_id :
134+ self .session_id = new_session_id
135+ logger .info (f"Received session ID: { self .session_id } " )
137136
138137 async def _handle_sse_event (
139138 self ,
140- sse : Any ,
139+ sse : ServerSentEvent ,
141140 read_stream_writer : StreamWriter ,
142- original_request_id : Any | None = None ,
141+ original_request_id : RequestId | None = None ,
143142 resumption_callback : Callable [[str ], Awaitable [None ]] | None = None ,
144143 ) -> bool :
145144 """Handle an SSE event, returning True if the response is complete."""
@@ -161,7 +160,8 @@ async def _handle_sse_event(
161160 if sse .id and resumption_callback :
162161 await resumption_callback (sse .id )
163162
164- # If this is a response or error, we're done
163+ # If this is a response or error return True indicating completion
164+ # Otherwise, return False to continue listening
165165 return isinstance (message .root , JSONRPCResponse | JSONRPCError )
166166
167167 except Exception as exc :
@@ -262,7 +262,8 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
262262 return
263263
264264 response .raise_for_status ()
265- self ._extract_session_id_from_response (response , is_initialization )
265+ if is_initialization :
266+ self ._maybe_extract_session_id_from_response (response )
266267
267268 content_type = response .headers .get (CONTENT_TYPE , "" ).lower ()
268269
@@ -324,7 +325,7 @@ async def _handle_unexpected_content_type(
324325 async def _send_session_terminated_error (
325326 self ,
326327 read_stream_writer : StreamWriter ,
327- request_id : Any ,
328+ request_id : RequestId ,
328329 ) -> None :
329330 """Send a session terminated error response."""
330331 jsonrpc_error = JSONRPCError (
0 commit comments