99from typing import Tuple , Awaitable , NoReturn , List , Union , Callable , Optional
1010from functools import cached_property
1111from distutils .util import strtobool
12+ from collections import deque
1213
1314from anyio import open_file
1415from aiohttp import web , ClientResponse , ClientSession , ClientConnectorError , ClientTimeout , TCPConnector
@@ -63,6 +64,7 @@ class Backend:
6364 version = VERSION
6465 msg_history = []
6566 sem : Semaphore = dataclasses .field (default_factory = Semaphore )
67+ queue : deque = dataclasses .field (default_factory = deque , repr = False )
6668 unsecured : bool = dataclasses .field (
6769 default_factory = lambda : bool (strtobool (os .environ .get ("UNSECURED" , "false" ))),
6870 )
@@ -141,6 +143,19 @@ async def __handle_request(
141143 workload = payload .count_workload ()
142144 request_metrics : RequestMetrics = RequestMetrics (request_idx = auth_data .request_idx , reqnum = auth_data .reqnum , workload = workload , status = "Created" )
143145
146+
def advance_queue_after_completion(event: asyncio.Event):
    """Retire ``event`` from the wait queue once its request has finished.

    When ``event`` is at the front of ``self.queue`` it is popped, and the
    event that becomes the new front (if there is one) is set. Otherwise
    ``event`` is removed from wherever it sits in the queue; an event that
    is no longer queued is silently ignored.

    ``self`` is taken from the enclosing method's scope.
    """
    waiters = self.queue
    at_front = bool(waiters) and waiters[0] is event
    if not at_front:
        # Not the head: just drop it. It may already have been removed.
        try:
            waiters.remove(event)
        except ValueError:
            pass
        return

    waiters.popleft()
    if waiters:
        # Signal the request that is now first in line.
        waiters[0].set()
158+
144159 async def cancel_api_call_if_disconnected () -> web .Response :
145160 await request .wait_for_disconnection ()
146161 log .debug (f"request with reqnum: { request_metrics .reqnum } was canceled" )
@@ -162,61 +177,137 @@ async def make_request() -> Union[web.Response, web.StreamResponse]:
162177 res = await handler .generate_client_response (request , response )
163178 self .metrics ._request_success (request_metrics )
164179 return res
165- except requests . exceptions . RequestException as e :
180+ except Exception as e :
166181 log .debug (f"[backend] Request error: { e } " )
167182 self .metrics ._request_errored (request_metrics )
168183 return web .Response (status = 500 )
169184
170185 ###########
171186
172187 if self .__check_signature (auth_data ) is False :
173- self .metrics ._request_reject (request_metrics )
174- return web .Response (status = 401 )
175-
188+ self .metrics ._request_reject (request_metrics )
189+ return web .Response (status = 401 )
190+
176191 if self .metrics .model_metrics .wait_time > self .max_wait_time :
177192 self .metrics ._request_reject (request_metrics )
178193 return web .Response (status = 429 )
179194
180- acquired = False
181- try :
182- self .metrics ._request_start (request_metrics )
183- if self .allow_parallel_requests is False :
184- log .debug (f"Waiting to aquire Sem for reqnum:{ request_metrics .reqnum } " )
185- await self .sem .acquire ()
186- acquired = True
187- log .debug (
188- f"Sem acquired for reqnum:{ request_metrics .reqnum } , starting request..."
189- )
190- else :
195+ disconnect_task = create_task (cancel_api_call_if_disconnected ())
196+ self .metrics ._request_start (request_metrics )
197+
198+ if self .allow_parallel_requests :
199+ ended = False
200+ try :
191201 log .debug (f"Starting request for reqnum:{ request_metrics .reqnum } " )
192- done , pending = await wait (
193- [
194- create_task (make_request ()),
195- create_task (cancel_api_call_if_disconnected ()),
196- ],
197- return_when = FIRST_COMPLETED ,
198- )
199- for t in pending :
200- t .cancel ()
201- await asyncio .gather (* pending , return_exceptions = True )
202+ work_task = create_task (make_request ())
203+ done , pending = await wait ([work_task , disconnect_task ], return_when = FIRST_COMPLETED )
204+
205+ for t in pending :
206+ t .cancel ()
207+ await asyncio .gather (* pending , return_exceptions = True )
208+
209+ if disconnect_task in done :
210+ # Make sure work_task is settled/cancelled
211+ try :
212+ await work_task
213+ except Exception :
214+ pass
215+ return web .Response (status = 499 )
216+
217+ # otherwise work_task completed
218+ return await work_task
219+
220+ except asyncio .CancelledError :
221+ return web .Response (status = 499 )
222+ except Exception as e :
223+ log .debug (f"Exception in main handler loop { e } " )
224+ return web .Response (status = 500 )
225+ finally :
226+ if not ended :
227+ self .metrics ._request_end (request_metrics )
202228
203- done_task = done .pop ()
229+ else :
230+ # Insert an Event into the queue for this request
231+ # Event.set() == our request is up next
232+ event = asyncio .Event ()
233+ self .queue .append (event )
234+ if self .queue and self .queue [0 ] is event :
235+ event .set ()
236+
237+ ended = False
204238 try :
205- return done_task .result ()
239+ # Race between our request being next and request being cancelled
240+ next_request_task = create_task (event .wait ())
241+ first_done , first_pending = await wait (
242+ [next_request_task , disconnect_task ], return_when = FIRST_COMPLETED
243+ )
244+
245+ # If the disconnect task wins the race
246+ if disconnect_task in first_done and not event .is_set ():
247+ was_head = (self .queue and self .queue [0 ] is event )
248+ try :
249+ self .queue .remove (event )
250+ except ValueError :
251+ pass
252+ if was_head and self .queue :
253+ self .queue [0 ].set ()
254+
255+ self .metrics ._request_end (request_metrics )
256+ ended = True
257+
258+ for t in first_pending :
259+ t .cancel ()
260+ await asyncio .gather (* first_pending , return_exceptions = True )
261+ return web .Response (status = 499 )
262+
263+ # We are the next-up request in the queue
264+ log .debug (f"Starting work on request { request_metrics .reqnum } ..." )
265+
266+ # Race the backend API call with the disconnect task
267+ work_task = create_task (make_request ())
268+ done , pending = await wait ([work_task , disconnect_task ], return_when = FIRST_COMPLETED )
269+ for t in pending :
270+ t .cancel ()
271+ await asyncio .gather (* pending , return_exceptions = True )
272+
273+ if disconnect_task in done :
274+ # ensure work is cancelled and accounted for
275+ try :
276+ await work_task
277+ except Exception :
278+ pass
279+ return web .Response (status = 499 )
280+
281+ # otherwise work_task completed
282+ return await work_task
283+
284+ except asyncio .CancelledError :
285+ # Cleanup if request was cancelled
286+ was_head = (self .queue and self .queue [0 ] is event )
287+ try :
288+ self .queue .remove (event )
289+ except ValueError :
290+ pass
291+ if was_head and self .queue :
292+ self .queue [0 ].set ()
293+
294+ if not ended :
295+ self .metrics ._request_end (request_metrics )
296+ ended = True
297+
298+ return web .Response (status = 499 )
299+
206300 except Exception as e :
207- log .debug (f"Request task raised exception: { e } " )
301+ log .debug (f"Exception in main handler loop { e } " )
208302 return web .Response (status = 500 )
209- except asyncio .CancelledError :
210- # Client is gone. Do not write a response; just unwind.
211- return web .Response (status = 499 )
212- except Exception as e :
213- log .debug (f"Exception in main handler loop { e } " )
214- return web .Response (status = 500 )
215- finally :
216- # Always release the semaphore if it was acquired
217- if acquired :
218- self .sem .release ()
219- self .metrics ._request_end (request_metrics )
303+
304+ finally :
305+ if not ended :
306+ self .metrics ._request_end (request_metrics )
307+ ended = True
308+ if event .is_set ():
309+ # The request is done, advance the queue
310+ advance_queue_after_completion (event )
220311
221312 @cached_property
222313 def healthcheck_session (self ):
0 commit comments