Skip to content

Commit

Permalink
added timeout fixture to be consistent across all tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Rose Yemelyanova committed May 23, 2023
1 parent 08ba695 commit e54cffb
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 39 deletions.
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from blueapi.service.main import app
from blueapi.worker.reworker import RunEngineWorker

_TIMEOUT = 10.0


def pytest_addoption(parser):
parser.addoption(
Expand Down Expand Up @@ -65,3 +67,8 @@ def handler() -> MockHandler:
@pytest.fixture(scope="session")
def client(handler: MockHandler) -> TestClient:
return Client(handler).client


@pytest.fixture(scope="session")
def timeout() -> float:
return _TIMEOUT
14 changes: 6 additions & 8 deletions tests/core/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

from blueapi.core import EventPublisher

_TIMEOUT: float = 10.0


@dataclass
class MyEvent:
Expand All @@ -20,22 +18,22 @@ def publisher() -> EventPublisher[MyEvent]:
return EventPublisher()


def test_publishes_event(publisher: EventPublisher[MyEvent]) -> None:
def test_publishes_event(timeout: float, publisher: EventPublisher[MyEvent]) -> None:
event = MyEvent("a")
f: Future = Future()
publisher.subscribe(lambda r, _: f.set_result(r))
publisher.publish(event)
assert f.result(timeout=_TIMEOUT) == event
assert f.result(timeout=timeout) == event


def test_multi_subscriber(publisher: EventPublisher[MyEvent]) -> None:
def test_multi_subscriber(timeout: float, publisher: EventPublisher[MyEvent]) -> None:
event = MyEvent("a")
f1: Future = Future()
f2: Future = Future()
publisher.subscribe(lambda r, _: f1.set_result(r))
publisher.subscribe(lambda r, _: f2.set_result(r))
publisher.publish(event)
assert f1.result(timeout=_TIMEOUT) == f2.result(timeout=_TIMEOUT) == event
assert f1.result(timeout=timeout) == f2.result(timeout=timeout) == event


def test_can_unsubscribe(publisher: EventPublisher[MyEvent]) -> None:
Expand Down Expand Up @@ -67,13 +65,13 @@ def test_can_unsubscribe_all(publisher: EventPublisher[MyEvent]) -> None:
assert list(_drain(q)) == [event_a, event_a, event_c]


def test_correlation_id(publisher: EventPublisher[MyEvent]) -> None:
def test_correlation_id(timeout: float, publisher: EventPublisher[MyEvent]) -> None:
event = MyEvent("a")
correlation_id = "foobar"
f: Future = Future()
publisher.subscribe(lambda _, c: f.set_result(c))
publisher.publish(event, correlation_id)
assert f.result(timeout=_TIMEOUT) == correlation_id
assert f.result(timeout=timeout) == correlation_id


def _drain(queue: Queue) -> Iterable:
Expand Down
53 changes: 32 additions & 21 deletions tests/messaging/test_stomptemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from blueapi.config import StompConfig
from blueapi.messaging import MessageContext, MessagingTemplate, StompMessagingTemplate

_TIMEOUT: float = 10.0
_COUNT = itertools.count()


Expand Down Expand Up @@ -41,31 +40,35 @@ def test_topic(template: MessagingTemplate) -> str:


@pytest.mark.stomp
def test_send(template: MessagingTemplate, test_queue: str) -> None:
def test_send(template: MessagingTemplate, timeout: float, test_queue: str) -> None:
f: Future = Future()

def callback(ctx: MessageContext, message: str) -> None:
f.set_result(message)

template.subscribe(test_queue, callback)
template.send(test_queue, "test_message")
assert f.result(timeout=_TIMEOUT)
assert f.result(timeout=timeout)


@pytest.mark.stomp
def test_send_to_topic(template: MessagingTemplate, test_topic: str) -> None:
def test_send_to_topic(
template: MessagingTemplate, timeout: float, test_topic: str
) -> None:
f: Future = Future()

def callback(ctx: MessageContext, message: str) -> None:
f.set_result(message)

template.subscribe(test_topic, callback)
template.send(test_topic, "test_message")
assert f.result(timeout=_TIMEOUT)
assert f.result(timeout=timeout)


@pytest.mark.stomp
def test_send_on_reply(template: MessagingTemplate, test_queue: str) -> None:
def test_send_on_reply(
template: MessagingTemplate, timeout: float, test_queue: str
) -> None:
acknowledge(template, test_queue)

f: Future = Future()
Expand All @@ -74,26 +77,28 @@ def callback(ctx: MessageContext, message: str) -> None:
f.set_result(message)

template.send(test_queue, "test_message", callback)
assert f.result(timeout=_TIMEOUT)
assert f.result(timeout=timeout)


@pytest.mark.stomp
def test_send_and_receive(template: MessagingTemplate, test_queue: str) -> None:
def test_send_and_receive(
template: MessagingTemplate, timeout: float, test_queue: str
) -> None:
acknowledge(template, test_queue)
reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT)
reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout)
assert reply == "ack"


@pytest.mark.stomp
def test_listener(template: MessagingTemplate, test_queue: str) -> None:
def test_listener(template: MessagingTemplate, timeout: float, test_queue: str) -> None:
@template.listener(test_queue)
def server(ctx: MessageContext, message: str) -> None:
reply_queue = ctx.reply_destination
if reply_queue is None:
raise RuntimeError("reply queue is None")
template.send(reply_queue, "ack")

reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT)
reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout)
assert reply == "ack"


Expand All @@ -108,7 +113,11 @@ class Foo(BaseModel):
[("test", str), (1, int), (Foo(a=1, b="test"), Foo)],
)
def test_deserialization(
template: MessagingTemplate, test_queue: str, message: Any, message_type: Type
template: MessagingTemplate,
timeout: float,
test_queue: str,
message: Any,
message_type: Type,
) -> None:
def server(ctx: MessageContext, message: message_type) -> None: # type: ignore
reply_queue = ctx.reply_destination
Expand All @@ -118,37 +127,39 @@ def server(ctx: MessageContext, message: message_type) -> None: # type: ignore

template.subscribe(test_queue, server)
reply = template.send_and_receive(test_queue, message, message_type).result(
timeout=_TIMEOUT
timeout=timeout
)
assert reply == message


@pytest.mark.stomp
def test_subscribe_before_connect(
disconnected_template: MessagingTemplate, test_queue: str
disconnected_template: MessagingTemplate, timeout: float, test_queue: str
) -> None:
acknowledge(disconnected_template, test_queue)
disconnected_template.connect()
reply = disconnected_template.send_and_receive(test_queue, "test", str).result(
timeout=_TIMEOUT
timeout=timeout
)
assert reply == "ack"


@pytest.mark.stomp
def test_reconnect(template: MessagingTemplate, test_queue: str) -> None:
def test_reconnect(
template: MessagingTemplate, timeout: float, test_queue: str
) -> None:
acknowledge(template, test_queue)
reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT)
reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout)
assert reply == "ack"
template.disconnect()
template.connect()
reply = template.send_and_receive(test_queue, "test", str).result(timeout=_TIMEOUT)
reply = template.send_and_receive(test_queue, "test", str).result(timeout=timeout)
assert reply == "ack"


@pytest.mark.stomp
def test_correlation_id(
template: MessagingTemplate, test_queue: str, test_queue_2: str
template: MessagingTemplate, timeout: float, test_queue: str, test_queue_2: str
) -> None:
correlation_id = "foobar"
q: Queue = Queue()
Expand All @@ -164,9 +175,9 @@ def client(ctx: MessageContext, msg: str) -> None:
template.subscribe(test_queue_2, client)
template.send(test_queue, "test", None, correlation_id)

ctx_req: MessageContext = q.get(timeout=_TIMEOUT)
ctx_req: MessageContext = q.get(timeout=timeout)
assert ctx_req.correlation_id == correlation_id
ctx_ack: MessageContext = q.get(timeout=_TIMEOUT)
ctx_ack: MessageContext = q.get(timeout=timeout)
assert ctx_ack.correlation_id == correlation_id


Expand Down
20 changes: 10 additions & 10 deletions tests/worker/test_reworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from concurrent.futures import Future
from typing import Callable, Iterable, List, Optional, TypeVar

import mock
import pytest

from blueapi.config import EnvironmentConfig, Source, SourceKind
Expand All @@ -19,7 +20,6 @@
WorkerEvent,
WorkerState,
)
import mock


class SleepMock(mock.MagicMock):
Expand Down Expand Up @@ -158,12 +158,12 @@ def test_does_not_allow_simultaneous_running_tasks(


@pytest.mark.parametrize("num_runs", [0, 1, 2])
def test_produces_worker_events(worker: Worker, num_runs: int) -> None:
def test_produces_worker_events(worker: Worker, timeout: float, num_runs: int) -> None:
task_ids = [worker.submit_task(_SIMPLE_TASK) for _ in range(num_runs)]
event_sequences = [_sleep_events(task_id) for task_id in task_ids]

for task_id, events in zip(task_ids, event_sequences):
assert_run_produces_worker_events(events, worker, task_id)
assert_run_produces_worker_events(events, worker, task_id, timeout)


def _sleep_events(task_id: str) -> List[WorkerEvent]:
Expand Down Expand Up @@ -195,7 +195,7 @@ def _sleep_events(task_id: str) -> List[WorkerEvent]:
]


def test_no_additional_progress_events_after_complete(worker: Worker):
def test_no_additional_progress_events_after_complete(worker: Worker, timeout: float):
"""
See https://github.com/bluesky/ophyd/issues/1115
"""
Expand All @@ -207,7 +207,7 @@ def test_no_additional_progress_events_after_complete(worker: Worker):
name="move", params={"moves": {"additional_status_device": 5.0}}
)
task_id = worker.submit_task(task)
begin_task_and_wait_until_complete(worker, task_id)
begin_task_and_wait_until_complete(worker, task_id, timeout)

# Extract all the display_name fields from the events
list_of_dict_keys = [pe.statuses.values() for pe in progress_events]
Expand All @@ -223,17 +223,17 @@ def test_no_additional_progress_events_after_complete(worker: Worker):


def assert_run_produces_worker_events(
expected_events: List[WorkerEvent],
worker: Worker,
task_id: str,
expected_events: List[WorkerEvent], worker: Worker, task_id: str, timeout: float
) -> None:
assert begin_task_and_wait_until_complete(worker, task_id) == expected_events
assert (
begin_task_and_wait_until_complete(worker, task_id, timeout) == expected_events
)


def begin_task_and_wait_until_complete(
worker: Worker,
task_id: str,
timeout: float = 200.0,
timeout: float,
) -> List[WorkerEvent]:
events: "Future[List[WorkerEvent]]" = take_events(
worker.worker_events,
Expand Down

0 comments on commit e54cffb

Please sign in to comment.