diff --git a/sample_project/sample_app/tests/tests_base_tasks.py b/sample_project/sample_app/tests/tests_base_tasks.py index e9ba99b..811ede5 100644 --- a/sample_project/sample_app/tests/tests_base_tasks.py +++ b/sample_project/sample_app/tests/tests_base_tasks.py @@ -1,12 +1,21 @@ import inspect from contextlib import contextmanager from typing import Callable, Type - +from django.test import SimpleTestCase from django.utils.connection import ConnectionProxy +from gcp_pilot.mocker import patch_auth LockType = (Callable | Exception | Type[Exception]) | None +class AuthenticationMixin(SimpleTestCase): + def setUp(self) -> None: + auth = patch_auth() + auth.start() + self.addCleanup(auth.stop) + super().setUp() + + def patch_cache_lock( lock_side_effect: LockType = None, unlock_side_effect: LockType = None, diff --git a/sample_project/sample_app/tests/tests_models/__init__.py b/sample_project/sample_app/tests/tests_models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/sample_project/sample_app/tests/tests_models/tests_pipeline_models.py b/sample_project/sample_app/tests/tests_models/tests_pipeline_models.py new file mode 100644 index 0000000..55f30f8 --- /dev/null +++ b/sample_project/sample_app/tests/tests_models/tests_pipeline_models.py @@ -0,0 +1,83 @@ +from unittest.mock import call, patch + +from django.test import TestCase +from django_cloud_tasks.tests import factories + + +class PipelineModelTest(TestCase): + def test_start_pipeline(self): + pipeline = factories.PipelineFactory() + leaf_already_completed = factories.RoutineWithoutSignalFactory(status="completed") + pipeline.routines.add(leaf_already_completed) + + leaf_already_reverted = factories.RoutineWithoutSignalFactory(status="reverted") + pipeline.routines.add(leaf_already_reverted) + + with patch("django_cloud_tasks.tasks.RoutineExecutorTask.asap") as task: + with self.assertNumQueries(1): + pipeline.start() + task.assert_not_called() + + second_routine = factories.RoutineFactory() + pipeline.routines.add(second_routine) + third_routine = factories.RoutineFactory() + pipeline.routines.add(third_routine) + first_routine = factories.RoutineFactory() + pipeline.routines.add(first_routine) + + another_first_routine = factories.RoutineFactory() + pipeline.routines.add(another_first_routine) + + factories.RoutineVertexFactory(routine=second_routine, next_routine=third_routine) + factories.RoutineVertexFactory(routine=first_routine, next_routine=second_routine) + + with patch("django_cloud_tasks.tasks.RoutineExecutorTask.asap") as task: + with self.assertNumQueries(7): + pipeline.start() + calls = [call(routine_id=first_routine.pk), call(routine_id=another_first_routine.pk)] + task.assert_has_calls(calls, any_order=True) + + def test_revert_pipeline(self): + pipeline = factories.PipelineFactory() + + leaf_already_reverted = factories.RoutineWithoutSignalFactory(status="reverted") + pipeline.routines.add(leaf_already_reverted) + + with patch("django_cloud_tasks.tasks.RoutineReverterTask.asap") as task: + with self.assertNumQueries(1): + pipeline.revert() + task.assert_not_called() + + second_routine = factories.RoutineFactory() + pipeline.routines.add(second_routine) + + third_routine = factories.RoutineWithoutSignalFactory(status="completed") + pipeline.routines.add(third_routine) + + first_routine = factories.RoutineFactory() + pipeline.routines.add(first_routine) + + fourth_routine = factories.RoutineWithoutSignalFactory(status="completed") + pipeline.routines.add(fourth_routine) + + factories.RoutineVertexFactory(routine=second_routine, next_routine=third_routine) + factories.RoutineVertexFactory(routine=first_routine, next_routine=second_routine) + + with patch("django_cloud_tasks.tasks.RoutineReverterTask.asap") as task: + with self.assertNumQueries(7): + pipeline.revert() + calls = [ + call(routine_id=fourth_routine.pk), + call(routine_id=third_routine.pk), + ] + task.assert_has_calls(calls, any_order=True) + + def test_add_routine(self): + pipeline = factories.PipelineFactory() + expected_routine_1 = { + "task_name": "DummyRoutineTask", + "body": {"spell": "wingardium leviosa"}, + } + routine = pipeline.add_routine(expected_routine_1) + self.assertEqual(expected_routine_1["body"], routine.body) + self.assertEqual(expected_routine_1["task_name"], routine.task_name) diff --git a/sample_project/sample_app/tests/tests_models.py b/sample_project/sample_app/tests/tests_models/tests_routine_models.py similarity index 53% rename from sample_project/sample_app/tests/tests_models.py rename to sample_project/sample_app/tests/tests_models/tests_routine_models.py index 84e5825..fc827a5 100644 --- a/sample_project/sample_app/tests/tests_models.py +++ b/sample_project/sample_app/tests/tests_models/tests_routine_models.py @@ -1,4 +1,3 @@ -from typing import List from unittest.mock import call, patch from django.core.exceptions import ValidationError @@ -6,7 +5,6 @@ from django.utils import timezone from freezegun import freeze_time from django.db import IntegrityError -from django_cloud_tasks import models from django_cloud_tasks.tests import factories @@ -150,142 +148,3 @@ def test_ensure_max_retries_greater_than_attempt_count(self): expected_exception=IntegrityError, expected_regex="constraint failed: max_retries_less_than_attempt_count" ): factories.RoutineFactory(max_retries=1, attempt_count=5) - - -class PipelineModelTest(TestCase): - def test_start_pipeline(self): - pipeline = factories.PipelineFactory() - leaf_already_completed = factories.RoutineWithoutSignalFactory(status="completed") - pipeline.routines.add(leaf_already_completed) - - leaf_already_reverted = factories.RoutineWithoutSignalFactory(status="reverted") - pipeline.routines.add(leaf_already_reverted) - - with patch("django_cloud_tasks.tasks.RoutineExecutorTask.asap") as task: - with self.assertNumQueries(1): - pipeline.start() - task.assert_not_called() - - second_routine = factories.RoutineFactory() - pipeline.routines.add(second_routine) - third_routine = factories.RoutineFactory() - pipeline.routines.add(third_routine) - first_routine = factories.RoutineFactory() - pipeline.routines.add(first_routine) - - another_first_routine = factories.RoutineFactory() - pipeline.routines.add(another_first_routine) - - factories.RoutineVertexFactory(routine=second_routine, next_routine=third_routine) - factories.RoutineVertexFactory(routine=first_routine, next_routine=second_routine) - - with patch("django_cloud_tasks.tasks.RoutineExecutorTask.asap") as task: - with self.assertNumQueries(7): - pipeline.start() - calls = [call(routine_id=first_routine.pk), call(routine_id=another_first_routine.pk)] - task.assert_has_calls(calls, any_order=True) - - def test_revert_pipeline(self): - pipeline = factories.PipelineFactory() - - leaf_already_reverted = factories.RoutineWithoutSignalFactory(status="reverted") - pipeline.routines.add(leaf_already_reverted) - - with patch("django_cloud_tasks.tasks.RoutineReverterTask.asap") as task: - with self.assertNumQueries(1): - pipeline.revert() - task.assert_not_called() - - second_routine = factories.RoutineFactory() - pipeline.routines.add(second_routine) - - third_routine = factories.RoutineWithoutSignalFactory(status="completed") - pipeline.routines.add(third_routine) - - first_routine = factories.RoutineFactory() - pipeline.routines.add(first_routine) - - fourth_routine = factories.RoutineWithoutSignalFactory(status="completed") - pipeline.routines.add(fourth_routine) - - factories.RoutineVertexFactory(routine=second_routine, next_routine=third_routine) - factories.RoutineVertexFactory(routine=first_routine, next_routine=second_routine) - - with patch("django_cloud_tasks.tasks.RoutineReverterTask.asap") as task: - with self.assertNumQueries(7): - pipeline.revert() - calls = [ - call(routine_id=fourth_routine.pk), - call(routine_id=third_routine.pk), - ] - task.assert_has_calls(calls, any_order=True) - - def test_add_routine(self): - pipeline = factories.PipelineFactory() - expected_routine_1 = { - "task_name": "DummyRoutineTask", - "body": {"spell": "wingardium leviosa"}, - } - routine = pipeline.add_routine(expected_routine_1) - self.assertEqual(expected_routine_1["body"], routine.body) - self.assertEqual(expected_routine_1["task_name"], routine.task_name) - - -class RoutineStateMachineTest(TestCase): - def setUp(self): - super().setUp() - revert_routine_task = patch("django_cloud_tasks.tasks.RoutineReverterTask.asap") - routine_task = patch("django_cloud_tasks.tasks.RoutineExecutorTask.asap") - routine_task.start() - revert_routine_task.start() - self.addCleanup(routine_task.stop) - self.addCleanup(revert_routine_task.stop) - - def _status_list(self, ignore_items: list) -> list: - statuses = models.Routine.Statuses.values - for item in ignore_items: - statuses.remove(item) - return statuses - - def test_dont_allow_initial_status_not_equal_pending(self): - for status in self._status_list(ignore_items=["pending", "failed", "scheduled"]): - msg_error = f"The initial routine's status must be 'pending' not '{status}'" - with self.assertRaises(ValidationError, msg=msg_error): - factories.RoutineFactory(status=status) - - def test_ignore_if_status_was_not_updated(self): - routine = factories.RoutineFactory(status="pending") - routine.status = "pending" - routine.save() - - def test_allow_to_update_status_from_scheduled_to_running_or_failed(self): - self.assert_machine_status(accepted_status=["running", "failed", "reverting"], from_status="scheduled") - - def test_allow_to_update_status_from_running_to_completed(self): - self.assert_machine_status( - accepted_status=["completed", "failed"], - from_status="running", - ) - - def test_allow_to_update_status_from_completed_to_failed_or_reverting(self): - self.assert_machine_status(accepted_status=["reverting"], from_status="completed") - - def test_allow_to_update_status_from_reverting_to_reverted(self): - self.assert_machine_status( - accepted_status=["reverted"], - from_status="reverting", - ) - - def assert_machine_status(self, from_status: str, accepted_status: List[str]): - for status in accepted_status: - routine = factories.RoutineWithoutSignalFactory(status=from_status) - routine.status = status - routine.save() - - accepted_status.append(from_status) - for status in self._status_list(ignore_items=accepted_status): - msg_error = f"Status update from '{from_status}' to '{status}' is not allowed" - with self.assertRaises(ValidationError, msg=msg_error): - routine = factories.RoutineWithoutSignalFactory(status=from_status) - routine.status = status - routine.save() diff --git a/sample_project/sample_app/tests/tests_models/tests_routine_state_machine_models.py b/sample_project/sample_app/tests/tests_models/tests_routine_state_machine_models.py new file mode 100644 index 0000000..03e0c54 --- /dev/null +++ b/sample_project/sample_app/tests/tests_models/tests_routine_state_machine_models.py @@ -0,0 +1,67 @@ +from typing import List +from unittest.mock import patch + +from django.core.exceptions import ValidationError +from django.test import TestCase +from django_cloud_tasks import models +from django_cloud_tasks.tests import factories + + +class RoutineStateMachineTest(TestCase): + def setUp(self): + super().setUp() + revert_routine_task = patch("django_cloud_tasks.tasks.RoutineReverterTask.asap") + routine_task = patch("django_cloud_tasks.tasks.RoutineExecutorTask.asap") + routine_task.start() + revert_routine_task.start() + self.addCleanup(routine_task.stop) + self.addCleanup(revert_routine_task.stop) + + def _status_list(self, ignore_items: list) -> list: + statuses = models.Routine.Statuses.values + for item in ignore_items: + statuses.remove(item) + return statuses + + def test_dont_allow_initial_status_not_equal_pending(self): + for status in self._status_list(ignore_items=["pending", "failed", "scheduled"]): + msg_error = f"The initial routine's status must be 'pending' not '{status}'" + with self.assertRaises(ValidationError, msg=msg_error): + factories.RoutineFactory(status=status) + + def test_ignore_if_status_was_not_updated(self): + routine = factories.RoutineFactory(status="pending") + routine.status = "pending" + routine.save() + + def test_allow_to_update_status_from_scheduled_to_running_or_failed(self): + self.assert_machine_status(accepted_status=["running", "failed", "reverting"], from_status="scheduled") + + def test_allow_to_update_status_from_running_to_completed(self): + self.assert_machine_status( + accepted_status=["completed", "failed"], + from_status="running", + ) + + def test_allow_to_update_status_from_completed_to_failed_or_reverting(self): + self.assert_machine_status(accepted_status=["reverting"], from_status="completed") + + def test_allow_to_update_status_from_reverting_to_reverted(self): + self.assert_machine_status( + accepted_status=["reverted"], + from_status="reverting", + ) + + def assert_machine_status(self, from_status: str, accepted_status: List[str]): + for status in accepted_status: + routine = factories.RoutineWithoutSignalFactory(status=from_status) + routine.status = status + routine.save() + + accepted_status.append(from_status) + for status in self._status_list(ignore_items=accepted_status): + msg_error = f"Status update from '{from_status}' to '{status}' is not allowed" + with self.assertRaises(ValidationError, msg=msg_error): + routine = factories.RoutineWithoutSignalFactory(status=from_status) + routine.status = status + routine.save() diff --git a/sample_project/sample_app/tests/tests_tasks/tests_routine_executor_tasks.py b/sample_project/sample_app/tests/tests_tasks/tests_routine_executor_tasks.py new file mode 100644 index 0000000..7da1414 --- /dev/null +++ b/sample_project/sample_app/tests/tests_tasks/tests_routine_executor_tasks.py @@ -0,0 +1,191 @@ +from datetime import datetime, UTC +from unittest.mock import patch + +from django.test import TestCase + +from django_cloud_tasks.tasks import RoutineExecutorTask, TaskMetadata +from django_cloud_tasks.tests import factories, tests_base +from django_cloud_tasks.tests.tests_base import EagerTasksMixin +from sample_app import tasks +from sample_app.tests.tests_base_tasks import patch_cache_lock + + +class RoutineExecutorTaskTest(EagerTasksMixin, TestCase): + _mock_lock = None + + def setUp(self): + super().setUp() + + self.mock_lock = patch_cache_lock() + self.mock_lock.start() + self.addCleanup(self.mock_lock.stop) + + def assert_routine_lock(self, routine_id: int): + self.mock_lock.assert_called_with( + key=f"lock-RoutineExecutorTask-{routine_id}", + timeout=60, + blocking_timeout=5, + ) + + def tests_dont_process_completed_routine(self): + routine = factories.RoutineWithoutSignalFactory( + status="completed", + task_name="SayHelloTask", + ) + with self.assertLogs(level="INFO") as context: + RoutineExecutorTask.asap(routine_id=routine.pk) + self.assert_routine_lock(routine_id=routine.pk) + self.assertEqual(context.output, [f"INFO:root:Routine #{routine.pk} is already completed"]) + + def tests_start_pipeline_revert_flow_if_exceeded_retries(self): + routine = factories.RoutineWithoutSignalFactory( + status="running", + task_name="SayHelloTask", + max_retries=3, + attempt_count=1, + ) + with ( + patch("django_cloud_tasks.models.Pipeline.revert") as revert, + self.assertLogs(level="INFO") as context, + patch("sample_app.tasks.SayHelloTask.sync", side_effect=Exception("any error")), + ): + RoutineExecutorTask.asap(routine_id=routine.pk) + self.assertEqual( + context.output, + [ + f"INFO:root:Routine #{routine.id} is running", + f"INFO:root:Routine #{routine.id} has failed", + f"INFO:root:Routine #{routine.id} is being enqueued to retry", + f"INFO:root:Routine #{routine.id} is running", + f"INFO:root:Routine #{routine.id} has failed", + f"INFO:root:Routine #{routine.id} is being enqueued to retry", + f"INFO:root:Routine #{routine.id} has exhausted retries and is being reverted", + ], + ) + + self.assert_routine_lock(routine_id=routine.pk) + revert.assert_called_once() + + def tests_store_task_output_into_routine(self): + routine = factories.RoutineWithoutSignalFactory( + status="running", + task_name="SayHelloTask", + body={"attributes": [1, 2, 3]}, + attempt_count=1, + ) + with self.assertLogs(level="INFO") as context: + RoutineExecutorTask.sync(routine_id=routine.pk) + self.assert_routine_lock(routine_id=routine.pk) + routine.refresh_from_db() + self.assertEqual( + context.output, + [ + f"INFO:root:Routine #{routine.id} is running", + f"INFO:root:Routine #{routine.id} just completed", + ], + ) + self.assertEqual("completed", routine.status) + self.assertEqual(2, routine.attempt_count) + + def tests_retry_and_complete_task_processing_once_failure(self): + routine = factories.RoutineWithoutSignalFactory( + status="scheduled", + task_name="SayHelloTask", + body={"attributes": [1, 2, 3]}, + attempt_count=0, + max_retries=2, + ) + with self.assertLogs(level="INFO") as context, patch( + "sample_app.tasks.SayHelloTask.sync", side_effect=[Exception("any error"), "success"] + ): + RoutineExecutorTask.sync(routine_id=routine.pk) + self.assert_routine_lock(routine_id=routine.pk) + routine.refresh_from_db() + self.assertEqual( + context.output, + [ + f"INFO:root:Routine #{routine.id} is running", + f"INFO:root:Routine #{routine.id} has failed", + f"INFO:root:Routine #{routine.id} is being enqueued to retry", + f"INFO:root:Routine #{routine.id} is running", + f"INFO:root:Routine #{routine.id} just completed", + ], + ) + self.assertEqual("completed", routine.status) + self.assertEqual(2, routine.attempt_count) + + +class SayHelloTaskTest(TestCase, tests_base.RoutineTaskTestMixin): + @property + def task(self): + return tasks.SayHelloTask + + +class SayHelloWithParamsTaskTest(TestCase, tests_base.RoutineTaskTestMixin): + @property + def task(self): + return tasks.SayHelloWithParamsTask + + @property + def task_run_params(self): + return {"spell": "Obliviate"} + + +class TestTaskMetadata(TestCase): + some_date = datetime(1990, 7, 19, 15, 30, 42, tzinfo=UTC) + + @property + def sample_headers(self) -> dict: + return { + "X-Cloudtasks-Taskexecutioncount": 7, + "X-Cloudtasks-Taskretrycount": 1, + "X-Cloudtasks-Tasketa": str(self.some_date.timestamp()), + "X-Cloudtasks-Projectname": "wizard-project", + "X-Cloudtasks-Queuename": "wizard-queue", + "X-Cloudtasks-Taskname": "hp-1234567", + } + + @property + def sample_metadata(self) -> TaskMetadata: + return TaskMetadata( + project_id="wizard-project", + queue_name="wizard-queue", + task_id="hp-1234567", + execution_number=7, + dispatch_number=1, + eta=self.some_date, + ) + + def test_create_from_headers(self): + metadata = TaskMetadata.from_headers(headers=self.sample_headers) + + self.assertEqual(7, metadata.execution_number) + self.assertEqual(1, metadata.dispatch_number) + self.assertEqual(2, metadata.attempt_number) + self.assertEqual(self.some_date, metadata.eta) + self.assertEqual("wizard-project", metadata.project_id) + self.assertEqual("wizard-queue", metadata.queue_name) + self.assertEqual("hp-1234567", metadata.task_id) + + def test_build_headers(self): + headers = self.sample_metadata.to_headers() + + self.assertEqual("7", headers["X-Cloudtasks-Taskexecutioncount"]) + self.assertEqual("1", headers["X-Cloudtasks-Taskretrycount"]) + self.assertEqual(str(int(self.some_date.timestamp())), headers["X-Cloudtasks-Tasketa"]) + self.assertEqual("wizard-project", headers["X-Cloudtasks-Projectname"]) + self.assertEqual("wizard-queue", headers["X-Cloudtasks-Queuename"]) + self.assertEqual("hp-1234567", headers["X-Cloudtasks-Taskname"]) + + def test_comparable(self): + reference = self.sample_metadata + + metadata_a = TaskMetadata.from_headers(self.sample_headers) + self.assertEqual(reference, metadata_a) + + metadata_b = TaskMetadata.from_headers(self.sample_headers) + metadata_b.execution_number += 1 + self.assertNotEqual(reference, metadata_b) + + not_metadata = True + self.assertNotEqual(reference, not_metadata) diff --git a/sample_project/sample_app/tests/tests_tasks/tests_routine_reverter_tasks.py b/sample_project/sample_app/tests/tests_tasks/tests_routine_reverter_tasks.py new file mode 100644 index 0000000..4cc534b --- /dev/null +++ b/sample_project/sample_app/tests/tests_tasks/tests_routine_reverter_tasks.py @@ -0,0 +1,35 @@ +from contextlib import ExitStack +from unittest.mock import patch + +from django.test import TestCase + +from django_cloud_tasks.tasks import RoutineReverterTask +from django_cloud_tasks.tests import factories +from django_cloud_tasks.tests.tests_base import EagerTasksMixin +from sample_app.tests.tests_base_tasks import patch_cache_lock + + +class RoutineReverterTaskTest(EagerTasksMixin, TestCase): + _mock_lock = None + + def setUp(self): + super().setUp() + + patched_settings = self.settings(EAGER_TASKS=True) + patched_settings.enable() + self.addCleanup(patched_settings.disable) + + stack = ExitStack() + self.mock_lock = stack.enter_context(patch_cache_lock()) + + def test_process_revert_and_update_routine_to_reverted(self): + routine = factories.RoutineWithoutSignalFactory( + status="reverting", + task_name="SayHelloTask", + output={"spell": "Obliviate"}, + ) + with patch("sample_app.tasks.SayHelloTask.revert") as revert: + RoutineReverterTask.asap(routine_id=routine.pk) + revert.assert_called_once_with(data=routine.output) + routine.refresh_from_db() + self.assertEqual(routine.status, "reverted") diff --git a/sample_project/sample_app/tests/tests_tasks.py b/sample_project/sample_app/tests/tests_tasks/tests_tasks.py similarity index 64% rename from sample_project/sample_app/tests/tests_tasks.py rename to sample_project/sample_app/tests/tests_tasks/tests_tasks.py index 1e2155a..3dcd118 100644 --- a/sample_project/sample_app/tests/tests_tasks.py +++ b/sample_project/sample_app/tests/tests_tasks/tests_tasks.py @@ -1,5 +1,4 @@ import json -from contextlib import ExitStack from datetime import timedelta, datetime, UTC from unittest.mock import patch @@ -11,11 +10,10 @@ from gcp_pilot.mocker import patch_auth from django_cloud_tasks import exceptions -from django_cloud_tasks.tasks import RoutineExecutorTask, RoutineReverterTask, Task, TaskMetadata -from django_cloud_tasks.tests import factories, tests_base -from django_cloud_tasks.tests.tests_base import EagerTasksMixin, eager_tasks +from django_cloud_tasks.tasks import Task, TaskMetadata +from django_cloud_tasks.tests import tests_base +from django_cloud_tasks.tests.tests_base import eager_tasks from sample_app import tasks -from sample_app.tests.tests_base_tasks import patch_cache_lock class TasksTest(SimpleTestCase): @@ -205,137 +203,6 @@ def test_singleton_client_creates_new_instance_on_new_task(self): self.assertEqual(2, client.call_count) -class RoutineReverterTaskTest(EagerTasksMixin, TestCase): - _mock_lock = None - - def setUp(self): - super().setUp() - - patched_settings = self.settings(EAGER_TASKS=True) - patched_settings.enable() - self.addCleanup(patched_settings.disable) - - stack = ExitStack() - self.mock_lock = stack.enter_context(patch_cache_lock()) - - def test_process_revert_and_update_routine_to_reverted(self): - routine = factories.RoutineWithoutSignalFactory( - status="reverting", - task_name="SayHelloTask", - output={"spell": "Obliviate"}, - ) - with patch("sample_app.tasks.SayHelloTask.revert") as revert: - RoutineReverterTask.asap(routine_id=routine.pk) - revert.assert_called_once_with(data=routine.output) - routine.refresh_from_db() - self.assertEqual(routine.status, "reverted") - - -class RoutineExecutorTaskTest(EagerTasksMixin, TestCase): - _mock_lock = None - - def setUp(self): - super().setUp() - - self.mock_lock = patch_cache_lock() - self.mock_lock.start() - self.addCleanup(self.mock_lock.stop) - - def assert_routine_lock(self, routine_id: int): - self.mock_lock.assert_called_with( - key=f"lock-RoutineExecutorTask-{routine_id}", - timeout=60, - blocking_timeout=5, - ) - - def tests_dont_process_completed_routine(self): - routine = factories.RoutineWithoutSignalFactory( - status="completed", - task_name="SayHelloTask", - ) - with self.assertLogs(level="INFO") as context: - RoutineExecutorTask.asap(routine_id=routine.pk) - self.assert_routine_lock(routine_id=routine.pk) - self.assertEqual(context.output, [f"INFO:root:Routine #{routine.pk} is already completed"]) - - def tests_start_pipeline_revert_flow_if_exceeded_retries(self): - routine = factories.RoutineWithoutSignalFactory( - status="running", - task_name="SayHelloTask", - max_retries=3, - attempt_count=1, - ) - with ( - patch("django_cloud_tasks.models.Pipeline.revert") as revert, - self.assertLogs(level="INFO") as context, - patch("sample_app.tasks.SayHelloTask.sync", side_effect=Exception("any error")), - ): - RoutineExecutorTask.asap(routine_id=routine.pk) - self.assertEqual( - context.output, - [ - f"INFO:root:Routine #{routine.id} is running", - f"INFO:root:Routine #{routine.id} has failed", - f"INFO:root:Routine #{routine.id} is being enqueued to retry", - f"INFO:root:Routine #{routine.id} is running", - f"INFO:root:Routine #{routine.id} has failed", - f"INFO:root:Routine #{routine.id} is being enqueued to retry", - f"INFO:root:Routine #{routine.id} has exhausted retries and is being reverted", - ], - ) - - self.assert_routine_lock(routine_id=routine.pk) - revert.assert_called_once() - - def tests_store_task_output_into_routine(self): - routine = factories.RoutineWithoutSignalFactory( - status="running", - task_name="SayHelloTask", - body={"attributes": [1, 2, 3]}, - attempt_count=1, - ) - with self.assertLogs(level="INFO") as context: - RoutineExecutorTask.sync(routine_id=routine.pk) - self.assert_routine_lock(routine_id=routine.pk) - routine.refresh_from_db() - self.assertEqual( - context.output, - [ - f"INFO:root:Routine #{routine.id} is running", - f"INFO:root:Routine #{routine.id} just completed", - ], - ) - self.assertEqual("completed", routine.status) - self.assertEqual(2, routine.attempt_count) - - def tests_retry_and_complete_task_processing_once_failure(self): - routine = factories.RoutineWithoutSignalFactory( - status="scheduled", - task_name="SayHelloTask", - body={"attributes": [1, 2, 3]}, - attempt_count=0, - max_retries=2, - ) - with self.assertLogs(level="INFO") as context, patch( - "sample_app.tasks.SayHelloTask.sync", side_effect=[Exception("any error"), "success"] - ): - RoutineExecutorTask.sync(routine_id=routine.pk) - self.assert_routine_lock(routine_id=routine.pk) - routine.refresh_from_db() - self.assertEqual( - context.output, - [ - f"INFO:root:Routine #{routine.id} is running", - f"INFO:root:Routine #{routine.id} has failed", - f"INFO:root:Routine #{routine.id} is being enqueued to retry", - f"INFO:root:Routine #{routine.id} is running", - f"INFO:root:Routine #{routine.id} just completed", - ], - ) - self.assertEqual("completed", routine.status) - self.assertEqual(2, routine.attempt_count) - - class SayHelloTaskTest(TestCase, tests_base.RoutineTaskTestMixin): @property def task(self): diff --git a/sample_project/sample_app/tests/tests_views/tests_publisher_task_views.py b/sample_project/sample_app/tests/tests_views/tests_publisher_task_views.py new file mode 100644 index 0000000..4f18973 --- /dev/null +++ b/sample_project/sample_app/tests/tests_views/tests_publisher_task_views.py @@ -0,0 +1,20 @@ +from unittest.mock import patch, ANY + +from django.test import TransactionTestCase + + +class PublisherTaskTest(TransactionTestCase): + def test_propagate_headerss(self): + url = "/create-person" + data = {"name": "Harry Potter"} + headers = { + "traceparent": "trace-this-potato", + "another-random-header": "please-do-not-propagate-this", + } + django_headers = {f"HTTP_{key.upper()}": value for key, value in headers.items()} + + with patch("gcp_pilot.pubsub.CloudPublisher.publish") as publish: + self.client.post(path=url, data=data, content_type="application/json", **django_headers) + + expected_attributes = {"HTTP_Traceparent": "trace-this-potato", "any-custom-attribute": "yay!"} + publish.assert_called_once_with(message=ANY, topic_id=ANY, attributes=expected_attributes) diff --git a/sample_project/sample_app/tests/tests_views/tests_subscriber_task_views.py b/sample_project/sample_app/tests/tests_views/tests_subscriber_task_views.py new file mode 100644 index 0000000..b3ae663 --- /dev/null +++ b/sample_project/sample_app/tests/tests_views/tests_subscriber_task_views.py @@ -0,0 +1,41 @@ +from unittest.mock import patch, ANY + +from gcp_pilot.pubsub import Message +from sample_app.tests.tests_base_tasks import AuthenticationMixin + + +class SubscriberTaskViewTest(AuthenticationMixin): + def url(self, name): + return f"/subscriptions/{name}" + + def trigger_subscriber(self, content, attributes): + url = "/subscriptions/ParentSubscriberTask" + message = Message( + id="i-dont-care", + data=content, + attributes=attributes, + subscription="potato", + ) + return self.client.post(path=url, data=message.dump(), content_type="application/json") + + def test_propagate_headers(self): + content = { + "price": 10, + "quantity": 42, + } + attributes = { + "HTTP_traceparent": "trace-this-potato", + "HTTP_another-random-header": "please-do-not-propagate-this", + } + + with patch("gcp_pilot.tasks.CloudTasks.push") as push: + with patch("django_cloud_tasks.tasks.TaskMetadata.from_task_obj"): + self.trigger_subscriber(content=content, attributes=attributes) + + expected_kwargs = { + "queue_name": "tasks", + "url": "http://localhost:8080/tasks/CalculatePriceTask", + "payload": '{"price": 10, "quantity": 42}', + "headers": {"Traceparent": "trace-this-potato", "X-CloudTasks-Projectname": ANY}, + } + push.assert_called_once_with(**expected_kwargs) diff --git a/sample_project/sample_app/tests/tests_views.py b/sample_project/sample_app/tests/tests_views/tests_task_views.py similarity index 60% rename from sample_project/sample_app/tests/tests_views.py rename to sample_project/sample_app/tests/tests_views/tests_task_views.py index 5cb42d9..1562260 100644 --- a/sample_project/sample_app/tests/tests_views.py +++ b/sample_project/sample_app/tests/tests_views/tests_task_views.py @@ -1,16 +1,6 @@ from unittest.mock import patch, ANY -from django.test import SimpleTestCase, TransactionTestCase -from gcp_pilot.pubsub import Message -from gcp_pilot.mocker import patch_auth - - -class AuthenticationMixin(SimpleTestCase): - def setUp(self) -> None: - auth = patch_auth() - auth.start() - self.addCleanup(auth.stop) - super().setUp() +from sample_app.tests.tests_base_tasks import AuthenticationMixin class TaskViewTest(AuthenticationMixin): @@ -113,63 +103,3 @@ def test_absorb_headers(self): expected_content = {"Traceparent": "trace-this-potato"} self.assertEqual(expected_content, response.json()["result"]) - - -class SubscriberTaskViewTest(AuthenticationMixin): - def url(self, name): - return f"/subscriptions/{name}" - - def trigger_subscriber(self, content, attributes): - url = "/subscriptions/ParentSubscriberTask" - message = Message( - id="i-dont-care", - data=content, - attributes=attributes, - subscription="potato", - ) - return self.client.post(path=url, data=message.dump(), content_type="application/json") - - def test_propagate_headers(self): - content = { - "price": 10, - "quantity": 42, - } - attributes = { - "HTTP_traceparent": "trace-this-potato", - "HTTP_another-random-header": "please-do-not-propagate-this", - } - - with patch("gcp_pilot.tasks.CloudTasks.push") as push: - with patch("django_cloud_tasks.tasks.TaskMetadata.from_task_obj"): - self.trigger_subscriber(content=content, attributes=attributes) - - expected_kwargs = { - "queue_name": "tasks", - "url": "http://localhost:8080/tasks/CalculatePriceTask", - "payload": '{"price": 10, "quantity": 42}', - "headers": {"Traceparent": "trace-this-potato", "X-CloudTasks-Projectname": ANY}, - } - push.assert_called_once_with(**expected_kwargs) - - -class PublisherTaskTest(TransactionTestCase): - def setUp(self) -> None: - auth = patch_auth() - auth.start() - self.addCleanup(auth.stop) - super().setUp() - - def test_propagate_headerss(self): - url = "/create-person" - data = {"name": "Harry Potter"} - headers = { - "traceparent": "trace-this-potato", - "another-random-header": "please-do-not-propagate-this", - } - django_headers = {f"HTTP_{key.upper()}": value for key, value in headers.items()} - - with patch("gcp_pilot.pubsub.CloudPublisher.publish") as publish: - self.client.post(path=url, data=data, content_type="application/json", **django_headers) - - expected_attributes = {"HTTP_Traceparent": "trace-this-potato", "any-custom-attribute": "yay!"} - publish.assert_called_once_with(message=ANY, topic_id=ANY, attributes=expected_attributes)