diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 6da29aefe..83a547acc 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -12,7 +12,8 @@ from blueapi import __version__ from blueapi.cli.format import OutputFormat from blueapi.client.client import BlueapiClient -from blueapi.client.event_bus import AnyEvent, BlueskyRemoteError, EventBusClient +from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient +from blueapi.client.rest import BlueskyRemoteControlError from blueapi.config import ApplicationConfig, ConfigLoader from blueapi.core import DataEvent from blueapi.messaging import MessageContext @@ -203,7 +204,7 @@ def on_event(event: AnyEvent) -> None: except ValidationError as e: pprint(f"failed to validate the task parameters, {task_id}, error: {e}") return - except BlueskyRemoteError as e: + except (BlueskyRemoteControlError, BlueskyStreamingError) as e: pprint(f"server error with this message: {e}") return except ValueError: diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index eeabccc1b..059787bc5 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -16,8 +16,8 @@ from blueapi.worker import Task, TrackableTask, WorkerEvent, WorkerState from blueapi.worker.event import ProgressEvent, TaskStatus -from .event_bus import AnyEvent, BlueskyRemoteError, EventBusClient, OnAnyEvent -from .rest import BlueapiRestClient +from .event_bus import AnyEvent, BlueskyStreamingError, EventBusClient, OnAnyEvent +from .rest import BlueapiRestClient, BlueskyRemoteControlError class BlueapiClient: @@ -194,9 +194,13 @@ def inner_on_event(ctx: MessageContext, event: AnyEvent) -> None: if isinstance(event, WorkerEvent) and ( (event.is_complete()) and (ctx.correlation_id == task_id) ): - if len(event.errors) > 0: + if event.task_status is not None and event.task_status.task_failed: complete.set_exception( - BlueskyRemoteError("\n".join(event.errors)) + BlueskyStreamingError( + "\n".join(event.errors) + if len(event.errors) > 0 + else "Unknown error" + ) ) else: complete.set_result(event) @@ -223,7 +227,7 @@ def create_and_start_task(self, task: Task) -> TaskResponse: if worker_response.task_id == response.task_id: return response else: - raise BlueskyRemoteError( + raise BlueskyRemoteControlError( f"Tried to create and start task {response.task_id} " f"but {worker_response.task_id} was started instead" ) @@ -334,7 +338,9 @@ def reload_environment( try: status = self._rest.delete_environment() except Exception as e: - raise BlueskyRemoteError("Failed to tear down the environment") from e + raise BlueskyRemoteControlError( + "Failed to tear down the environment" + ) from e return self._wait_for_reload( status, timeout, @@ -355,7 +361,7 @@ def _wait_for_reload( # Poll until the environment is restarted or the timeout is reached status = self._rest.get_environment() if status.error_message is not None: - raise BlueskyRemoteError(status.error_message) + raise BlueskyRemoteControlError(status.error_message) elif status.initialized: return status time.sleep(polling_interval) diff --git a/src/blueapi/client/event_bus.py b/src/blueapi/client/event_bus.py index a6d48f687..bfd0afd18 100644 --- a/src/blueapi/client/event_bus.py +++ b/src/blueapi/client/event_bus.py @@ -1,4 +1,3 @@ -import threading from collections.abc import Callable from blueapi.core import DataEvent @@ -6,7 +5,7 @@ from blueapi.worker import ProgressEvent, WorkerEvent -class BlueskyRemoteError(Exception): +class BlueskyStreamingError(Exception): def __init__(self, message: str) -> None: super().__init__(message) @@ -17,13 +16,9 @@ def __init__(self, message: str) -> None: class EventBusClient: app: MessagingTemplate - complete: threading.Event - timed_out: bool | None def __init__(self, app: MessagingTemplate) -> None: self.app = app - self.complete = threading.Event() - self.timed_out = None def __enter__(self) -> None: self.app.connect() @@ -35,12 +30,12 @@ def subscribe_to_all_events( self, on_event: Callable[[MessageContext, AnyEvent], None], ) -> None: - self.app.subscribe( - self.app.destinations.topic("public.worker.event"), - on_event, - ) - - def wait_for_complete(self, timeout: float | None = None) -> None: - self.timed_out = not self.complete.wait(timeout=timeout) - - self.complete.clear() + try: + self.app.subscribe( + self.app.destinations.topic("public.worker.event"), + on_event, + ) + except Exception as err: + raise BlueskyStreamingError( + "Unable to subscribe to messages from blueapi" + ) from err diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index 20e93dae4..81ece17d6 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -1,5 +1,4 @@ from collections.abc import Callable, Mapping -from http import HTTPStatus from typing import Any, Literal, TypeVar import requests @@ -17,22 +16,22 @@ ) from blueapi.worker import Task, TrackableTask, WorkerState -from .event_bus import BlueskyRemoteError - T = TypeVar("T") -def get_status_message(code: int) -> str: - """Returns the standard description for a given HTTP status code.""" - try: - message = HTTPStatus(code).phrase - return message - except ValueError: - return "Unknown Status Code" +class BlueskyRemoteControlError(Exception): + def __init__(self, message: str) -> None: + super().__init__(message) -def _is_exception(response: requests.Response) -> bool: - return response.status_code >= 400 +def _exception(response: requests.Response) -> Exception | None: + code = response.status_code + if code < 400: + return None + elif code == 404: + return KeyError(str(response.json())) + else: + return BlueskyRemoteControlError(str(response)) class BlueapiRestClient: @@ -107,32 +106,33 @@ def cancel_current_task( data={"new_state": state, "reason": reason}, ) + def get_environment(self) -> EnvironmentResponse: + return self._request_and_deserialize("/environment", EnvironmentResponse) + + def delete_environment(self) -> EnvironmentResponse: + return self._request_and_deserialize( + "/environment", EnvironmentResponse, method="DELETE" + ) + def _request_and_deserialize( self, suffix: str, target_type: type[T], data: Mapping[str, Any] | None = None, method="GET", - raise_if: Callable[[requests.Response], bool] = _is_exception, + get_exception: Callable[[requests.Response], Exception | None] = _exception, ) -> T: url = self._url(suffix) if data: response = requests.request(method, url, json=data) else: response = requests.request(method, url) - if raise_if(response): - raise BlueskyRemoteError(str(response)) + exception = get_exception(response) + if exception is not None: + raise exception deserialized = parse_obj_as(target_type, response.json()) return deserialized def _url(self, suffix: str) -> str: base_url = f"{self._config.protocol}://{self._config.host}:{self._config.port}" return f"{base_url}{suffix}" - - def get_environment(self) -> EnvironmentResponse: - return self._request_and_deserialize("/environment", EnvironmentResponse) - - def delete_environment(self) -> EnvironmentResponse: - return self._request_and_deserialize( - "/environment", EnvironmentResponse, method="DELETE" - ) diff --git a/tests/client/test_client.py b/tests/client/test_client.py new file mode 100644 index 000000000..433e7115e --- /dev/null +++ b/tests/client/test_client.py @@ -0,0 +1,411 @@ +from collections.abc import Callable +from unittest.mock import MagicMock, Mock, call + +import pytest + +from blueapi.client.client import BlueapiClient +from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient +from blueapi.client.rest import BlueapiRestClient, BlueskyRemoteControlError +from blueapi.core import DataEvent +from blueapi.messaging.context import MessageContext +from blueapi.service.model import ( + DeviceModel, + DeviceResponse, + EnvironmentResponse, + PlanModel, + PlanResponse, + TaskResponse, + WorkerTask, +) +from blueapi.worker import ProgressEvent, Task, TrackableTask, WorkerEvent, WorkerState +from blueapi.worker.event import TaskStatus + +PLANS = PlanResponse( + plans=[ + PlanModel(name="foo"), + PlanModel(name="bar"), + ] +) +PLAN = PlanModel(name="foo") +DEVICES = DeviceResponse( + devices=[ + DeviceModel(name="foo", protocols=[]), + DeviceModel(name="bar", protocols=[]), + ] +) +DEVICE = DeviceModel(name="foo", protocols=[]) +TASK = TrackableTask(task_id="foo", task=Task(name="bar", params={})) +ACTIVE_TASK = WorkerTask(task_id="bar") +ENV = EnvironmentResponse(initialized=True) +COMPLETE_EVENT = WorkerEvent( + state=WorkerState.IDLE, + task_status=TaskStatus( + task_id="foo", + task_complete=True, + task_failed=False, + ), +) +FAILED_EVENT = WorkerEvent( + state=WorkerState.IDLE, + task_status=TaskStatus( + task_id="foo", + task_complete=True, + task_failed=True, + ), +) + + +@pytest.fixture +def mock_rest() -> BlueapiRestClient: + mock = Mock(spec=BlueapiRestClient) + + mock.get_plans.return_value = PLANS + mock.get_plan.return_value = PLAN + mock.get_devices.return_value = DEVICES + mock.get_device.return_value = DEVICE + mock.get_state.return_value = WorkerState.IDLE + mock.get_task.return_value = TASK + mock.get_active_task.return_value = ACTIVE_TASK + mock.get_environment.return_value = ENV + mock.delete_environment.return_value = EnvironmentResponse(initialized=False) + + return mock + + +@pytest.fixture +def mock_events() -> EventBusClient: + mock_events = MagicMock(spec=EventBusClient) + ctx = Mock() + ctx.correlation_id = "foo" + mock_events.subscribe_to_all_events = lambda on_event: on_event(ctx, COMPLETE_EVENT) + return mock_events + + +@pytest.fixture +def client(mock_rest: Mock) -> BlueapiClient: + return BlueapiClient(rest=mock_rest) + + +@pytest.fixture +def client_with_events(mock_rest: Mock, mock_events: MagicMock): + return BlueapiClient(rest=mock_rest, events=mock_events) + + +def test_get_plans(client: BlueapiClient): + assert client.get_plans() == PLANS + + +def test_get_plan(client: BlueapiClient): + assert client.get_plan("foo") == PLAN + + +def test_get_nonexistant_plan( + client: BlueapiClient, + mock_rest: Mock, +): + mock_rest.get_plan.side_effect = KeyError("Not found") + with pytest.raises(KeyError): + client.get_plan("baz") + + +def test_get_devices(client: BlueapiClient): + assert client.get_devices() == DEVICES + + +def test_get_device(client: BlueapiClient): + assert client.get_device("foo") == DEVICE + + +def test_get_nonexistant_device( + client: BlueapiClient, + mock_rest: Mock, +): + mock_rest.get_device.side_effect = KeyError("Not found") + with pytest.raises(KeyError): + client.get_device("baz") + + +def test_get_state(client: BlueapiClient): + assert client.get_state() == WorkerState.IDLE + + +def test_get_task(client: BlueapiClient): + assert client.get_task("foo") == TASK + + +def test_get_nonexistant_task( + client: BlueapiClient, + mock_rest: Mock, +): + mock_rest.get_task.side_effect = KeyError("Not found") + with pytest.raises(KeyError): + client.get_task("baz") + + +def test_create_task( + client: BlueapiClient, + mock_rest: Mock, +): + client.create_task(task=Task(name="foo")) + mock_rest.create_task.assert_called_once_with(Task(name="foo")) + + +def test_create_task_does_not_start_task( + client: BlueapiClient, + mock_rest: Mock, +): + client.create_task(task=Task(name="foo")) + mock_rest.update_worker_task.assert_not_called() + + +def test_clear_task( + client: BlueapiClient, + mock_rest: Mock, +): + client.clear_task(task_id="foo") + mock_rest.clear_task.assert_called_once_with("foo") + + +def test_get_active_task(client: BlueapiClient): + assert client.get_active_task() == ACTIVE_TASK + + +def test_start_task( + client: BlueapiClient, + mock_rest: Mock, +): + client.start_task(task=WorkerTask(task_id="bar")) + mock_rest.update_worker_task.assert_called_once_with(WorkerTask(task_id="bar")) + + +def test_start_nonexistant_task( + client: BlueapiClient, + mock_rest: Mock, +): + mock_rest.update_worker_task.side_effect = KeyError("Not found") + with pytest.raises(KeyError): + client.start_task(task=WorkerTask(task_id="bar")) + + +def test_create_and_start_task_calls_both_creating_and_starting_endpoints( + client: BlueapiClient, + mock_rest: Mock, +): + mock_rest.create_task.return_value = TaskResponse(task_id="baz") + mock_rest.update_worker_task.return_value = TaskResponse(task_id="baz") + client.create_and_start_task(Task(name="baz")) + mock_rest.create_task.assert_called_once_with(Task(name="baz")) + mock_rest.update_worker_task.assert_called_once_with(WorkerTask(task_id="baz")) + + +def test_create_and_start_task_fails_if_task_creation_fails( + client: BlueapiClient, + mock_rest: Mock, +): + mock_rest.create_task.side_effect = BlueskyRemoteControlError("No can do") + with pytest.raises(BlueskyRemoteControlError): + client.create_and_start_task(Task(name="baz")) + + +def test_create_and_start_task_fails_if_task_id_is_wrong( + client: BlueapiClient, + mock_rest: Mock, +): + mock_rest.create_task.return_value = TaskResponse(task_id="baz") + mock_rest.update_worker_task.return_value = TaskResponse(task_id="bar") + with pytest.raises(BlueskyRemoteControlError): + client.create_and_start_task(Task(name="baz")) + + +def test_create_and_start_task_fails_if_task_start_fails( + client: BlueapiClient, + mock_rest: Mock, +): + mock_rest.create_task.return_value = TaskResponse(task_id="baz") + mock_rest.update_worker_task.side_effect = BlueskyRemoteControlError("No can do") + with pytest.raises(BlueskyRemoteControlError): + client.create_and_start_task(Task(name="baz")) + + +def test_get_environment(client: BlueapiClient): + assert client.get_environment() == ENV + + +def test_reload_environment( + client: BlueapiClient, + mock_rest: Mock, +): + client.reload_environment() + mock_rest.get_environment.assert_called_once() + mock_rest.delete_environment.assert_called_once() + + +def test_reload_environment_failure( + client: BlueapiClient, + mock_rest: Mock, +): + mock_rest.get_environment.return_value = EnvironmentResponse( + initialized=False, error_message="foo" + ) + with pytest.raises(BlueskyRemoteControlError, match="foo"): + client.reload_environment() + + +def test_abort( + client: BlueapiClient, + mock_rest: Mock, +): + client.abort(reason="foo") + mock_rest.cancel_current_task.assert_called_once_with( + WorkerState.ABORTING, + reason="foo", + ) + + +def test_stop( + client: BlueapiClient, + mock_rest: Mock, +): + client.stop() + mock_rest.cancel_current_task.assert_called_once_with(WorkerState.STOPPING) + + +def test_pause( + client: BlueapiClient, + mock_rest: Mock, +): + client.pause(defer=True) + mock_rest.set_state.assert_called_once_with( + WorkerState.PAUSED, + defer=True, + ) + + +def test_resume( + client: BlueapiClient, + mock_rest: Mock, +): + client.resume() + mock_rest.set_state.assert_called_once_with( + WorkerState.RUNNING, + defer=False, + ) + + +def test_cannot_run_task_without_message_bus(client: BlueapiClient): + with pytest.raises( + RuntimeError, + match="Cannot run plans without Stomp configuration to track progress", + ): + client.run_task(Task(name="foo")) + + +def test_run_task_sets_up_control( + client_with_events: BlueapiClient, + mock_rest: Mock, + mock_events: MagicMock, +): + mock_rest.create_task.return_value = TaskResponse(task_id="foo") + mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo") + + client_with_events.run_task(Task(name="foo")) + + mock_rest.create_task.assert_called_once_with(Task(name="foo")) + mock_rest.update_worker_task.assert_called_once_with(WorkerTask(task_id="foo")) + + +def test_run_task_fails_on_failing_event( + client_with_events: BlueapiClient, + mock_rest: Mock, + mock_events: MagicMock, +): + mock_rest.create_task.return_value = TaskResponse(task_id="foo") + mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo") + + ctx = Mock() + ctx.correlation_id = "foo" + mock_events.subscribe_to_all_events = lambda on_event: on_event(ctx, FAILED_EVENT) + + on_event = Mock() + with pytest.raises(BlueskyStreamingError): + client_with_events.run_task(Task(name="foo"), on_event=on_event) + + on_event.assert_called_with(FAILED_EVENT) + + +@pytest.mark.parametrize( + "test_event", + [ + WorkerEvent( + state=WorkerState.RUNNING, + task_status=TaskStatus( + task_id="foo", + task_complete=False, + task_failed=False, + ), + ), + ProgressEvent(task_id="foo"), + DataEvent(name="start", doc={}), + ], +) +def test_run_task_calls_event_callback( + client_with_events: BlueapiClient, + mock_rest: Mock, + mock_events: MagicMock, + test_event: AnyEvent, +): + mock_rest.create_task.return_value = TaskResponse(task_id="foo") + mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo") + + ctx = Mock() + ctx.correlation_id = "foo" + + def callback(on_event: Callable[[MessageContext, AnyEvent], None]): + on_event(ctx, test_event) + on_event(ctx, COMPLETE_EVENT) + + mock_events.subscribe_to_all_events = callback # type: ignore + + mock_on_event = Mock() + client_with_events.run_task(Task(name="foo"), on_event=mock_on_event) + + assert mock_on_event.mock_calls == [call(test_event), call(COMPLETE_EVENT)] + + +@pytest.mark.parametrize( + "test_event", + [ + WorkerEvent( + state=WorkerState.RUNNING, + task_status=TaskStatus( + task_id="bar", + task_complete=False, + task_failed=False, + ), + ), + ProgressEvent(task_id="bar"), + object(), + ], +) +def test_run_task_ignores_non_matching_events( + client_with_events: BlueapiClient, + mock_rest: Mock, + mock_events: MagicMock, + test_event: AnyEvent, +): + mock_rest.create_task.return_value = TaskResponse(task_id="foo") # type: ignore + mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo") # type: ignore + + ctx = Mock() + ctx.correlation_id = "foo" + + def callback(on_event: Callable[[MessageContext, AnyEvent], None]): + on_event(ctx, test_event) + on_event(ctx, COMPLETE_EVENT) + + mock_events.subscribe_to_all_events = callback + + mock_on_event = Mock() + client_with_events.run_task(Task(name="foo"), on_event=mock_on_event) + + mock_on_event.assert_called_once_with(COMPLETE_EVENT) diff --git a/tests/client/test_event_bus.py b/tests/client/test_event_bus.py new file mode 100644 index 000000000..45aab501b --- /dev/null +++ b/tests/client/test_event_bus.py @@ -0,0 +1,54 @@ +from unittest.mock import ANY, Mock + +import pytest + +from blueapi.client.event_bus import BlueskyStreamingError, EventBusClient +from blueapi.messaging import MessagingTemplate + + +@pytest.fixture +def mock_template() -> MessagingTemplate: + return Mock(spec=MessagingTemplate) + + +@pytest.fixture +def events(mock_template: MessagingTemplate) -> EventBusClient: + return EventBusClient(app=mock_template) + + +def test_context_manager_connects_and_disconnects( + events: EventBusClient, + mock_template: Mock, +): + mock_template.connect.assert_not_called() + mock_template.disconnect.assert_not_called() + + with events: + mock_template.connect.assert_called_once() + mock_template.disconnect.assert_not_called() + + mock_template.disconnect.assert_called_once() + + +def test_client_subscribes_to_all_events( + events: EventBusClient, + mock_template: Mock, +): + on_event = Mock + with events: + events.subscribe_to_all_events(on_event=on_event) # type: ignore + mock_template.subscribe.assert_called_once_with(ANY, on_event) + + +def test_client_raises_streaming_error_on_subscribe_failure( + events: EventBusClient, + mock_template: Mock, +): + mock_template.subscribe.side_effect = RuntimeError("Foo") + on_event = Mock + with events: + with pytest.raises( + BlueskyStreamingError, + match="Unable to subscribe to messages from blueapi", + ): + events.subscribe_to_all_events(on_event=on_event) # type: ignore diff --git a/tests/test_cli.py b/tests/test_cli.py index 3b9a4fb91..b379433f5 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -15,7 +15,7 @@ from blueapi import __version__ from blueapi.cli.cli import main from blueapi.cli.format import OutputFormat -from blueapi.client.event_bus import BlueskyRemoteError +from blueapi.client.rest import BlueskyRemoteControlError from blueapi.core.bluesky_types import Plan from blueapi.service.handler import Handler, teardown_handler from blueapi.service.model import ( @@ -393,8 +393,8 @@ def test_env_reload_server_side_error( result = runner.invoke(main, ["controller", "env", "-r"]) if result.exception is not None: assert isinstance( - result.exception, BlueskyRemoteError - ), "Expected a BlueskyRemoteError" + result.exception, BlueskyRemoteControlError + ), "Expected a BlueskyRemoteControlError" assert result.exception.args[0] == "Failed to tear down the environment" else: raise AssertionError("Expected an exception but got None") @@ -424,7 +424,7 @@ def mock_config(): "exception, expected_exit_code", [ (ValidationError("Invalid parameters", BaseModel), 1), - (BlueskyRemoteError("Server error"), 1), + (BlueskyRemoteControlError("Server error"), 1), (ValueError("Error parsing parameters"), 1), ], )