Skip to content

Commit

Permalink
increase task_class input validation and log messages and add unit te…
Browse files Browse the repository at this point in the history
…st (#191)
  • Loading branch information
backmari authored Oct 15, 2024
1 parent b6b0d49 commit 100be00
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 8 deletions.
42 changes: 34 additions & 8 deletions src/workflow_app/workflow/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
from .settings import POSTPROCESS_ERROR, CATALOG_DATA_READY
from .settings import REDUCTION_DATA_READY, REDUCTION_CATALOG_DATA_READY
from .database import transactions

import importlib
import inspect
import json
import logging
import re


class StateAction:
Expand Down Expand Up @@ -47,6 +51,30 @@ def _call_default_task(self, headers, message):
action_cls = globals()[destination]
action_cls(connection=self._send_connection)(headers, message)

def _get_class_from_path(self, class_path: str):
"""
Returns the class given by the class path
:param class_path: the class, e.g. "module_name.ClassName"
:return: class or None
"""
# check that the string is in the format "package_name.module_name.class_name"
pattern = r"^[a-zA-Z0-9_\.]+\.[a-zA-Z0-9_]+$"
if not re.match(pattern, class_path):
logging.error(f"task_class {class_path} does not match pattern module_name.ClassName")
return None
module_name, class_name = class_path.rsplit(".", 1)

# try importing the class
try:
module = importlib.import_module(module_name)
cls = getattr(module, class_name)
if not inspect.isclass(cls):
raise ValueError
return cls
except (ModuleNotFoundError, AttributeError, ValueError):
logging.error(f"task_class {class_path} cannot be imported")
return None

def _call_db_task(self, task_data, headers, message):
"""
:param task_data: JSON-encoded task definition
Expand All @@ -59,14 +87,12 @@ def _call_db_task(self, task_data, headers, message):
and (task_def["task_class"] is not None)
and len(task_def["task_class"].strip()) > 0
):
try:
toks = task_def["task_class"].strip().split(".")
module = ".".join(toks[: len(toks) - 1])
cls = toks[len(toks) - 1]
exec("from %s import %s as action_cls" % (module, cls))
action_cls(connection=self._send_connection)(headers, message) # noqa: F821
except: # noqa: E722
logging.exception("Task [%s] failed:", headers["destination"])
action_cls = self._get_class_from_path(task_def["task_class"])
if action_cls:
try:
action_cls(connection=self._send_connection)(headers, message) # noqa: F821
except: # noqa: E722
logging.exception("Task [%s] failed:", headers["destination"])
if "task_queues" in task_def:
for item in task_def["task_queues"]:
destination = "/queue/%s" % item
Expand Down
64 changes: 64 additions & 0 deletions src/workflow_app/workflow/tests/test_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,20 @@
_ = [workflow]


class FakeTestClass:
def __init__(self, connection):
pass

def __call__(self, headers, message):
raise ValueError


class StateActionTest(TestCase):

@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self.caplog = caplog

def test_call_default_task(self):
from workflow.states import StateAction

Expand Down Expand Up @@ -47,6 +60,57 @@ def test_call(self, mock_get_task):
sa(headers, message)
assert mock_connection.send.call_count - original_call_count == 2 # one per task queue

@mock.patch("workflow.states.transactions.get_task")
def test_task_class_path(self, mock_get_task):
from workflow.states import StateAction

mock_connection = mock.Mock()
sa = StateAction(connection=mock_connection, use_db_task=True)
headers = {"destination": "test", "message-id": "test-0"}
message = '{"facility": "SNS", "instrument": "arcs", "ipts": "IPTS-5", "run_number": 3, "data_file": "test"}'

# test with task class "-" (inserted by Django admin interface when left empty)
mock_get_task.return_value = '{"task_class": "-"}'
self.caplog.clear()
sa(headers, message)
assert "does not match pattern" in self.caplog.text

# test with task class that does not follow the pattern "module_name.ClassName"
mock_get_task.return_value = '{"task_class": "FakeClass"}'
self.caplog.clear()
sa(headers, message)
assert "does not match pattern" in self.caplog.text

# test with module that does not exist
mock_get_task.return_value = '{"task_class": "fake_module.FakeClass"}'
self.caplog.clear()
sa(headers, message)
assert "cannot be imported" in self.caplog.text

# test with module exists but class does not
mock_get_task.return_value = '{"task_class": "workflow.states.FakeClass"}'
self.caplog.clear()
sa(headers, message)
assert "cannot be imported" in self.caplog.text

# test with module attribute is not a class
mock_get_task.return_value = '{"task_class": "workflow.state_utilities.decode_message"}'
self.caplog.clear()
sa(headers, message)
assert "cannot be imported" in self.caplog.text

# test with calling class fails
mock_get_task.return_value = '{"task_class": "workflow.tests.test_states.FakeTestClass"}'
self.caplog.clear()
sa(headers, message)
assert "Task [test] failed" in self.caplog.text

# test with valid class
mock_get_task.return_value = '{"task_class": "workflow.states.Reduction_request"}'
self.caplog.clear()
sa(headers, message)
assert mock_connection.send.call_count == 1

@mock.patch("workflow.database.transactions.add_status_entry")
def test_send(self, mockAddStatusEntry):
from workflow.states import StateAction
Expand Down

0 comments on commit 100be00

Please sign in to comment.