diff --git a/README.md b/README.md index 11516d0..03f099f 100644 --- a/README.md +++ b/README.md @@ -223,9 +223,14 @@ Jobs have a `state` field which can have one of the following values: * `NEW` (has been created, waiting for a worker process to run the next task) * `READY` (has run a task before, awaiting a worker process to run the next task) * `PROCESSING` (a task is currently being processed by a worker) +* `STOPPING` (the worker process has received a signal from the OS requesting it to exit) * `COMPLETED` (all job tasks have completed successfully) * `FAILED` (a job task failed) +#### State diagram + +![state diagram](states.png) + ### API #### Model methods diff --git a/django_dbq/management/commands/worker.py b/django_dbq/management/commands/worker.py index 9d72871..5f8d7b2 100644 --- a/django_dbq/management/commands/worker.py +++ b/django_dbq/management/commands/worker.py @@ -14,72 +14,13 @@ DEFAULT_QUEUE_NAME = "default" -def process_job(queue_name): - """This function grabs the next available job for a given queue, and runs its next task.""" - - with transaction.atomic(): - job = Job.objects.get_ready_or_none(queue_name) - if not job: - return - - logger.info( - 'Processing job: name="%s" queue="%s" id=%s state=%s next_task=%s', - job.name, - queue_name, - job.pk, - job.state, - job.next_task, - ) - job.state = Job.STATES.PROCESSING - job.save() - - try: - task_function = import_string(job.next_task) - task_function(job) - job.update_next_task() - if not job.next_task: - job.state = Job.STATES.COMPLETE - else: - job.state = Job.STATES.READY - except Exception as exception: - logger.exception("Job id=%s failed", job.pk) - job.state = Job.STATES.FAILED - - failure_hook_name = job.get_failure_hook_name() - if failure_hook_name: - logger.info( - "Running failure hook %s for job id=%s", failure_hook_name, job.pk - ) - failure_hook_function = import_string(failure_hook_name) - failure_hook_function(job, exception) - else: - logger.info("No failure hook for job id=%s", job.pk) - - logger.info( - 'Updating job: name="%s" id=%s state=%s next_task=%s', - job.name, - job.pk, - job.state, - job.next_task or "none", - ) - - try: - job.save() - except: - logger.error( - "Failed to save job: id=%s org=%s", - job.pk, - job.workspace.get("organisation_id"), - ) - raise - - class Worker: def __init__(self, name, rate_limit_in_seconds): self.queue_name = name self.rate_limit_in_seconds = rate_limit_in_seconds self.alive = True self.last_job_finished = None + self.current_job = None self.init_signals() def init_signals(self): @@ -93,6 +34,9 @@ def init_signals(self): def shutdown(self, signum, frame): self.alive = False + if self.current_job: + self.current_job.state = Job.STATES.STOPPING + self.current_job.save(update_fields=["state"]) def run(self): while self.alive: @@ -107,9 +51,66 @@ def process_job(self): ): return - process_job(self.queue_name) + self._process_job() + self.last_job_finished = timezone.now() + def _process_job(self): + with transaction.atomic(): + job = Job.objects.get_ready_or_none(self.queue_name) + if not job: + return + + logger.info( + 'Processing job: name="%s" queue="%s" id=%s state=%s next_task=%s', + job.name, + self.queue_name, + job.pk, + job.state, + job.next_task, + ) + job.state = Job.STATES.PROCESSING + job.save() + self.current_job = job + + try: + task_function = import_string(job.next_task) + task_function(job) + job.update_next_task() + if not job.next_task: + job.state = Job.STATES.COMPLETE + else: + job.state = Job.STATES.READY + except Exception as exception: + logger.exception("Job id=%s failed", job.pk) + job.state = Job.STATES.FAILED + + failure_hook_name = job.get_failure_hook_name() + if failure_hook_name: + logger.info( + "Running failure hook %s for job id=%s", failure_hook_name, job.pk + ) + failure_hook_function = import_string(failure_hook_name) + failure_hook_function(job, exception) + else: + logger.info("No failure hook for job id=%s", job.pk) + + logger.info( + 'Updating job: name="%s" id=%s state=%s next_task=%s', + job.name, + job.pk, + job.state, + job.next_task or "none", + ) + + try: + job.save() + except: + logger.exception("Failed to save job: id=%s", job.pk) + raise + + self.current_job = None + class Command(BaseCommand): diff --git a/django_dbq/migrations/0006_alter_job_state.py b/django_dbq/migrations/0006_alter_job_state.py new file mode 100644 index 0000000..e7c51cb --- /dev/null +++ b/django_dbq/migrations/0006_alter_job_state.py @@ -0,0 +1,30 @@ +# Generated by Django 3.2rc1 on 2021-11-29 04:48 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("django_dbq", "0005_job_run_after"), + ] + + operations = [ + migrations.AlterField( + model_name="job", + name="state", + field=models.CharField( + choices=[ + ("NEW", "New"), + ("READY", "Ready"), + ("PROCESSING", "Processing"), + ("STOPPING", "Stopping"), + ("FAILED", "Failed"), + ("COMPLETE", "Complete"), + ], + db_index=True, + default="NEW", + max_length=20, + ), + ), + ] diff --git a/django_dbq/models.py b/django_dbq/models.py index 5669861..3e12289 100644 --- a/django_dbq/models.py +++ b/django_dbq/models.py @@ -53,7 +53,11 @@ def delete_old(self): """ Delete all jobs older than DELETE_JOBS_AFTER_HOURS """ - delete_jobs_in_states = [Job.STATES.FAILED, Job.STATES.COMPLETE] + delete_jobs_in_states = [ + Job.STATES.FAILED, + Job.STATES.COMPLETE, + Job.STATES.STOPPING, + ] delete_jobs_created_before = timezone.now() - datetime.timedelta( hours=DELETE_JOBS_AFTER_HOURS ) @@ -82,6 +86,7 @@ class STATES(TextChoices): NEW = "NEW" READY = "READY" PROCESSING = "PROCESSING" + STOPPING = "STOPPING" FAILED = "FAILED" COMPLETE = "COMPLETE" diff --git a/django_dbq/tests.py b/django_dbq/tests.py index c1f828d..3ae7ab9 100644 --- a/django_dbq/tests.py +++ b/django_dbq/tests.py @@ -7,7 +7,7 @@ from django.test.utils import override_settings from django.utils import timezone -from django_dbq.management.commands.worker import process_job, Worker +from django_dbq.management.commands.worker import Worker from django_dbq.models import Job from io import StringIO @@ -123,41 +123,53 @@ def test_queue_depth_for_queue_with_zero_jobs(self): @freezegun.freeze_time() @mock.patch("django_dbq.management.commands.worker.sleep") -@mock.patch("django_dbq.management.commands.worker.process_job") class WorkerProcessProcessJobTestCase(TestCase): def setUp(self): super().setUp() - self.MockWorker = mock.MagicMock() - self.MockWorker.queue_name = "default" - self.MockWorker.rate_limit_in_seconds = 5 - self.MockWorker.last_job_finished = None + self.mock_worker = mock.MagicMock() + self.mock_worker.queue_name = "default" + self.mock_worker.rate_limit_in_seconds = 5 + self.mock_worker.last_job_finished = None - def test_process_job_no_previous_job_run(self, mock_process_job, mock_sleep): - Worker.process_job(self.MockWorker) + def test_process_job_no_previous_job_run(self, mock_sleep): + Worker.process_job(self.mock_worker) self.assertEqual(mock_sleep.call_count, 1) - self.assertEqual(mock_process_job.call_count, 1) - self.assertEqual(self.MockWorker.last_job_finished, timezone.now()) + self.assertEqual(self.mock_worker._process_job.call_count, 1) + self.assertEqual(self.mock_worker.last_job_finished, timezone.now()) - def test_process_job_previous_job_too_soon(self, mock_process_job, mock_sleep): - self.MockWorker.last_job_finished = timezone.now() - timezone.timedelta( + def test_process_job_previous_job_too_soon(self, mock_sleep): + self.mock_worker.last_job_finished = timezone.now() - timezone.timedelta( seconds=2 ) - Worker.process_job(self.MockWorker) + Worker.process_job(self.mock_worker) self.assertEqual(mock_sleep.call_count, 1) - self.assertEqual(mock_process_job.call_count, 0) + self.assertEqual(self.mock_worker._process_job.call_count, 0) self.assertEqual( - self.MockWorker.last_job_finished, + self.mock_worker.last_job_finished, timezone.now() - timezone.timedelta(seconds=2), ) - def test_process_job_previous_job_long_time_ago(self, mock_process_job, mock_sleep): - self.MockWorker.last_job_finished = timezone.now() - timezone.timedelta( + def test_process_job_previous_job_long_time_ago(self, mock_sleep): + self.mock_worker.last_job_finished = timezone.now() - timezone.timedelta( seconds=7 ) - Worker.process_job(self.MockWorker) + Worker.process_job(self.mock_worker) self.assertEqual(mock_sleep.call_count, 1) - self.assertEqual(mock_process_job.call_count, 1) - self.assertEqual(self.MockWorker.last_job_finished, timezone.now()) + self.assertEqual(self.mock_worker._process_job.call_count, 1) + self.assertEqual(self.mock_worker.last_job_finished, timezone.now()) + + +@override_settings(JOBS={"testjob": {"tasks": ["a"]}}) +class ShutdownTestCase(TestCase): + def test_shutdown_sets_state_to_stopping(self): + job = Job.objects.create(name="testjob") + worker = Worker("default", 1) + worker.current_job = job + + worker.shutdown(None, None) + + job.refresh_from_db() + self.assertEqual(job.state, Job.STATES.STOPPING) @override_settings(JOBS={"testjob": {"tasks": ["a"]}}) @@ -267,7 +279,7 @@ def test_task_sequence(self): class ProcessJobTestCase(TestCase): def test_process_job(self): job = Job.objects.create(name="testjob") - process_job("default") + Worker("default", 1)._process_job() job = Job.objects.get() self.assertEqual(job.state, Job.STATES.COMPLETE) @@ -276,7 +288,7 @@ def test_process_job_wrong_queue(self): Processing a different queue shouldn't touch our other job """ job = Job.objects.create(name="testjob", queue_name="lol") - process_job("default") + Worker("default", 1)._process_job() job = Job.objects.get() self.assertEqual(job.state, Job.STATES.NEW) @@ -315,7 +327,7 @@ def test_creation_hook_only_runs_on_create(self): class JobFailureHookTestCase(TestCase): def test_failure_hook(self): job = Job.objects.create(name="testjob") - process_job("default") + Worker("default", 1)._process_job() job = Job.objects.get() self.assertEqual(job.state, Job.STATES.FAILED) self.assertEqual(job.workspace["output"], "failure hook ran") @@ -334,14 +346,18 @@ def test_delete_old_jobs(self): j2.created = two_days_ago j2.save() - j3 = Job.objects.create(name="testjob", state=Job.STATES.NEW) + j3 = Job.objects.create(name="testjob", state=Job.STATES.STOPPING) j3.created = two_days_ago j3.save() - j4 = Job.objects.create(name="testjob", state=Job.STATES.COMPLETE) + j4 = Job.objects.create(name="testjob", state=Job.STATES.NEW) + j4.created = two_days_ago + j4.save() + + j5 = Job.objects.create(name="testjob", state=Job.STATES.COMPLETE) Job.objects.delete_old() self.assertEqual(Job.objects.count(), 2) - self.assertTrue(j3 in Job.objects.all()) self.assertTrue(j4 in Job.objects.all()) + self.assertTrue(j5 in Job.objects.all()) diff --git a/states.png b/states.png new file mode 100644 index 0000000..acc2858 Binary files /dev/null and b/states.png differ