Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mock sleep #228

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from blueapi.service.main import app
from src.blueapi.core import BlueskyContext

_TIMEOUT = 10.0


def pytest_addoption(parser):
parser.addoption(
Expand Down Expand Up @@ -58,3 +60,8 @@ def no_op():
@pytest.fixture(scope="session")
def client(handler: Handler) -> TestClient:
return Client(handler).client


@pytest.fixture(scope="session")
def timeout() -> float:
return _TIMEOUT
14 changes: 6 additions & 8 deletions tests/core/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

from blueapi.core import EventPublisher

_TIMEOUT: float = 10.0


@dataclass
class MyEvent:
Expand All @@ -20,22 +18,22 @@ def publisher() -> EventPublisher[MyEvent]:
return EventPublisher()


def test_publishes_event(publisher: EventPublisher[MyEvent]) -> None:
def test_publishes_event(timeout: float, publisher: EventPublisher[MyEvent]) -> None:
event = MyEvent("a")
f: Future = Future()
publisher.subscribe(lambda r, _: f.set_result(r))
publisher.publish(event)
assert f.result(timeout=_TIMEOUT) == event
assert f.result(timeout=timeout) == event


def test_multi_subscriber(publisher: EventPublisher[MyEvent]) -> None:
def test_multi_subscriber(timeout: float, publisher: EventPublisher[MyEvent]) -> None:
event = MyEvent("a")
f1: Future = Future()
f2: Future = Future()
publisher.subscribe(lambda r, _: f1.set_result(r))
publisher.subscribe(lambda r, _: f2.set_result(r))
publisher.publish(event)
assert f1.result(timeout=_TIMEOUT) == f2.result(timeout=_TIMEOUT) == event
assert f1.result(timeout=timeout) == f2.result(timeout=timeout) == event


def test_can_unsubscribe(publisher: EventPublisher[MyEvent]) -> None:
Expand Down Expand Up @@ -67,13 +65,13 @@ def test_can_unsubscribe_all(publisher: EventPublisher[MyEvent]) -> None:
assert list(_drain(q)) == [event_a, event_a, event_c]


def test_correlation_id(publisher: EventPublisher[MyEvent]) -> None:
def test_correlation_id(timeout: float, publisher: EventPublisher[MyEvent]) -> None:
event = MyEvent("a")
correlation_id = "foobar"
f: Future = Future()
publisher.subscribe(lambda _, c: f.set_result(c))
publisher.publish(event, correlation_id)
assert f.result(timeout=_TIMEOUT) == correlation_id
assert f.result(timeout=timeout) == correlation_id


def _drain(queue: Queue) -> Iterable:
Expand Down
53 changes: 32 additions & 21 deletions tests/messaging/test_stomptemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from blueapi.config import StompConfig
from blueapi.messaging import MessageContext, MessagingTemplate, StompMessagingTemplate

_TIMEOUT: float = 10.0
_COUNT = itertools.count()


Expand Down Expand Up @@ -41,31 +40,35 @@ def test_topic(template: MessagingTemplate) -> str:


@pytest.mark.stomp
def test_send(template: MessagingTemplate, test_queue: str) -> None:
def test_send(template: MessagingTemplate, timeout: float, test_queue: str) -> None:
f: Future = Future()

def callback(ctx: MessageContext, message: str) -> None:
f.set_result(message)

template.subscribe(test_queue, callback)
template.send(test_queue, "test_message")
assert f.result(timeout=_TIMEOUT)
assert f.result(timeout=timeout)


@pytest.mark.stomp
def test_send_to_topic(template: MessagingTemplate, test_topic: str) -> None:
def test_send_to_topic(
template: MessagingTemplate, timeout: float, test_topic: str
) -> None:
f: Future = Future()

def callback(ctx: MessageContext, message: str) -> None:
f.set_result(message)

template.subscribe(test_topic, callback)
template.send(test_topic, "test_message")
assert f.result(timeout=_TIMEOUT)
assert f.result(timeout=timeout)


@pytest.mark.stomp
def test_send_on_reply(template: MessagingTemplate, test_queue: str) -> None:
def test_send_on_reply(
template: MessagingTemplate, timeout: float, test_queue: str
) -> None:
acknowledge(template, test_queue)

f: Future = Future()
Expand All @@ -74,26 +77,28 @@ def callback(ctx: MessageContext, message: str) -> None:
f.set_result(message)

template.send(test_queue, "test_message", callback)
assert f.result(timeout=_TIMEOUT)
assert f.result(timeout=timeout)


@pytest.mark.stomp
def test_send_and_receive(template: MessagingTemplate, test_queue: str) -> None:
def test_send_and_receive(
template: MessagingTemplate, timeout: float, test_queue: str
) -> None:
acknowledge(template, test_queue)
reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT)
reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout)
assert reply == "ack"


@pytest.mark.stomp
def test_listener(template: MessagingTemplate, test_queue: str) -> None:
def test_listener(template: MessagingTemplate, timeout: float, test_queue: str) -> None:
@template.listener(test_queue)
def server(ctx: MessageContext, message: str) -> None:
reply_queue = ctx.reply_destination
if reply_queue is None:
raise RuntimeError("reply queue is None")
template.send(reply_queue, "ack")

reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT)
reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout)
assert reply == "ack"


Expand All @@ -108,7 +113,11 @@ class Foo(BaseModel):
[("test", str), (1, int), (Foo(a=1, b="test"), Foo)],
)
def test_deserialization(
template: MessagingTemplate, test_queue: str, message: Any, message_type: Type
template: MessagingTemplate,
timeout: float,
test_queue: str,
message: Any,
message_type: Type,
) -> None:
def server(ctx: MessageContext, message: message_type) -> None: # type: ignore
reply_queue = ctx.reply_destination
Expand All @@ -118,37 +127,39 @@ def server(ctx: MessageContext, message: message_type) -> None: # type: ignore

template.subscribe(test_queue, server)
reply = template.send_and_receive(test_queue, message, message_type).result(
timeout=_TIMEOUT
timeout=timeout
)
assert reply == message


@pytest.mark.stomp
def test_subscribe_before_connect(
disconnected_template: MessagingTemplate, test_queue: str
disconnected_template: MessagingTemplate, timeout: float, test_queue: str
) -> None:
acknowledge(disconnected_template, test_queue)
disconnected_template.connect()
reply = disconnected_template.send_and_receive(test_queue, "test", str).result(
timeout=_TIMEOUT
timeout=timeout
)
assert reply == "ack"


@pytest.mark.stomp
def test_reconnect(template: MessagingTemplate, test_queue: str) -> None:
def test_reconnect(
template: MessagingTemplate, timeout: float, test_queue: str
) -> None:
acknowledge(template, test_queue)
reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT)
reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout)
assert reply == "ack"
template.disconnect()
template.connect()
reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT)
reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout)
assert reply == "ack"


@pytest.mark.stomp
def test_correlation_id(
template: MessagingTemplate, test_queue: str, test_queue_2: str
template: MessagingTemplate, timeout: float, test_queue: str, test_queue_2: str
) -> None:
correlation_id = "foobar"
q: Queue = Queue()
Expand All @@ -164,9 +175,9 @@ def client(ctx: MessageContext, msg: str) -> None:
template.subscribe(test_queue_2, client)
template.send(test_queue, "test", None, correlation_id)

ctx_req: MessageContext = q.get(timeout=_TIMEOUT)
ctx_req: MessageContext = q.get(timeout=timeout)
assert ctx_req.correlation_id == correlation_id
ctx_ack: MessageContext = q.get(timeout=_TIMEOUT)
ctx_ack: MessageContext = q.get(timeout=timeout)
assert ctx_ack.correlation_id == correlation_id


Expand Down
57 changes: 35 additions & 22 deletions tests/worker/test_reworker.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import itertools
import threading
from concurrent.futures import Future
from typing import Callable, Iterable, List, Optional, TypeVar
from typing import Callable, Generator, Iterable, List, Optional, TypeVar

import mock
import pytest

from blueapi.config import EnvironmentConfig, Source, SourceKind
Expand All @@ -20,8 +21,14 @@
WorkerState,
)

_SIMPLE_TASK = RunPlan(name="sleep", params={"time": 0.0})
_LONG_TASK = RunPlan(name="sleep", params={"time": 1.0})

class SleepMock(mock.MagicMock):
async def __call__(self, *args, **kwargs):
return super(SleepMock, self).__call__(*args, **kwargs)


_SIMPLE_TASK = RunPlan(name="sleep", params={"time": 10.0})
_LONG_TASK = RunPlan(name="sleep", params={"time": 200.0})
_INDEFINITE_TASK = RunPlan(
name="set_absolute",
params={"movable": "fake_device", "value": 4.0},
Expand Down Expand Up @@ -49,15 +56,17 @@ def fake_device() -> FakeDevice:


@pytest.fixture
def context(fake_device: FakeDevice) -> BlueskyContext:
ctx = BlueskyContext()
ctx_config = EnvironmentConfig()
ctx_config.sources.append(
Source(kind=SourceKind.DEVICE_FUNCTIONS, module="devices")
)
ctx.device(fake_device)
ctx.with_config(ctx_config)
return ctx
def context(fake_device: FakeDevice) -> Generator[BlueskyContext, None, None]:
with mock.patch("bluesky.run_engine.asyncio.sleep", new_callable=SleepMock):
ctx = BlueskyContext()

ctx_config = EnvironmentConfig()
ctx_config.sources.append(
Source(kind=SourceKind.DEVICE_FUNCTIONS, module="devices")
)
ctx.device(fake_device)
ctx.with_config(ctx_config)
yield ctx


@pytest.fixture
Expand Down Expand Up @@ -173,12 +182,12 @@ def test_does_not_allow_simultaneous_running_tasks(


@pytest.mark.parametrize("num_runs", [0, 1, 2])
def test_produces_worker_events(worker: Worker, num_runs: int) -> None:
def test_produces_worker_events(worker: Worker, timeout: float, num_runs: int) -> None:
task_ids = [worker.submit_task(_SIMPLE_TASK) for _ in range(num_runs)]
event_sequences = [_sleep_events(task_id) for task_id in task_ids]

for task_id, events in zip(task_ids, event_sequences):
assert_run_produces_worker_events(events, worker, task_id)
assert_run_produces_worker_events(events, worker, task_id, timeout)


def _sleep_events(task_id: str) -> List[WorkerEvent]:
Expand Down Expand Up @@ -210,7 +219,7 @@ def _sleep_events(task_id: str) -> List[WorkerEvent]:
]


def test_no_additional_progress_events_after_complete(worker: Worker):
def test_no_additional_progress_events_after_complete(worker: Worker, timeout: float):
"""
See https://github.com/bluesky/ophyd/issues/1115
"""
Expand All @@ -222,7 +231,7 @@ def test_no_additional_progress_events_after_complete(worker: Worker):
name="move", params={"moves": {"additional_status_device": 5.0}}
)
task_id = worker.submit_task(task)
begin_task_and_wait_until_complete(worker, task_id)
begin_task_and_wait_until_complete(worker, task_id, timeout)

# Extract all the display_name fields from the events
list_of_dict_keys = [pe.statuses.values() for pe in progress_events]
Expand All @@ -238,23 +247,27 @@ def test_no_additional_progress_events_after_complete(worker: Worker):


def assert_run_produces_worker_events(
expected_events: List[WorkerEvent],
worker: Worker,
task_id: str,
expected_events: List[WorkerEvent], worker: Worker, task_id: str, timeout: float
) -> None:
assert begin_task_and_wait_until_complete(worker, task_id) == expected_events
assert (
begin_task_and_wait_until_complete(worker, task_id, timeout) == expected_events
)


#
# Worker helpers
#


def begin_task_and_wait_until_complete(
worker: Worker,
task_id: str,
timeout: float = 5.0,
timeout: float,
) -> List[WorkerEvent]:
events: "Future[List[WorkerEvent]]" = take_events(
worker.worker_events,
lambda event: event.is_complete(),
)

worker.begin_task(task_id)
return events.result(timeout=timeout)

Expand Down