diff --git a/src/aind_data_transfer_service/configs/job_configs.py b/src/aind_data_transfer_service/configs/job_configs.py index 561e233..47cbddb 100644 --- a/src/aind_data_transfer_service/configs/job_configs.py +++ b/src/aind_data_transfer_service/configs/job_configs.py @@ -82,7 +82,7 @@ class BasicUploadJobConfigs(BaseSettings): _TIME_PATTERN2 = re.compile(r"^\d{1,2}:\d{1,2}:\d{1,2}$") _MODALITY_ENTRY_PATTERN = re.compile(r"^modality(\d*)$") - aws_param_store_name: str + aws_param_store_name: Optional[str] = Field(None) s3_bucket: str = Field( ..., @@ -312,7 +312,7 @@ def _parse_modality_configs_from_row(cls, cleaned_row: dict) -> None: def from_csv_row( cls, row: dict, - aws_param_store_name: str, + aws_param_store_name: Optional[str] = None, temp_directory: Optional[str] = None, ): """ diff --git a/src/aind_data_transfer_service/server.py b/src/aind_data_transfer_service/server.py index e52fc72..15c8d55 100644 --- a/src/aind_data_transfer_service/server.py +++ b/src/aind_data_transfer_service/server.py @@ -1,8 +1,9 @@ """Starts and Runs Starlette Service""" import csv import io +import json import os -from time import sleep +from asyncio import sleep from fastapi import Request from fastapi.responses import JSONResponse @@ -29,6 +30,88 @@ templates = Jinja2Templates(directory=template_directory) +async def validate_csv(request: Request): + """Validate a csv file. 
Return parsed contents as json.""" + async with request.form() as form: + content = await form["file"].read() + data = content.decode("utf-8") + csv_reader = csv.DictReader(io.StringIO(data)) + basic_jobs = [] + errors = [] + for row in csv_reader: + try: + job = BasicUploadJobConfigs.from_csv_row(row=row) + # Construct hpc job setting most of the vars from the env + basic_jobs.append(job.json()) + except Exception as e: + errors.append(repr(e)) + message = ( + f"Errors: {json.dumps(errors)}" + if len(errors) > 0 + else "Valid Data" + ) + status_code = 406 if len(errors) > 0 else 200 + content = {"message": message, "data": {"jobs": basic_jobs}} + return JSONResponse( + content=content, + status_code=status_code, + ) + + +async def submit_basic_jobs(request: Request): + """Post BasicJobConfigs raw json to hpc server to process.""" + content = await request.json() + hpc_client_conf = HpcClientConfigs() + hpc_client = HpcClient(configs=hpc_client_conf) + basic_jobs = content["jobs"] + hpc_jobs = [] + parsing_errors = [] + for job in basic_jobs: + try: + basic_upload_job = BasicUploadJobConfigs.parse_raw(job) + hpc_job = HpcJobConfigs(basic_upload_job_configs=basic_upload_job) + hpc_jobs.append(hpc_job) + except Exception as e: + parsing_errors.append(f"Error parsing {job}: {e.__class__}") + if parsing_errors: + status_code = 406 + message = f"Errors: {json.dumps(parsing_errors)}" + content = { + "message": message, + "data": {"responses": [], "errors": json.dumps(parsing_errors)}, + } + else: + responses = [] + hpc_errors = [] + for hpc_job in hpc_jobs: + try: + job_def = hpc_job.job_definition + response = hpc_client.submit_job(job_def) + response_json = response.json() + responses.append(response_json) + # Add pause to stagger job requests to the hpc + await sleep(0.05) + except Exception as e: + hpc_errors.append( + f"Error processing {hpc_job.basic_upload_job_configs}: " + f"{e.__class__}" + ) + message = ( + f"Errors: {json.dumps(hpc_errors)}" + if 
len(hpc_errors) > 0 + else "Submitted Jobs." + ) + status_code = 500 if len(hpc_errors) > 0 else 200 + content = { + "message": message, + "data": {"responses": responses, "errors": hpc_errors}, + } + return JSONResponse( + content=content, + status_code=status_code, + ) + + @csrf_protect async def index(request: Request): """GET|POST /: form handler""" @@ -65,7 +148,7 @@ async def index(request: Request): response_json = response.json() responses.append(response_json) # Add pause to stagger job requests to the hpc - sleep(1) + await sleep(1) return JSONResponse( content={ @@ -119,6 +202,10 @@ async def jobs(request: Request): routes = [ Route("/", endpoint=index, methods=["GET", "POST"]), + Route("/api/validate_csv", endpoint=validate_csv, methods=["POST"]), + Route( + "/api/submit_basic_jobs", endpoint=submit_basic_jobs, methods=["POST"] + ), Route("/jobs", endpoint=jobs, methods=["GET"]), ] diff --git a/tests/test_server.py b/tests/test_server.py index 6b23960..e82f751 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -4,6 +4,7 @@ import json import os import unittest +from copy import deepcopy from pathlib import Path from unittest.mock import MagicMock, call, patch @@ -12,8 +13,11 @@ from requests import Response from aind_data_transfer_service.server import app +from tests.test_configs import TestJobConfigs TEST_DIRECTORY = Path(os.path.dirname(os.path.realpath(__file__))) +SAMPLE_FILE = TEST_DIRECTORY / "resources" / "sample.csv" +# MALFORMED_SAMPLE_FILE = TEST_DIRECTORY / "resources" / "sample_malformed.csv" MOCK_DB_FILE = TEST_DIRECTORY / "test_server" / "db.json" @@ -44,6 +48,27 @@ class TestServer(unittest.TestCase): with open(MOCK_DB_FILE) as f: json_contents = json.load(f) + expected_job_configs = deepcopy(TestJobConfigs.expected_job_configs) + for config in expected_job_configs: + config.aws_param_store_name = None + + @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True) + def test_validate_csv(self): + """Tests that a valid csv file is parsed 
and returned as json as expected.""" + with TestClient(app) as client: + with open(TEST_DIRECTORY / "resources" / "sample.csv", "rb") as f: + files = { + "file": f, + } + response = client.post(url="/api/validate_csv", files=files) + expected_jobs = [j.json() for j in self.expected_job_configs] + expected_response = { + "message": "Valid Data", + "data": {"jobs": expected_jobs}, + } + self.assertEqual(response.status_code, 200) + self.assertEqual(expected_response, response.json()) + @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True) def test_index(self): """Tests that form renders at startup as expected."""