diff --git a/django_cloud_tasks/exceptions.py b/django_cloud_tasks/exceptions.py index 127fdd0..d4431d4 100644 --- a/django_cloud_tasks/exceptions.py +++ b/django_cloud_tasks/exceptions.py @@ -10,4 +10,11 @@ def __init__(self, name: str): super().__init__(message) -class DiscardTaskException(Exception): ... +class DiscardTaskException(Exception): + default_http_status_code: int = 202 + default_http_status_reason: str | None = None # only needed for custom HTTP status codes + + def __init__(self, *args, http_status_code: int | None = None, http_status_reason: str | None = None, **kwargs): + super().__init__(*args, **kwargs) + self.http_status_code = http_status_code or self.default_http_status_code + self.http_status_reason = http_status_reason or self.default_http_status_reason diff --git a/django_cloud_tasks/views.py b/django_cloud_tasks/views.py index a0758c3..7dcf8f6 100644 --- a/django_cloud_tasks/views.py +++ b/django_cloud_tasks/views.py @@ -29,17 +29,21 @@ def post(self, request, task_name, *args, **kwargs): output = self.execute_task(task_class=task_class, task_metadata=task_metadata, task_kwargs=task_kwargs) status = "executed" status_code = 200 - except exceptions.DiscardTaskException: + status_reason = None + except exceptions.DiscardTaskException as e: output = None status = "discarded" - status_code = 202 + status_code = e.http_status_code + status_reason = e.http_status_reason data = {"result": output, "status": status} try: - return JsonResponse(status=status_code, data=data) + return JsonResponse(status=status_code, reason=status_reason, data=data) except TypeError: logger.warning(f"Unable to serialize task output from {request.path}: {str(output)}") - return JsonResponse(status=status_code, data={"result": str(output), "status": "executed"}) + return JsonResponse( + status=status_code, reason=status_reason, data={"result": str(output), "status": "executed"} + ) def get_task(self, name: str) -> Type[Task]: app = apps.get_app_config("django_cloud_tasks") diff --git a/pyproject.toml b/pyproject.toml index 96bcf3a..ba4064e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "django-google-cloud-tasks" -version = "2.16.4" +version = "2.17.0" description = "Async Tasks with HTTP endpoints" authors = ["Joao Daher "] packages = [ diff --git a/sample_project/sample_app/tasks.py b/sample_project/sample_app/tasks.py index a150bcb..d002f04 100644 --- a/sample_project/sample_app/tasks.py +++ b/sample_project/sample_app/tasks.py @@ -3,6 +3,7 @@ from django.db.models import Model from django_cloud_tasks.tasks import PeriodicTask, RoutineTask, SubscriberTask, Task, ModelPublisherTask, TaskMetadata +from django_cloud_tasks.exceptions import DiscardTaskException class BaseAbstractTask(Task, abc.ABC): @@ -99,6 +100,49 @@ def build_message_attributes(cls, obj: Model, event: str, **kwargs) -> dict[str, return {"any-custom-attribute": "yay!", "event": event} +class FindPrimeNumbersTask(Task): + storage: list[int] = [] + + @classmethod + def reset(cls): + cls.storage = [] + + def run(self, quantity): + if not isinstance(quantity, int): + raise DiscardTaskException( + "Can't find a non-integer amount of prime numbers", + http_status_code=299, + http_status_reason="Unretriable failure", + ) + + if len(self.storage) >= quantity: + raise DiscardTaskException("Nothing to do here") + + return self._find_primes(quantity) + + @classmethod + def _find_primes(cls, quantity: int) -> list[int]: + if not cls.storage: + cls.storage = [2] + + while len(cls.storage) < quantity: + cls.storage.append(cls._find_next_prime(cls.storage[-1] + 1)) + + return cls.storage + + @classmethod + def _find_next_prime(cls, starting_number: int) -> int: + candidate = starting_number + + while True: + for prime in cls.storage: + if candidate % prime == 0: + candidate += 1 + break + else: + return candidate + + class DummyRoutineTask(RoutineTask): def run(self, **kwargs): ... diff --git a/sample_project/sample_app/tests/tests_tasks/tests_tasks.py b/sample_project/sample_app/tests/tests_tasks/tests_tasks.py index 744ae76..d060991 100644 --- a/sample_project/sample_app/tests/tests_tasks/tests_tasks.py +++ b/sample_project/sample_app/tests/tests_tasks/tests_tasks.py @@ -50,17 +50,18 @@ def app_config(self): def test_registered_tasks(self): expected_tasks = { "CalculatePriceTask", + "DummyRoutineTask", + "ExposeCustomHeadersTask", "FailMiserablyTask", + "FindPrimeNumbersTask", + "NonCompliantTask", "OneBigDedicatedTask", + "ParentCallingChildTask", + "PublishPersonTask", "RoutineExecutorTask", + "RoutineReverterTask", "SayHelloTask", "SayHelloWithParamsTask", - "DummyRoutineTask", - "RoutineReverterTask", - "ParentCallingChildTask", - "ExposeCustomHeadersTask", - "PublishPersonTask", - "NonCompliantTask", } self.assertEqual(expected_tasks, set(self.app_config.on_demand_tasks)) @@ -89,6 +90,7 @@ def test_get_tasks(self): tasks.SayHelloTask, tasks.SayHelloWithParamsTask, tasks.PublishPersonTask, + tasks.FindPrimeNumbersTask, tasks.DummyRoutineTask, another_app_tasks.deep_down_tasks.one_dedicated_task.OneBigDedicatedTask, another_app_tasks.deep_down_tasks.one_dedicated_task.NonCompliantTask, diff --git a/sample_project/sample_app/tests/tests_views/tests_task_views.py b/sample_project/sample_app/tests/tests_views/tests_task_views.py index 62d820c..b54b4b0 100644 --- a/sample_project/sample_app/tests/tests_views/tests_task_views.py +++ b/sample_project/sample_app/tests/tests_views/tests_task_views.py @@ -3,6 +3,7 @@ from another_app.tasks.deep_down_tasks.one_dedicated_task import NonCompliantTask from django_cloud_tasks.tasks import TaskMetadata from sample_app.tests.tests_base_tasks import AuthenticationMixin +from sample_app.tasks import FindPrimeNumbersTask class TaskViewTest(AuthenticationMixin): @@ -116,3 +117,33 @@ def test_non_compliant_task_called(self): response = self.client.post(path=url, data=data, content_type="application/json") self.assertEqual(200, response.status_code) self.assertEqual({"result": ANY, "status": "executed"}, response.json()) + + +class TaskDiscardingTest(AuthenticationMixin): + url = "/tasks/FindPrimeNumbersTask" + + def setUp(self): + super().setUp() + FindPrimeNumbersTask.reset() + + def call_task(self, data): + return self.client.post(path=self.url, data=data, content_type="application/json") + + def test_when_task_is_not_discarded(self): + response = self.call_task(data={"quantity": 3}) + self.assertEqual(200, response.status_code) + self.assertEqual({"result": [2, 3, 5], "status": "executed"}, response.json()) + + def test_when_task_is_discarded_due_to_no_longer_being_needed(self): + self.call_task(data={"quantity": 3}) + + response = self.call_task(data={"quantity": 3}) + self.assertEqual(202, response.status_code) + self.assertEqual("Accepted", response.reason_phrase) + self.assertEqual({"result": None, "status": "discarded"}, response.json()) + + def test_when_task_is_discarded_due_to_permanent_error(self): + response = self.call_task(data={"quantity": "not-a-number"}) + self.assertEqual(299, response.status_code) + self.assertEqual("Unretriable failure", response.reason_phrase) + self.assertEqual({"result": None, "status": "discarded"}, response.json())