From b8f2868ebeb37e5e5db2c72e2e008cd1b1d08ceb Mon Sep 17 00:00:00 2001 From: Rodrigo Almeida <1205851+rodrigoalmeidaee@users.noreply.github.com> Date: Sun, 4 Aug 2024 14:25:09 -0300 Subject: [PATCH] feat: allow custom HTTP status codes when using DiscardTaskException To prevent a Google Cloud Task from being retried, it is necessary to return a status code in the 200-299 range. The mechanism django_cloud_tasks currently offers for this is raising `DiscardTaskException`, but in this case the status code will always be HTTP 202 (Accepted). When we want to discard a task due to an unrecoverable error, this HTTP status code offers no good semantics. Also, from a monitoring perspective, simply discarding a task (perhaps it is no longer needed - for example: attempting to delete something that has already been deleted) and discarding a task due to an unrecoverable error (for example: the task input is invalid) are two different things, but we have no means to differentiate them if the status code is always the same. This PR adds a bit of flexibility, allowing DiscardTaskException to receive an HTTP status code / HTTP status reason phrase as either constructor arguments, or by subclassing it and overriding default_http_status_code / default_http_status_reason. This PR won't add a built-in "UnrecoverableTaskException" base class because there is no HTTP 2xx status code (even when considering augmented standards) to reflect this scenario, so we will leave it up to each project that uses django-cloud-tasks to configure this setup, as it will be project-specific by definition. --- django_cloud_tasks/exceptions.py | 9 +++- django_cloud_tasks/views.py | 12 +++-- pyproject.toml | 2 +- sample_project/sample_app/tasks.py | 44 +++++++++++++++++++ .../tests/tests_tasks/tests_tasks.py | 14 +++--- .../tests/tests_views/tests_task_views.py | 31 +++++++++++++ 6 files changed, 100 insertions(+), 12 deletions(-) diff --git a/django_cloud_tasks/exceptions.py b/django_cloud_tasks/exceptions.py index 127fdd0..d4431d4 100644 --- a/django_cloud_tasks/exceptions.py +++ b/django_cloud_tasks/exceptions.py @@ -10,4 +10,11 @@ def __init__(self, name: str): super().__init__(message) -class DiscardTaskException(Exception): ... +class DiscardTaskException(Exception): + default_http_status_code: int = 202 + default_http_status_reason: str | None = None # only needed for custom HTTP status codes + + def __init__(self, *args, http_status_code: int | None = None, http_status_reason: str | None = None, **kwargs): + super().__init__(*args, **kwargs) + self.http_status_code = http_status_code or self.default_http_status_code + self.http_status_reason = http_status_reason or self.default_http_status_reason diff --git a/django_cloud_tasks/views.py b/django_cloud_tasks/views.py index a0758c3..7dcf8f6 100644 --- a/django_cloud_tasks/views.py +++ b/django_cloud_tasks/views.py @@ -29,17 +29,21 @@ def post(self, request, task_name, *args, **kwargs): output = self.execute_task(task_class=task_class, task_metadata=task_metadata, task_kwargs=task_kwargs) status = "executed" status_code = 200 - except exceptions.DiscardTaskException: + status_reason = None + except exceptions.DiscardTaskException as e: output = None status = "discarded" - status_code = 202 + status_code = e.http_status_code + status_reason = e.http_status_reason data = {"result": output, "status": status} try: - return JsonResponse(status=status_code, data=data) + return JsonResponse(status=status_code, reason=status_reason, data=data) except TypeError: logger.warning(f"Unable to serialize task output from {request.path}: {str(output)}") - return JsonResponse(status=status_code, data={"result": str(output), "status": "executed"}) + return JsonResponse( + status=status_code, reason=status_reason, data={"result": str(output), "status": "executed"} + ) def get_task(self, name: str) -> Type[Task]: app = apps.get_app_config("django_cloud_tasks") diff --git a/pyproject.toml b/pyproject.toml index 96bcf3a..ba4064e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "django-google-cloud-tasks" -version = "2.16.4" +version = "2.17.0" description = "Async Tasks with HTTP endpoints" authors = ["Joao Daher "] packages = [ diff --git a/sample_project/sample_app/tasks.py b/sample_project/sample_app/tasks.py index a150bcb..d002f04 100644 --- a/sample_project/sample_app/tasks.py +++ b/sample_project/sample_app/tasks.py @@ -3,6 +3,7 @@ from django.db.models import Model from django_cloud_tasks.tasks import PeriodicTask, RoutineTask, SubscriberTask, Task, ModelPublisherTask, TaskMetadata +from django_cloud_tasks.exceptions import DiscardTaskException class BaseAbstractTask(Task, abc.ABC): @@ -99,6 +100,49 @@ def build_message_attributes(cls, obj: Model, event: str, **kwargs) -> dict[str, return {"any-custom-attribute": "yay!", "event": event} +class FindPrimeNumbersTask(Task): + storage: list[int] = [] + + @classmethod + def reset(cls): + cls.storage = [] + + def run(self, quantity): + if not isinstance(quantity, int): + raise DiscardTaskException( + "Can't find a non-integer amount of prime numbers", + http_status_code=299, + http_status_reason="Unretriable failure", + ) + + if len(self.storage) >= quantity: + raise DiscardTaskException("Nothing to do here") + + return self._find_primes(quantity) + + @classmethod + def _find_primes(cls, quantity: int) -> list[int]: + if not cls.storage: + cls.storage = [2] + + while len(cls.storage) < quantity: + cls.storage.append(cls._find_next_prime(cls.storage[-1] + 1)) + + return cls.storage + + @classmethod + def _find_next_prime(cls, starting_number: int) -> int: + candidate = starting_number + + while True: + for prime in cls.storage: + if candidate % prime == 0: + candidate += 1 + break + else: + return candidate + + class DummyRoutineTask(RoutineTask): def run(self, **kwargs): ... diff --git a/sample_project/sample_app/tests/tests_tasks/tests_tasks.py b/sample_project/sample_app/tests/tests_tasks/tests_tasks.py index 744ae76..d060991 100644 --- a/sample_project/sample_app/tests/tests_tasks/tests_tasks.py +++ b/sample_project/sample_app/tests/tests_tasks/tests_tasks.py @@ -50,17 +50,18 @@ def app_config(self): def test_registered_tasks(self): expected_tasks = { "CalculatePriceTask", + "DummyRoutineTask", + "ExposeCustomHeadersTask", "FailMiserablyTask", + "FindPrimeNumbersTask", + "NonCompliantTask", "OneBigDedicatedTask", + "ParentCallingChildTask", + "PublishPersonTask", "RoutineExecutorTask", + "RoutineReverterTask", "SayHelloTask", "SayHelloWithParamsTask", - "DummyRoutineTask", - "RoutineReverterTask", - "ParentCallingChildTask", - "ExposeCustomHeadersTask", - "PublishPersonTask", - "NonCompliantTask", } self.assertEqual(expected_tasks, set(self.app_config.on_demand_tasks)) @@ -89,6 +90,7 @@ def test_get_tasks(self): tasks.SayHelloTask, tasks.SayHelloWithParamsTask, tasks.PublishPersonTask, + tasks.FindPrimeNumbersTask, tasks.DummyRoutineTask, another_app_tasks.deep_down_tasks.one_dedicated_task.OneBigDedicatedTask, another_app_tasks.deep_down_tasks.one_dedicated_task.NonCompliantTask, diff --git a/sample_project/sample_app/tests/tests_views/tests_task_views.py b/sample_project/sample_app/tests/tests_views/tests_task_views.py index 62d820c..b54b4b0 100644 --- a/sample_project/sample_app/tests/tests_views/tests_task_views.py +++ b/sample_project/sample_app/tests/tests_views/tests_task_views.py @@ -3,6 +3,7 @@ from another_app.tasks.deep_down_tasks.one_dedicated_task import NonCompliantTask from django_cloud_tasks.tasks import TaskMetadata from sample_app.tests.tests_base_tasks import AuthenticationMixin +from sample_app.tasks import FindPrimeNumbersTask class TaskViewTest(AuthenticationMixin): @@ -116,3 +117,33 @@ def test_non_compliant_task_called(self): response = self.client.post(path=url, data=data, content_type="application/json") self.assertEqual(200, response.status_code) self.assertEqual({"result": ANY, "status": "executed"}, response.json()) + + +class TaskDiscardingTest(AuthenticationMixin): + url = "/tasks/FindPrimeNumbersTask" + + def setUp(self): + super().setUp() + FindPrimeNumbersTask.reset() + + def call_task(self, data): + return self.client.post(path=self.url, data=data, content_type="application/json") + + def test_when_task_is_not_discarded(self): + response = self.call_task(data={"quantity": 3}) + self.assertEqual(200, response.status_code) + self.assertEqual({"result": [2, 3, 5], "status": "executed"}, response.json()) + + def test_when_task_is_discarded_due_to_no_longer_being_needed(self): + self.call_task(data={"quantity": 3}) + + response = self.call_task(data={"quantity": 3}) + self.assertEqual(202, response.status_code) + self.assertEqual("Accepted", response.reason_phrase) + self.assertEqual({"result": None, "status": "discarded"}, response.json()) + + def test_when_task_is_discarded_due_to_permanent_error(self): + response = self.call_task(data={"quantity": "not-a-number"}) + self.assertEqual(299, response.status_code) + self.assertEqual("Unretriable failure", response.reason_phrase) + self.assertEqual({"result": None, "status": "discarded"}, response.json())