Skip to content

Commit

Permalink
Deduplicate worker tests
Browse files Browse the repository at this point in the history
A lot of the setup and teardown boilerplate for Worker tests was
repeated between tests.

This commit eliminates that and ensures that each test is using the same
process to set up and tear down the worker, by adding a `worker` fixture
and allowing it to be parameterized indirectly.
  • Loading branch information
nickstenning committed Aug 23, 2024
1 parent 6c77620 commit 35893d7
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 190 deletions.
45 changes: 45 additions & 0 deletions python/tests/server/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from cog.command import ast_openapi_schema
from cog.server.http import create_app
from cog.server.worker import Worker


@define
Expand All @@ -19,6 +20,23 @@ class AppConfig:
options: Optional[Dict[str, Any]]


@define
class WorkerConfig:
fixture_name: str
setup: bool = True


def pytest_make_parametrize_id(config, val):
"""
Generates more readable IDs for parametrized tests that use AppConfig or
WorkerConfig values.
"""
if isinstance(val, AppConfig):
return val.predictor_fixture
elif isinstance(val, WorkerConfig):
return val.fixture_name


def _fixture_path(name):
# HACK: `name` can either be in the form "<name>.py:Predictor" or just "<name>".
if ":" not in name:
Expand Down Expand Up @@ -51,6 +69,21 @@ def uses_predictor_with_client_options(name, **options):
)


def uses_worker(name_or_names, setup=True):
"""
Decorator for tests that require a Worker instance. `name_or_names` can be
a single fixture name, or a sequence (list, tuple) of fixture names. If
it's a sequence, the test will be run once for each worker.
If `setup` is True (the default) setup will be run before the test runs.
"""
if isinstance(name_or_names, (tuple, list)):
values = (WorkerConfig(fixture_name=n, setup=setup) for n in name_or_names)
else:
values = (WorkerConfig(fixture_name=name_or_names, setup=setup),)
return pytest.mark.parametrize("worker", values, indirect=True)


def make_client(
fixture_name: str,
upload_url: Optional[str] = None,
Expand Down Expand Up @@ -104,3 +137,15 @@ def static_schema(client) -> dict:
ref = _fixture_path(client.ref)
module_path = ref.split(":", 1)[0]
return ast_openapi_schema.extract_file(module_path)


@pytest.fixture
def worker(request):
ref = _fixture_path(request.param.fixture_name)
w = Worker(predictor_ref=ref, tee_output=False)
if request.param.setup:
assert not w.setup().result().error
try:
yield w
finally:
w.shutdown()
Loading

0 comments on commit 35893d7

Please sign in to comment.