diff --git a/src/workflow_app/workflow/states.py b/src/workflow_app/workflow/states.py index e4ca05a3..6e1ba87a 100644 --- a/src/workflow_app/workflow/states.py +++ b/src/workflow_app/workflow/states.py @@ -53,7 +53,7 @@ def _call_db_task(self, task_data, headers, message): :param message: JSON-encoded message content """ task_def = json.loads(task_data) - if "task_class" in task_def and len(task_def["task_class"].strip()) > 0: + if "task_class" in task_def and (task_def["task_class"] is not None) and len(task_def["task_class"].strip()) > 0: try: toks = task_def["task_class"].strip().split(".") module = ".".join(toks[: len(toks) - 1]) diff --git a/src/workflow_app/workflow/tests/test_states.py b/src/workflow_app/workflow/tests/test_states.py index 9d1397f5..899fda85 100644 --- a/src/workflow_app/workflow/tests/test_states.py +++ b/src/workflow_app/workflow/tests/test_states.py @@ -26,10 +26,26 @@ def test_call_db_task(self): message = "test_msg" sa._call_db_task(task_data, headers, message) - def test_call(self): - # NOTE: the decorator logged_action is preventing unittest mock from - # isolating the functionality for unit test, hence skipping - pass + @mock.patch("workflow.states.transactions.get_task") + def test_call(self, mock_get_task): + # NOTE: skipping testing of importing a task class as this option is not currently used + from workflow.states import StateAction + + mock_connection = mock.Mock() + sa = StateAction(connection=mock_connection, use_db_task=True) + headers = {"destination": "test", "message-id": "test-0"} + message = '{"facility": "SNS", "instrument": "arcs", "ipts": "IPTS-5", "run_number": 3, "data_file": "test"}' + + # test with task class: null + mock_get_task.return_value = '{"task_class": null, "task_queues": ["QUEUE-0", "QUEUE-1"]}' + sa(headers, message) + assert mock_connection.send.call_count == 2 # one per task queue + original_call_count = mock_connection.send.call_count + + # test with task class: empty string + mock_get_task.return_value = '{"task_class": "", "task_queues": ["QUEUE-0", "QUEUE-1"]}' + sa(headers, message) + assert mock_connection.send.call_count - original_call_count == 2 # one per task queue @mock.patch("workflow.database.transactions.add_status_entry") def test_send(self, mockAddStatusEntry):