diff --git a/conftest.py b/conftest.py index f88ba10e..44246d1d 100644 --- a/conftest.py +++ b/conftest.py @@ -4,6 +4,7 @@ import pytest from hydra import compose, initialize_config_dir from omegaconf import DictConfig +from ray.serve.context import get_global_client from kazu.data.data import Document, SynonymTermWithMetrics from kazu.annotation.label_studio import ( @@ -19,7 +20,7 @@ ) from kazu.tests.utils import CONFIG_DIR, DummyParser from kazu.utils.constants import HYDRA_VERSION_BASE -from kazu.web.server import start, stop +from kazu.web.server import start, stop, KazuWebAPI from kazu.utils.caching import kazu_disk_cache from kazu.steps.linking.post_processing.disambiguation.context_scoring import TfIdfScorer from kazu.utils.utils import Singleton @@ -126,6 +127,13 @@ def _make_label_studio_manager( return _make_label_studio_manager +def _wait_for_api_running(): + client = get_global_client() + # type ignore needed because this attribute is added by the ray serve + # decorators around the class, which mypy doesn't 'understand' what they do + client._wait_for_deployment_healthy(KazuWebAPI.name, timeout_s=600) # type:ignore[attr-defined] + + @pytest.fixture(scope="function") def ray_server(override_kazu_test_config): # clear any residual singleton info, as ray runs separate processes and @@ -135,6 +143,7 @@ def ray_server(override_kazu_test_config): overrides=["ray=local", "ray.serve.detached=true"], ) start(cfg) + _wait_for_api_running() yield {} stop() @@ -149,6 +158,7 @@ def ray_server_with_jwt_auth(override_kazu_test_config): overrides=["ray=local", "ray.serve.detached=true", "Middlewares=jwt"], ) start(cfg) + _wait_for_api_running() yield { "Authorization": f'Bearer {jwt.encode({"username": "user"}, os.environ["KAZU_JWT_KEY"], algorithm="HS256")}' }