Skip to content

Commit

Permalink
Add tests for client code
Browse files Browse the repository at this point in the history
  • Loading branch information
callumforrester committed Jul 9, 2024
1 parent 46f0e13 commit 0f2c772
Show file tree
Hide file tree
Showing 6 changed files with 468 additions and 41 deletions.
5 changes: 3 additions & 2 deletions src/blueapi/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from blueapi import __version__
from blueapi.cli.format import OutputFormat
from blueapi.client.client import BlueapiClient
from blueapi.client.event_bus import AnyEvent, BlueskyRemoteError, EventBusClient
from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient
from blueapi.client.rest import BlueskyRemoteControlError
from blueapi.config import ApplicationConfig, ConfigLoader
from blueapi.core import DataEvent
from blueapi.messaging import MessageContext
Expand Down Expand Up @@ -203,7 +204,7 @@ def on_event(event: AnyEvent) -> None:
except ValidationError as e:
pprint(f"failed to validate the task parameters, {task_id}, error: {e}")
return
except BlueskyRemoteError as e:
except (BlueskyRemoteControlError, BlueskyStreamingError) as e:
pprint(f"server error with this message: {e}")
return
except ValueError:
Expand Down
20 changes: 13 additions & 7 deletions src/blueapi/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from blueapi.worker import Task, TrackableTask, WorkerEvent, WorkerState
from blueapi.worker.event import ProgressEvent, TaskStatus

from .event_bus import AnyEvent, BlueskyRemoteError, EventBusClient, OnAnyEvent
from .rest import BlueapiRestClient
from .event_bus import AnyEvent, BlueskyStreamingError, EventBusClient, OnAnyEvent
from .rest import BlueapiRestClient, BlueskyRemoteControlError


class BlueapiClient:
Expand Down Expand Up @@ -194,9 +194,13 @@ def inner_on_event(ctx: MessageContext, event: AnyEvent) -> None:
if isinstance(event, WorkerEvent) and (
(event.is_complete()) and (ctx.correlation_id == task_id)
):
if len(event.errors) > 0:
if event.task_status is not None and event.task_status.task_failed:
complete.set_exception(
BlueskyRemoteError("\n".join(event.errors))
BlueskyStreamingError(
"\n".join(event.errors)
if len(event.errors) > 0
else "Unknown error"
)
)
else:
complete.set_result(event)
Expand All @@ -223,7 +227,7 @@ def create_and_start_task(self, task: Task) -> TaskResponse:
if worker_response.task_id == response.task_id:
return response
else:
raise BlueskyRemoteError(
raise BlueskyRemoteControlError(
f"Tried to create and start task {response.task_id} "
f"but {worker_response.task_id} was started instead"
)
Expand Down Expand Up @@ -334,7 +338,9 @@ def reload_environment(
try:
status = self._rest.delete_environment()
except Exception as e:
raise BlueskyRemoteError("Failed to tear down the environment") from e
raise BlueskyRemoteControlError(
"Failed to tear down the environment"
) from e
return self._wait_for_reload(
status,
timeout,
Expand All @@ -355,7 +361,7 @@ def _wait_for_reload(
# Poll until the environment is restarted or the timeout is reached
status = self._rest.get_environment()
if status.error_message is not None:
raise BlueskyRemoteError(status.error_message)
raise BlueskyRemoteControlError(status.error_message)
elif status.initialized:
return status
time.sleep(polling_interval)
Expand Down
25 changes: 10 additions & 15 deletions src/blueapi/client/event_bus.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import threading
from collections.abc import Callable

from blueapi.core import DataEvent
from blueapi.messaging import MessageContext, MessagingTemplate
from blueapi.worker import ProgressEvent, WorkerEvent


class BlueskyRemoteError(Exception):
class BlueskyStreamingError(Exception):
def __init__(self, message: str) -> None:
super().__init__(message)

Expand All @@ -17,13 +16,9 @@ def __init__(self, message: str) -> None:

class EventBusClient:
app: MessagingTemplate
complete: threading.Event
timed_out: bool | None

def __init__(self, app: MessagingTemplate) -> None:
self.app = app
self.complete = threading.Event()
self.timed_out = None

def __enter__(self) -> None:
self.app.connect()
Expand All @@ -35,12 +30,12 @@ def subscribe_to_all_events(
self,
on_event: Callable[[MessageContext, AnyEvent], None],
) -> None:
self.app.subscribe(
self.app.destinations.topic("public.worker.event"),
on_event,
)

def wait_for_complete(self, timeout: float | None = None) -> None:
self.timed_out = not self.complete.wait(timeout=timeout)

self.complete.clear()
try:
self.app.subscribe(
self.app.destinations.topic("public.worker.event"),
on_event,
)
except Exception as err:
raise BlueskyStreamingError(
"Unable to subscribe to messages from blueapi"
) from err
40 changes: 27 additions & 13 deletions src/blueapi/client/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
)
from blueapi.worker import Task, TrackableTask, WorkerState

from .event_bus import BlueskyRemoteError

T = TypeVar("T")


class BlueskyRemoteControlError(Exception):
def __init__(self, message: str) -> None:
super().__init__(message)


def get_status_message(code: int) -> str:
"""Returns the standard description for a given HTTP status code."""
try:
Expand All @@ -35,6 +38,16 @@ def _is_exception(response: requests.Response) -> bool:
return response.status_code >= 400


def _exception(response: requests.Response) -> Exception | None:
code = response.status_code
if code < 400:
return None
elif code == 404:
return KeyError(str(response.json()))
else:
return BlueskyRemoteControlError(str(response))


class BlueapiRestClient:
_config: RestConfig

Expand Down Expand Up @@ -107,32 +120,33 @@ def cancel_current_task(
data={"new_state": state, "reason": reason},
)

def get_environment(self) -> EnvironmentResponse:
return self._request_and_deserialize("/environment", EnvironmentResponse)

def delete_environment(self) -> EnvironmentResponse:
return self._request_and_deserialize(
"/environment", EnvironmentResponse, method="DELETE"
)

def _request_and_deserialize(
self,
suffix: str,
target_type: type[T],
data: Mapping[str, Any] | None = None,
method="GET",
raise_if: Callable[[requests.Response], bool] = _is_exception,
get_exception: Callable[[requests.Response], Exception | None] = _exception,
) -> T:
url = self._url(suffix)
if data:
response = requests.request(method, url, json=data)
else:
response = requests.request(method, url)
if raise_if(response):
raise BlueskyRemoteError(str(response))
exception = get_exception(response)
if exception is not None:
raise exception
deserialized = parse_obj_as(target_type, response.json())
return deserialized

def _url(self, suffix: str) -> str:
base_url = f"{self._config.protocol}://{self._config.host}:{self._config.port}"
return f"{base_url}{suffix}"

def get_environment(self) -> EnvironmentResponse:
return self._request_and_deserialize("/environment", EnvironmentResponse)

def delete_environment(self) -> EnvironmentResponse:
return self._request_and_deserialize(
"/environment", EnvironmentResponse, method="DELETE"
)
Loading

0 comments on commit 0f2c772

Please sign in to comment.