@@ -49,119 +49,117 @@ async def sse_client(
4949 write_stream , write_stream_reader = anyio .create_memory_object_stream (0 )
5050
5151 with catch ({Exception : handle_exception }):
52- async with anyio .create_task_group () as tg :
53- try :
54- logger .info (f"Connecting to SSE endpoint: { remove_request_params (url )} " )
55- async with httpx .AsyncClient (headers = headers ) as client :
56- async with aconnect_sse (
57- client ,
58- "GET" ,
59- url ,
60- timeout = httpx .Timeout (timeout , read = sse_read_timeout ),
61- ) as event_source :
62- event_source .response .raise_for_status ()
63- logger .debug ("SSE connection established" )
64-
65- async def sse_reader (
66- task_status : TaskStatus [str ] = anyio .TASK_STATUS_IGNORED ,
67- ):
68- try :
69- async for sse in event_source .aiter_sse ():
70- logger .debug (f"Received SSE event: { sse .event } " )
71- match sse .event :
72- case "endpoint" :
73- endpoint_url = urljoin (url , sse .data )
74- logger .info (
75- f"Received endpoint URL: { endpoint_url } "
76- )
77-
78- url_parsed = urlparse (url )
79- endpoint_parsed = urlparse (endpoint_url )
80- if (
81- url_parsed .netloc
82- != endpoint_parsed .netloc
83- or url_parsed .scheme
84- != endpoint_parsed .scheme
85- ):
86- error_msg = (
87- "Endpoint origin does not match "
88- f"connection origin: { endpoint_url } "
89- )
90- logger .error (error_msg )
91- raise ValueError (error_msg )
92-
93- task_status .started (endpoint_url )
94-
95- case "message" :
96- try :
97- message = types .JSONRPCMessage .model_validate_json ( # noqa: E501
98- sse .data
99- )
100- logger .debug (
101- "Received server message: "
102- f"{ message } "
103- )
104- except Exception as exc :
105- logger .error (
106- "Error parsing server message: "
107- f"{ exc } "
108- )
109- await read_stream_writer .send (exc )
110- continue
111-
112- session_message = SessionMessage (
113- message = message
114- )
115- await read_stream_writer .send (
116- session_message
117- )
118- case _:
119- logger .warning (
120- f"Unknown SSE event: { sse .event } "
121- )
122- except Exception as exc :
123- logger .error (f"Error in sse_reader: { exc } " )
124- await read_stream_writer .send (exc )
125- finally :
126- await read_stream_writer .aclose ()
127-
128- async def post_writer (endpoint_url : str ):
129- try :
130- async with write_stream_reader :
131- async for session_message in write_stream_reader :
132- logger .debug (
133- f"Sending client message: { session_message } "
52+ logger .info (f"Connecting to SSE endpoint: { remove_request_params (url )} " )
53+ async with httpx .AsyncClient (headers = headers ) as client :
54+ async with aconnect_sse (
55+ client ,
56+ "GET" ,
57+ url ,
58+ timeout = httpx .Timeout (timeout , read = sse_read_timeout ),
59+ ) as event_source :
60+ event_source .response .raise_for_status ()
61+ logger .debug ("SSE connection established" )
62+
63+ async def sse_reader (
64+ task_status : TaskStatus [str ] = anyio .TASK_STATUS_IGNORED ,
65+ ):
66+ try :
67+ async for sse in event_source .aiter_sse ():
68+ logger .debug (f"Received SSE event: { sse .event } " )
69+ match sse .event :
70+ case "endpoint" :
71+ endpoint_url = urljoin (url , sse .data )
72+ logger .info (
73+ f"Received endpoint URL: { endpoint_url } "
74+ )
75+
76+ url_parsed = urlparse (url )
77+ endpoint_parsed = urlparse (endpoint_url )
78+ if (
79+ url_parsed .netloc
80+ != endpoint_parsed .netloc
81+ or url_parsed .scheme
82+ != endpoint_parsed .scheme
83+ ):
84+ error_msg = (
85+ "Endpoint origin does not match "
86+ f"connection origin: { endpoint_url } "
13487 )
135- response = await client .post (
136- endpoint_url ,
137- json = session_message .message .model_dump (
138- by_alias = True ,
139- mode = "json" ,
140- exclude_none = True ,
141- ),
88+ logger .error (error_msg )
89+ raise ValueError (error_msg )
90+
91+ task_status .started (endpoint_url )
92+
93+ case "message" :
94+ try :
95+ message = types .JSONRPCMessage .model_validate_json ( # noqa: E501
96+ sse .data
14297 )
143- response .raise_for_status ()
14498 logger .debug (
145- "Client message sent successfully : "
146- f"{ response . status_code } "
99+ "Received server message : "
100+ f"{ message } "
147101 )
148- except Exception as exc :
149- logger .error (f"Error in post_writer: { exc } " )
150- finally :
151- await write_stream .aclose ()
152-
102+ except Exception as exc :
103+ logger .error (
104+ "Error parsing server message: "
105+ f"{ exc } "
106+ )
107+ await read_stream_writer .send (exc )
108+ continue
109+
110+ session_message = SessionMessage (
111+ message = message
112+ )
113+ await read_stream_writer .send (
114+ session_message
115+ )
116+ case _:
117+ logger .warning (
118+ f"Unknown SSE event: { sse .event } "
119+ )
120+ except Exception as exc :
121+ logger .error (f"Error in sse_reader: { exc } " )
122+ await read_stream_writer .send (exc )
123+ finally :
124+ await read_stream_writer .aclose ()
125+
126+ async def post_writer (endpoint_url : str ):
127+ try :
128+ async with write_stream_reader :
129+ async for session_message in write_stream_reader :
130+ logger .debug (
131+ f"Sending client message: { session_message } "
132+ )
133+ response = await client .post (
134+ endpoint_url ,
135+ json = session_message .message .model_dump (
136+ by_alias = True ,
137+ mode = "json" ,
138+ exclude_none = True ,
139+ ),
140+ )
141+ response .raise_for_status ()
142+ logger .debug (
143+ "Client message sent successfully: "
144+ f"{ response .status_code } "
145+ )
146+ except Exception as exc :
147+ logger .error (f"Error in post_writer: { exc } " )
148+ finally :
149+ await write_stream .aclose ()
150+
151+ try :
152+ async with anyio .create_task_group () as tg :
153153 endpoint_url = await tg .start (sse_reader )
154154 logger .info (
155155 f"Starting post writer with endpoint URL: { endpoint_url } "
156156 )
157157 tg .start_soon (post_writer , endpoint_url )
158-
159- try :
160- yield read_stream , write_stream
161- finally :
162- tg .cancel_scope .cancel ()
163- finally :
164- await read_stream_writer .aclose ()
165- await write_stream .aclose ()
166- await read_stream .aclose ()
167- await write_stream_reader .aclose ()
158+
159+ # Move streams outside
160+ yield read_stream , write_stream
161+ finally :
162+ await read_stream_writer .aclose ()
163+ await write_stream .aclose ()
164+ await read_stream .aclose ()
165+ await write_stream_reader .aclose ()
0 commit comments