From 71f53be102e8ce86bb1982385399d757bdf69b90 Mon Sep 17 00:00:00 2001 From: James Addison Date: Thu, 20 Jun 2024 09:25:18 -0700 Subject: [PATCH] Add pre and post task hooks. --- README.md | 23 +++++++++++ django_dbq/__init__.py | 2 +- django_dbq/management/commands/worker.py | 4 ++ django_dbq/models.py | 22 ++++++++++ django_dbq/tasks.py | 12 ++++++ django_dbq/tests.py | 52 ++++++++++++++++++++++++ 6 files changed, 114 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 5562651..cb77aef 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,29 @@ JOBS = { } ``` +#### Pre & Post Task Hooks +You can also run pre task or post task hooks, which happen in the normal processing of your `Job` instances and are executed in the worker process. + +Both pre and post task hooks receive your `Job` instance as their only argument. Here's an example: + +```python +def my_pre_task_hook(job): + ... # configure something before running your task +``` + +To ensure these hooks gets run, simply add a `pre_task_hook` or `post_task_hook` key (or both, if needed) to your job config like so: + +```python +JOBS = { + "my_job": { + "tasks": ["project.common.jobs.my_task"], + "pre_task_hook": "project.common.jobs.my_pre_task_hook", + "post_task_hook": "project.common.jobs.my_post_task_hook", + }, +} +``` + + ### Start the worker In another terminal: diff --git a/django_dbq/__init__.py b/django_dbq/__init__.py index f5f41e5..1173108 100644 --- a/django_dbq/__init__.py +++ b/django_dbq/__init__.py @@ -1 +1 @@ -__version__ = "3.1.0" +__version__ = "3.2.0" diff --git a/django_dbq/management/commands/worker.py b/django_dbq/management/commands/worker.py index 9215aad..7434981 100644 --- a/django_dbq/management/commands/worker.py +++ b/django_dbq/management/commands/worker.py @@ -61,6 +61,8 @@ def _process_job(self): if not job: return + job.run_pre_task_hook() + logger.info( 'Processing job: name="%s" queue="%s" id=%s state=%s next_task=%s', job.name, @@ -109,6 +111,8 @@ def _process_job(self): logger.exception("Failed to save job: id=%s", job.pk) raise + job.run_post_task_hook() + self.current_job = None diff --git a/django_dbq/models.py b/django_dbq/models.py index d93a05b..31c4aef 100644 --- a/django_dbq/models.py +++ b/django_dbq/models.py @@ -3,6 +3,8 @@ from django.utils.module_loading import import_string from django_dbq.tasks import ( get_next_task_name, + get_pre_task_hook_name, + get_post_task_hook_name, get_failure_hook_name, get_creation_hook_name, ) @@ -126,12 +128,32 @@ def save(self, *args, **kwargs): def update_next_task(self): self.next_task = get_next_task_name(self.name, self.next_task) or "" + def get_pre_task_hook_name(self): + return get_pre_task_hook_name(self.name) + + def get_post_task_hook_name(self): + return get_post_task_hook_name(self.name) + def get_failure_hook_name(self): return get_failure_hook_name(self.name) def get_creation_hook_name(self): return get_creation_hook_name(self.name) + def run_pre_task_hook(self): + pre_task_hook_name = self.get_pre_task_hook_name() + if pre_task_hook_name: + logger.info("Running pre_task hook %s for new job", pre_task_hook_name) + pre_task_hook_function = import_string(pre_task_hook_name) + pre_task_hook_function(self) + + def run_post_task_hook(self): + post_task_hook_name = self.get_post_task_hook_name() + if post_task_hook_name: + logger.info("Running post_task hook %s for new job", post_task_hook_name) + post_task_hook_function = import_string(post_task_hook_name) + post_task_hook_function(self) + def run_creation_hook(self): creation_hook_name = self.get_creation_hook_name() if creation_hook_name: diff --git a/django_dbq/tasks.py b/django_dbq/tasks.py index 3e43da3..a95b4a5 100644 --- a/django_dbq/tasks.py +++ b/django_dbq/tasks.py @@ -2,6 +2,8 @@ TASK_LIST_KEY = "tasks" +PRE_TASK_HOOK_KEY = "pre_task_hook" +POST_TASK_HOOK_KEY = "post_task_hook" FAILURE_HOOK_KEY = "failure_hook" CREATION_HOOK_KEY = "creation_hook" @@ -24,6 +26,16 @@ def get_next_task_name(job_name, current_task=None): return None +def get_pre_task_hook_name(job_name): + """Return the name of the pre task hook for the given job (as a string) or None""" + return settings.JOBS[job_name].get(PRE_TASK_HOOK_KEY) + + +def get_post_task_hook_name(job_name): + """Return the name of the post_task hook for the given job (as a string) or None""" + return settings.JOBS[job_name].get(POST_TASK_HOOK_KEY) + + def get_failure_hook_name(job_name): """Return the name of the failure hook for the given job (as a string) or None""" return settings.JOBS[job_name].get(FAILURE_HOOK_KEY) diff --git a/django_dbq/tests.py b/django_dbq/tests.py index ad376cd..39c0ba5 100644 --- a/django_dbq/tests.py +++ b/django_dbq/tests.py @@ -34,12 +34,25 @@ def failing_task(job): raise Exception("uh oh") +def pre_task_hook(job): + job.workspace["output"] = "pre task hook ran" + job.workspace["job_id"] = str(job.id) + + +def post_task_hook(job): + job.workspace["output"] = "post task hook ran" + job.workspace["job_id"] = str(job.id) + + def failure_hook(job, exception): job.workspace["output"] = "failure hook ran" + job.workspace["exception"] = str(exception) + job.workspace["job_id"] = str(job.id) def creation_hook(job): job.workspace["output"] = "creation hook ran" + job.workspace["job_id"] = str(job.id) @override_settings(JOBS={"testjob": {"tasks": ["a"]}}) @@ -316,6 +329,7 @@ def test_creation_hook(self): job = Job.objects.create(name="testjob") job = Job.objects.get() self.assertEqual(job.workspace["output"], "creation hook ran") + self.assertEqual(job.workspace["job_id"], str(job.id)) def test_creation_hook_only_runs_on_create(self): job = Job.objects.create(name="testjob") @@ -326,6 +340,42 @@ def test_creation_hook_only_runs_on_create(self): self.assertEqual(job.workspace["output"], "creation hook output removed") +@override_settings( + JOBS={ + "testjob": { + "tasks": ["django_dbq.tests.failing_task"], + "pre_task_hook": "django_dbq.tests.pre_task_hook", + } + } +) +class JobPreTaskHookTestCase(TestCase): + def test_pre_task_hook(self): + job = Job.objects.create(name="testjob") + Worker("default", 1)._process_job() + job = Job.objects.get() + self.assertEqual(job.state, Job.STATES.FAILED) + self.assertEqual(job.workspace["output"], "failure hook ran") + self.assertEqual(job.workspace["job_id"], str(job.id)) + + +@override_settings( + JOBS={ + "testjob": { + "tasks": ["django_dbq.tests.failing_task"], + "post_task_hook": "django_dbq.tests.post_task_hook", + } + } +) +class JobPostTaskHookTestCase(TestCase): + def test_post_task_hook(self): + job = Job.objects.create(name="testjob") + Worker("default", 1)._process_job() + job = Job.objects.get() + self.assertEqual(job.state, Job.STATES.FAILED) + self.assertEqual(job.workspace["output"], "post task hook ran") + self.assertEqual(job.workspace["job_id"], str(job.id)) + + @override_settings( JOBS={ "testjob": { @@ -341,6 +391,8 @@ def test_failure_hook(self): job = Job.objects.get() self.assertEqual(job.state, Job.STATES.FAILED) self.assertEqual(job.workspace["output"], "failure hook ran") + self.assertIn("uh oh", job.workspace["exception"]) + self.assertEqual(job.workspace["job_id"], str(job.id)) @override_settings(JOBS={"testjob": {"tasks": ["a"]}})