diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml index 9553e78..b071fb5 100644 --- a/.github/workflows/code-checks.yaml +++ b/.github/workflows/code-checks.yaml @@ -51,7 +51,7 @@ jobs: aws-region: eu-central-1 - name: Run tests run: | - poetry run pytest -m "aws or not(aws)" --junitxml=pytest.xml --cov-report=term-missing --cov=workerfacing_api | tee pytest-coverage.txt + poetry run pytest -m "aws and not(deprecated) or not(aws)" --junitxml=pytest.xml --cov-report=term-missing --cov=workerfacing_api | tee pytest-coverage.txt echo "test_exit_code=${PIPESTATUS[0]}" >> $GITHUB_ENV - name: Coverage comment uses: MishaKav/pytest-coverage-comment@main diff --git a/pyproject.toml b/pyproject.toml index 4ee6cd6..aff5cbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,8 @@ serve = "scripts.serve:main" [tool.pytest.ini_options] markers = [ - "aws: requires aws credentials" + "aws: requires aws credentials", + "deprecated: tests for deprecated features", ] addopts = "-m 'not aws'" diff --git a/tests/conftest.py b/tests/conftest.py index c1da77f..0bd4a58 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,16 +17,10 @@ REGION_NAME: BucketLocationConstraintType = "eu-central-1" -@pytest.fixture(scope="session") -def monkeypatch_module() -> Generator[pytest.MonkeyPatch, Any, None]: - with pytest.MonkeyPatch.context() as mp: - yield mp - - -@pytest.fixture(autouse=True, scope="session") -def patch_update_job(monkeypatch_module: pytest.MonkeyPatch) -> MagicMock: +@pytest.fixture(autouse=True) +def patch_update_job(monkeypatch: pytest.MonkeyPatch) -> MagicMock: mock_update_job = MagicMock() - monkeypatch_module.setattr(job_tracking, "update_job", mock_update_job) + monkeypatch.setattr(job_tracking, "update_job", mock_update_job) return mock_update_job @@ -40,7 +34,7 @@ def create(self) -> None: self.add_ingress_rule() self.db_url = self.create_db_url() self.engine = self.get_engine() - self.delete_db_tables() + self.cleanup() def get_engine(self) -> Engine: for _ in range(5): @@ -79,7 +73,20 @@ def add_ingress_rule(self) -> None: else: raise e - def delete_db_tables(self) -> None: + def remove_ingress_rules(self) -> None: + # cleans up earlier tests too (in case of failures) + security_groups = self.ec2_client.describe_security_groups( + GroupNames=[self.vpc_sg_rule_params["GroupName"]] + ) + for sg in security_groups["SecurityGroups"]: + for rule in sg["IpPermissions"]: + if rule.get("FromPort") == 5432 and rule.get("ToPort") == 5432: + self.ec2_client.revoke_security_group_ingress( + GroupId=sg["GroupId"], + IpPermissions=[rule], # type: ignore + ) + + def cleanup(self) -> None: metadata = MetaData() engine = self.engine metadata.reflect(engine) @@ -138,14 +145,11 @@ def create_db_url(self) -> str: address = response["DBInstances"][0]["Endpoint"]["Address"] return f"postgresql://{user}:{password}@{address}:5432/{self.db_name}" - def cleanup(self) -> None: - self.delete_db_tables() - self.ec2_client.revoke_security_group_ingress(**self.vpc_sg_rule_params) - def delete(self) -> None: # never used (AWS tests skipped) if not hasattr(self, "rds_client"): return + self.remove_ingress_rules() self.rds_client.delete_db_instance( DBInstanceIdentifier=self.db_name, SkipFinalSnapshot=True, diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index fd58219..656b4e3 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,21 +1,31 @@ +import datetime import shutil -from typing import Any, Generator, cast +from 
typing import Generator, cast import pytest +from mypy_boto3_s3 import S3Client from tests.conftest import RDSTestingInstance, S3TestingBucket -from workerfacing_api import settings +from workerfacing_api.core.auth import APIKeyDependency, GroupClaims from workerfacing_api.core.filesystem import FileSystem, LocalFilesystem, S3Filesystem -from workerfacing_api.core.queue import RDSJobQueue +from workerfacing_api.core.queue import RDSJobQueue, SQLiteRDSJobQueue from workerfacing_api.dependencies import ( - APIKeyDependency, - GroupClaims, authorizer, current_user_dep, filesystem_dep, queue_dep, ) from workerfacing_api.main import workerfacing_app +from workerfacing_api.schemas.queue_jobs import ( + AppSpecs, + EnvironmentTypes, + HandlerSpecs, + HardwareSpecs, + JobSpecs, + MetaSpecs, + PathsUploadSpecs, + SubmittedJob, +) @pytest.fixture(scope="session") @@ -24,8 +34,8 @@ def test_username() -> str: @pytest.fixture(scope="session") -def base_dir() -> str: - return "int_test_dir" +def base_dir(tmp_path_factory: pytest.TempPathFactory) -> str: + return str(tmp_path_factory.mktemp("int_test_dir")) @pytest.fixture(scope="session") @@ -35,101 +45,97 @@ def internal_api_key_secret() -> str: @pytest.fixture( scope="session", - params=["local", pytest.param("aws", marks=pytest.mark.aws)], + params=["local-fs", pytest.param("aws-fs", marks=pytest.mark.aws)], ) -def env( - request: pytest.FixtureRequest, - rds_testing_instance: RDSTestingInstance, - s3_testing_bucket: S3TestingBucket, -) -> Generator[str, Any, None]: - env = cast(str, request.param) - if env == "aws": - rds_testing_instance.create() - s3_testing_bucket.create() - yield env - if env == "aws": - rds_testing_instance.cleanup() - s3_testing_bucket.cleanup() - - -@pytest.fixture(scope="session") def base_filesystem( - env: str, base_dir: str, - monkeypatch_module: pytest.MonkeyPatch, s3_testing_bucket: S3TestingBucket, -) -> Generator[FileSystem, Any, None]: - monkeypatch_module.setattr( - settings, - "user_data_root_path", - base_dir, - ) - monkeypatch_module.setattr( - settings, - "filesystem", - "local" if env == "local" else "s3", - ) - - if env == "local": - shutil.rmtree(base_dir, ignore_errors=True) - yield LocalFilesystem(base_dir, base_dir) - shutil.rmtree(base_dir, ignore_errors=True) - - elif env == "aws": - # Update settings to use the actual unique bucket name created by S3TestingBucket - monkeypatch_module.setattr( - settings, - "s3_bucket", - s3_testing_bucket.bucket_name, - ) - yield S3Filesystem(s3_testing_bucket.s3_client, s3_testing_bucket.bucket_name) - s3_testing_bucket.cleanup() - + request: pytest.FixtureRequest, +) -> FileSystem: + if request.param == "local-fs": + return LocalFilesystem(base_dir, base_dir) + elif request.param == "aws-fs": + s3_testing_bucket.create() + return S3Filesystem(s3_testing_bucket.s3_client, s3_testing_bucket.bucket_name) else: raise NotImplementedError -@pytest.fixture(scope="session") +@pytest.fixture( + scope="session", + params=["local-queue", pytest.param("aws-queue", marks=pytest.mark.aws)], +) def queue( - env: str, + base_filesystem: FileSystem, + s3_testing_bucket: S3TestingBucket, rds_testing_instance: RDSTestingInstance, tmpdir_factory: pytest.TempdirFactory, -) -> Generator[RDSJobQueue, Any, None]: - if env == "local": - queue = RDSJobQueue( - f"sqlite:///{tmpdir_factory.mktemp('integration')}/local.db" + request: pytest.FixtureRequest, +) -> RDSJobQueue: + retry_different = False # allow retries on same worker + if request.param == "local-queue": + queue_path = 
tmpdir_factory.mktemp("integration") / "local.db" + s3_bucket: str | None = None + s3_client: S3Client | None = None + if isinstance(base_filesystem, S3Filesystem): + s3_bucket = s3_testing_bucket.bucket_name + s3_client = s3_testing_bucket.s3_client + return SQLiteRDSJobQueue( + f"sqlite:///{queue_path}", + retry_different=retry_different, + s3_client=s3_client, + s3_bucket=s3_bucket, ) + elif request.param == "aws-queue": + if isinstance(base_filesystem, LocalFilesystem): + pytest.skip("Only testing RDS queue in combination with S3 filesystem") + rds_testing_instance.create() + return RDSJobQueue(rds_testing_instance.db_url, retry_different=retry_different) else: - queue = RDSJobQueue(rds_testing_instance.db_url) - queue.create(err_on_exists=True) - yield queue + raise NotImplementedError -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(autouse=True) def override_filesystem_dep( - base_filesystem: FileSystem, monkeypatch_module: pytest.MonkeyPatch -) -> None: - monkeypatch_module.setitem( + base_filesystem: FileSystem, + s3_testing_bucket: S3TestingBucket, + base_dir: str, + monkeypatch: pytest.MonkeyPatch, +) -> Generator[None, None, None]: + monkeypatch.setitem( workerfacing_app.dependency_overrides, # type: ignore filesystem_dep, lambda: base_filesystem, ) + yield + # cleanup after every test + if isinstance(base_filesystem, S3Filesystem): + s3_testing_bucket.cleanup() + else: + shutil.rmtree(base_dir, ignore_errors=True) -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(autouse=True) def override_queue_dep( - queue: RDSJobQueue, monkeypatch_module: pytest.MonkeyPatch -) -> None: - monkeypatch_module.setitem( + queue: RDSJobQueue, + rds_testing_instance: RDSTestingInstance, + monkeypatch: pytest.MonkeyPatch, +) -> Generator[None, None, None]: + monkeypatch.setitem( workerfacing_app.dependency_overrides, # type: ignore queue_dep, lambda: queue, ) + yield + if isinstance(queue, SQLiteRDSJobQueue): + queue.delete() + else: + rds_testing_instance.cleanup() -@pytest.fixture(scope="session", autouse=True) -def override_auth(monkeypatch_module: pytest.MonkeyPatch, test_username: str) -> None: - monkeypatch_module.setitem( +@pytest.fixture(autouse=True) +def override_auth(monkeypatch: pytest.MonkeyPatch, test_username: str) -> None: + monkeypatch.setitem( workerfacing_app.dependency_overrides, # type: ignore current_user_dep, lambda: GroupClaims( @@ -142,13 +148,42 @@ def override_auth(monkeypatch_module: pytest.MonkeyPatch, test_username: str) -> ) -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(autouse=True) def override_internal_api_key_secret( - monkeypatch_module: pytest.MonkeyPatch, internal_api_key_secret: str + monkeypatch: pytest.MonkeyPatch, internal_api_key_secret: str ) -> str: - monkeypatch_module.setitem( + monkeypatch.setitem( workerfacing_app.dependency_overrides, # type: ignore authorizer, APIKeyDependency(internal_api_key_secret), ) return internal_api_key_secret + + +@pytest.fixture +def base_job(base_filesystem: FileSystem, test_username: str) -> SubmittedJob: + time_now = datetime.datetime.now(datetime.timezone.utc).isoformat() + if isinstance(base_filesystem, S3Filesystem): + base_path = f"s3://{base_filesystem.bucket}" + else: + base_path = cast(LocalFilesystem, base_filesystem).base_post_path + paths_upload = PathsUploadSpecs( + output=f"{base_path}/{test_username}/test_out/1", + log=f"{base_path}/{test_username}/test_log/1", + artifact=f"{base_path}/{test_username}/test_arti/1", + ) + return SubmittedJob( + job=JobSpecs( 
+ app=AppSpecs(cmd=["cmd"], env={"env": "var"}), + handler=HandlerSpecs(image_url="u", files_up={"output": "out"}), + hardware=HardwareSpecs(), + meta=MetaSpecs( + job_id=1, + date_created=time_now, + ), + ), + environment=EnvironmentTypes.local, + group=None, + priority=1, + paths_upload=paths_upload, + ) diff --git a/tests/integration/endpoints/conftest.py b/tests/integration/endpoints/conftest.py index 6a8fd0d..23a69e3 100644 --- a/tests/integration/endpoints/conftest.py +++ b/tests/integration/endpoints/conftest.py @@ -1,6 +1,6 @@ import abc from dataclasses import dataclass, field -from typing import Any +from typing import Any, Generator import pytest from fastapi.testclient import TestClient @@ -9,6 +9,13 @@ from workerfacing_api.main import workerfacing_app +@pytest.fixture +def client() -> Generator[TestClient, None, None]: + # run everything in lifespan context + with TestClient(workerfacing_app) as client: + yield client + + @dataclass class EndpointParams: method: str @@ -24,10 +31,6 @@ class _TestEndpoint(abc.ABC): def passing_params(self, *args: Any, **kwargs: Any) -> list[EndpointParams]: raise NotImplementedError - @pytest.fixture(scope="session") - def client(self) -> TestClient: - return TestClient(workerfacing_app) - def test_required_auth( self, monkeypatch: pytest.MonkeyPatch, diff --git a/tests/integration/endpoints/test_files.py b/tests/integration/endpoints/test_files.py index 69d7283..62ee28b 100644 --- a/tests/integration/endpoints/test_files.py +++ b/tests/integration/endpoints/test_files.py @@ -1,6 +1,5 @@ import os from io import BytesIO -from typing import cast import pytest import requests @@ -10,67 +9,60 @@ from workerfacing_api.core.filesystem import FileSystem, S3Filesystem -@pytest.fixture(scope="session") +@pytest.fixture def data_file1_name(base_dir: str) -> str: return f"{base_dir}/data/test/data_file1.txt" -@pytest.fixture(scope="session") -def data_file1_path(env: str, data_file1_name: str, base_filesystem: FileSystem) -> str: - if env == "aws": - base_filesystem = cast(S3Filesystem, base_filesystem) +@pytest.fixture +def data_file1_path(data_file1_name: str, base_filesystem: FileSystem) -> str: + if isinstance(base_filesystem, S3Filesystem): return f"s3://{base_filesystem.bucket}/{data_file1_name}" return data_file1_name -@pytest.fixture(scope="session") +@pytest.fixture def data_file1_contents() -> str: return "data_file1" -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(autouse=True) def data_file1( - env: str, - base_filesystem: FileSystem, - data_file1_name: str, - data_file1_contents: str, + base_filesystem: FileSystem, data_file1_name: str, data_file1_contents: str ) -> None: - if env == "local": - os.makedirs(os.path.dirname(data_file1_name), exist_ok=True) - with open(data_file1_name, "w") as f: - f.write(data_file1_contents) - else: - base_filesystem = cast(S3Filesystem, base_filesystem) + if isinstance(base_filesystem, S3Filesystem): base_filesystem.s3_client.put_object( Bucket=base_filesystem.bucket, Key=data_file1_name, Body=BytesIO(data_file1_contents.encode("utf-8")), ) + else: + os.makedirs(os.path.dirname(data_file1_name), exist_ok=True) + with open(data_file1_name, "w") as f: + f.write(data_file1_contents) class TestFiles(_TestEndpoint): endpoint = "/files" - @pytest.fixture(scope="session") + @pytest.fixture def passing_params(self, data_file1_path: str) -> list[EndpointParams]: - return [ - EndpointParams("get", f"{data_file1_path}/url"), - ] + return [EndpointParams("get", f"{data_file1_path}/url")] def 
test_get_file( self, - env: str, data_file1_path: str, data_file1_contents: str, client: TestClient, + base_filesystem: FileSystem, ) -> None: - if env == "local": + if isinstance(base_filesystem, S3Filesystem): file_resp = client.get(f"{self.endpoint}/{data_file1_path}/download") - assert file_resp.status_code == 200 - assert file_resp.content.decode("utf-8") == data_file1_contents + assert file_resp.status_code == 403 else: file_resp = client.get(f"{self.endpoint}/{data_file1_path}/download") - assert file_resp.status_code == 403 + assert file_resp.status_code == 200 + assert file_resp.content.decode("utf-8") == data_file1_contents def test_get_file_not_exists( self, data_file1_path: str, client: TestClient @@ -84,21 +76,21 @@ def test_get_file_not_permitted(self, client: TestClient) -> None: def test_get_file_url( self, - env: str, data_file1_path: str, data_file1_contents: str, client: TestClient, + base_filesystem: FileSystem, ) -> None: req = f"{self.endpoint}/{data_file1_path}/url" url_resp = client.get(req) assert url_resp.status_code == 200 - if env == "local": - assert req.replace("/url", "/download") in url_resp.text - else: + if isinstance(base_filesystem, S3Filesystem): assert ( requests.request(**url_resp.json()).content.decode("utf-8") == data_file1_contents ) + else: + assert req.replace("/url", "/download") in url_resp.text def test_get_file_url_not_exists( self, data_file1_path: str, client: TestClient diff --git a/tests/integration/endpoints/test_jobs.py b/tests/integration/endpoints/test_jobs.py index 3243c44..4650d9f 100644 --- a/tests/integration/endpoints/test_jobs.py +++ b/tests/integration/endpoints/test_jobs.py @@ -1,4 +1,3 @@ -import datetime import os import time from io import BytesIO @@ -9,91 +8,22 @@ import requests from fastapi.testclient import TestClient -from tests.conftest import RDSTestingInstance from tests.integration.endpoints.conftest import EndpointParams, _TestEndpoint from workerfacing_api.core.filesystem import FileSystem, LocalFilesystem, S3Filesystem from workerfacing_api.core.queue import RDSJobQueue from workerfacing_api.crud import job_tracking from workerfacing_api.exceptions import JobDeletedException -from workerfacing_api.schemas.queue_jobs import ( - AppSpecs, - EnvironmentTypes, - HandlerSpecs, - HardwareSpecs, - JobSpecs, - MetaSpecs, - PathsUploadSpecs, - SubmittedJob, -) +from workerfacing_api.schemas.queue_jobs import EnvironmentTypes, SubmittedJob from workerfacing_api.schemas.rds_models import JobStates -@pytest.fixture(scope="session") -def app() -> AppSpecs: - return AppSpecs(cmd=["cmd"], env={"env": "var"}) - - -@pytest.fixture(scope="session") -def handler() -> HandlerSpecs: - return HandlerSpecs(image_url="u", files_up={"output": "out"}) - - -@pytest.fixture(scope="session") -def paths_upload( - env: str, test_username: str, base_filesystem: FileSystem -) -> PathsUploadSpecs: - if env == "local": - base_path = cast(LocalFilesystem, base_filesystem).base_post_path - else: - base_path = f"s3://{cast(S3Filesystem, base_filesystem).bucket}" - return PathsUploadSpecs( - output=f"{base_path}/{test_username}/test_out/1", - log=f"{base_path}/{test_username}/test_log/1", - artifact=f"{base_path}/{test_username}/test_arti/1", - ) - - class TestJobs(_TestEndpoint): endpoint = "/jobs" - @pytest.fixture(scope="session") + @pytest.fixture def passing_params(self) -> list[EndpointParams]: return [EndpointParams("get", params={"memory": 1})] - @pytest.fixture(scope="function", autouse=True) - def cleanup_queue( - self, - queue: 
RDSJobQueue, - env: str, - rds_testing_instance: RDSTestingInstance, - ) -> None: - if env == "local": - queue.delete() - else: - rds_testing_instance.cleanup() - queue.create() - - @pytest.fixture(scope="function") - def base_job( - self, app: AppSpecs, handler: HandlerSpecs, paths_upload: PathsUploadSpecs - ) -> SubmittedJob: - time_now = datetime.datetime.now(datetime.timezone.utc).isoformat() - return SubmittedJob( - job=JobSpecs( - app=app, - handler=handler, - hardware=HardwareSpecs(), - meta=MetaSpecs( - job_id=1, - date_created=time_now, - ), - ), - environment=EnvironmentTypes.local, - group=None, - priority=1, - paths_upload=paths_upload, - ) - def test_get_jobs( self, queue: RDSJobQueue, @@ -295,7 +225,6 @@ def mock_update_job(*args: Any, **kwargs: Any) -> None: def test_job_files_post( self, - env: str, queue: RDSJobQueue, base_filesystem: FileSystem, base_job: SubmittedJob, @@ -309,7 +238,7 @@ def test_job_files_post( params={"type": "output", "base_path": "test"}, ) assert res.status_code == 201 - if env == "local": + if isinstance(base_filesystem, LocalFilesystem): req_base = client else: req_base = requests # type: ignore @@ -324,8 +253,7 @@ def test_job_files_post( }, ) res.raise_for_status() - if env == "local": - base_filesystem = cast(LocalFilesystem, base_filesystem) + if isinstance(base_filesystem, LocalFilesystem): assert os.path.exists( f"{base_filesystem.base_post_path}/{test_username}/test_out/1/test/file.txt" ) diff --git a/tests/integration/endpoints/test_jobs_post.py b/tests/integration/endpoints/test_jobs_post.py index 03387bc..88ac4f2 100644 --- a/tests/integration/endpoints/test_jobs_post.py +++ b/tests/integration/endpoints/test_jobs_post.py @@ -12,13 +12,11 @@ endpoint = "/_jobs" -@pytest.fixture(scope="function") -def queue_enqueue( - monkeypatch_module: pytest.MonkeyPatch, -) -> MagicMock: +@pytest.fixture +def queue_enqueue(monkeypatch: pytest.MonkeyPatch) -> MagicMock: queue = MagicMock() queue.enqueue = MagicMock() - monkeypatch_module.setitem( + monkeypatch.setitem( workerfacing_app.dependency_overrides, # type: ignore queue_dep, lambda: queue, @@ -26,7 +24,7 @@ def queue_enqueue( return queue.enqueue -@pytest.fixture(scope="function") +@pytest.fixture def queue_job() -> dict[str, Any]: return { "job": { diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py new file mode 100644 index 0000000..1468ec8 --- /dev/null +++ b/tests/integration/test_main.py @@ -0,0 +1,173 @@ +import gzip +import sqlite3 +import tempfile +import time +from typing import cast + +import pytest +from fastapi.testclient import TestClient + +from tests.conftest import S3TestingBucket +from workerfacing_api import settings +from workerfacing_api.core.filesystem import FileSystem, S3Filesystem +from workerfacing_api.core.queue import RDSJobQueue, SQLiteRDSJobQueue +from workerfacing_api.dependencies import queue_dep +from workerfacing_api.main import workerfacing_app +from workerfacing_api.schemas.queue_jobs import SubmittedJob +from workerfacing_api.schemas.rds_models import JobStates + + +@pytest.fixture +def client() -> TestClient: + return TestClient(workerfacing_app) + + +class TestCronHandleTimeouts: + @pytest.fixture(autouse=True) + def setup_timeout_failure(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Set timeout_failure to 5 seconds for faster testing (but sufficient margin).""" + monkeypatch.setattr(settings, "timeout_failure", 5) + + @pytest.fixture(autouse=True) + def setup_max_retries(self, monkeypatch: pytest.MonkeyPatch) -> None: + 
"""Set max retries to 1 for faster testing.""" + monkeypatch.setattr(settings, "max_retries", 1) + + @pytest.fixture(autouse=True) + def setup_cron_interval(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Set cron interval to 1 second for faster testing.""" + monkeypatch.setattr(settings, "cron_timeout_check_interval", 1) + + def test_handle_timeouts( + self, + queue: RDSJobQueue, + base_job: SubmittedJob, + client: TestClient, + ) -> None: + job_id = base_job.job.meta.job_id + with client: + # Push the job + queue.enqueue(base_job) + job = queue.get_job(job_id) + assert job.status == JobStates.queued.value + assert job.num_retries == 0 + + # Pull the job + get_params = {"memory": 1} + assert len(client.get("/jobs", params=get_params).json()) == 1 + job = queue.get_job(job_id) + assert job.status == JobStates.pulled.value + assert job.num_retries == 0 + + # Job kept alive by periodic status updates + for _ in range(5): + client.put( + f"/jobs/{job_id}/status", + params={"status": "running", "runtime_details": "Processing..."}, + ) + assert len(client.get("/jobs", params=get_params).json()) == 0 + job = queue.get_job(job_id) + assert job.status == JobStates.running.value + assert job.num_retries == 0 + time.sleep(2) + + # Let timeout (wait longer than timeout_failure) + time.sleep(10) + job = queue.get_job(job_id) + assert job.status == JobStates.queued.value + assert job.num_retries == 1 + + # Pull again + assert len(client.get("/jobs", params=get_params).json()) == 1 + job = queue.get_job(job_id) + assert job.status == JobStates.pulled.value + assert job.num_retries == 1 + + # Let timeout and fail (wait longer than timeout_failure) + time.sleep(10) + job = queue.get_job(job_id) + assert job.status == JobStates.error.value + assert job.num_retries == 1 + + +class TestCronBackupDatabase: + @pytest.fixture(autouse=True) + def skip_if_not_sqlite_s3( + self, queue: RDSJobQueue, base_filesystem: FileSystem + ) -> None: + """Skip tests if not using SQLite queue with S3 filesystem.""" + if not isinstance(queue, SQLiteRDSJobQueue) or not isinstance( + base_filesystem, S3Filesystem + ): + pytest.skip("Backup tests only run with SQLite queue and S3 filesystem") + + @pytest.fixture(autouse=True) + def setup_backup_cron_interval(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Set backup cron interval to 1 seconds for faster testing.""" + monkeypatch.setattr(settings, "cron_backup_interval", 1) + + def get_backup_nrows(self, s3_testing_bucket: S3TestingBucket) -> int: + """Helper to get number of rows in backup database.""" + response = s3_testing_bucket.s3_client.get_object( + Bucket=s3_testing_bucket.bucket_name, + Key=SQLiteRDSJobQueue.BACKUP_KEY, + ) + backup_data_gzip = response["Body"].read() + backup_data = gzip.decompress(backup_data_gzip) + with tempfile.NamedTemporaryFile(suffix=".db") as tmp_file: + tmp_file.write(backup_data) + tmp_path = tmp_file.name + conn = sqlite3.connect(tmp_path) + cursor = conn.cursor() + cursor.execute("SELECT COUNT(*) FROM queued_jobs") + n_rows = cursor.fetchone()[0] + conn.close() + return cast(int, n_rows) + + def test_sqlite_backup( + self, + queue: SQLiteRDSJobQueue, + base_job: SubmittedJob, + client: TestClient, + s3_testing_bucket: S3TestingBucket, + tmpdir_factory: pytest.TempdirFactory, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Test the backup and restore functionality of the SQLiteRDSJobQueue.""" + # Startup: no backup present + with pytest.raises(s3_testing_bucket.s3_client.exceptions.NoSuchKey): + 
self.get_backup_nrows(s3_testing_bucket) + + with client: + # First start-up: no jobs + time.sleep(2) # wait for backup to run + assert self.get_backup_nrows(s3_testing_bucket) == 0 + + # Enqueue a job and verify it's backed up + queue.enqueue(base_job) + time.sleep(2) # wait for backup to run + assert self.get_backup_nrows(s3_testing_bucket) == 1 + + # Enqueue a second job and shutdown before backup runs + queue.enqueue(base_job) + + # On shutdown, final backup should run with both jobs + assert self.get_backup_nrows(s3_testing_bucket) == 2 + + # New queue (e.g., application started again) should restore from backup + new_db_url = f"sqlite:///{tmpdir_factory.mktemp('integration') / 'restored.db'}" + new_queue = SQLiteRDSJobQueue( + new_db_url, + s3_client=s3_testing_bucket.s3_client, + s3_bucket=s3_testing_bucket.bucket_name, + ) + monkeypatch.setitem( + workerfacing_app.dependency_overrides, # type: ignore + queue_dep, + lambda: new_queue, + ) + with client: + assert ( + len(client.get("/jobs", params={"memory": 1, "limit": 5}).json()) == 2 + ) + assert self.get_backup_nrows(s3_testing_bucket) == 2 diff --git a/tests/unit/core/test_queue.py b/tests/unit/core/test_queue.py index 5d527f7..440750a 100644 --- a/tests/unit/core/test_queue.py +++ b/tests/unit/core/test_queue.py @@ -69,7 +69,7 @@ def queue( success = False for _ in range(10): # i.p. SQS, RDS, etc. might need some time to delete try: - base_queue.create(err_on_exists=True) + base_queue.create() success = True break except Exception: @@ -171,6 +171,7 @@ def base_queue( base_queue.delete() +@pytest.mark.deprecated class TestSQSQueue(_TestJobQueue): @pytest.fixture( params=[True, pytest.param(False, marks=pytest.mark.aws)], scope="class" diff --git a/workerfacing_api/core/auth.py b/workerfacing_api/core/auth.py new file mode 100644 index 0000000..6fd0237 --- /dev/null +++ b/workerfacing_api/core/auth.py @@ -0,0 +1,33 @@ +from typing import Any + +from fastapi import Header, HTTPException +from fastapi.security import HTTPAuthorizationCredentials +from fastapi_cloudauth.cognito import CognitoClaims, CognitoCurrentUser # type: ignore +from pydantic import Field + + +# https://github.com/iwpnd/fastapi-key-auth/blob/main/fastapi_key_auth/dependency/authorizer.py +class APIKeyDependency: + def __init__(self, key: str | None): + self.key = key + + def __call__(self, x_api_key: str | None = Header(...)) -> str | None: + if x_api_key != self.key: + raise HTTPException(status_code=401, detail="unauthorized") + return x_api_key + + +class GroupClaims(CognitoClaims): # type: ignore + cognito_groups: list[str] | None = Field(alias="cognito:groups") + + +class WorkerGroupCognitoCurrentUser(CognitoCurrentUser): # type: ignore + user_info = GroupClaims + + async def call(self, http_auth: HTTPAuthorizationCredentials) -> Any: + user_info = await super().call(http_auth) + if "workers" not in (getattr(user_info, "cognito_groups") or []): + raise HTTPException( + status_code=403, detail="Not a member of the 'workers' group" + ) + return user_info diff --git a/workerfacing_api/core/queue.py b/workerfacing_api/core/queue.py index a9e668b..81ed92d 100644 --- a/workerfacing_api/core/queue.py +++ b/workerfacing_api/core/queue.py @@ -1,22 +1,27 @@ import datetime +import gzip import json import os import pickle +import sqlite3 +import tempfile import threading import time from abc import ABC, abstractmethod +from contextlib import nullcontext from types import TracebackType -from typing import Any, Type +from typing import Any, Type, cast import 
botocore.exceptions +from botocore.exceptions import ClientError from deprecated import deprecated from dict_hash import sha256 +from mypy_boto3_s3 import S3Client from mypy_boto3_sqs import SQSClient -from sqlalchemy import create_engine, inspect, not_ +from sqlalchemy import create_engine, not_ from sqlalchemy.engine import Engine from sqlalchemy.orm import Query, Session -from workerfacing_api import settings from workerfacing_api.crud import job_tracking from workerfacing_api.exceptions import JobDeletedException, JobNotAssignedException from workerfacing_api.schemas.queue_jobs import ( @@ -49,30 +54,11 @@ def __exit__( self.lock.release() -class MockUpdateLock: - """ - Mock context manager. - Used for RDSQueue on databases that are not SQLite, - since locking is already achieved via `with_for_update`. - """ - - def __enter__(self) -> None: - pass - - def __exit__( - self, - exc_type: Type[BaseException] | None, - exc_val: BaseException | None, - exc_tb: TracebackType | None, - ) -> None: - pass - - class JobQueue(ABC): """Abstract multi-environment job queue.""" @abstractmethod - def create(self, err_on_exists: bool = True) -> None: + def create(self) -> None: """Create the initialized queue.""" raise NotImplementedError @@ -136,9 +122,7 @@ def __init__(self, queue_path: str): self.queue_path = queue_path self.update_lock = UpdateLock() - def create(self, err_on_exists: bool = True) -> None: - if os.path.exists(self.queue_path) and err_on_exists: - raise ValueError("A queue at this path already exists.") + def create(self) -> None: queue: dict[EnvironmentTypes, list[SubmittedJob]] = { env: [] for env in EnvironmentTypes } @@ -231,7 +215,7 @@ def __init__(self, sqs_client: SQSClient): except self.sqs_client.exceptions.QueueDoesNotExist: pass - def create(self, err_on_exists: bool = True) -> None: + def create(self) -> None: for environment, queue_name in self.queue_names.items(): try: res = self.sqs_client.create_queue( @@ -244,10 +228,7 @@ def create(self, err_on_exists: bool = True) -> None: ) self.queue_urls[environment] = res["QueueUrl"] except self.sqs_client.exceptions.QueueNameExists: - if err_on_exists: - raise ValueError( - f"A queue with the name {queue_name} already exists." - ) + pass def delete(self) -> None: for queue_url in self.queue_urls.values(): @@ -332,44 +313,47 @@ class RDSJobQueue(JobQueue): Allows job tracking. 
""" - def __init__(self, db_url: str, max_retries: int = 10, retry_wait: int = 60): + def __init__( + self, + db_url: str, + retry_different: bool = True, + connect_kwargs: dict[str, Any] | None = None, + locking_context: UpdateLock | None = None, + ): self.db_url = db_url - self.update_lock = ( - UpdateLock() if self.db_url.startswith("sqlite") else MockUpdateLock() - ) - self.engine = self._get_engine(self.db_url, max_retries, retry_wait) + self.retry_different = retry_different + self.update_lock = locking_context or nullcontext() + self.connect_kwargs = connect_kwargs or {} self.table_name = QueuedJob.__tablename__ - def _get_engine(self, db_url: str, max_retries: int, retry_wait: int) -> Engine: + @property + def engine(self) -> Engine: + if hasattr(self, "_engine"): + return cast(Engine, self._engine) # type: ignore[has-type] retries = 0 while True: try: - engine = create_engine( - db_url, - connect_args=( - {"check_same_thread": False} - if db_url.startswith("sqlite") - else {} - ), - ) + engine = create_engine(self.db_url, connect_args=self.connect_kwargs) # Attempt to create a connection or perform any necessary operations engine.connect() + self._engine = engine return engine # Connection successful except Exception as e: - if retries >= max_retries: + if retries >= 10: raise RuntimeError(f"Could not create engine: {str(e)}") retries += 1 - time.sleep(retry_wait) + time.sleep(60) - def create(self, err_on_exists: bool = True) -> None: - inspector = inspect(self.engine) - if inspector.has_table(self.table_name) and err_on_exists: - raise ValueError(f"A table with the name {self.table_name} already exists.") + def create(self) -> None: Base.metadata.create_all(self.engine) def delete(self) -> None: Base.metadata.drop_all(self.engine) + def get_all(self) -> Any: + with Session(self.engine) as session: + return session.query(QueuedJob).all() + def enqueue(self, job: SubmittedJob) -> None: with Session(self.engine) as session: session.add( @@ -426,7 +410,7 @@ def filter_sort_query(query: Query[QueuedJob]) -> QueuedJob | None: (QueuedJob.gpu_mem <= filter.gpu_mem) | (QueuedJob.gpu_mem.is_(None)), ) - if settings.retry_different: + if self.retry_different: # only if worker did not already try running this job query = query.filter(not_(QueuedJob.workers.contains(hostname))) query = query.order_by(QueuedJob.priority.desc()).order_by( @@ -541,7 +525,11 @@ def handle_timeouts( < time_now - datetime.timedelta(seconds=timeout_failure) ), ) - jobs_retry = jobs_timeout.filter(QueuedJob.num_retries < max_retries) + # Evaluate both queries before modifying any jobs to avoid race condition + jobs_retry = jobs_timeout.filter(QueuedJob.num_retries < max_retries).all() + jobs_failed = jobs_timeout.filter( + QueuedJob.num_retries >= max_retries + ).all() for job in jobs_retry: # TODO: increase priority? job.num_retries += 1 @@ -556,7 +544,6 @@ def handle_timeouts( except JobDeletedException: # job probably deleted by user, skip updating status pass - jobs_failed = jobs_timeout.filter(QueuedJob.num_retries >= max_retries) for job in jobs_failed: try: self.update_job_status( @@ -570,3 +557,88 @@ def handle_timeouts( pass session.commit() return n_retry, n_failed + + def backup(self) -> bool: + """Backup the database. To be implemented by subclasses if supported.""" + return False + + +class SQLiteRDSJobQueue(RDSJobQueue): + """SQLite-specific RDS job queue with optional S3 backup support. + + Extends RDSJobQueue with specifics of SQLite databases. + Allows S3 backup and restore functionality. 
+ """ + + BACKUP_KEY = "workerapi_sqlite_backup/backup.db.gz" + + def __init__( + self, + db_url: str, + retry_different: bool = True, + s3_client: S3Client | None = None, + s3_bucket: str | None = None, + ): + if not db_url.startswith("sqlite:///"): + raise ValueError(f"SQLiteRDSJobQueue requires SQLite DB URL, got: {db_url}") + if not ((s3_client is None) == (s3_bucket is None)): + raise ValueError( + "Both s3_client and s3_bucket must be provided for S3 backup/restore, or both must be None." + ) + self.s3_client = s3_client + self.s3_bucket = s3_bucket + super().__init__( + db_url, + retry_different=retry_different, + connect_kwargs={"check_same_thread": False}, + locking_context=UpdateLock(), + ) + + def create(self) -> None: + self._restore_database() + super().create() + + @property + def db_path(self) -> str: + return self.db_url[len("sqlite:///") :] + + def backup(self) -> bool: + """Backup the SQLite database to S3.""" + if not self.s3_bucket or not self.s3_client: + return False + + with tempfile.TemporaryDirectory() as temp_dir: + tmp_backup_path = os.path.join(temp_dir, "backup.db") + tmp_gzip_path = os.path.join(temp_dir, "backup.db.gz") + with sqlite3.connect(self.db_path) as source_conn: + with sqlite3.connect(tmp_backup_path) as backup_conn: + source_conn.backup(backup_conn) + + with open(tmp_backup_path, "rb") as f_in: + with gzip.open(tmp_gzip_path, "wb") as f_out: + f_out.writelines(f_in) + self.s3_client.upload_file(tmp_gzip_path, self.s3_bucket, self.BACKUP_KEY) + return True + + def _restore_database(self) -> bool: + """Restore the SQLite database from S3.""" + if not self.s3_bucket or not self.s3_client: + return False + + try: + self.s3_client.head_object(Bucket=self.s3_bucket, Key=self.BACKUP_KEY) + except ClientError as e: + if e.response["Error"]["Code"] == "404": + return False + raise + + with tempfile.TemporaryDirectory() as temp_dir: + tmp_gzip_path = os.path.join(temp_dir, "backup.db.gz") + tmp_backup_path = os.path.join(temp_dir, "backup.db") + self.s3_client.download_file(self.s3_bucket, self.BACKUP_KEY, tmp_gzip_path) + with gzip.open(tmp_gzip_path, "rb") as f_in: + with open(tmp_backup_path, "wb") as f_out: + f_out.write(f_in.read()) + os.makedirs(os.path.dirname(self.db_path), exist_ok=True) + os.rename(tmp_backup_path, self.db_path) + return True diff --git a/workerfacing_api/dependencies.py b/workerfacing_api/dependencies.py index d2fc7e6..09f4c1c 100644 --- a/workerfacing_api/dependencies.py +++ b/workerfacing_api/dependencies.py @@ -1,20 +1,35 @@ -from typing import Any - import boto3 from botocore.config import Config from botocore.utils import fix_s3_host -from fastapi import Depends, Header, HTTPException, Request -from fastapi.security import HTTPAuthorizationCredentials -from fastapi_cloudauth.cognito import CognitoClaims, CognitoCurrentUser # type: ignore -from pydantic import Field +from fastapi import Depends, Request +from fastapi_cloudauth.cognito import CognitoClaims # type: ignore from workerfacing_api import settings -from workerfacing_api.core import filesystem, queue +from workerfacing_api.core import auth, filesystem, queue + +# S3 client setup +s3_client = None +if settings.s3_bucket: + s3_client = boto3.client( + "s3", + region_name=settings.s3_region, + config=Config(signature_version="v4", s3={"addressing_style": "path"}), + ) + # this and config=... 
required to avoid DNS problems with new buckets + s3_client.meta.events.unregister("before-sign.s3", fix_s3_host) # Queue queue_db_url = settings.queue_db_url -queue_ = queue.RDSJobQueue(queue_db_url) -queue_.create(err_on_exists=False) +retry_different = settings.retry_different +if queue_db_url.startswith("sqlite"): + queue_: queue.RDSJobQueue = queue.SQLiteRDSJobQueue( + db_url=queue_db_url, + retry_different=retry_different, + s3_client=s3_client, + s3_bucket=settings.s3_bucket, + ) +else: + queue_ = queue.RDSJobQueue(db_url=queue_db_url, retry_different=retry_different) def queue_dep() -> queue.RDSJobQueue: @@ -22,38 +37,11 @@ def queue_dep() -> queue.RDSJobQueue: # App-internal authentication (i.e. user-facing API <-> worker-facing API) -# https://github.com/iwpnd/fastapi-key-auth/blob/main/fastapi_key_auth/dependency/authorizer.py -class APIKeyDependency: - def __init__(self, key: str | None): - self.key = key - - def __call__(self, x_api_key: str | None = Header(...)) -> str | None: - if x_api_key != self.key: - raise HTTPException(status_code=401, detail="unauthorized") - return x_api_key - - -authorizer = APIKeyDependency(key=settings.internal_api_key_secret) +authorizer = auth.APIKeyDependency(key=settings.internal_api_key_secret) # Worker authentication -class GroupClaims(CognitoClaims): # type: ignore - cognito_groups: list[str] | None = Field(alias="cognito:groups") - - -class WorkerGroupCognitoCurrentUser(CognitoCurrentUser): # type: ignore - user_info = GroupClaims - - async def call(self, http_auth: HTTPAuthorizationCredentials) -> Any: - user_info = await super().call(http_auth) - if "workers" not in (getattr(user_info, "cognito_groups") or []): - raise HTTPException( - status_code=403, detail="Not a member of the 'workers' group" - ) - return user_info - - -current_user_dep = WorkerGroupCognitoCurrentUser( +current_user_dep = auth.WorkerGroupCognitoCurrentUser( region=settings.cognito_region, userPoolId=settings.cognito_user_pool_id, client_id=settings.cognito_client_id, @@ -67,18 +55,11 @@ async def current_user_global_dep( return current_user -# Files +# Filesystem async def filesystem_dep() -> filesystem.FileSystem: if settings.filesystem == "s3": - s3_client = boto3.client( - "s3", - region_name=settings.s3_region, - config=Config(signature_version="v4", s3={"addressing_style": "path"}), - ) - # this and config=... 
required to avoid DNS problems with new buckets - s3_client.meta.events.unregister("before-sign.s3", fix_s3_host) - if settings.s3_bucket is None: - raise ValueError("S3 bucket not configured") + if s3_client is None or settings.s3_bucket is None: + raise ValueError("S3 bucket or client not configured") return filesystem.S3Filesystem(s3_client, settings.s3_bucket) elif settings.filesystem == "local": if settings.user_data_root_path is None: diff --git a/workerfacing_api/main.py b/workerfacing_api/main.py index 9c851d1..c4be533 100644 --- a/workerfacing_api/main.py +++ b/workerfacing_api/main.py @@ -1,13 +1,64 @@ +import asyncio +import logging +from contextlib import asynccontextmanager +from typing import AsyncGenerator + import dotenv from fastapi import Depends, FastAPI -from fastapi_utils.tasks import repeat_every dotenv.load_dotenv() +logger = logging.getLogger(__name__) + from workerfacing_api import dependencies, settings, tags +from workerfacing_api.core.queue import RDSJobQueue from workerfacing_api.endpoints import access, files, jobs, jobs_post -workerfacing_app = FastAPI(openapi_tags=tags.tags_metadata) + +async def cron_handle_timeouts(queue: RDSJobQueue) -> None: + while True: + logger.info("Silent fails check: starting...") + try: + max_retries = settings.max_retries + timeout_failure = settings.timeout_failure + n_retry, n_fail = queue.handle_timeouts(max_retries, timeout_failure) + logger.info(f"Silent fails check: {n_retry} re-queued, {n_fail} failed.") + except Exception as e: + logger.error(f"Silent fails check: failed with {e}") + await asyncio.sleep(settings.cron_timeout_check_interval) + + +async def cron_backup_database(queue: RDSJobQueue) -> None: + while True: + logger.info("Database backup: starting...") + # Run backup in thread pool to avoid blocking event loop; + # Fine instead of making backup async since it runs infrequently. 
+ try: + if await asyncio.to_thread(queue.backup): + logger.info("Backed up database.") + except Exception as e: + logger.error(f"Database backup failed with {e}") + await asyncio.sleep(settings.cron_backup_interval) + + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + queue = app.dependency_overrides.get( + dependencies.queue_dep, dependencies.queue_dep + )() + assert isinstance(queue, RDSJobQueue) + queue.create() + task_failed_jobs = asyncio.create_task(cron_handle_timeouts(queue)) + task_backup = asyncio.create_task(cron_backup_database(queue)) + yield + task_failed_jobs.cancel() + task_backup.cancel() + await asyncio.gather(task_failed_jobs, task_backup, return_exceptions=True) + if queue.backup(): + logger.info("Created final backup on shutdown.") + + +workerfacing_app = FastAPI(openapi_tags=tags.tags_metadata, lifespan=lifespan) workerfacing_app.include_router( jobs.router, @@ -27,24 +78,6 @@ ) -queue = dependencies.queue_dep() - - -@workerfacing_app.on_event("startup") # type: ignore -@repeat_every(seconds=60, raise_exceptions=True) -async def find_failed_jobs() -> dict[str, int]: - print("Silent fails check: starting...") - try: - max_retries = settings.max_retries - timeout_failure = settings.timeout_failure - n_retry, n_fail = queue.handle_timeouts(max_retries, timeout_failure) - print(f"Silent fails check: {n_retry} re-queued, {n_fail} failed.") - return {"n_retry": n_retry, "n_fail": n_fail} - except Exception as e: - print(f"Silent fails check: failed with {e}") - return {"n_retry": 0, "n_fail": 0} - - @workerfacing_app.get("/") async def root() -> dict[str, str]: return {"message": "Welcome to the DECODE OpenCloud Worker-facing API"} diff --git a/workerfacing_api/settings.py b/workerfacing_api/settings.py index 8d06d19..7831df5 100644 --- a/workerfacing_api/settings.py +++ b/workerfacing_api/settings.py @@ -12,6 +12,11 @@ def get_secret_from_env(secret_name: str) -> str | None: return secret +# Cron job intervals +cron_timeout_check_interval = 300 # 5 minutes +cron_backup_interval = 3600 # 1 hour + + # Data filesystem = os.environ.get("FILESYSTEM") # filesystem s3_bucket = os.environ.get("S3_BUCKET")