77import httpx
88from anyio .abc import TaskStatus
99from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
10+ from exceptiongroup import BaseExceptionGroup , catch
1011from httpx_sse import aconnect_sse
1112
1213import mcp .types as types
@@ -19,6 +20,12 @@ def remove_request_params(url: str) -> str:
1920 return urljoin (url , urlparse (url ).path )
2021
2122
23+ def handle_exception (exc : BaseExceptionGroup [Exception ]) -> str :
24+ """Handle ExceptionGroup and Exceptions for Client transport for SSE"""
25+ messages = "; " .join (str (e ) for e in exc .exceptions )
26+ raise Exception (f"TaskGroup failed with: { messages } " ) from None
27+
28+
2229@asynccontextmanager
2330async def sse_client (
2431 url : str ,
@@ -41,114 +48,120 @@ async def sse_client(
4148 read_stream_writer , read_stream = anyio .create_memory_object_stream (0 )
4249 write_stream , write_stream_reader = anyio .create_memory_object_stream (0 )
4350
44- async with anyio .create_task_group () as tg :
45- try :
46- logger .info (f"Connecting to SSE endpoint: { remove_request_params (url )} " )
47- async with httpx .AsyncClient (headers = headers ) as client :
48- async with aconnect_sse (
49- client ,
50- "GET" ,
51- url ,
52- timeout = httpx .Timeout (timeout , read = sse_read_timeout ),
53- ) as event_source :
54- event_source .response .raise_for_status ()
55- logger .debug ("SSE connection established" )
56-
57- async def sse_reader (
58- task_status : TaskStatus [str ] = anyio .TASK_STATUS_IGNORED ,
59- ):
60- try :
61- async for sse in event_source .aiter_sse ():
62- logger .debug (f"Received SSE event: { sse .event } " )
63- match sse .event :
64- case "endpoint" :
65- endpoint_url = urljoin (url , sse .data )
66- logger .info (
67- f"Received endpoint URL: { endpoint_url } "
68- )
69-
70- url_parsed = urlparse (url )
71- endpoint_parsed = urlparse (endpoint_url )
72- if (
73- url_parsed .netloc != endpoint_parsed .netloc
74- or url_parsed .scheme
75- != endpoint_parsed .scheme
76- ):
77- error_msg = (
78- "Endpoint origin does not match "
79- f"connection origin: { endpoint_url } "
51+ with catch ({Exception : handle_exception }):
52+ async with anyio .create_task_group () as tg :
53+ try :
54+ logger .info (f"Connecting to SSE endpoint: { remove_request_params (url )} " )
55+ async with httpx .AsyncClient (headers = headers ) as client :
56+ async with aconnect_sse (
57+ client ,
58+ "GET" ,
59+ url ,
60+ timeout = httpx .Timeout (timeout , read = sse_read_timeout ),
61+ ) as event_source :
62+ event_source .response .raise_for_status ()
63+ logger .debug ("SSE connection established" )
64+
65+ async def sse_reader (
66+ task_status : TaskStatus [str ] = anyio .TASK_STATUS_IGNORED ,
67+ ):
68+ try :
69+ async for sse in event_source .aiter_sse ():
70+ logger .debug (f"Received SSE event: { sse .event } " )
71+ match sse .event :
72+ case "endpoint" :
73+ endpoint_url = urljoin (url , sse .data )
74+ logger .info (
75+ f"Received endpoint URL: { endpoint_url } "
8076 )
81- logger .error (error_msg )
82- raise ValueError (error_msg )
8377
84- task_status .started (endpoint_url )
85-
86- case "message" :
87- try :
88- message = types .JSONRPCMessage .model_validate_json ( # noqa: E501
89- sse .data
78+ url_parsed = urlparse (url )
79+ endpoint_parsed = urlparse (endpoint_url )
80+ if (
81+ url_parsed .netloc
82+ != endpoint_parsed .netloc
83+ or url_parsed .scheme
84+ != endpoint_parsed .scheme
85+ ):
86+ error_msg = (
87+ "Endpoint origin does not match "
88+ f"connection origin: { endpoint_url } "
89+ )
90+ logger .error (error_msg )
91+ raise ValueError (error_msg )
92+
93+ task_status .started (endpoint_url )
94+
95+ case "message" :
96+ try :
97+ message = types .JSONRPCMessage .model_validate_json ( # noqa: E501
98+ sse .data
99+ )
100+ logger .debug (
101+ "Received server message: "
102+ f"{ message } "
103+ )
104+ except Exception as exc :
105+ logger .error (
106+ "Error parsing server message: "
107+ f"{ exc } "
108+ )
109+ await read_stream_writer .send (exc )
110+ continue
111+
112+ session_message = SessionMessage (
113+ message = message
90114 )
91- logger . debug (
92- f"Received server message: { message } "
115+ await read_stream_writer . send (
116+ session_message
93117 )
94- except Exception as exc :
95- logger .error (
96- f"Error parsing server message : { exc } "
118+ case _ :
119+ logger .warning (
120+ f"Unknown SSE event : { sse . event } "
97121 )
98- await read_stream_writer .send (exc )
99- continue
100-
101- session_message = SessionMessage (
102- message = message
122+ except Exception as exc :
123+ logger .error (f"Error in sse_reader: { exc } " )
124+ await read_stream_writer .send (exc )
125+ finally :
126+ await read_stream_writer .aclose ()
127+
128+ async def post_writer (endpoint_url : str ):
129+ try :
130+ async with write_stream_reader :
131+ async for session_message in write_stream_reader :
132+ logger .debug (
133+ f"Sending client message: { session_message } "
103134 )
104- await read_stream_writer .send (session_message )
105- case _:
106- logger .warning (
107- f"Unknown SSE event: { sse .event } "
135+ response = await client .post (
136+ endpoint_url ,
137+ json = session_message .message .model_dump (
138+ by_alias = True ,
139+ mode = "json" ,
140+ exclude_none = True ,
141+ ),
108142 )
109- except Exception as exc :
110- logger .error (f"Error in sse_reader: { exc } " )
111- await read_stream_writer .send (exc )
112- finally :
113- await read_stream_writer .aclose ()
143+ response .raise_for_status ()
144+ logger .debug (
145+ "Client message sent successfully: "
146+ f"{ response .status_code } "
147+ )
148+ except Exception as exc :
149+ logger .error (f"Error in post_writer: { exc } " )
150+ finally :
151+ await write_stream .aclose ()
152+
153+ endpoint_url = await tg .start (sse_reader )
154+ logger .info (
155+ f"Starting post writer with endpoint URL: { endpoint_url } "
156+ )
157+ tg .start_soon (post_writer , endpoint_url )
114158
115- async def post_writer (endpoint_url : str ):
116159 try :
117- async with write_stream_reader :
118- async for session_message in write_stream_reader :
119- logger .debug (
120- f"Sending client message: { session_message } "
121- )
122- response = await client .post (
123- endpoint_url ,
124- json = session_message .message .model_dump (
125- by_alias = True ,
126- mode = "json" ,
127- exclude_none = True ,
128- ),
129- )
130- response .raise_for_status ()
131- logger .debug (
132- "Client message sent successfully: "
133- f"{ response .status_code } "
134- )
135- except Exception as exc :
136- logger .error (f"Error in post_writer: { exc } " )
160+ yield read_stream , write_stream
137161 finally :
138- await write_stream .aclose ()
139-
140- endpoint_url = await tg .start (sse_reader )
141- logger .info (
142- f"Starting post writer with endpoint URL: { endpoint_url } "
143- )
144- tg .start_soon (post_writer , endpoint_url )
145-
146- try :
147- yield read_stream , write_stream
148- finally :
149- tg .cancel_scope .cancel ()
150- finally :
151- await read_stream_writer .aclose ()
152- await write_stream .aclose ()
153- await read_stream .aclose ()
154- await write_stream_reader .aclose ()
162+ tg .cancel_scope .cancel ()
163+ finally :
164+ await read_stream_writer .aclose ()
165+ await write_stream .aclose ()
166+ await read_stream .aclose ()
167+ await write_stream_reader .aclose ()
0 commit comments