Skip to content

Commit

Permalink
Clean refactoring of api.py.
Browse files Browse the repository at this point in the history
The structure is now as follows

class GenericApi:
-- contains handle_error, add_time

class WorkerApi(GenericApi):
-- contains the apis used by the worker

class UserApi(GenericApi):
-- Apis useful for application developers

class InternalApi(GenericApi):
-- Apis (possibly exposing sensitive data) that should only be
   used by Fishtest itself (possibly access controlled).
  • Loading branch information
vdbergh authored and ppigazzini committed May 12, 2024
1 parent 291c616 commit 8c7ec81
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 185 deletions.
287 changes: 134 additions & 153 deletions server/fishtest/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from fishtest.schemas import api_access_schema, api_schema, gzip_data
from fishtest.stats.stat_util import SPRT_elo, get_elo
from fishtest.util import worker_name
from fishtest.util import strip_run, worker_name
from pyramid.httpexceptions import (
HTTPBadRequest,
HTTPFound,
Expand Down Expand Up @@ -36,37 +36,6 @@
WORKER_VERSION = 237


def validate_request(request):
validate(api_schema, request, "request")


# Avoids exposing sensitive data about the workers to the client and skips some heavy data.
def strip_run(run):
# a deep copy, avoiding copies of a few large lists.
stripped = {}
for k1, v1 in run.items():
if k1 in ("tasks", "bad_tasks"):
stripped[k1] = []
elif k1 == "args":
stripped[k1] = {}
for k2, v2 in v1.items():
if k2 == "spsa":
stripped[k1][k2] = {
k3: [] if k3 == "param_history" else copy.deepcopy(v3)
for k3, v3 in v2.items()
}
else:
stripped[k1][k2] = copy.deepcopy(v2)
else:
stripped[k1] = copy.deepcopy(v1)

# and some string conversions
for key in ("_id", "start_time", "last_updated"):
stripped[key] = str(run[key])

return stripped


@exception_view_config(HTTPBadRequest)
def badrequest_failed(error, request):
response = Response(json_body=error.detail)
Expand All @@ -81,12 +50,19 @@ def authentication_failed(error, request):
return response


@view_defaults(renderer="json")
class ApiView(object):
"""All API endpoints that require authentication are used by workers"""

class GenericApi:
def __init__(self, request):
self.request = request
self.__t0 = datetime.now(timezone.utc)

def timestamp(self):
return self.__t0

def add_time(self, result):
result["duration"] = (
datetime.now(timezone.utc) - self.timestamp()
).total_seconds()
return result

def handle_error(self, error, exception=HTTPBadRequest):
if error != "":
Expand All @@ -96,14 +72,20 @@ def handle_error(self, error, exception=HTTPBadRequest):
print(error, flush=True)
raise exception(self.add_time({"error": error}))

def validate_username_password(self):
self.__t0 = datetime.now(timezone.utc)

@view_defaults(renderer="json", request_method="POST")
class WorkerApi(GenericApi):
"""All API endpoints that require authentication are used by workers"""

def __init__(self, request):
super().__init__(request)
# is the request valid json?
try:
self.request_body = self.request.json_body
self.request_body = request.json_body
except:
self.handle_error("request is not json encoded")

def validate_username_password(self):
# Is the request syntactically correct?
try:
validate(api_access_schema, self.request_body, "request")
Expand Down Expand Up @@ -137,7 +119,7 @@ def validate_request(self):

# Is the request syntactically correct?
try:
validate_request(self.request_body)
validate(api_schema, self.request_body, "request")
except ValidationError as e:
self.handle_error(str(e))

Expand Down Expand Up @@ -170,10 +152,6 @@ def validate_request(self):
)
self.__task = task

def add_time(self, result):
result["duration"] = (datetime.now(timezone.utc) - self.__t0).total_seconds()
return result

def get_username(self):
return self.request_body["worker_info"]["username"]

Expand Down Expand Up @@ -234,6 +212,116 @@ def get_country_code(self):
country_code = self.request.headers.get("X-Country-Code")
return "?" if country_code in (None, "ZZ") else country_code

@view_config(route_name="api_request_task")
def request_task(self):
self.validate_request()
worker_info = self.worker_info()
# rundb.request_task() needs this for an error message...
worker_info["host_url"] = self.request.host_url
result = self.request.rundb.request_task(worker_info)
if "task_waiting" in result:
return self.add_time(result)

# Strip the run of unneccesary information
run = result["run"]
task = run["tasks"][result["task_id"]]
min_task = {"num_games": task["num_games"], "start": task["start"]}
if "stats" in task:
min_task["stats"] = task["stats"]
min_run = {"_id": str(run["_id"]), "args": run["args"], "my_task": min_task}
result["run"] = min_run
return self.add_time(result)

@view_config(route_name="api_update_task")
def update_task(self):
self.validate_request()
result = self.request.rundb.update_task(
worker_info=self.worker_info(),
run_id=self.run_id(),
task_id=self.task_id(),
stats=self.stats(),
spsa=self.spsa(),
)
return self.add_time(result)

@view_config(route_name="api_failed_task")
def failed_task(self):
self.validate_request()
result = self.request.rundb.failed_task(
self.run_id(), self.task_id(), self.message()
)
return self.add_time(result)

@view_config(route_name="api_upload_pgn")
def upload_pgn(self):
self.validate_request()
try:
pgn_zip = base64.b64decode(self.pgn())
validate(gzip_data, pgn_zip, "pgn")
except Exception as e:
self.handle_error(str(e))
result = self.request.rundb.upload_pgn(
run_id="{}-{}".format(self.run_id(), self.task_id()),
pgn_zip=pgn_zip,
)
return self.add_time(result)

@view_config(route_name="api_stop_run")
def stop_run(self):
self.validate_request()
error = ""
if self.cpu_hours() < 1000:
error = "User {} has too few games to stop a run".format(
self.get_username()
)
with self.request.rundb.active_run_lock(self.run_id()):
run = self.run()
message = self.message()[:1024] + (
" (not authorized)" if error != "" else ""
)
self.request.actiondb.stop_run(
username=self.get_username(),
run=run,
task_id=self.task_id(),
message=message,
)
if error == "":
run["finished"] = True
run["failed"] = True
self.request.rundb.stop_run(self.run_id())
else:
task = self.task()
task["active"] = False
self.request.rundb.buffer(run, True)

self.handle_error(error, exception=HTTPUnauthorized)
return self.add_time({})

@view_config(route_name="api_request_version")
def request_version(self):
# By being mor lax here we can be more strict
# elsewhere since the worker will upgrade.
self.validate_username_password()
return self.add_time({"version": WORKER_VERSION})

@view_config(route_name="api_beat")
def beat(self):
self.validate_request()
run = self.run()
task = self.task()
task["last_updated"] = datetime.now(timezone.utc)
self.request.rundb.buffer(run, False)
return self.add_time({})

@view_config(route_name="api_request_spsa")
def request_spsa(self):
self.validate_request()
result = self.request.rundb.request_spsa(self.run_id(), self.task_id())
return self.add_time(result)


@view_defaults(renderer="json")
class UserApi(GenericApi):
@view_config(route_name="api_active_runs")
def active_runs(self):
runs = self.request.rundb.runs.find(
Expand All @@ -250,8 +338,6 @@ def active_runs(self):

@view_config(route_name="api_finished_runs")
def finished_runs(self):
self.__t0 = datetime.now(timezone.utc)

username = self.request.params.get("username", "")
success_only = self.request.params.get("success_only", False)
yellow_only = self.request.params.get("yellow_only", False)
Expand Down Expand Up @@ -367,8 +453,6 @@ def get_elo(self):

@view_config(route_name="api_calc_elo")
def calc_elo(self):
self.__t0 = datetime.now(timezone.utc)

W = self.request.params.get("W")
D = self.request.params.get("D")
L = self.request.params.get("L")
Expand Down Expand Up @@ -473,60 +557,6 @@ def calc_elo(self):
elo_model=elo_model,
)

@view_config(route_name="api_request_task")
def request_task(self):
self.validate_request()
worker_info = self.worker_info()
# rundb.request_task() needs this for an error message...
worker_info["host_url"] = self.request.host_url
result = self.request.rundb.request_task(worker_info)
if "task_waiting" in result:
return self.add_time(result)

# Strip the run of unneccesary information
run = result["run"]
task = run["tasks"][result["task_id"]]
min_task = {"num_games": task["num_games"], "start": task["start"]}
if "stats" in task:
min_task["stats"] = task["stats"]
min_run = {"_id": str(run["_id"]), "args": run["args"], "my_task": min_task}
result["run"] = min_run
return self.add_time(result)

@view_config(route_name="api_update_task")
def update_task(self):
self.validate_request()
result = self.request.rundb.update_task(
worker_info=self.worker_info(),
run_id=self.run_id(),
task_id=self.task_id(),
stats=self.stats(),
spsa=self.spsa(),
)
return self.add_time(result)

@view_config(route_name="api_failed_task")
def failed_task(self):
self.validate_request()
result = self.request.rundb.failed_task(
self.run_id(), self.task_id(), self.message()
)
return self.add_time(result)

@view_config(route_name="api_upload_pgn")
def upload_pgn(self):
self.validate_request()
try:
pgn_zip = base64.b64decode(self.pgn())
validate(gzip_data, pgn_zip, "pgn")
except Exception as e:
self.handle_error(str(e))
result = self.request.rundb.upload_pgn(
run_id="{}-{}".format(self.run_id(), self.task_id()),
pgn_zip=pgn_zip,
)
return self.add_time(result)

@view_config(route_name="api_download_pgn", renderer="string")
def download_pgn(self):
zip_name = self.request.matchdict["id"]
Expand Down Expand Up @@ -569,55 +599,6 @@ def download_nn(self):
"https://data.stockfishchess.org/nn/" + self.request.matchdict["id"]
)

@view_config(route_name="api_stop_run")
def stop_run(self):
self.validate_request()
error = ""
if self.cpu_hours() < 1000:
error = "User {} has too few games to stop a run".format(
self.get_username()
)
with self.request.rundb.active_run_lock(self.run_id()):
run = self.run()
message = self.message()[:1024] + (
" (not authorized)" if error != "" else ""
)
self.request.actiondb.stop_run(
username=self.get_username(),
run=run,
task_id=self.task_id(),
message=message,
)
if error == "":
run["finished"] = True
run["failed"] = True
self.request.rundb.stop_run(self.run_id())
else:
task = self.task()
task["active"] = False
self.request.rundb.buffer(run, True)

self.handle_error(error, exception=HTTPUnauthorized)
return self.add_time({})

@view_config(route_name="api_request_version")
def request_version(self):
# By being mor lax here we can be more strict
# elsewhere since the worker will upgrade.
self.validate_username_password()
return self.add_time({"version": WORKER_VERSION})

@view_config(route_name="api_beat")
def beat(self):
self.validate_request()
run = self.run()
task = self.task()
task["last_updated"] = datetime.now(timezone.utc)
self.request.rundb.buffer(run, False)
return self.add_time({})

@view_config(route_name="api_request_spsa")
def request_spsa(self):
self.validate_request()
result = self.request.rundb.request_spsa(self.run_id(), self.task_id())
return self.add_time(result)
class InternalApi(GenericApi):
pass
Loading

0 comments on commit 8c7ec81

Please sign in to comment.