153 changes: 116 additions & 37 deletions lib/backend.py
@@ -9,6 +9,7 @@
from typing import Tuple, Awaitable, NoReturn, List, Union, Callable, Optional
from functools import cached_property
from distutils.util import strtobool
from collections import deque

from anyio import open_file
from aiohttp import web, ClientResponse, ClientSession, ClientConnectorError, ClientTimeout, TCPConnector
@@ -30,7 +31,7 @@
BenchmarkResult
)

VERSION = "0.2.0"
VERSION = "0.2.1"

MSG_HISTORY_LEN = 100
log = logging.getLogger(__file__)
@@ -63,6 +64,7 @@ class Backend:
version = VERSION
msg_history = []
sem: Semaphore = dataclasses.field(default_factory=Semaphore)
queue: deque = dataclasses.field(default_factory=deque, repr=False)
unsecured: bool = dataclasses.field(
default_factory=lambda: bool(strtobool(os.environ.get("UNSECURED", "false"))),
)
@@ -141,6 +143,19 @@ async def __handle_request(
workload = payload.count_workload()
request_metrics: RequestMetrics = RequestMetrics(request_idx=auth_data.request_idx, reqnum=auth_data.reqnum, workload=workload, status="Created")


def advance_queue_after_completion(event: asyncio.Event):
"""Pop current head and wake next waiter, if any."""
if self.queue and self.queue[0] is event:
Contributor: do we need to check if [0] is an event? Small little nit

Contributor (Author): Not checking if it is an event, but verifying that the event we pass in is in fact the current head of the queue.

self.queue.popleft()
if self.queue:
self.queue[0].set()
else:
try:
self.queue.remove(event)
except ValueError:
pass
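
Not part of the diff: a minimal standalone sketch of the FIFO hand-off the author describes above, in which each request enqueues its own asyncio.Event, only the head event is set, and a finishing (or cancelled) request wakes the next waiter. The run_in_turn helper and the work callable are illustrative names, not from this PR.

```python
import asyncio
from collections import deque

queue: deque = deque()  # holds one asyncio.Event per waiting request

async def run_in_turn(work):
    """Run `work()` only once all earlier callers have finished."""
    event = asyncio.Event()
    queue.append(event)
    if queue[0] is event:
        # Nobody ahead of us: we are immediately the head of the queue.
        event.set()
    try:
        await event.wait()  # block until we reach the head
        return await work()
    finally:
        if queue and queue[0] is event:
            # Pop ourselves and wake the next waiter, mirroring
            # advance_queue_after_completion() in the diff.
            queue.popleft()
            if queue:
                queue[0].set()
        else:
            # Cancelled before reaching the head: just drop out of line.
            try:
                queue.remove(event)
            except ValueError:
                pass
```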

async def cancel_api_call_if_disconnected() -> web.Response:
await request.wait_for_disconnection()
log.debug(f"request with reqnum: {request_metrics.reqnum} was canceled")
@@ -162,7 +177,7 @@ async def make_request() -> Union[web.Response, web.StreamResponse]:
res = await handler.generate_client_response(request, response)
self.metrics._request_success(request_metrics)
return res
except requests.exceptions.RequestException as e:
except Exception as e:
log.debug(f"[backend] Request error: {e}")
self.metrics._request_errored(request_metrics)
return web.Response(status=500)
@@ -177,46 +192,110 @@ async def make_request() -> Union[web.Response, web.StreamResponse]:
self.metrics._request_reject(request_metrics)
return web.Response(status=429)

acquired = False
try:
self.metrics._request_start(request_metrics)
if self.allow_parallel_requests is False:
log.debug(f"Waiting to aquire Sem for reqnum:{request_metrics.reqnum}")
await self.sem.acquire()
acquired = True
log.debug(
f"Sem acquired for reqnum:{request_metrics.reqnum}, starting request..."
)
else:
disconnect_task = create_task(cancel_api_call_if_disconnected())
self.metrics._request_start(request_metrics)

if self.allow_parallel_requests:
try:
log.debug(f"Starting request for reqnum:{request_metrics.reqnum}")
done, pending = await wait(
[
create_task(make_request()),
create_task(cancel_api_call_if_disconnected()),
],
return_when=FIRST_COMPLETED,
)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)
work_task = create_task(make_request())
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)

for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)

if disconnect_task in done:
# Make sure work_task is settled/cancelled
try:
await work_task
except Exception:
pass
return web.Response(status=499)

# otherwise work_task completed
return await work_task

except asyncio.CancelledError:
return web.Response(status=499)
except Exception as e:
log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500)
finally:
self.metrics._request_end(request_metrics)

else:
# Insert an Event into the queue for this request
# Event.set() == our request is up next
event = asyncio.Event()
self.queue.append(event)
if self.queue and self.queue[0] is event:
event.set()

done_task = done.pop()
try:
return done_task.result()
# Race between our request being next and request being cancelled
next_request_task = create_task(event.wait())
first_done, first_pending = await wait(
[next_request_task, disconnect_task], return_when=FIRST_COMPLETED
)

# If the disconnect task wins the race
if disconnect_task in first_done and not event.is_set():
was_head = (self.queue and self.queue[0] is event)
try:
self.queue.remove(event)
except ValueError:
pass
if was_head and self.queue:
self.queue[0].set()

for t in first_pending:
t.cancel()
await asyncio.gather(*first_pending, return_exceptions=True)
return web.Response(status=499)
Contributor: isn't this duplicating advance_queue_after_completion?


# We are the next-up request in the queue
log.debug(f"Starting work on request {request_metrics.reqnum}...")

# Race the backend API call with the disconnect task
work_task = create_task(make_request())
done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)
for t in pending:
t.cancel()
await asyncio.gather(*pending, return_exceptions=True)

if disconnect_task in done:
# ensure work is cancelled and accounted for
try:
await work_task
except Exception:
pass
return web.Response(status=499)

# otherwise work_task completed
return await work_task

except asyncio.CancelledError:
# Cleanup if request was cancelled
was_head = (self.queue and self.queue[0] is event)
try:
self.queue.remove(event)
except ValueError:
pass
if was_head and self.queue:
self.queue[0].set()

return web.Response(status=499)
Contributor: same here?


except Exception as e:
log.debug(f"Request task raised exception: {e}")
log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500)
except asyncio.CancelledError:
# Client is gone. Do not write a response; just unwind.
return web.Response(status=499)
except Exception as e:
log.debug(f"Exception in main handler loop {e}")
return web.Response(status=500)
finally:
# Always release the semaphore if it was acquired
if acquired:
self.sem.release()
self.metrics._request_end(request_metrics)

finally:
self.metrics._request_end(request_metrics)
if event.is_set():
# The request is done, advance the queue
advance_queue_after_completion(event)
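
Also not part of the diff: both branches above race the backend call against the disconnect watcher with asyncio.wait(return_when=FIRST_COMPLETED), cancel whichever task loses, and translate a client disconnect into a 499 response. A minimal standalone sketch of that race, with do_work and watch_disconnect as illustrative placeholders and a bare status-code integer standing in for the aiohttp response:

```python
import asyncio
from asyncio import FIRST_COMPLETED, create_task, wait

async def race_work_against_disconnect(do_work, watch_disconnect):
    """Return do_work()'s result, or 499 if the client disconnects first."""
    work_task = create_task(do_work())
    disconnect_task = create_task(watch_disconnect())

    done, pending = await wait([work_task, disconnect_task], return_when=FIRST_COMPLETED)

    # Cancel the losing task and let it unwind before continuing.
    for t in pending:
        t.cancel()
    await asyncio.gather(*pending, return_exceptions=True)

    if disconnect_task in done:
        # Client went away first: settle the work task quietly and
        # report "client closed request" (nginx-style 499).
        try:
            await work_task
        except (asyncio.CancelledError, Exception):
            pass
        return 499

    return await work_task  # backend call finished first
```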

@cached_property
def healthcheck_session(self):