feat: adds unit test
jtyoung84 committed Sep 19, 2023
1 parent cd2faba commit b72ec6a
Showing 3 changed files with 116 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/aind_data_transfer_service/configs/job_configs.py
@@ -82,7 +82,7 @@ class BasicUploadJobConfigs(BaseSettings):
    _TIME_PATTERN2 = re.compile(r"^\d{1,2}:\d{1,2}:\d{1,2}$")
    _MODALITY_ENTRY_PATTERN = re.compile(r"^modality(\d*)$")

    aws_param_store_name: str
    aws_param_store_name: Optional[str] = Field(None)

    s3_bucket: str = Field(
        ...,
@@ -312,7 +312,7 @@ def _parse_modality_configs_from_row(cls, cleaned_row: dict) -> None:
    def from_csv_row(
        cls,
        row: dict,
        aws_param_store_name: str,
        aws_param_store_name: Optional[str] = None,
        temp_directory: Optional[str] = None,
    ):
        """
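With aws_param_store_name now optional, a csv row can be parsed without supplying a parameter store name up front. A minimal sketch of the intended call, assuming a csv shaped like the sample file used in the tests:

    import csv

    from aind_data_transfer_service.configs.job_configs import (
        BasicUploadJobConfigs,
    )

    with open("sample.csv", newline="") as f:
        for row in csv.DictReader(f):
            # aws_param_store_name now defaults to None; it can be set later
            job = BasicUploadJobConfigs.from_csv_row(row=row)
            print(job.json())
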
91 changes: 89 additions & 2 deletions src/aind_data_transfer_service/server.py
@@ -1,8 +1,9 @@
"""Starts and Runs Starlette Service"""
import csv
import io
import json
import os
from time import sleep
from asyncio import sleep

from fastapi import Request
from fastapi.responses import JSONResponse
@@ -29,6 +30,88 @@
templates = Jinja2Templates(directory=template_directory)


async def validate_csv(request: Request):
    """Validate a csv file. Return parsed contents as json."""
    async with request.form() as form:
        content = await form["file"].read()
        data = content.decode("utf-8")
        csv_reader = csv.DictReader(io.StringIO(data))
        basic_jobs = []
        errors = []
        for row in csv_reader:
            try:
                job = BasicUploadJobConfigs.from_csv_row(row=row)
                # Collect the parsed job; hpc settings are added later
                basic_jobs.append(job.json())
            except Exception as e:
                errors.append(repr(e))
        message = (
            f"Errors: {json.dumps(errors)}"
            if len(errors) > 0
            else "Valid Data"
        )
        status_code = 406 if len(errors) > 0 else 200
        content = {"message": message, "data": {"jobs": basic_jobs}}
        return JSONResponse(
            content=content,
            status_code=status_code,
        )


async def submit_basic_jobs(request: Request):
    """Post BasicUploadJobConfigs raw json to hpc server to process."""
    content = await request.json()
    hpc_client_conf = HpcClientConfigs()
    hpc_client = HpcClient(configs=hpc_client_conf)
    basic_jobs = content["jobs"]
    hpc_jobs = []
    parsing_errors = []
    for job in basic_jobs:
        try:
            basic_upload_job = BasicUploadJobConfigs.parse_raw(job)
            hpc_job = HpcJobConfigs(basic_upload_job_configs=basic_upload_job)
            hpc_jobs.append(hpc_job)
        except Exception as e:
            parsing_errors.append(f"Error parsing {job}: {e.__class__}")
    if parsing_errors:
        status_code = 406
        message = f"Errors: {json.dumps(parsing_errors)}"
        content = {
            "message": message,
            "data": {"responses": [], "errors": json.dumps(parsing_errors)},
        }
    else:
        responses = []
        hpc_errors = []
        for hpc_job in hpc_jobs:
            try:
                job_def = hpc_job.job_definition
                response = hpc_client.submit_job(job_def)
                response_json = response.json()
                responses.append(response_json)
                # Add pause to stagger job requests to the hpc
                await sleep(0.05)
            except Exception as e:
                hpc_errors.append(
                    f"Error processing {hpc_job.basic_upload_job_configs}: "
                    f"{e.__class__}"
                )
        message = (
            f"Errors: {json.dumps(hpc_errors)}"
            if len(hpc_errors) > 0
            else "Submitted Jobs."
        )
        status_code = 500 if len(hpc_errors) > 0 else 200
        content = {
            "message": message,
            "data": {"responses": responses, "errors": hpc_errors},
        }
    return JSONResponse(
        content=content,
        status_code=status_code,
    )


@csrf_protect
async def index(request: Request):
    """GET|POST /: form handler"""
Expand Down Expand Up @@ -65,7 +148,7 @@ async def index(request: Request):
            response_json = response.json()
            responses.append(response_json)
            # Add pause to stagger job requests to the hpc
            sleep(1)
            await sleep(1)

        return JSONResponse(
            content={
@@ -119,6 +202,10 @@ async def jobs(request: Request):

routes = [
    Route("/", endpoint=index, methods=["GET", "POST"]),
    Route("/api/validate_csv", endpoint=validate_csv, methods=["POST"]),
    Route(
        "/api/submit_basic_jobs", endpoint=submit_basic_jobs, methods=["POST"]
    ),
    Route("/jobs", endpoint=jobs, methods=["GET"]),
]
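
Taken together, the two new routes give a two-step flow: validate a csv first, then post the validated jobs back for submission. A rough sketch of how a client might exercise them (the httpx client, host, and port are assumptions, not part of this change):

    import httpx

    base_url = "http://localhost:8000"

    # Step 1: validate the csv; parsed jobs come back as json strings.
    with open("sample.csv", "rb") as f:
        validate_response = httpx.post(
            f"{base_url}/api/validate_csv", files={"file": f}
        )
    jobs = validate_response.json()["data"]["jobs"]

    # Step 2: submit the validated jobs to the hpc server.
    submit_response = httpx.post(
        f"{base_url}/api/submit_basic_jobs", json={"jobs": jobs}
    )
    print(submit_response.json()["message"])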

25 changes: 25 additions & 0 deletions tests/test_server.py
@@ -4,6 +4,7 @@
import json
import os
import unittest
from copy import deepcopy
from pathlib import Path
from unittest.mock import MagicMock, call, patch

@@ -12,8 +13,11 @@
from requests import Response

from aind_data_transfer_service.server import app
from tests.test_configs import TestJobConfigs

TEST_DIRECTORY = Path(os.path.dirname(os.path.realpath(__file__)))
SAMPLE_FILE = TEST_DIRECTORY / "resources" / "sample.csv"
# MALFORMED_SAMPLE_FILE = TEST_DIRECTORY / "resources" / "sample_malformed.csv"
MOCK_DB_FILE = TEST_DIRECTORY / "test_server" / "db.json"


@@ -44,6 +48,27 @@ class TestServer(unittest.TestCase):
    with open(MOCK_DB_FILE) as f:
        json_contents = json.load(f)

    expected_job_configs = deepcopy(TestJobConfigs.expected_job_configs)
    for config in expected_job_configs:
        config.aws_param_store_name = None

    @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
    def test_validate_csv(self):
        """Tests that a valid csv file is parsed into jobs as expected."""
        with TestClient(app) as client:
            with open(SAMPLE_FILE, "rb") as f:
                files = {
                    "file": f,
                }
                response = client.post(url="/api/validate_csv", files=files)
        expected_jobs = [j.json() for j in self.expected_job_configs]
        expected_response = {
            "message": "Valid Data",
            "data": {"jobs": expected_jobs},
        }
        self.assertEqual(response.status_code, 200)
        self.assertEqual(expected_response, response.json())

    @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
    def test_index(self):
        """Tests that form renders at startup as expected."""
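
The commented-out MALFORMED_SAMPLE_FILE constant points at the error path of /api/validate_csv, which returns a 406 with the collected errors. A hedged sketch of what that follow-up test might look like once such a fixture exists (the malformed csv file and its contents are hypothetical, not part of this commit):

    @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
    def test_validate_malformed_csv(self):
        """Tests that a malformed csv returns a 406 with error messages."""
        # Hypothetical fixture; sample_malformed.csv is not added here.
        malformed_file = TEST_DIRECTORY / "resources" / "sample_malformed.csv"
        with TestClient(app) as client:
            with open(malformed_file, "rb") as f:
                response = client.post(
                    url="/api/validate_csv", files={"file": f}
                )
        self.assertEqual(response.status_code, 406)
        self.assertIn("Errors:", response.json()["message"])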
