Skip to content

Commit

Permalink
Refactor client logic out of CLI (#548)
Browse files Browse the repository at this point in the history
Fixes #547
  • Loading branch information
callumforrester authored Jul 31, 2024
1 parent 4a2232b commit 0a4abfe
Show file tree
Hide file tree
Showing 11 changed files with 1,014 additions and 250 deletions.
151 changes: 52 additions & 99 deletions src/blueapi/cli/cli.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,34 @@
import json
import logging
from collections import deque
from functools import wraps
from pathlib import Path
from pprint import pprint
from time import sleep

import click
from bluesky.callbacks.best_effort import BestEffortCallback
from pydantic import ValidationError
from requests.exceptions import ConnectionError

from blueapi import __version__
from blueapi.cli.event_bus_client import BlueskyRemoteError, EventBusClient
from blueapi.cli.format import OutputFormat
from blueapi.client.client import BlueapiClient
from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient
from blueapi.client.rest import BlueskyRemoteControlError
from blueapi.config import ApplicationConfig, ConfigLoader
from blueapi.core import DataEvent
from blueapi.messaging import MessageContext
from blueapi.messaging.stomptemplate import StompMessagingTemplate
from blueapi.service.main import start
from blueapi.service.model import WorkerTask
from blueapi.service.openapi import (
DOCS_SCHEMA_LOCATION,
generate_schema,
print_schema_as_yaml,
write_schema_as_yaml,
)
from blueapi.worker import ProgressEvent, Task, WorkerEvent, WorkerState
from blueapi.worker import ProgressEvent, Task, WorkerEvent

from .rest import BlueapiRestClient
from .scratch import setup_scratch
from .updates import CliEventRenderer


@click.group(invoke_without_command=True)
Expand Down Expand Up @@ -107,7 +107,7 @@ def controller(ctx: click.Context, output: str) -> None:
ctx.ensure_object(dict)
config: ApplicationConfig = ctx.obj["config"]
ctx.obj["fmt"] = OutputFormat(output)
ctx.obj["rest_client"] = BlueapiRestClient(config.api)
ctx.obj["client"] = BlueapiClient.from_config(config)


def check_connection(func):
Expand All @@ -126,7 +126,7 @@ def wrapper(*args, **kwargs):
@click.pass_obj
def get_plans(obj: dict) -> None:
"""Get a list of plans available for the worker to use"""
client: BlueapiRestClient = obj["rest_client"]
client: BlueapiClient = obj["client"]
obj["fmt"].display(client.get_plans())


Expand All @@ -135,7 +135,7 @@ def get_plans(obj: dict) -> None:
@click.pass_obj
def get_devices(obj: dict) -> None:
"""Get a list of devices available for the worker to use"""
client: BlueapiRestClient = obj["rest_client"]
client: BlueapiClient = obj["client"]
obj["fmt"].display(client.get_devices())


Expand Down Expand Up @@ -184,52 +184,37 @@ def run_plan(
obj: dict, name: str, parameters: str | None, timeout: float | None
) -> None:
"""Run a plan with parameters"""
config: ApplicationConfig = obj["config"]
client: BlueapiRestClient = obj["rest_client"]

logger = logging.getLogger(__name__)
if config.stomp is not None:
_message_template = StompMessagingTemplate.autoconfigured(config.stomp)
else:
raise RuntimeError(
"Cannot run plans without Stomp configuration to track progress"
)
event_bus_client = EventBusClient(_message_template)
finished_event: deque[WorkerEvent] = deque()

def store_finished_event(event: WorkerEvent) -> None:
if event.is_complete():
finished_event.append(event)
client: BlueapiClient = obj["client"]

parameters = parameters or "{}"
task_id = ""
parsed_params = json.loads(parameters) if isinstance(parameters, str) else {}

progress_bar = CliEventRenderer()
callback = BestEffortCallback()

def on_event(event: AnyEvent) -> None:
if isinstance(event, ProgressEvent):
progress_bar.on_progress_event(event)
elif isinstance(event, DataEvent):
callback(event.name, event.doc)

try:
task = Task(name=name, params=parsed_params)
resp = client.create_task(task)
task_id = resp.task_id
resp = client.run_task(task, on_event=on_event)
except ValidationError as e:
pprint(f"failed to validate the task parameters, {task_id}, error: {e}")
return
except BlueskyRemoteError as e:
except (BlueskyRemoteControlError, BlueskyStreamingError) as e:
pprint(f"server error with this message: {e}")
return
except ValueError:
pprint("task could not run")
return

with event_bus_client:
event_bus_client.subscribe_to_topics(task_id, on_event=store_finished_event)
updated = client.update_worker_task(WorkerTask(task_id=task_id))

event_bus_client.wait_for_complete(timeout=timeout)

if event_bus_client.timed_out:
logger.error(f"Plan did not complete within {timeout} seconds")
return

process_event_after_finished(finished_event.pop(), logger)
pprint(updated.dict())
pprint(resp.dict())
if resp.task_status is not None and not resp.task_status.task_failed:
print("Plan Succeeded")


@controller.command(name="state")
Expand All @@ -238,7 +223,7 @@ def store_finished_event(event: WorkerEvent) -> None:
def get_state(obj: dict) -> None:
"""Print the current state of the worker"""

client: BlueapiRestClient = obj["rest_client"]
client: BlueapiClient = obj["client"]
print(client.get_state().name)


Expand All @@ -249,8 +234,8 @@ def get_state(obj: dict) -> None:
def pause(obj: dict, defer: bool = False) -> None:
"""Pause the execution of the current task"""

client: BlueapiRestClient = obj["rest_client"]
pprint(client.set_state(WorkerState.PAUSED, defer=defer))
client: BlueapiClient = obj["client"]
pprint(client.pause(defer=defer))


@controller.command(name="resume")
Expand All @@ -259,8 +244,8 @@ def pause(obj: dict, defer: bool = False) -> None:
def resume(obj: dict) -> None:
"""Resume the execution of the current task"""

client: BlueapiRestClient = obj["rest_client"]
pprint(client.set_state(WorkerState.RUNNING))
client: BlueapiClient = obj["client"]
pprint(client.resume())


@controller.command(name="abort")
Expand All @@ -273,8 +258,8 @@ def abort(obj: dict, reason: str | None = None) -> None:
with optional reason
"""

client: BlueapiRestClient = obj["rest_client"]
pprint(client.cancel_current_task(state=WorkerState.ABORTING, reason=reason))
client: BlueapiClient = obj["client"]
pprint(client.abort(reason=reason))


@controller.command(name="stop")
Expand All @@ -285,8 +270,8 @@ def stop(obj: dict) -> None:
Stop the execution of the current task, marking as ongoing runs as success
"""

client: BlueapiRestClient = obj["rest_client"]
pprint(client.cancel_current_task(state=WorkerState.STOPPING))
client: BlueapiClient = obj["client"]
pprint(client.stop())


@controller.command(name="env")
Expand All @@ -295,51 +280,35 @@ def stop(obj: dict) -> None:
"-r",
"--reload",
is_flag=True,
type=bool,
help="Reload the current environment",
default=False,
)
@click.option(
"-t",
"--timeout",
type=float,
help="Timeout to wait for reload in seconds, defaults to 10",
default=10.0,
)
@click.pass_obj
def env(obj: dict, reload: bool | None) -> None:
def env(
obj: dict,
reload: bool,
timeout: float | None,
) -> None:
"""
Inspect or restart the environment
"""

assert isinstance(client := obj["rest_client"], BlueapiRestClient)
assert isinstance(client := obj["client"], BlueapiClient)
if reload:
# Reload the environment if needed
print("Reloading the environment...")
try:
deserialized = client.reload_environment()
print(deserialized)

except BlueskyRemoteError as e:
raise BlueskyRemoteError("Failed to reload the environment") from e

# Initialize a variable to keep track of the environment status
environment_initialized = False
polling_count = 0
max_polling_count = 10
# Use a while loop to keep checking until the environment is initialized
while not environment_initialized and polling_count < max_polling_count:
# Fetch the current environment status
environment_status = client.get_environment()

# Check if the environment is initialized
if environment_status.initialized:
print("Environment is initialized.")
environment_initialized = True
else:
print("Waiting for environment to initialize...")
polling_count += 1
sleep(1) # Wait for 1 seconds before checking again
if polling_count == max_polling_count:
raise TimeoutError("Environment initialization timed out.")

# Once out of the loop, print the initialized environment status
print(environment_status)
print("Reloading environment")
status = client.reload_environment(timeout=timeout)
print("Environment is initialized")
else:
print(client.get_environment())
status = client.get_environment()
print(status)


@main.command(name="setup-scratch")
Expand All @@ -350,19 +319,3 @@ def scratch(obj: dict) -> None:
setup_scratch(config.scratch)
else:
raise KeyError("No scratch config supplied")


# helper function
def process_event_after_finished(event: WorkerEvent, logger: logging.Logger):
if event.is_error():
logger.info("Failed with errors: \n")
for error in event.errors:
logger.error(error)
return
if len(event.warnings) != 0:
logger.info("Passed with warnings: \n")
for warning in event.warnings:
logger.warn(warning)
return

logger.info("Plan passed")
76 changes: 0 additions & 76 deletions src/blueapi/cli/event_bus_client.py

This file was deleted.

22 changes: 2 additions & 20 deletions src/blueapi/cli/updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,36 +43,18 @@ def _update(self, name: str, view: StatusView) -> None:


class CliEventRenderer:
_task_id: str | None
_pbar_renderer: ProgressBarRenderer

def __init__(
self,
task_id: str | None = None,
pbar_renderer: ProgressBarRenderer | None = None,
) -> None:
self._task_id = task_id
if pbar_renderer is None:
pbar_renderer = ProgressBarRenderer()
self._pbar_renderer = pbar_renderer

def on_progress_event(self, event: ProgressEvent) -> None:
if self._relates_to_task(event):
self._pbar_renderer.update(event.statuses)
self._pbar_renderer.update(event.statuses)

def on_worker_event(self, event: WorkerEvent) -> None:
if self._relates_to_task(event):
print(str(event.state))

def _relates_to_task(self, event: WorkerEvent | ProgressEvent) -> bool:
if self._task_id is None:
return True
elif isinstance(event, WorkerEvent):
return (
event.task_status is not None
and event.task_status.task_id == self._task_id
)
elif isinstance(event, ProgressEvent):
return event.task_id == self._task_id
else:
return False
print(str(event.state))
Empty file added src/blueapi/client/__init__.py
Empty file.
Loading

0 comments on commit 0a4abfe

Please sign in to comment.