From 8c7ec8133f70518c34c42d1b13ed71cce5394115 Mon Sep 17 00:00:00 2001 From: Michel Van den Bergh Date: Sat, 11 May 2024 13:18:08 +0000 Subject: [PATCH] Clean refactoring of api.py. The structure is now as follows class GenericApi: -- contains handle_error, add_time class WorkerApi(GenericApi): -- contains the apis used by the worker class UserApi(GenericApi): -- Apis useful for application developers class InternalApi(GenericApi): -- Apis (possibly exposing sensitive data) that should only be used by Fishtest itself (possibly access controlled). --- server/fishtest/api.py | 287 ++++++++++++++++++--------------------- server/fishtest/util.py | 28 ++++ server/tests/test_api.py | 64 ++++----- 3 files changed, 194 insertions(+), 185 deletions(-) diff --git a/server/fishtest/api.py b/server/fishtest/api.py index e8bc58529..3c8a18e71 100644 --- a/server/fishtest/api.py +++ b/server/fishtest/api.py @@ -7,7 +7,7 @@ from fishtest.schemas import api_access_schema, api_schema, gzip_data from fishtest.stats.stat_util import SPRT_elo, get_elo -from fishtest.util import worker_name +from fishtest.util import strip_run, worker_name from pyramid.httpexceptions import ( HTTPBadRequest, HTTPFound, @@ -36,37 +36,6 @@ WORKER_VERSION = 237 -def validate_request(request): - validate(api_schema, request, "request") - - -# Avoids exposing sensitive data about the workers to the client and skips some heavy data. -def strip_run(run): - # a deep copy, avoiding copies of a few large lists. - stripped = {} - for k1, v1 in run.items(): - if k1 in ("tasks", "bad_tasks"): - stripped[k1] = [] - elif k1 == "args": - stripped[k1] = {} - for k2, v2 in v1.items(): - if k2 == "spsa": - stripped[k1][k2] = { - k3: [] if k3 == "param_history" else copy.deepcopy(v3) - for k3, v3 in v2.items() - } - else: - stripped[k1][k2] = copy.deepcopy(v2) - else: - stripped[k1] = copy.deepcopy(v1) - - # and some string conversions - for key in ("_id", "start_time", "last_updated"): - stripped[key] = str(run[key]) - - return stripped - - @exception_view_config(HTTPBadRequest) def badrequest_failed(error, request): response = Response(json_body=error.detail) @@ -81,12 +50,19 @@ def authentication_failed(error, request): return response -@view_defaults(renderer="json") -class ApiView(object): - """All API endpoints that require authentication are used by workers""" - +class GenericApi: def __init__(self, request): self.request = request + self.__t0 = datetime.now(timezone.utc) + + def timestamp(self): + return self.__t0 + + def add_time(self, result): + result["duration"] = ( + datetime.now(timezone.utc) - self.timestamp() + ).total_seconds() + return result def handle_error(self, error, exception=HTTPBadRequest): if error != "": @@ -96,14 +72,20 @@ def handle_error(self, error, exception=HTTPBadRequest): print(error, flush=True) raise exception(self.add_time({"error": error})) - def validate_username_password(self): - self.__t0 = datetime.now(timezone.utc) + +@view_defaults(renderer="json", request_method="POST") +class WorkerApi(GenericApi): + """All API endpoints that require authentication are used by workers""" + + def __init__(self, request): + super().__init__(request) # is the request valid json? try: - self.request_body = self.request.json_body + self.request_body = request.json_body except: self.handle_error("request is not json encoded") + def validate_username_password(self): # Is the request syntactically correct? try: validate(api_access_schema, self.request_body, "request") @@ -137,7 +119,7 @@ def validate_request(self): # Is the request syntactically correct? try: - validate_request(self.request_body) + validate(api_schema, self.request_body, "request") except ValidationError as e: self.handle_error(str(e)) @@ -170,10 +152,6 @@ def validate_request(self): ) self.__task = task - def add_time(self, result): - result["duration"] = (datetime.now(timezone.utc) - self.__t0).total_seconds() - return result - def get_username(self): return self.request_body["worker_info"]["username"] @@ -234,6 +212,116 @@ def get_country_code(self): country_code = self.request.headers.get("X-Country-Code") return "?" if country_code in (None, "ZZ") else country_code + @view_config(route_name="api_request_task") + def request_task(self): + self.validate_request() + worker_info = self.worker_info() + # rundb.request_task() needs this for an error message... + worker_info["host_url"] = self.request.host_url + result = self.request.rundb.request_task(worker_info) + if "task_waiting" in result: + return self.add_time(result) + + # Strip the run of unneccesary information + run = result["run"] + task = run["tasks"][result["task_id"]] + min_task = {"num_games": task["num_games"], "start": task["start"]} + if "stats" in task: + min_task["stats"] = task["stats"] + min_run = {"_id": str(run["_id"]), "args": run["args"], "my_task": min_task} + result["run"] = min_run + return self.add_time(result) + + @view_config(route_name="api_update_task") + def update_task(self): + self.validate_request() + result = self.request.rundb.update_task( + worker_info=self.worker_info(), + run_id=self.run_id(), + task_id=self.task_id(), + stats=self.stats(), + spsa=self.spsa(), + ) + return self.add_time(result) + + @view_config(route_name="api_failed_task") + def failed_task(self): + self.validate_request() + result = self.request.rundb.failed_task( + self.run_id(), self.task_id(), self.message() + ) + return self.add_time(result) + + @view_config(route_name="api_upload_pgn") + def upload_pgn(self): + self.validate_request() + try: + pgn_zip = base64.b64decode(self.pgn()) + validate(gzip_data, pgn_zip, "pgn") + except Exception as e: + self.handle_error(str(e)) + result = self.request.rundb.upload_pgn( + run_id="{}-{}".format(self.run_id(), self.task_id()), + pgn_zip=pgn_zip, + ) + return self.add_time(result) + + @view_config(route_name="api_stop_run") + def stop_run(self): + self.validate_request() + error = "" + if self.cpu_hours() < 1000: + error = "User {} has too few games to stop a run".format( + self.get_username() + ) + with self.request.rundb.active_run_lock(self.run_id()): + run = self.run() + message = self.message()[:1024] + ( + " (not authorized)" if error != "" else "" + ) + self.request.actiondb.stop_run( + username=self.get_username(), + run=run, + task_id=self.task_id(), + message=message, + ) + if error == "": + run["finished"] = True + run["failed"] = True + self.request.rundb.stop_run(self.run_id()) + else: + task = self.task() + task["active"] = False + self.request.rundb.buffer(run, True) + + self.handle_error(error, exception=HTTPUnauthorized) + return self.add_time({}) + + @view_config(route_name="api_request_version") + def request_version(self): + # By being mor lax here we can be more strict + # elsewhere since the worker will upgrade. + self.validate_username_password() + return self.add_time({"version": WORKER_VERSION}) + + @view_config(route_name="api_beat") + def beat(self): + self.validate_request() + run = self.run() + task = self.task() + task["last_updated"] = datetime.now(timezone.utc) + self.request.rundb.buffer(run, False) + return self.add_time({}) + + @view_config(route_name="api_request_spsa") + def request_spsa(self): + self.validate_request() + result = self.request.rundb.request_spsa(self.run_id(), self.task_id()) + return self.add_time(result) + + +@view_defaults(renderer="json") +class UserApi(GenericApi): @view_config(route_name="api_active_runs") def active_runs(self): runs = self.request.rundb.runs.find( @@ -250,8 +338,6 @@ def active_runs(self): @view_config(route_name="api_finished_runs") def finished_runs(self): - self.__t0 = datetime.now(timezone.utc) - username = self.request.params.get("username", "") success_only = self.request.params.get("success_only", False) yellow_only = self.request.params.get("yellow_only", False) @@ -367,8 +453,6 @@ def get_elo(self): @view_config(route_name="api_calc_elo") def calc_elo(self): - self.__t0 = datetime.now(timezone.utc) - W = self.request.params.get("W") D = self.request.params.get("D") L = self.request.params.get("L") @@ -473,60 +557,6 @@ def calc_elo(self): elo_model=elo_model, ) - @view_config(route_name="api_request_task") - def request_task(self): - self.validate_request() - worker_info = self.worker_info() - # rundb.request_task() needs this for an error message... - worker_info["host_url"] = self.request.host_url - result = self.request.rundb.request_task(worker_info) - if "task_waiting" in result: - return self.add_time(result) - - # Strip the run of unneccesary information - run = result["run"] - task = run["tasks"][result["task_id"]] - min_task = {"num_games": task["num_games"], "start": task["start"]} - if "stats" in task: - min_task["stats"] = task["stats"] - min_run = {"_id": str(run["_id"]), "args": run["args"], "my_task": min_task} - result["run"] = min_run - return self.add_time(result) - - @view_config(route_name="api_update_task") - def update_task(self): - self.validate_request() - result = self.request.rundb.update_task( - worker_info=self.worker_info(), - run_id=self.run_id(), - task_id=self.task_id(), - stats=self.stats(), - spsa=self.spsa(), - ) - return self.add_time(result) - - @view_config(route_name="api_failed_task") - def failed_task(self): - self.validate_request() - result = self.request.rundb.failed_task( - self.run_id(), self.task_id(), self.message() - ) - return self.add_time(result) - - @view_config(route_name="api_upload_pgn") - def upload_pgn(self): - self.validate_request() - try: - pgn_zip = base64.b64decode(self.pgn()) - validate(gzip_data, pgn_zip, "pgn") - except Exception as e: - self.handle_error(str(e)) - result = self.request.rundb.upload_pgn( - run_id="{}-{}".format(self.run_id(), self.task_id()), - pgn_zip=pgn_zip, - ) - return self.add_time(result) - @view_config(route_name="api_download_pgn", renderer="string") def download_pgn(self): zip_name = self.request.matchdict["id"] @@ -569,55 +599,6 @@ def download_nn(self): "https://data.stockfishchess.org/nn/" + self.request.matchdict["id"] ) - @view_config(route_name="api_stop_run") - def stop_run(self): - self.validate_request() - error = "" - if self.cpu_hours() < 1000: - error = "User {} has too few games to stop a run".format( - self.get_username() - ) - with self.request.rundb.active_run_lock(self.run_id()): - run = self.run() - message = self.message()[:1024] + ( - " (not authorized)" if error != "" else "" - ) - self.request.actiondb.stop_run( - username=self.get_username(), - run=run, - task_id=self.task_id(), - message=message, - ) - if error == "": - run["finished"] = True - run["failed"] = True - self.request.rundb.stop_run(self.run_id()) - else: - task = self.task() - task["active"] = False - self.request.rundb.buffer(run, True) - - self.handle_error(error, exception=HTTPUnauthorized) - return self.add_time({}) - - @view_config(route_name="api_request_version") - def request_version(self): - # By being mor lax here we can be more strict - # elsewhere since the worker will upgrade. - self.validate_username_password() - return self.add_time({"version": WORKER_VERSION}) - - @view_config(route_name="api_beat") - def beat(self): - self.validate_request() - run = self.run() - task = self.task() - task["last_updated"] = datetime.now(timezone.utc) - self.request.rundb.buffer(run, False) - return self.add_time({}) - @view_config(route_name="api_request_spsa") - def request_spsa(self): - self.validate_request() - result = self.request.rundb.request_spsa(self.run_id(), self.task_id()) - return self.add_time(result) +class InternalApi(GenericApi): + pass diff --git a/server/fishtest/util.py b/server/fishtest/util.py index b7ea591af..63d7c3126 100644 --- a/server/fishtest/util.py +++ b/server/fishtest/util.py @@ -1,3 +1,4 @@ +import copy import hashlib import math import re @@ -512,3 +513,30 @@ def get_hash(s): if h: return int(h.group(1)) return 0 + + +# Avoids exposing sensitive data about the workers to the client and skips some heavy data. +def strip_run(run): + # a deep copy, avoiding copies of a few large lists. + stripped = {} + for k1, v1 in run.items(): + if k1 in ("tasks", "bad_tasks"): + stripped[k1] = [] + elif k1 == "args": + stripped[k1] = {} + for k2, v2 in v1.items(): + if k2 == "spsa": + stripped[k1][k2] = { + k3: [] if k3 == "param_history" else copy.deepcopy(v3) + for k3, v3 in v2.items() + } + else: + stripped[k1][k2] = copy.deepcopy(v2) + else: + stripped[k1] = copy.deepcopy(v1) + + # and some string conversions + for key in ("_id", "start_time", "last_updated"): + stripped[key] = str(run[key]) + + return stripped diff --git a/server/tests/test_api.py b/server/tests/test_api.py index 9e5494b0e..9c99757d3 100644 --- a/server/tests/test_api.py +++ b/server/tests/test_api.py @@ -6,7 +6,7 @@ import unittest from datetime import datetime, timezone -from fishtest.api import WORKER_VERSION, ApiView +from fishtest.api import WORKER_VERSION, UserApi, WorkerApi from pyramid.httpexceptions import HTTPBadRequest, HTTPUnauthorized from pyramid.testing import DummyRequest from util import get_rundb @@ -186,19 +186,19 @@ def correct_password_request(self, json_body={}): def test_get_active_runs(self): run_id = new_run(self) request = DummyRequest(rundb=self.rundb) - response = ApiView(request).active_runs() + response = UserApi(request).active_runs() self.assertTrue(run_id in response) def test_get_run(self): run_id = new_run(self) request = DummyRequest(rundb=self.rundb, matchdict={"id": run_id}) - response = ApiView(request).get_run() + response = UserApi(request).get_run() self.assertEqual(run_id, response["_id"]) def test_get_elo(self): run_id = new_run(self) request = DummyRequest(rundb=self.rundb, matchdict={"id": run_id}) - response = ApiView(request).get_elo() + response = UserApi(request).get_elo() # /api/get_elo only works for SPRT self.assertFalse(response) @@ -209,12 +209,12 @@ def test_request_task(self): request = self.invalid_password_request() with self.assertRaises(HTTPUnauthorized): - response = ApiView(request).request_task() + response = WorkerApi(request).request_task() self.assertTrue("error" in response) print(response["error"]) request = self.correct_password_request() - response = ApiView(request).request_task() + response = WorkerApi(request).request_task() run = response["run"] run_id = str(run["_id"]) @@ -235,7 +235,7 @@ def test_update_task(self): # Request fails if username/password is invalid with self.assertRaises(HTTPUnauthorized): - response = ApiView(self.invalid_password_request()).update_task() + response = WorkerApi(self.invalid_password_request()).update_task() self.assertTrue("error" in response) print(response["error"]) @@ -254,7 +254,7 @@ def test_update_task(self): }, } ) - response = ApiView(cleanup(request)).update_task() + response = WorkerApi(cleanup(request)).update_task() self.assertTrue(response["task_alive"]) # Task is still active @@ -268,7 +268,7 @@ def test_update_task(self): "time_losses": 0, "pentanomial": [0, 0, d // 2, 0, w // 2], } - response = ApiView(cleanup(request)).update_task() + response = WorkerApi(cleanup(request)).update_task() self.assertTrue(response["task_alive"]) # Task is still active. Odd update. @@ -281,7 +281,7 @@ def test_update_task(self): "pentanomial": [0, 0, d // 2, 0, w // 2], } with self.assertRaises(HTTPBadRequest): - response = ApiView(cleanup(request)).update_task() + response = WorkerApi(cleanup(request)).update_task() request.json_body["stats"] = { "wins": w + 2, @@ -292,7 +292,7 @@ def test_update_task(self): "pentanomial": [0, 0, d // 2, 0, w // 2 + 1], } - response = ApiView(cleanup(request)).update_task() + response = WorkerApi(cleanup(request)).update_task() self.assertTrue(response["task_alive"]) # Go back in time @@ -304,7 +304,7 @@ def test_update_task(self): "time_losses": 0, "pentanomial": [0, 0, d // 2, 0, w // 2], } - response = ApiView(cleanup(request)).update_task() + response = WorkerApi(cleanup(request)).update_task() self.assertFalse(response["task_alive"]) # revive the task @@ -322,7 +322,7 @@ def test_update_task(self): "time_losses": 0, "pentanomial": [0, 0, 0, 0, task_num_games // 2], } - response = ApiView(cleanup(request)).update_task() + response = WorkerApi(cleanup(request)).update_task() self.assertFalse(response["task_alive"]) run = self.rundb.get_run(run_id) task = run["tasks"][0] @@ -334,7 +334,7 @@ def test_failed_task(self): # Request fails if username/password is invalid request = self.invalid_password_request() with self.assertRaises(HTTPUnauthorized): - response = ApiView(self.invalid_password_request()).update_task() + response = WorkerApi(self.invalid_password_request()).update_task() self.assertTrue("error" in response) print(response["error"]) @@ -343,13 +343,13 @@ def test_failed_task(self): request = self.correct_password_request( {"run_id": run_id, "task_id": 0, "message": message} ) - response = ApiView(request).failed_task() + response = WorkerApi(request).failed_task() response.pop("duration", None) self.assertEqual(response, {}) self.assertFalse(run["tasks"][0]["active"]) request = self.correct_password_request({"run_id": run_id, "task_id": 0}) - response = ApiView(request).failed_task() + response = WorkerApi(request).failed_task() self.assertTrue("info" in response) print(response["info"]) self.assertFalse(run["tasks"][0]["active"]) @@ -360,7 +360,7 @@ def test_failed_task(self): request = self.correct_password_request( {"run_id": run_id, "task_id": 0, "message": message} ) - response = ApiView(request).failed_task() + response = WorkerApi(request).failed_task() response.pop("duration", None) self.assertTrue(response == {}) self.assertFalse(run["tasks"][0]["active"]) @@ -368,7 +368,7 @@ def test_failed_task(self): def test_stop_run(self): run_id = new_run(self, add_tasks=1) with self.assertRaises(HTTPUnauthorized): - response = ApiView(self.invalid_password_request()).stop_run() + response = WorkerApi(self.invalid_password_request()).stop_run() self.assertTrue("error" in response) print(response["error"]) @@ -380,7 +380,7 @@ def test_stop_run(self): {"run_id": run_id, "task_id": 0, "message": message} ) with self.assertRaises(HTTPUnauthorized): - response = ApiView(request).stop_run() + response = WorkerApi(request).stop_run() self.assertTrue("error" in response) self.assertFalse(run["tasks"][0]["active"]) @@ -388,7 +388,7 @@ def test_stop_run(self): {"username": self.username}, {"$set": {"cpu_hours": 10000}} ) - response = ApiView(request).stop_run() + response = WorkerApi(request).stop_run() response.pop("duration", None) self.assertTrue(response == {}) @@ -410,7 +410,7 @@ def test_upload_pgn(self): "pgn": base64.b64encode(gz_buffer.getvalue()).decode(), } ) - response = ApiView(request).upload_pgn() + response = WorkerApi(request).upload_pgn() response.pop("duration", None) self.assertTrue(response == {}) @@ -435,30 +435,30 @@ def test_request_spsa(self): ], } request = self.correct_password_request({"run_id": run_id, "task_id": 0}) - response = ApiView(request).request_spsa() + response = WorkerApi(request).request_spsa() self.assertTrue(response["task_alive"]) self.assertTrue(response["w_params"] is not None) self.assertTrue(response["b_params"] is not None) def test_request_version(self): with self.assertRaises(HTTPUnauthorized): - response = ApiView(self.invalid_password_request()).request_version() + response = WorkerApi(self.invalid_password_request()).request_version() self.assertTrue("error" in response) print(response["error"]) - response = ApiView(self.correct_password_request()).request_version() + response = WorkerApi(self.correct_password_request()).request_version() self.assertEqual(WORKER_VERSION, response["version"]) def test_beat(self): run_id = new_run(self, add_tasks=1) with self.assertRaises(HTTPUnauthorized): - response = ApiView(self.invalid_password_request()).beat() + response = WorkerApi(self.invalid_password_request()).beat() self.assertTrue("error" in response) print(response["error"]) request = self.correct_password_request({"run_id": run_id, "task_id": 0}) - response = ApiView(request).beat() + response = WorkerApi(request).beat() response.pop("duration", None) self.assertEqual(response, {}) @@ -543,11 +543,11 @@ def test_duplicate_workers(self): self.rundb.buffer(run, True) # Request task 1 of 2 request = self.correct_password_request() - response = ApiView(request).request_task() + response = WorkerApi(request).request_task() self.assertFalse("error" in response) # Request task 2 of 2 request = self.correct_password_request() - response = ApiView(request).request_task() + response = WorkerApi(request).request_task() self.assertFalse("error" in response) # TODO Add test for a different worker connecting @@ -561,7 +561,7 @@ def test_auto_purge_runs(self): # Request task 1 of 2 request = self.correct_password_request() - response = ApiView(request).request_task() + response = WorkerApi(request).request_task() self.assertEqual(response["run"]["_id"], str(run["_id"])) self.assertEqual(response["task_id"], 0) task1 = self.rundb.get_run(run_id)["tasks"][0] @@ -586,14 +586,14 @@ def test_auto_purge_runs(self): }, } ) - response = ApiView(request).update_task() + response = WorkerApi(request).update_task() self.assertFalse(response["task_alive"]) run = self.rundb.get_run(run_id) self.assertFalse(run["finished"]) # Request task 2 of 2 request = self.correct_password_request() - response = ApiView(request).request_task() + response = WorkerApi(request).request_task() self.assertEqual(response["run"]["_id"], str(run["_id"])) self.assertEqual(response["task_id"], 1) task2 = self.rundb.get_run(run_id)["tasks"][1] @@ -621,7 +621,7 @@ def test_auto_purge_runs(self): }, } ) - response = ApiView(request).update_task() + response = WorkerApi(request).update_task() self.assertFalse(response["task_alive"]) # The run should be marked as finished after the last task completes