Skip to content

Commit

Permalink
Merge pull request #156 from neutrons/workflow_task_class
Browse files Browse the repository at this point in the history
Fix error thrown when the task_class of a task is None
  • Loading branch information
backmari authored Feb 28, 2024
2 parents 57a7f75 + 69c4d49 commit fd5df89
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
6 changes: 5 additions & 1 deletion src/workflow_app/workflow/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ def _call_db_task(self, task_data, headers, message):
:param message: JSON-encoded message content
"""
task_def = json.loads(task_data)
if "task_class" in task_def and len(task_def["task_class"].strip()) > 0:
if (
"task_class" in task_def
and (task_def["task_class"] is not None)
and len(task_def["task_class"].strip()) > 0
):
try:
toks = task_def["task_class"].strip().split(".")
module = ".".join(toks[: len(toks) - 1])
Expand Down
24 changes: 20 additions & 4 deletions src/workflow_app/workflow/tests/test_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,26 @@ def test_call_db_task(self):
message = "test_msg"
sa._call_db_task(task_data, headers, message)

def test_call(self):
# NOTE: the decorator logged_action is preventing unittest mock from
# isolating the functionality for unit test, hence skipping
pass
@mock.patch("workflow.states.transactions.get_task")
def test_call(self, mock_get_task):
# NOTE: skipping testing of importing a task class as this option is not currently used
from workflow.states import StateAction

mock_connection = mock.Mock()
sa = StateAction(connection=mock_connection, use_db_task=True)
headers = {"destination": "test", "message-id": "test-0"}
message = '{"facility": "SNS", "instrument": "arcs", "ipts": "IPTS-5", "run_number": 3, "data_file": "test"}'

# test with task class: null
mock_get_task.return_value = '{"task_class": null, "task_queues": ["QUEUE-0", "QUEUE-1"]}'
sa(headers, message)
assert mock_connection.send.call_count == 2 # one per task queue
original_call_count = mock_connection.send.call_count

# test with task class: empty string
mock_get_task.return_value = '{"task_class": "", "task_queues": ["QUEUE-0", "QUEUE-1"]}'
sa(headers, message)
assert mock_connection.send.call_count - original_call_count == 2 # one per task queue

@mock.patch("workflow.database.transactions.add_status_entry")
def test_send(self, mockAddStatusEntry):
Expand Down

0 comments on commit fd5df89

Please sign in to comment.