From 35893d71ffe07a401ee54c01bc25a190b26db3e1 Mon Sep 17 00:00:00 2001 From: Nick Stenning Date: Fri, 23 Aug 2024 15:26:09 +0200 Subject: [PATCH] Deduplicate worker tests A lot of the setup and teardown boilerplate for Worker tests was repeated between tests. This commit eliminates that and ensures that each test is using the same process to set up and tear down the worker, by adding a `worker` fixture and allowing it to be parameterized indirectly. --- python/tests/server/conftest.py | 45 +++++ python/tests/server/test_worker.py | 309 +++++++++++------------------ 2 files changed, 164 insertions(+), 190 deletions(-) diff --git a/python/tests/server/conftest.py b/python/tests/server/conftest.py index 85306bb3db..89bd4cfa1a 100644 --- a/python/tests/server/conftest.py +++ b/python/tests/server/conftest.py @@ -11,6 +11,7 @@ from cog.command import ast_openapi_schema from cog.server.http import create_app +from cog.server.worker import Worker @define @@ -19,6 +20,23 @@ class AppConfig: options: Optional[Dict[str, Any]] +@define +class WorkerConfig: + fixture_name: str + setup: bool = True + + +def pytest_make_parametrize_id(config, val): + """ + Generates more readable IDs for parametrized tests that use AppConfig or + WorkerConfig values. + """ + if isinstance(val, AppConfig): + return val.predictor_fixture + elif isinstance(val, WorkerConfig): + return val.fixture_name + + def _fixture_path(name): # HACK: `name` can either be in the form ".py:Predictor" or just "". if ":" not in name: @@ -51,6 +69,21 @@ def uses_predictor_with_client_options(name, **options): ) +def uses_worker(name_or_names, setup=True): + """ + Decorator for tests that require a Worker instance. `name_or_names` can be + a single fixture name, or a sequence (list, tuple) of fixture names. If + it's a sequence, the test will be run once for each worker. + + If `setup` is True (the default) setup will be run before the test runs. + """ + if isinstance(name_or_names, (tuple, list)): + values = (WorkerConfig(fixture_name=n, setup=setup) for n in name_or_names) + else: + values = (WorkerConfig(fixture_name=name_or_names, setup=setup),) + return pytest.mark.parametrize("worker", values, indirect=True) + + def make_client( fixture_name: str, upload_url: Optional[str] = None, @@ -104,3 +137,15 @@ def static_schema(client) -> dict: ref = _fixture_path(client.ref) module_path = ref.split(":", 1)[0] return ast_openapi_schema.extract_file(module_path) + + +@pytest.fixture +def worker(request): + ref = _fixture_path(request.param.fixture_name) + w = Worker(predictor_ref=ref, tee_output=False) + if request.param.setup: + assert not w.setup().result().error + try: + yield w + finally: + w.shutdown() diff --git a/python/tests/server/test_worker.py b/python/tests/server/test_worker.py index dd6fe61133..f2c2557843 100644 --- a/python/tests/server/test_worker.py +++ b/python/tests/server/test_worker.py @@ -6,7 +6,7 @@ import pytest from attrs import define, field -from hypothesis import given, settings +from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st from hypothesis.stateful import RuleBasedStateMachine, precondition, rule @@ -14,6 +14,8 @@ from cog.server.exceptions import FatalWorkerException, InvalidStateException from cog.server.worker import Worker +from .conftest import WorkerConfig, _fixture_path, uses_worker + # Set a longer deadline on CI as the instances are a bit slower. settings.register_profile("ci", max_examples=100, deadline=2000) settings.register_profile("default", max_examples=10, deadline=1500) @@ -27,43 +29,46 @@ ST_NAMES = st.sampled_from(["John", "Barry", "Elspeth", "Hamid", "Ronnie", "Yasmeen"]) SETUP_FATAL_FIXTURES = [ - ("exc_in_setup", {}), - ("exc_in_setup_and_predict", {}), - ("exc_on_import", {}), - ("exit_in_setup", {}), - ("exit_on_import", {}), - ("missing_predictor", {}), - ("nonexistent_file", {}), # this fixture doesn't even exist + "exc_in_setup", + "exc_in_setup_and_predict", + "exc_on_import", + "exit_in_setup", + "exit_on_import", + "missing_predictor", + "nonexistent_file", ] PREDICTION_FATAL_FIXTURES = [ - ("exit_in_predict", {}), - ("killed_in_predict", {}), + "exit_in_predict", + "killed_in_predict", ] RUNNABLE_FIXTURES = [ - ("simple", {}), - ("exc_in_predict", {}), - ("missing_predict", {}), + "simple", + "exc_in_predict", + "missing_predict", ] OUTPUT_FIXTURES = [ ( - "hello_world", + WorkerConfig("hello_world"), {"name": ST_NAMES}, lambda x: f"hello, {x['name']}", ), ( - "count_up", + WorkerConfig("count_up"), {"upto": st.integers(min_value=0, max_value=100)}, lambda x: list(range(x["upto"])), ), - ("complex_output", {}, lambda _: {"number": 42, "text": "meaning of life"}), + ( + WorkerConfig("complex_output"), + {}, + lambda _: {"number": 42, "text": "meaning of life"}, + ), ] SETUP_LOGS_FIXTURES = [ ( - "logging", ( "writing some stuff from C at import time\n" "writing to stdout at import time\n" @@ -75,8 +80,6 @@ PREDICT_LOGS_FIXTURES = [ ( - "logging", - {}, ("writing from C\n" "writing with print\n"), ("WARNING:root:writing log message\n" "writing to stderr\n"), ) @@ -145,274 +148,200 @@ def _process(worker, work, swallow_exceptions=False): return result -def _fixture_path(name): - test_dir = os.path.dirname(os.path.realpath(__file__)) - return os.path.join(test_dir, f"fixtures/{name}.py") + ":Predictor" - - -@pytest.mark.parametrize("name,payloads", SETUP_FATAL_FIXTURES) -def test_fatalworkerexception_from_setup_failures(name, payloads): +@uses_worker(SETUP_FATAL_FIXTURES, setup=False) +def test_fatalworkerexception_from_setup_failures(worker): """ Any failure during setup is fatal and should raise FatalWorkerException. """ - w = Worker(predictor_ref=_fixture_path(name), tee_output=False) - with pytest.raises(FatalWorkerException): - _process(w, w.setup) - - w.terminate() + _process(worker, worker.setup) -@pytest.mark.timeout(HYPOTHESIS_TEST_TIMEOUT) -@pytest.mark.parametrize("name,payloads", PREDICTION_FATAL_FIXTURES) -@given(data=st.data()) -def test_fatalworkerexception_from_irrecoverable_failures(data, name, payloads): +@uses_worker(PREDICTION_FATAL_FIXTURES) +def test_fatalworkerexception_from_irrecoverable_failures(worker): """ Certain kinds of failure during predict (crashes, unexpected exits) are irrecoverable and should raise FatalWorkerException. """ - w = Worker(predictor_ref=_fixture_path(name), tee_output=False) - - result = _process(w, w.setup) - assert not result.done.error - with pytest.raises(FatalWorkerException): - _process(w, lambda: w.predict(data.draw(st.fixed_dictionaries(payloads)))) + _process(worker, lambda: worker.predict({})) with pytest.raises(InvalidStateException): - _process(w, lambda: w.predict(data.draw(st.fixed_dictionaries(payloads)))) - - w.terminate() + _process(worker, lambda: worker.predict({})) -@pytest.mark.timeout(HYPOTHESIS_TEST_TIMEOUT) -@pytest.mark.parametrize("name,payloads", RUNNABLE_FIXTURES) -@given(data=st.data()) -def test_no_exceptions_from_recoverable_failures(data, name, payloads): +@uses_worker(RUNNABLE_FIXTURES) +def test_no_exceptions_from_recoverable_failures(worker): """ Well-behaved predictors, or those that only throw exceptions, should not raise. """ - w = Worker(predictor_ref=_fixture_path(name), tee_output=False) + for _ in range(5): + _process(worker, lambda: worker.predict({})) - try: - result = _process(w, w.setup) - assert not result.done.error - - for _ in range(5): - _process(w, lambda: w.predict(data.draw(st.fixed_dictionaries(payloads)))) - finally: - w.terminate() - -def test_stream_redirector_race_condition(): +@uses_worker("stream_redirector_race_condition") +def test_stream_redirector_race_condition(worker): """ StreamRedirector and _ChildWorker are using the same _events pipe to send data. When there are multiple threads trying to write to the same pipe, it can cause data corruption by race condition. The data corruption will cause pipe receiver to raise an exception due to unpickling error. """ - w = Worker( - predictor_ref=_fixture_path("stream_redirector_race_condition"), - tee_output=False, - ) - - try: - result = _process(w, w.setup) + for _ in range(5): + result = _process(worker, lambda: worker.predict({})) assert not result.done.error - for _ in range(5): - result = _process(w, lambda: w.predict({})) - assert not result.done.error - finally: - w.terminate() - @pytest.mark.timeout(HYPOTHESIS_TEST_TIMEOUT) -@pytest.mark.parametrize("name,payloads,output_generator", OUTPUT_FIXTURES) +@pytest.mark.parametrize( + "worker,payloads,output_generator", OUTPUT_FIXTURES, indirect=["worker"] +) +@settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) @given(data=st.data()) -def test_output(data, name, payloads, output_generator): +def test_output(worker, payloads, output_generator, data): """ We should get the outputs we expect from predictors that generate output. Note that most of the validation work here is actually done in _process. """ - w = Worker(predictor_ref=_fixture_path(name), tee_output=False) + payload = data.draw(st.fixed_dictionaries(payloads)) + expected_output = output_generator(payload) - try: - result = _process(w, w.setup) - assert not result.done.error + result = _process(worker, lambda: worker.predict(payload)) - payload = data.draw(st.fixed_dictionaries(payloads)) - expected_output = output_generator(payload) + assert result.output == expected_output - result = _process(w, lambda: w.predict(payload)) - - assert result.output == expected_output - finally: - w.terminate() - -@pytest.mark.parametrize("name,expected_stdout,expected_stderr", SETUP_LOGS_FIXTURES) -def test_setup_logging(name, expected_stdout, expected_stderr): +@uses_worker("logging", setup=False) +@pytest.mark.parametrize("expected_stdout,expected_stderr", SETUP_LOGS_FIXTURES) +def test_setup_logging(worker, expected_stdout, expected_stderr): """ We should get the logs we expect from predictors that generate logs during setup. """ - w = Worker(predictor_ref=_fixture_path(name), tee_output=False) - - try: - result = _process(w, w.setup) - assert not result.done.error + result = _process(worker, worker.setup) + assert not result.done.error - assert result.stdout == expected_stdout - assert result.stderr == expected_stderr - finally: - w.terminate() + assert result.stdout == expected_stdout + assert result.stderr == expected_stderr -@pytest.mark.parametrize( - "name,payloads,expected_stdout,expected_stderr", PREDICT_LOGS_FIXTURES -) -def test_predict_logging(name, payloads, expected_stdout, expected_stderr): +@uses_worker("logging") +@pytest.mark.parametrize("expected_stdout,expected_stderr", PREDICT_LOGS_FIXTURES) +def test_predict_logging(worker, expected_stdout, expected_stderr): """ We should get the logs we expect from predictors that generate logs during predict. """ - w = Worker(predictor_ref=_fixture_path(name), tee_output=False) + result = _process(worker, lambda: worker.predict({})) - try: - result = _process(w, w.setup) - assert not result.done.error - - result = _process(w, lambda: w.predict({})) - - assert result.stdout == expected_stdout - assert result.stderr == expected_stderr - finally: - w.terminate() + assert result.stdout == expected_stdout + assert result.stderr == expected_stderr -def test_cancel_is_safe(): +@uses_worker("sleep", setup=False) +def test_cancel_is_safe(worker): """ Calls to cancel at any time should not result in unexpected things happening or the cancelation of unexpected predictions. """ - w = Worker(predictor_ref=_fixture_path("sleep"), tee_output=True) - - try: - for _ in range(50): - w.cancel() + for _ in range(50): + worker.cancel() - _process(w, w.setup) + result = _process(worker, worker.setup) + assert not result.done.error - for _ in range(50): - w.cancel() + for _ in range(50): + worker.cancel() - result1 = _process( - w, lambda: w.predict({"sleep": 0.5}), swallow_exceptions=True - ) + result1 = _process( + worker, lambda: worker.predict({"sleep": 0.5}), swallow_exceptions=True + ) - for _ in range(50): - w.cancel() + for _ in range(50): + worker.cancel() - result2 = _process( - w, lambda: w.predict({"sleep": 0.1}), swallow_exceptions=True - ) + result2 = _process( + worker, lambda: worker.predict({"sleep": 0.1}), swallow_exceptions=True + ) - assert not result1.exception - assert not result1.done.canceled - assert not result2.exception - assert not result2.done.canceled - assert result2.output == "done in 0.1 seconds" - finally: - w.terminate() + assert not result1.exception + assert not result1.done.canceled + assert not result2.exception + assert not result2.done.canceled + assert result2.output == "done in 0.1 seconds" -def test_cancel_idempotency(): +@uses_worker("sleep", setup=False) +def test_cancel_idempotency(worker): """ Multiple calls to cancel within the same prediction, while not necessary or recommended, should still only result in a single cancelled prediction, and should not affect subsequent predictions. """ - w = Worker(predictor_ref=_fixture_path("sleep"), tee_output=True) def cancel_a_bunch(_): for _ in range(100): - w.cancel() + worker.cancel() - try: - _process(w, w.setup) - - fut = w.predict({"sleep": 0.5}) - # We call cancel a WHOLE BUNCH to make sure that we don't propagate any - # of those cancelations to subsequent predictions, regardless of the - # internal implementation of exceptions raised inside signal handlers. - for _ in range(5): - time.sleep(0.05) - for _ in range(100): - w.cancel() - result = fut.result() - assert result.canceled - - result2 = _process(w, lambda: w.predict({"sleep": 0.1})) - - assert not result2.done.canceled - assert result2.output == "done in 0.1 seconds" - finally: - w.terminate() + result = _process(worker, worker.setup) + assert not result.done.error + + fut = worker.predict({"sleep": 0.5}) + # We call cancel a WHOLE BUNCH to make sure that we don't propagate any + # of those cancelations to subsequent predictions, regardless of the + # internal implementation of exceptions raised inside signal handlers. + for _ in range(5): + time.sleep(0.05) + for _ in range(100): + worker.cancel() + result1 = fut.result() + assert result1.canceled + result2 = _process(worker, lambda: worker.predict({"sleep": 0.1})) -def test_cancel_multiple_predictions(): + assert not result2.done.canceled + assert result2.output == "done in 0.1 seconds" + + +@uses_worker("sleep") +def test_cancel_multiple_predictions(worker): """ Multiple predictions cancelled in a row shouldn't be a problem. This test is mainly ensuring that the _allow_cancel latch in Worker is correctly reset every time a prediction starts. """ + dones: list[Done] = [] + for _ in range(5): + fut = worker.predict({"sleep": 0.1}) + time.sleep(0.01) + worker.cancel() + dones.append(fut.result()) + assert all(d.canceled for d in dones) - w = Worker(predictor_ref=_fixture_path("sleep"), tee_output=True) - - try: - _process(w, w.setup) - - dones: list[Done] = [] - for _ in range(5): - fut = w.predict({"sleep": 0.1}) - time.sleep(0.01) - w.cancel() - dones.append(fut.result()) - assert all(d.canceled for d in dones) - - done_future = w.predict({"sleep": 0}) - assert not done_future.result().canceled - finally: - w.terminate() + assert not worker.predict({"sleep": 0}).result().canceled -def test_graceful_shutdown(): +@uses_worker("sleep") +def test_graceful_shutdown(worker): """ On shutdown, the worker should finish running the current prediction, and then exit. """ - w = Worker(predictor_ref=_fixture_path("sleep"), tee_output=False) saw_first_event = threading.Event() - try: - _process(w, w.setup) - - # When we see the first event, we'll start the shutdown process. - w.subscribe(lambda event: saw_first_event.set()) + # When we see the first event, we'll start the shutdown process. + worker.subscribe(lambda event: saw_first_event.set()) - fut = w.predict({"sleep": 1}) + fut = worker.predict({"sleep": 1}) - saw_first_event.wait(timeout=1) - w.shutdown(timeout=2) + saw_first_event.wait(timeout=1) + worker.shutdown(timeout=2) - assert fut.result() == Done() - finally: - w.terminate() + assert fut.result() == Done() class WorkerState(RuleBasedStateMachine):