@@ -120,6 +120,74 @@ async def mock_aiter():
120120 record , expected_response_count = 2 , expected_response_type = SSEMessage
121121 )
122122
123+ @pytest .mark .asyncio
124+ @pytest .mark .parametrize (
125+ "comment_value,expected_error_text" ,
126+ [
127+ ("Rate limit exceeded" , "Rate limit exceeded" ),
128+ (None , "Unknown error in SSE response" ),
129+ ],
130+ )
131+ async def test_sse_stream_error_event_handling (
132+ self ,
133+ aiohttp_client : AioHttpClient ,
134+ mock_sse_response : Mock ,
135+ comment_value : str | None ,
136+ expected_error_text : str ,
137+ ) -> None :
138+ """Test that SSE error events are properly caught and handled in the client."""
139+ from aiperf .common .enums import SSEEventType , SSEFieldType
140+ from aiperf .common .models import SSEField
141+
142+ packets = [
143+ SSEField (name = SSEFieldType .EVENT , value = SSEEventType .ERROR ),
144+ ]
145+ if comment_value :
146+ packets .append (SSEField (name = SSEFieldType .COMMENT , value = comment_value ))
147+ packets .append (SSEField (name = SSEFieldType .DATA , value = "{}" ))
148+
149+ mock_error_message = SSEMessage (perf_ns = 123456789 , packets = packets )
150+
151+ with (
152+ patch ("aiohttp.ClientSession" ) as mock_session_class ,
153+ patch (
154+ "aiperf.transports.aiohttp_client.AsyncSSEStreamReader"
155+ ) as mock_reader_class ,
156+ ):
157+
158+ async def mock_content_iter ():
159+ yield b"event: error\n "
160+ if comment_value :
161+ yield f": { comment_value } \n " .encode ()
162+ yield b"data: {}\n \n "
163+
164+ mock_sse_response .content = mock_content_iter ()
165+
166+ setup_mock_session (mock_session_class , mock_sse_response , ["request" ])
167+
168+ async def mock_aiter ():
169+ yield mock_error_message
170+ from aiperf .transports .sse_utils import AsyncSSEStreamReader
171+
172+ AsyncSSEStreamReader .inspect_message_for_error (mock_error_message )
173+
174+ mock_reader = Mock ()
175+ mock_reader .__aiter__ = Mock (return_value = mock_aiter ())
176+ mock_reader_class .return_value = mock_reader
177+
178+ record = await aiohttp_client .post_request (
179+ "http://test.com/stream" ,
180+ '{"stream": true}' ,
181+ {"Accept" : "text/event-stream" },
182+ )
183+
184+ assert record .error is not None
185+ assert record .error .code == 502
186+ assert record .error .type == "SSEResponseError"
187+ assert expected_error_text in record .error .message
188+ assert len (record .responses ) == 1
189+ assert isinstance (record .responses [0 ], SSEMessage )
190+
123191 @pytest .mark .asyncio
124192 @pytest .mark .parametrize (
125193 "status_code,reason,error_text" ,
@@ -184,13 +252,9 @@ async def test_exception_handling(
184252 "exception_class,message,expected_type" ,
185253 [
186254 (aiohttp .ClientConnectorError , "Connection failed" , "ClientConnectorError" ),
187- (
188- aiohttp .ClientResponseError ,
189- "Internal Server Error" ,
190- "ClientResponseError" ,
191- ),
255+ (aiohttp .ClientResponseError , "Internal Server Error" , "ClientResponseError" ),
192256 ],
193- )
257+ ) # fmt: skip
194258 async def test_aiohttp_specific_exceptions (
195259 self ,
196260 aiohttp_client : AioHttpClient ,
0 commit comments