Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 102 additions & 70 deletions tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,31 @@
import datetime
import shutil
from typing import Any, Generator, cast
from typing import Generator, cast

import pytest
from mypy_boto3_s3 import S3Client

from tests.conftest import RDSTestingInstance, S3TestingBucket
from workerfacing_api import settings
from workerfacing_api.core.auth import APIKeyDependency, GroupClaims
from workerfacing_api.core.filesystem import FileSystem, LocalFilesystem, S3Filesystem
from workerfacing_api.core.queue import RDSJobQueue
from workerfacing_api.core.queue import RDSJobQueue, SQLiteRDSJobQueue
from workerfacing_api.dependencies import (
APIKeyDependency,
GroupClaims,
authorizer,
current_user_dep,
filesystem_dep,
queue_dep,
)
from workerfacing_api.main import workerfacing_app
from workerfacing_api.schemas.queue_jobs import (
AppSpecs,
EnvironmentTypes,
HandlerSpecs,
HardwareSpecs,
JobSpecs,
MetaSpecs,
PathsUploadSpecs,
SubmittedJob,
)


@pytest.fixture(scope="session")
Expand All @@ -24,8 +34,8 @@ def test_username() -> str:


@pytest.fixture(scope="session")
def base_dir() -> str:
return "int_test_dir"
def base_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
return str(tmp_path_factory.mktemp("int_test_dir"))


@pytest.fixture(scope="session")
Expand All @@ -35,99 +45,92 @@ def internal_api_key_secret() -> str:

@pytest.fixture(
scope="session",
params=["local", pytest.param("aws", marks=pytest.mark.aws)],
params=["local-fs", pytest.param("aws-fs", marks=pytest.mark.aws)],
)
def env(
request: pytest.FixtureRequest,
rds_testing_instance: RDSTestingInstance,
s3_testing_bucket: S3TestingBucket,
) -> Generator[str, Any, None]:
env = cast(str, request.param)
if env == "aws":
rds_testing_instance.create()
s3_testing_bucket.create()
yield env
if env == "aws":
rds_testing_instance.cleanup()
s3_testing_bucket.cleanup()


@pytest.fixture(scope="session")
def base_filesystem(
env: str,
base_dir: str,
monkeypatch_module: pytest.MonkeyPatch,
s3_testing_bucket: S3TestingBucket,
) -> Generator[FileSystem, Any, None]:
monkeypatch_module.setattr(
settings,
"user_data_root_path",
base_dir,
)
monkeypatch_module.setattr(
settings,
"filesystem",
"local" if env == "local" else "s3",
)

if env == "local":
shutil.rmtree(base_dir, ignore_errors=True)
yield LocalFilesystem(base_dir, base_dir)
shutil.rmtree(base_dir, ignore_errors=True)

elif env == "aws":
# Update settings to use the actual unique bucket name created by S3TestingBucket
monkeypatch_module.setattr(
settings,
"s3_bucket",
s3_testing_bucket.bucket_name,
)
yield S3Filesystem(s3_testing_bucket.s3_client, s3_testing_bucket.bucket_name)
s3_testing_bucket.cleanup()

request: pytest.FixtureRequest,
) -> FileSystem:
if request.param == "local-fs":
return LocalFilesystem(base_dir, base_dir)
elif request.param == "aws-fs":
s3_testing_bucket.create()
return S3Filesystem(s3_testing_bucket.s3_client, s3_testing_bucket.bucket_name)
else:
raise NotImplementedError


@pytest.fixture(scope="session")
@pytest.fixture(
scope="session",
params=["local-queue", pytest.param("aws-queue", marks=pytest.mark.aws)],
)
def queue(
env: str,
base_filesystem: FileSystem,
rds_testing_instance: RDSTestingInstance,
tmpdir_factory: pytest.TempdirFactory,
) -> Generator[RDSJobQueue, Any, None]:
if env == "local":
queue = RDSJobQueue(
f"sqlite:///{tmpdir_factory.mktemp('integration')}/local.db"
request: pytest.FixtureRequest,
) -> RDSJobQueue:
retry_different = False # allow retries on same worker
if request.param == "local-queue":
queue_path = tmpdir_factory.mktemp("integration") / "local.db"
s3_bucket: str | None = None
s3_client: S3Client | None = None
if isinstance(base_filesystem, S3Filesystem):
s3_bucket = base_filesystem.bucket
s3_client = base_filesystem.s3_client
return SQLiteRDSJobQueue(
f"sqlite:///{queue_path}",
retry_different=retry_different,
s3_client=s3_client,
s3_bucket=s3_bucket,
)
elif request.param == "aws-queue":
rds_testing_instance.create()
return RDSJobQueue(rds_testing_instance.db_url, retry_different=retry_different)
else:
queue = RDSJobQueue(rds_testing_instance.db_url)
queue.create(err_on_exists=True)
yield queue
raise NotImplementedError


@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(autouse=True)
def override_filesystem_dep(
base_filesystem: FileSystem, monkeypatch_module: pytest.MonkeyPatch
) -> None:
base_filesystem: FileSystem,
s3_testing_bucket: S3TestingBucket,
base_dir: str,
monkeypatch_module: pytest.MonkeyPatch,
) -> Generator[None, None, None]:
monkeypatch_module.setitem(
workerfacing_app.dependency_overrides, # type: ignore
filesystem_dep,
lambda: base_filesystem,
)
yield
# cleanup after every test
if isinstance(base_filesystem, S3Filesystem):
s3_testing_bucket.cleanup()
else:
shutil.rmtree(base_dir, ignore_errors=True)


@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(autouse=True)
def override_queue_dep(
queue: RDSJobQueue, monkeypatch_module: pytest.MonkeyPatch
) -> None:
queue: RDSJobQueue,
rds_testing_instance: RDSTestingInstance,
monkeypatch_module: pytest.MonkeyPatch,
) -> Generator[None, None, None]:
monkeypatch_module.setitem(
workerfacing_app.dependency_overrides, # type: ignore
queue_dep,
lambda: queue,
)
yield
if isinstance(queue, SQLiteRDSJobQueue):
queue.delete()
else:
rds_testing_instance.cleanup()


@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(autouse=True)
def override_auth(monkeypatch_module: pytest.MonkeyPatch, test_username: str) -> None:
monkeypatch_module.setitem(
workerfacing_app.dependency_overrides, # type: ignore
Expand All @@ -142,7 +145,7 @@ def override_auth(monkeypatch_module: pytest.MonkeyPatch, test_username: str) ->
)


@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(autouse=True)
def override_internal_api_key_secret(
monkeypatch_module: pytest.MonkeyPatch, internal_api_key_secret: str
) -> str:
Expand All @@ -152,3 +155,32 @@ def override_internal_api_key_secret(
APIKeyDependency(internal_api_key_secret),
)
return internal_api_key_secret


@pytest.fixture
def base_job(base_filesystem: FileSystem, test_username: str) -> SubmittedJob:
time_now = datetime.datetime.now(datetime.timezone.utc).isoformat()
if isinstance(base_filesystem, S3Filesystem):
base_path = f"s3://{base_filesystem.bucket}"
else:
base_path = cast(LocalFilesystem, base_filesystem).base_post_path
paths_upload = PathsUploadSpecs(
output=f"{base_path}/{test_username}/test_out/1",
log=f"{base_path}/{test_username}/test_log/1",
artifact=f"{base_path}/{test_username}/test_arti/1",
)
return SubmittedJob(
job=JobSpecs(
app=AppSpecs(cmd=["cmd"], env={"env": "var"}),
handler=HandlerSpecs(image_url="u", files_up={"output": "out"}),
hardware=HardwareSpecs(),
meta=MetaSpecs(
job_id=1,
date_created=time_now,
),
),
environment=EnvironmentTypes.local,
group=None,
priority=1,
paths_upload=paths_upload,
)
13 changes: 8 additions & 5 deletions tests/integration/endpoints/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
from dataclasses import dataclass, field
from typing import Any
from typing import Any, Generator

import pytest
from fastapi.testclient import TestClient
Expand All @@ -9,6 +9,13 @@
from workerfacing_api.main import workerfacing_app


@pytest.fixture
def client() -> Generator[TestClient, None, None]:
# run everything in lifespan context
with TestClient(workerfacing_app) as client:
yield client


@dataclass
class EndpointParams:
method: str
Expand All @@ -24,10 +31,6 @@ class _TestEndpoint(abc.ABC):
def passing_params(self, *args: Any, **kwargs: Any) -> list[EndpointParams]:
raise NotImplementedError

@pytest.fixture(scope="session")
def client(self) -> TestClient:
return TestClient(workerfacing_app)

def test_required_auth(
self,
monkeypatch: pytest.MonkeyPatch,
Expand Down
54 changes: 23 additions & 31 deletions tests/integration/endpoints/test_files.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
from io import BytesIO
from typing import cast

import pytest
import requests
Expand All @@ -10,67 +9,60 @@
from workerfacing_api.core.filesystem import FileSystem, S3Filesystem


@pytest.fixture(scope="session")
@pytest.fixture
def data_file1_name(base_dir: str) -> str:
return f"{base_dir}/data/test/data_file1.txt"


@pytest.fixture(scope="session")
def data_file1_path(env: str, data_file1_name: str, base_filesystem: FileSystem) -> str:
if env == "aws":
base_filesystem = cast(S3Filesystem, base_filesystem)
@pytest.fixture
def data_file1_path(data_file1_name: str, base_filesystem: FileSystem) -> str:
if isinstance(base_filesystem, S3Filesystem):
return f"s3://{base_filesystem.bucket}/{data_file1_name}"
return data_file1_name


@pytest.fixture(scope="session")
@pytest.fixture
def data_file1_contents() -> str:
return "data_file1"


@pytest.fixture(scope="session", autouse=True)
@pytest.fixture(autouse=True)
def data_file1(
env: str,
base_filesystem: FileSystem,
data_file1_name: str,
data_file1_contents: str,
base_filesystem: FileSystem, data_file1_name: str, data_file1_contents: str
) -> None:
if env == "local":
os.makedirs(os.path.dirname(data_file1_name), exist_ok=True)
with open(data_file1_name, "w") as f:
f.write(data_file1_contents)
else:
base_filesystem = cast(S3Filesystem, base_filesystem)
if isinstance(base_filesystem, S3Filesystem):
base_filesystem.s3_client.put_object(
Bucket=base_filesystem.bucket,
Key=data_file1_name,
Body=BytesIO(data_file1_contents.encode("utf-8")),
)
else:
os.makedirs(os.path.dirname(data_file1_name), exist_ok=True)
with open(data_file1_name, "w") as f:
f.write(data_file1_contents)


class TestFiles(_TestEndpoint):
endpoint = "/files"

@pytest.fixture(scope="session")
@pytest.fixture
def passing_params(self, data_file1_path: str) -> list[EndpointParams]:
return [
EndpointParams("get", f"{data_file1_path}/url"),
]
return [EndpointParams("get", f"{data_file1_path}/url")]

def test_get_file(
self,
env: str,
data_file1_path: str,
data_file1_contents: str,
client: TestClient,
base_filesystem: FileSystem,
) -> None:
if env == "local":
if isinstance(base_filesystem, S3Filesystem):
file_resp = client.get(f"{self.endpoint}/{data_file1_path}/download")
assert file_resp.status_code == 200
assert file_resp.content.decode("utf-8") == data_file1_contents
assert file_resp.status_code == 403
else:
file_resp = client.get(f"{self.endpoint}/{data_file1_path}/download")
assert file_resp.status_code == 403
assert file_resp.status_code == 200
assert file_resp.content.decode("utf-8") == data_file1_contents

def test_get_file_not_exists(
self, data_file1_path: str, client: TestClient
Expand All @@ -84,21 +76,21 @@ def test_get_file_not_permitted(self, client: TestClient) -> None:

def test_get_file_url(
self,
env: str,
data_file1_path: str,
data_file1_contents: str,
client: TestClient,
base_filesystem: FileSystem,
) -> None:
req = f"{self.endpoint}/{data_file1_path}/url"
url_resp = client.get(req)
assert url_resp.status_code == 200
if env == "local":
assert req.replace("/url", "/download") in url_resp.text
else:
if isinstance(base_filesystem, S3Filesystem):
assert (
requests.request(**url_resp.json()).content.decode("utf-8")
== data_file1_contents
)
else:
assert req.replace("/url", "/download") in url_resp.text

def test_get_file_url_not_exists(
self, data_file1_path: str, client: TestClient
Expand Down
Loading
Loading