diff --git a/README.md b/README.md index 5562651..c4e698a 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,35 @@ JOBS = { } ``` +#### Pre & Post Task Hooks +You can also run pre task or post task hooks, which happen in the normal processing of your `Job` instances and are executed inside the worker process. + +Both pre and post task hooks receive your `Job` instance as their only argument. Here's an example: + +```python +def my_pre_task_hook(job): + ... # configure something before running your task +``` + +To ensure these hooks are run, simply add a `pre_task_hook` or `post_task_hook` key (or both, if needed) to your job config like so: + +```python +JOBS = { + "my_job": { + "tasks": ["project.common.jobs.my_task"], + "pre_task_hook": "project.common.jobs.my_pre_task_hook", + "post_task_hook": "project.common.jobs.my_post_task_hook", + }, +} +``` + +Notes: + +* If the `pre_task_hook` fails (raises an exception), the task function is not run, and django-db-queue behaves as if the task function itself had failed: the failure hook is called, and the job is goes into the `FAILED` state. +* The `post_task_hook` is always run, even if the job fails. In this case, it runs after the `failure_hook`. +* If the `post_task_hook` raises an exception, this is logged but the the job is **not marked as failed** and the failure hook does not run. This is because the `post_task_hook` might need to perform cleanup that always happens after the task, no matter whether it succeeds or fails. + + ### Start the worker In another terminal: diff --git a/django_dbq/__init__.py b/django_dbq/__init__.py index f5f41e5..1173108 100644 --- a/django_dbq/__init__.py +++ b/django_dbq/__init__.py @@ -1 +1 @@ -__version__ = "3.1.0" +__version__ = "3.2.0" diff --git a/django_dbq/management/commands/worker.py b/django_dbq/management/commands/worker.py index 9215aad..d166b8d 100644 --- a/django_dbq/management/commands/worker.py +++ b/django_dbq/management/commands/worker.py @@ -74,9 +74,10 @@ def _process_job(self): self.current_job = job try: - task_function = import_string(job.next_task) - task_function(job) + job.run_pre_task_hook() + job.run_next_task() job.update_next_task() + if not job.next_task: job.state = Job.STATES.COMPLETE else: @@ -84,16 +85,12 @@ def _process_job(self): except Exception as exception: logger.exception("Job id=%s failed", job.pk) job.state = Job.STATES.FAILED - - failure_hook_name = job.get_failure_hook_name() - if failure_hook_name: - logger.info( - "Running failure hook %s for job id=%s", failure_hook_name, job.pk - ) - failure_hook_function = import_string(failure_hook_name) - failure_hook_function(job, exception) - else: - logger.info("No failure hook for job id=%s", job.pk) + job.run_failure_hook(exception) + finally: + try: + job.run_post_task_hook() + except: + logger.exception("Job id=%s post_task_hook failed", job.pk) logger.info( 'Updating job: name="%s" id=%s state=%s next_task=%s', diff --git a/django_dbq/models.py b/django_dbq/models.py index d93a05b..b58eef4 100644 --- a/django_dbq/models.py +++ b/django_dbq/models.py @@ -3,6 +3,8 @@ from django.utils.module_loading import import_string from django_dbq.tasks import ( get_next_task_name, + get_pre_task_hook_name, + get_post_task_hook_name, get_failure_hook_name, get_creation_hook_name, ) @@ -126,16 +128,47 @@ def save(self, *args, **kwargs): def update_next_task(self): self.next_task = get_next_task_name(self.name, self.next_task) or "" + def run_next_task(self): + next_task_function = import_string(self.next_task) + next_task_function(self) + + def get_pre_task_hook_name(self): + return get_pre_task_hook_name(self.name) + + def get_post_task_hook_name(self): + return get_post_task_hook_name(self.name) + def get_failure_hook_name(self): return get_failure_hook_name(self.name) def get_creation_hook_name(self): return get_creation_hook_name(self.name) + def run_pre_task_hook(self): + pre_task_hook_name = self.get_pre_task_hook_name() + if pre_task_hook_name: + logger.info("Running pre_task hook %s for job", pre_task_hook_name) + pre_task_hook_function = import_string(pre_task_hook_name) + pre_task_hook_function(self) + + def run_post_task_hook(self): + post_task_hook_name = self.get_post_task_hook_name() + if post_task_hook_name: + logger.info("Running post_task hook %s for job", post_task_hook_name) + post_task_hook_function = import_string(post_task_hook_name) + post_task_hook_function(self) + + def run_failure_hook(self, exception): + failure_hook_name = self.get_failure_hook_name() + if failure_hook_name: + logger.info("Running failure hook %s for job", failure_hook_name) + failure_hook_function = import_string(failure_hook_name) + failure_hook_function(self, exception) + def run_creation_hook(self): creation_hook_name = self.get_creation_hook_name() if creation_hook_name: - logger.info("Running creation hook %s for new job", creation_hook_name) + logger.info("Running creation hook %s for job", creation_hook_name) creation_hook_function = import_string(creation_hook_name) creation_hook_function(self) diff --git a/django_dbq/tasks.py b/django_dbq/tasks.py index 3e43da3..a95b4a5 100644 --- a/django_dbq/tasks.py +++ b/django_dbq/tasks.py @@ -2,6 +2,8 @@ TASK_LIST_KEY = "tasks" +PRE_TASK_HOOK_KEY = "pre_task_hook" +POST_TASK_HOOK_KEY = "post_task_hook" FAILURE_HOOK_KEY = "failure_hook" CREATION_HOOK_KEY = "creation_hook" @@ -24,6 +26,16 @@ def get_next_task_name(job_name, current_task=None): return None +def get_pre_task_hook_name(job_name): + """Return the name of the pre task hook for the given job (as a string) or None""" + return settings.JOBS[job_name].get(PRE_TASK_HOOK_KEY) + + +def get_post_task_hook_name(job_name): + """Return the name of the post_task hook for the given job (as a string) or None""" + return settings.JOBS[job_name].get(POST_TASK_HOOK_KEY) + + def get_failure_hook_name(job_name): """Return the name of the failure hook for the given job (as a string) or None""" return settings.JOBS[job_name].get(FAILURE_HOOK_KEY) diff --git a/django_dbq/tests.py b/django_dbq/tests.py index ad376cd..dd83540 100644 --- a/django_dbq/tests.py +++ b/django_dbq/tests.py @@ -34,12 +34,25 @@ def failing_task(job): raise Exception("uh oh") +def pre_task_hook(job): + job.workspace["output"] = "pre task hook ran" + job.workspace["job_id"] = str(job.id) + + +def post_task_hook(job): + job.workspace["output"] = "post task hook ran" + job.workspace["job_id"] = str(job.id) + + def failure_hook(job, exception): job.workspace["output"] = "failure hook ran" + job.workspace["exception"] = str(exception) + job.workspace["job_id"] = str(job.id) def creation_hook(job): job.workspace["output"] = "creation hook ran" + job.workspace["job_id"] = str(job.id) @override_settings(JOBS={"testjob": {"tasks": ["a"]}}) @@ -316,6 +329,7 @@ def test_creation_hook(self): job = Job.objects.create(name="testjob") job = Job.objects.get() self.assertEqual(job.workspace["output"], "creation hook ran") + self.assertEqual(job.workspace["job_id"], str(job.id)) def test_creation_hook_only_runs_on_create(self): job = Job.objects.create(name="testjob") @@ -326,6 +340,42 @@ def test_creation_hook_only_runs_on_create(self): self.assertEqual(job.workspace["output"], "creation hook output removed") +@override_settings( + JOBS={ + "testjob": { + "tasks": ["django_dbq.tests.test_task"], + "pre_task_hook": "django_dbq.tests.pre_task_hook", + } + } +) +class JobPreTaskHookTestCase(TestCase): + def test_pre_task_hook(self): + job = Job.objects.create(name="testjob") + Worker("default", 1)._process_job() + job = Job.objects.get() + self.assertEqual(job.state, Job.STATES.COMPLETE) + self.assertEqual(job.workspace["output"], "pre task hook ran") + self.assertEqual(job.workspace["job_id"], str(job.id)) + + +@override_settings( + JOBS={ + "testjob": { + "tasks": ["django_dbq.tests.test_task"], + "post_task_hook": "django_dbq.tests.post_task_hook", + } + } +) +class JobPostTaskHookTestCase(TestCase): + def test_post_task_hook(self): + job = Job.objects.create(name="testjob") + Worker("default", 1)._process_job() + job = Job.objects.get() + self.assertEqual(job.state, Job.STATES.COMPLETE) + self.assertEqual(job.workspace["output"], "post task hook ran") + self.assertEqual(job.workspace["job_id"], str(job.id)) + + @override_settings( JOBS={ "testjob": { @@ -341,6 +391,8 @@ def test_failure_hook(self): job = Job.objects.get() self.assertEqual(job.state, Job.STATES.FAILED) self.assertEqual(job.workspace["output"], "failure hook ran") + self.assertIn("uh oh", job.workspace["exception"]) + self.assertEqual(job.workspace["job_id"], str(job.id)) @override_settings(JOBS={"testjob": {"tasks": ["a"]}})