diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml index 9553e78..ee39913 100644 --- a/.github/workflows/code-checks.yaml +++ b/.github/workflows/code-checks.yaml @@ -51,7 +51,7 @@ jobs: aws-region: eu-central-1 - name: Run tests run: | - poetry run pytest -m "aws or not(aws)" --junitxml=pytest.xml --cov-report=term-missing --cov=workerfacing_api | tee pytest-coverage.txt + poetry run pytest -m "aws or not(aws)" -x --junitxml=pytest.xml --cov-report=term-missing --cov=workerfacing_api | tee pytest-coverage.txt echo "test_exit_code=${PIPESTATUS[0]}" >> $GITHUB_ENV - name: Coverage comment uses: MishaKav/pytest-coverage-comment@main diff --git a/tests/conftest.py b/tests/conftest.py index c1da77f..76bb4fb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,8 @@ TEST_BUCKET_PREFIX = "decode-cloud-worker-api-tests-" REGION_NAME: BucketLocationConstraintType = "eu-central-1" +# True for Aurora clusters, False for RDS instances +USE_AURORA_CLUSTERS = True @pytest.fixture(scope="session") @@ -30,9 +32,10 @@ def patch_update_job(monkeypatch_module: pytest.MonkeyPatch) -> MagicMock: return mock_update_job -class RDSTestingInstance: +class DatabaseTestingInstance: def __init__(self, db_name: str): self.db_name = db_name + self.use_aurora = USE_AURORA_CLUSTERS def create(self) -> None: self.rds_client = boto3.client("rds", "eu-central-1") @@ -103,39 +106,125 @@ def get_db_password(self) -> str: def create_db_url(self) -> str: user = "postgres" password = self.get_db_password() - try: - response = self.rds_client.describe_db_instances( - DBInstanceIdentifier=self.db_name - ) - except self.rds_client.exceptions.DBInstanceNotFoundFault: - while True: + + if self.use_aurora: + # Aurora cluster mode + try: + response_cluster = self.rds_client.describe_db_clusters( + DBClusterIdentifier=self.db_name + ) + except self.rds_client.exceptions.DBClusterNotFoundFault: + while True: + try: + self.rds_client.create_db_cluster( + DatabaseName=self.db_name, + DBClusterIdentifier=self.db_name, + Engine="aurora-postgresql", + EngineMode="provisioned", + ServerlessV2ScalingConfiguration={ + "MinCapacity": 0, + "MaxCapacity": 1, + }, + MasterUsername=user, + MasterUserPassword=password, + DeletionProtection=False, + EnableCloudwatchLogsExports=[], + ) + break + except self.rds_client.exceptions.DBClusterAlreadyExistsFault: + pass + + max_wait_time = 600 # 10 minutes + start_time = time.time() + while True: + if time.time() - start_time > max_wait_time: + raise TimeoutError( + f"Aurora cluster {self.db_name} did not become available within {max_wait_time} seconds" + ) + response_cluster = self.rds_client.describe_db_clusters( + DBClusterIdentifier=self.db_name + ) + assert len(response_cluster["DBClusters"]) == 1 + if response_cluster["DBClusters"][0]["Status"] == "available": + break + else: + time.sleep(5) + try: - self.rds_client.create_db_instance( - DBName=self.db_name, - DBInstanceIdentifier=self.db_name, - AllocatedStorage=20, - DBInstanceClass="db.t4g.micro", - Engine="postgres", - MasterUsername=user, - MasterUserPassword=password, - DeletionProtection=False, - BackupRetentionPeriod=0, - MultiAZ=False, - EnablePerformanceInsights=False, + self.rds_client.describe_db_instances( + DBInstanceIdentifier=self.db_name ) - break - except self.rds_client.exceptions.DBInstanceAlreadyExistsFault: - pass - while True: - response = self.rds_client.describe_db_instances( + except self.rds_client.exceptions.DBInstanceNotFoundFault: + while True: + try: + self.rds_client.create_db_instance( + DBInstanceIdentifier=self.db_name, + DBClusterIdentifier=self.db_name, + DBInstanceClass="db.serverless", + Engine="aurora-postgresql", + ) + break + except self.rds_client.exceptions.DBInstanceAlreadyExistsFault: + pass + max_wait_time = 600 # 10 minutes + start_time = time.time() + while True: + if time.time() - start_time > max_wait_time: + raise TimeoutError( + f"Aurora instance {self.db_name} did not become available within {max_wait_time} seconds" + ) + response_instance = self.rds_client.describe_db_instances( + DBInstanceIdentifier=self.db_name + ) + assert len(response_instance["DBInstances"]) == 1 + if ( + response_instance["DBInstances"][0]["DBInstanceStatus"] + == "available" + ): + break + time.sleep(5) + + # Get the cluster endpoint + address = response_cluster["DBClusters"][0]["Endpoint"] + else: + # RDS instance mode + try: + response_instance = self.rds_client.describe_db_instances( DBInstanceIdentifier=self.db_name ) - assert len(response["DBInstances"]) == 1 - if response["DBInstances"][0]["DBInstanceStatus"] == "available": - break - else: - time.sleep(5) - address = response["DBInstances"][0]["Endpoint"]["Address"] + except self.rds_client.exceptions.DBInstanceNotFoundFault: + while True: + try: + self.rds_client.create_db_instance( + DBName=self.db_name, + DBInstanceIdentifier=self.db_name, + AllocatedStorage=20, + DBInstanceClass="db.t4g.micro", + Engine="postgres", + MasterUsername=user, + MasterUserPassword=password, + DeletionProtection=False, + BackupRetentionPeriod=0, + MultiAZ=False, + EnablePerformanceInsights=False, + ) + break + except self.rds_client.exceptions.DBInstanceAlreadyExistsFault: + pass + while True: + response_instance = self.rds_client.describe_db_instances( + DBInstanceIdentifier=self.db_name + ) + assert len(response_instance["DBInstances"]) == 1 + if ( + response_instance["DBInstances"][0]["DBInstanceStatus"] + == "available" + ): + break + else: + time.sleep(5) + address = response_instance["DBInstances"][0]["Endpoint"]["Address"] + return f"postgresql://{user}:{password}@{address}:5432/{self.db_name}" def cleanup(self) -> None: @@ -146,11 +235,37 @@ def delete(self) -> None: # never used (AWS tests skipped) if not hasattr(self, "rds_client"): return - self.rds_client.delete_db_instance( - DBInstanceIdentifier=self.db_name, - SkipFinalSnapshot=True, - DeleteAutomatedBackups=True, - ) + + if self.use_aurora: + # Delete Aurora instance first, then cluster + instance_identifier = f"{self.db_name}-instance" + try: + self.rds_client.delete_db_instance( + DBInstanceIdentifier=instance_identifier, + SkipFinalSnapshot=True, + ) + # Wait for instance deletion before deleting cluster + while True: + try: + self.rds_client.describe_db_instances( + DBInstanceIdentifier=instance_identifier + ) + time.sleep(5) + except self.rds_client.exceptions.DBInstanceNotFoundFault: + break + except self.rds_client.exceptions.DBInstanceNotFoundFault: + pass + + self.rds_client.delete_db_cluster( + DBClusterIdentifier=self.db_name, + SkipFinalSnapshot=True, + ) + else: + self.rds_client.delete_db_instance( + DBInstanceIdentifier=self.db_name, + SkipFinalSnapshot=True, + DeleteAutomatedBackups=True, + ) class S3TestingBucket: @@ -201,11 +316,12 @@ def delete(self) -> None: @pytest.fixture(scope="session") -def rds_testing_instance() -> Generator[RDSTestingInstance, Any, None]: - # tests themselves must create the instance by calling instance.create(); - # this way, if no test that needs the DB is run, no RDS instance is created - # instance.delete() only deletes the RDS instance if it was created - instance = RDSTestingInstance("decodecloudintegrationtestsworkerapi") +def database_testing_instance() -> Generator[DatabaseTestingInstance, Any, None]: + # tests themselves must create the database by calling instance.create(); + # this way, if no test that needs the DB is run, no database is created + # instance.delete() only deletes the database if it was created + # Uses Aurora clusters if USE_AURORA_CLUSTERS=True, RDS instances if False + instance = DatabaseTestingInstance("decodecloudintegrationtestsworkerapi") yield instance instance.delete() diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index fd58219..f68420b 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -3,7 +3,7 @@ import pytest -from tests.conftest import RDSTestingInstance, S3TestingBucket +from tests.conftest import DatabaseTestingInstance, S3TestingBucket from workerfacing_api import settings from workerfacing_api.core.filesystem import FileSystem, LocalFilesystem, S3Filesystem from workerfacing_api.core.queue import RDSJobQueue @@ -39,16 +39,16 @@ def internal_api_key_secret() -> str: ) def env( request: pytest.FixtureRequest, - rds_testing_instance: RDSTestingInstance, + database_testing_instance: DatabaseTestingInstance, s3_testing_bucket: S3TestingBucket, ) -> Generator[str, Any, None]: env = cast(str, request.param) if env == "aws": - rds_testing_instance.create() + database_testing_instance.create() s3_testing_bucket.create() yield env if env == "aws": - rds_testing_instance.cleanup() + database_testing_instance.cleanup() s3_testing_bucket.cleanup() @@ -92,7 +92,7 @@ def base_filesystem( @pytest.fixture(scope="session") def queue( env: str, - rds_testing_instance: RDSTestingInstance, + database_testing_instance: DatabaseTestingInstance, tmpdir_factory: pytest.TempdirFactory, ) -> Generator[RDSJobQueue, Any, None]: if env == "local": @@ -100,7 +100,7 @@ def queue( f"sqlite:///{tmpdir_factory.mktemp('integration')}/local.db" ) else: - queue = RDSJobQueue(rds_testing_instance.db_url) + queue = RDSJobQueue(database_testing_instance.db_url) queue.create(err_on_exists=True) yield queue diff --git a/tests/integration/endpoints/test_jobs.py b/tests/integration/endpoints/test_jobs.py index 3243c44..2ae00dd 100644 --- a/tests/integration/endpoints/test_jobs.py +++ b/tests/integration/endpoints/test_jobs.py @@ -9,7 +9,7 @@ import requests from fastapi.testclient import TestClient -from tests.conftest import RDSTestingInstance +from tests.conftest import DatabaseTestingInstance from tests.integration.endpoints.conftest import EndpointParams, _TestEndpoint from workerfacing_api.core.filesystem import FileSystem, LocalFilesystem, S3Filesystem from workerfacing_api.core.queue import RDSJobQueue @@ -65,12 +65,12 @@ def cleanup_queue( self, queue: RDSJobQueue, env: str, - rds_testing_instance: RDSTestingInstance, + database_testing_instance: DatabaseTestingInstance, ) -> None: if env == "local": queue.delete() else: - rds_testing_instance.cleanup() + database_testing_instance.cleanup() queue.create() @pytest.fixture(scope="function") diff --git a/tests/unit/core/test_queue.py b/tests/unit/core/test_queue.py index 5d527f7..7c9247b 100644 --- a/tests/unit/core/test_queue.py +++ b/tests/unit/core/test_queue.py @@ -10,7 +10,7 @@ import pytest from moto import mock_aws -from tests.conftest import RDSTestingInstance +from tests.conftest import DatabaseTestingInstance from workerfacing_api.core.queue import ( JobQueue, LocalJobQueue, @@ -302,8 +302,8 @@ def base_queue( class TestRDSAWSQueue(_TestRDSQueue): @pytest.fixture(scope="class") def base_queue( - self, rds_testing_instance: RDSTestingInstance + self, database_testing_instance: DatabaseTestingInstance ) -> Generator[RDSJobQueue, Any, None]: - rds_testing_instance.create() - yield RDSJobQueue(rds_testing_instance.db_url) - rds_testing_instance.cleanup() + database_testing_instance.create() + yield RDSJobQueue(database_testing_instance.db_url) + database_testing_instance.cleanup()