diff --git a/cms/server/contest/handlers/api.py b/cms/server/contest/handlers/api.py index 519106ec4c..9cba1231ce 100644 --- a/cms/server/contest/handlers/api.py +++ b/cms/server/contest/handlers/api.py @@ -20,18 +20,15 @@ """ -from collections.abc import Callable -import functools import ipaddress import logging -import typing from cms.db.submission import Submission from cms.server import multi_contest from cms.server.contest.authentication import validate_login from cms.server.contest.submission import \ UnacceptableSubmission, accept_submission -from .contest import ContestHandler +from .contest import ContestHandler, api_login_required from ..phase_management import actual_phase_required logger = logging.getLogger(__name__) @@ -47,27 +44,6 @@ def __init__(self, *args, **kwargs): self.api_request = True -_P = typing.ParamSpec("_P") -_R = typing.TypeVar("_R") -_Self = typing.TypeVar("_Self", bound="ApiContestHandler") - -def api_login_required( - func: Callable[typing.Concatenate[_Self, _P], _R], -) -> Callable[typing.Concatenate[_Self, _P], _R | None]: - """A decorator filtering out unauthenticated requests. - - """ - - @functools.wraps(func) - def wrapped(self: _Self, *args: _P.args, **kwargs: _P.kwargs): - if not self.current_user: - self.json({"error": "An authenticated user is required"}, 403) - else: - return func(self, *args, **kwargs) - - return wrapped - - class ApiLoginHandler(ApiContestHandler): """Login handler. diff --git a/cms/server/contest/handlers/contest.py b/cms/server/contest/handlers/contest.py index 9624816666..be0b55b74e 100644 --- a/cms/server/contest/handlers/contest.py +++ b/cms/server/contest/handlers/contest.py @@ -29,9 +29,12 @@ """ +from collections.abc import Callable +import functools import ipaddress import json import logging +import typing import collections @@ -334,3 +337,26 @@ def check_xsrf_cookie(self): class FileHandler(ContestHandler, FileHandlerMixin): pass + +_P = typing.ParamSpec("_P") +_R = typing.TypeVar("_R") +_Self = typing.TypeVar("_Self", bound="ContestHandler") + +def api_login_required( + func: Callable[typing.Concatenate[_Self, _P], _R], +) -> Callable[typing.Concatenate[_Self, _P], _R | None]: + """A decorator filtering out unauthenticated requests. + + Unlike @tornado.web.authenticated, this returns a JSON error instead of + redirecting. + + """ + + @functools.wraps(func) + def wrapped(self: _Self, *args: _P.args, **kwargs: _P.kwargs): + if not self.current_user: + self.json({"error": "An authenticated user is required"}, 403) + else: + return func(self, *args, **kwargs) + + return wrapped diff --git a/cms/server/contest/handlers/main.py b/cms/server/contest/handlers/main.py index 42121eaf0b..d9af80f961 100644 --- a/cms/server/contest/handlers/main.py +++ b/cms/server/contest/handlers/main.py @@ -58,7 +58,7 @@ UnacceptablePrintJob from cmscommon.crypto import hash_password, validate_password from cmscommon.datetime import make_datetime, make_timestamp -from .contest import ContestHandler +from .contest import ContestHandler, api_login_required from ..phase_management import actual_phase_required @@ -294,7 +294,7 @@ class NotificationsHandler(ContestHandler): refresh_cookie = False - @tornado.web.authenticated + @api_login_required @multi_contest def get(self): participation: Participation = self.current_user diff --git a/cms/server/contest/handlers/tasksubmission.py b/cms/server/contest/handlers/tasksubmission.py index 609613c749..a5cdedee4b 100644 --- a/cms/server/contest/handlers/tasksubmission.py +++ b/cms/server/contest/handlers/tasksubmission.py @@ -57,7 +57,7 @@ UnacceptableToken, TokenAlreadyPlayed, accept_token, tokens_available from cmscommon.crypto import encrypt_number from cmscommon.mimetypes import get_type_for_file_name -from .contest import ContestHandler, FileHandler +from .contest import ContestHandler, FileHandler, api_login_required from ..phase_management import actual_phase_required @@ -236,7 +236,7 @@ def add_task_score(self, participation: Participation, task: Task, data: dict): data["task_tokened_score"], score_type.max_score, None, task.score_precision, translation=self.translation) - @tornado.web.authenticated + @api_login_required @actual_phase_required(0, 1, 2, 3, 4) @multi_contest def get(self, task_name, opaque_id): @@ -296,7 +296,7 @@ class SubmissionDetailsHandler(ContestHandler): refresh_cookie = False - @tornado.web.authenticated + @api_login_required @actual_phase_required(0, 1, 2, 3, 4) @multi_contest def get(self, task_name, opaque_id): diff --git a/cms/server/contest/handlers/taskusertest.py b/cms/server/contest/handlers/taskusertest.py index 7ac572cb21..fc9e9814db 100644 --- a/cms/server/contest/handlers/taskusertest.py +++ b/cms/server/contest/handlers/taskusertest.py @@ -48,7 +48,7 @@ TestingNotAllowed, UnacceptableUserTest, accept_user_test from cmscommon.crypto import encrypt_number from cmscommon.mimetypes import get_type_for_file_name -from .contest import ContestHandler, FileHandler +from .contest import ContestHandler, FileHandler, api_login_required from ..phase_management import actual_phase_required @@ -166,7 +166,7 @@ class UserTestStatusHandler(ContestHandler): refresh_cookie = False - @tornado.web.authenticated + @api_login_required @actual_phase_required(0) @multi_contest def get(self, task_name, user_test_num): @@ -221,7 +221,7 @@ class UserTestDetailsHandler(ContestHandler): refresh_cookie = False - @tornado.web.authenticated + @api_login_required @actual_phase_required(0) @multi_contest def get(self, task_name, user_test_num): diff --git a/cms/server/contest/templates/task_submissions.html b/cms/server/contest/templates/task_submissions.html index e821fec4b4..ea5a08f2b8 100644 --- a/cms/server/contest/templates/task_submissions.html +++ b/cms/server/contest/templates/task_submissions.html @@ -40,7 +40,11 @@ var modal = $("#submission_detail"); var modal_body = modal.children(".modal-body"); modal_body.html('