Skip to content

Commit

Permalink
analysts
Browse files Browse the repository at this point in the history
  • Loading branch information
jayceslesar committed Apr 9, 2024
1 parent 6f46e16 commit 97de002
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 3 deletions.
39 changes: 37 additions & 2 deletions masterbase/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from masterbase.lib import (
check_analyst,
check_is_active,
check_key_exists,
check_steam_id_has_api_key,
Expand All @@ -27,6 +28,7 @@
generate_uuid4_int,
is_limited_account,
late_bytes_helper,
list_demos_helper,
make_db_uri,
make_demo_path,
provision_api_key,
Expand Down Expand Up @@ -85,6 +87,16 @@ async def valid_key_guard(connection: ASGIConnection, _: BaseRouteHandler) -> No
raise NotAuthorizedException()


async def analyst_guard(connection: ASGIConnection, _: BaseRouteHandler) -> None:
"""Guard clause to User is an analyst."""
api_key = connection.query_params["api_key"]

async_engine = connection.app.state.async_engine
exists = await check_analyst(async_engine, api_key)
if not exists:
raise NotAuthorizedException()


async def user_in_session_guard(connection: ASGIConnection, _: BaseRouteHandler) -> None:
"""Assert that the user is not currently in a session."""
async_engine = connection.app.state.async_engine
Expand Down Expand Up @@ -173,7 +185,21 @@ def late_bytes(request: Request, api_key: str, data: dict[str, str]) -> dict[str
return {"late_bytes": True}


@get("/demodata", guards=[valid_key_guard, session_closed_guard], sync_to_thread=False)
@get("/list_demos", guards=[valid_key_guard, analyst_guard], sync_to_thread=False)
def list_demos(
request: Request, api_key: str, page_size: int | None = None, page_number: int | None = None
) -> list[dict[str, str]]:
"""List demo data."""
if page_size is None or page_size >= 50 or page_size < 1:
page_size = 50
if page_number is None or page_number < 1:
page_number = 1
engine = request.app.state.engine
demos = list_demos_helper(engine, api_key, page_size, page_number)
return demos


@get("/demodata", guards=[valid_key_guard, session_closed_guard, analyst_guard], sync_to_thread=False)
def demodata(request: Request, api_key: str, session_id: str) -> Stream:
"""Return the demo."""
engine = request.app.state.engine
Expand Down Expand Up @@ -345,7 +371,16 @@ def provision_handler(request: Request) -> str:

app = Litestar(
on_startup=[get_db_connection, get_async_db_connection],
route_handlers=[session_id, close_session, DemoHandler, provision, provision_handler, late_bytes, demodata],
route_handlers=[
session_id,
close_session,
DemoHandler,
provision,
provision_handler,
late_bytes,
demodata,
list_demos,
],
on_shutdown=[close_db_connection, close_async_db_connection],
)

Expand Down
50 changes: 49 additions & 1 deletion masterbase/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import os
from datetime import datetime, timezone
from typing import IO, Generator
from typing import IO, Any, Generator
from uuid import uuid4
from xml.etree import ElementTree

Expand Down Expand Up @@ -89,6 +89,33 @@ async def check_is_active(engine: AsyncEngine, api_key: str) -> bool:
return is_active


async def check_analyst(engine: AsyncEngine, api_key: str) -> bool:
"""Determine if a user is in an analyst session."""
sql = """
SELECT
*
FROM
api_keys
JOIN
analyst_steam_ids ON analyst_steam_ids.steam_id = api_keys.steam_id
WHERE
api_keys.api_key = :api_key
;
"""
params = {"api_key": api_key}

async with engine.connect() as conn:
_result = await conn.execute(
sa.text(sql),
params,
)

result = _result.one_or_none()
analyst = True if result is not None else False

return analyst


async def session_closed(engine: AsyncEngine, session_id: str) -> bool:
"""Determine if a session is active."""
sql = "SELECT active FROM demo_sessions WHERE session_id = :session_id;"
Expand Down Expand Up @@ -312,6 +339,27 @@ def demodata_helper(engine: Engine, api_key: str, session_id: str) -> Generator[
yield bytestream


def list_demos_helper(engine: Engine, api_key: str, page_size: int, page_number: int) -> list[dict[str, Any]]:
"""List all demos in the DB for a user with pagination."""
offset = (page_number - 1) * page_size

sql = """
SELECT
demo_name, session_id, map, start_time, end_time
FROM demo_sessions
WHERE
api_key = :api_key
AND active = false
LIMIT :page_size OFFSET :offset
;
"""

with engine.connect() as conn:
data = conn.execute(sa.text(sql), {"api_key": api_key, "page_size": page_size, "offset": offset})

return [row._asdict() for row in data.all()]


def check_steam_id_has_api_key(engine: Engine, steam_id: str) -> str | None:
"""Check that a given steam id has an API key or not."""
with engine.connect() as conn:
Expand Down
9 changes: 9 additions & 0 deletions migrations/versions/58fb39990d30_initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ def upgrade() -> None:
"""
)


op.execute(
"""
CREATE TABLE analyst_steam_ids (
steam_id varchar PRIMARY KEY
);
"""
)

def downgrade() -> None:
op.execute(
"""
Expand Down

0 comments on commit 97de002

Please sign in to comment.