Skip to content

Commit 43788c7

Browse files
add fifo queue
1 parent 7db54f3 commit 43788c7

File tree

1 file changed

+130
-39
lines changed

1 file changed

+130
-39
lines changed

lib/backend.py

Lines changed: 130 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
1010
from functools import cached_property
1111
from distutils.util import strtobool
12+
from collections import deque
1213

1314
from anyio import open_file
1415
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
@@ -63,6 +64,7 @@ class Backend:
6364
version = VERSION
6465
msg_history = []
6566
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
67+
queue: deque = dataclasses.field(default_factory=deque, repr=False)
6668
unsecured: bool = dataclasses.field(
6769
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
6870
)
@@ -141,6 +143,19 @@ async def __handle_request(
141143
workload = payload.count_workload()
142144
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")
143145

146+
147+
def advance_queue_after_completion(event: asyncio.Event):
148+
"""Pop current head and wake next waiter, if any."""
149+
if self.queue and self.queue[0] is event:
150+
self.queue.popleft()
151+
if self.queue:
152+
self.queue[0].set()
153+
else:
154+
try:
155+
self.queue.remove(event)
156+
except ValueError:
157+
pass
158+
144159
async def cancel_api_call_if_disconnected() -> web.Response:
145160
await request.wait_for_disconnection()
146161
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
@@ -162,61 +177,137 @@ async def make_request() -> Union[web.Response, web.StreamResponse]:
162177
res = await handler.generate_client_response(request, response)
163178
self.metrics._request_success(request_metrics)
164179
return res
165-
except requests.exceptions.RequestException as e:
180+
except Exception as e:
166181
log.debug(f"[backend] Request error: {e}")
167182
self.metrics._request_errored(request_metrics)
168183
return web.Response(status=500)
169184

170185
###########
171186

172187
if self.__check_signature(auth_data) is False:
173-
self.metrics._request_reject(request_metrics)
174-
return web.Response(status=401)
175-
188+
self.metrics._request_reject(request_metrics)
189+
return web.Response(status=401)
190+
176191
if self.metrics.model_metrics.wait_time > self.max_wait_time:
177192
self.metrics._request_reject(request_metrics)
178193
return web.Response(status=429)
179194

180-
acquired = False
181-
try:
182-
self.metrics._request_start(request_metrics)
183-
if self.allow_parallel_requests is False:
184-
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
185-
await self.sem.acquire()
186-
acquired = True
187-
log.debug(
188-
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
189-
)
190-
else:
195+
disconnect_task = create_task(cancel_api_call_if_disconnected())
196+
self.metrics._request_start(request_metrics)
197+
198+
if self.allow_parallel_requests:
199+
ended = False
200+
try:
191201
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
192-
done, pending = await wait(
193-
[
194-
create_task(make_request()),
195-
create_task(cancel_api_call_if_disconnected()),
196-
],
197-
return_when=FIRST_COMPLETED,
198-
)
199-
for t in pending:
200-
t.cancel()
201-
await asyncio.gather(*pending, return_exceptions=True)
202+
work_task = create_task(make_request())
203+
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
204+
205+
for t in pending:
206+
t.cancel()
207+
await asyncio.gather(*pending, return_exceptions=True)
208+
209+
if disconnect_task in done:
210+
# Make sure work_task is settled/cancelled
211+
try:
212+
await work_task
213+
except Exception:
214+
pass
215+
return web.Response(status=499)
216+
217+
# otherwise work_task completed
218+
return await work_task
219+
220+
except asyncio.CancelledError:
221+
return web.Response(status=499)
222+
except Exception as e:
223+
log.debug(f"Exception in main handler loop {e}")
224+
return web.Response(status=500)
225+
finally:
226+
if not ended:
227+
self.metrics._request_end(request_metrics)
202228

203-
done_task = done.pop()
229+
else:
230+
# Insert a Event into the queue for this request
231+
# Event.set() == our request is up next
232+
event = asyncio.Event()
233+
self.queue.append(event)
234+
if self.queue and self.queue[0] is event:
235+
event.set()
236+
237+
ended = False
204238
try:
205-
return done_task.result()
239+
# Race between our request being next and request being cancelled
240+
next_request_task = create_task(event.wait())
241+
first_done, first_pending = await wait(
242+
[next_request_task, disconnect_task], return_when=FIRST_COMPLETED
243+
)
244+
245+
# If the disconnect task wins the race
246+
if disconnect_task in first_done and not event.is_set():
247+
was_head = (self.queue and self.queue[0] is event)
248+
try:
249+
self.queue.remove(event)
250+
except ValueError:
251+
pass
252+
if was_head and self.queue:
253+
self.queue[0].set()
254+
255+
self.metrics._request_end(request_metrics)
256+
ended = True
257+
258+
for t in first_pending:
259+
t.cancel()
260+
await asyncio.gather(*first_pending, return_exceptions=True)
261+
return web.Response(status=499)
262+
263+
# We are the next-up request in the queue
264+
log.debug(f"Starting work on request {request_metrics.reqnum}...")
265+
266+
# Race the backend API call with the disconnect task
267+
work_task = create_task(make_request())
268+
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
269+
for t in pending:
270+
t.cancel()
271+
await asyncio.gather(*pending, return_exceptions=True)
272+
273+
if disconnect_task in done:
274+
# ensure work is cancelled and accounted for
275+
try:
276+
await work_task
277+
except Exception:
278+
pass
279+
return web.Response(status=499)
280+
281+
# otherwise work_task completed
282+
return await work_task
283+
284+
except asyncio.CancelledError:
285+
# Cleanup if request was cancelled
286+
was_head = (self.queue and self.queue[0] is event)
287+
try:
288+
self.queue.remove(event)
289+
except ValueError:
290+
pass
291+
if was_head and self.queue:
292+
self.queue[0].set()
293+
294+
if not ended:
295+
self.metrics._request_end(request_metrics)
296+
ended = True
297+
298+
return web.Response(status=499)
299+
206300
except Exception as e:
207-
log.debug(f"Request task raised exception: {e}")
301+
log.debug(f"Exception in main handler loop {e}")
208302
return web.Response(status=500)
209-
except asyncio.CancelledError:
210-
# Client is gone. Do not write a response; just unwind.
211-
return web.Response(status=499)
212-
except Exception as e:
213-
log.debug(f"Exception in main handler loop {e}")
214-
return web.Response(status=500)
215-
finally:
216-
# Always release the semaphore if it was acquired
217-
if acquired:
218-
self.sem.release()
219-
self.metrics._request_end(request_metrics)
303+
304+
finally:
305+
if not ended:
306+
self.metrics._request_end(request_metrics)
307+
ended = True
308+
if event.is_set():
309+
# The request is done, advance the queue
310+
advance_queue_after_completion(event)
220311

221312
@cached_property
222313
def healthcheck_session(self):

0 commit comments

Comments
 (0)