feat: adds rest api to parse csv file

jtyoung84 committed Sep 19, 2023
1 parent b72ec6a commit c5c7654

Showing 4 changed files with 159 additions and 21 deletions.
8 changes: 4 additions & 4 deletions src/aind_data_transfer_service/configs/job_configs.py
@@ -96,6 +96,7 @@ class BasicUploadJobConfigs(BaseSettings):
...,
description="Data collection modalities and their directory location",
title="Modalities",
min_items=1,
)
subject_id: str = Field(..., description="Subject ID", title="Subject ID")
acq_date: date = Field(
@@ -244,14 +245,13 @@ def _map_row_and_key_to_modality_config(
"""
modality: str = cleaned_row[modality_key]
source = cleaned_row.get(f"{modality_key}.source")
extra_configs = cleaned_row.get(f"{modality_key}.extra_configs")

# Return None if modality not in Modality list
if modality not in list(Modality.__members__.keys()):
if modality is None or modality.strip() == "":
return None

modality_configs = ModalityConfigs(
modality=modality,
source=source,
modality=modality, source=source, extra_configs=extra_configs
)
num_id = modality_counts.get(modality)
modality_configs._number_id = num_id
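
For context, a minimal standalone sketch of the new row-mapping behavior (illustrative only, not a call into the real _map_row_and_key_to_modality_config): blank modality cells are now skipped rather than rejected against the Modality member list, and an optional "<modality_key>.extra_configs" column is carried through. The row values below mirror the sample CSVs in this commit; the helper name and dict return type are hypothetical.

# Hypothetical re-statement of the updated mapping rules, not the real internal method.
from typing import Optional

def map_modality(cleaned_row: dict, modality_key: str) -> Optional[dict]:
    modality = cleaned_row.get(modality_key)
    if modality is None or modality.strip() == "":
        return None  # blank cell -> no config for this modality slot
    return {
        "modality": modality,
        "source": cleaned_row.get(f"{modality_key}.source"),
        "extra_configs": cleaned_row.get(f"{modality_key}.extra_configs"),
    }

row = {"modality0": "ECEPHYS", "modality0.source": "dir/data_set_1", "modality1": ""}
print(map_modality(row, "modality0"))  # populated config dict
print(map_modality(row, "modality1"))  # None, since the modality1 cell is blank
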
31 changes: 19 additions & 12 deletions src/aind_data_transfer_service/server.py
@@ -1,7 +1,7 @@
"""Starts and Runs Starlette Service"""
import csv
import io
import json
import logging
import os
from asyncio import sleep

@@ -45,13 +45,12 @@ async def validate_csv(request: Request):
basic_jobs.append(job.json())
except Exception as e:
errors.append(repr(e))
message = (
f"Errors: {json.dumps(errors)}"
if len(errors) > 0
else "Valid Data"
)
message = "There were errors" if len(errors) > 0 else "Valid Data"
status_code = 406 if len(errors) > 0 else 200
content = {"message": message, "data": {"jobs": basic_jobs}}
content = {
"message": message,
"data": {"jobs": basic_jobs, "errors": errors},
}
return JSONResponse(
content=content,
status_code=status_code,
@@ -69,16 +68,23 @@ async def submit_basic_jobs(request: Request):
for job in basic_jobs:
try:
basic_upload_job = BasicUploadJobConfigs.parse_raw(job)
# Add aws_param_store_name and temp_dir
basic_upload_job.aws_param_store_name = os.getenv(
"HPC_AWS_PARAM_STORE_NAME"
)
basic_upload_job.temp_directory = os.getenv(
"HPC_STAGING_DIRECTORY"
)
hpc_job = HpcJobConfigs(basic_upload_job_configs=basic_upload_job)
hpc_jobs.append(hpc_job)
except Exception as e:
parsing_errors.append(f"Error parsing {job}: {e.__class__}")
if parsing_errors:
status_code = 406
message = f"Errors: {json.dumps(parsing_errors)}"
message = "There were errors parsing the basic job configs"
content = {
"message": message,
"data": {"responses": [], "errors": json.dumps(parsing_errors)},
"data": {"responses": [], "errors": parsing_errors},
}
else:
responses = []
@@ -92,12 +98,13 @@
# Add pause to stagger job requests to the hpc
await sleep(0.05)
except Exception as e:
logging.error(repr(e))
hpc_errors.append(
f"Error processing {hpc_job.basic_upload_job_configs}: "
f"{e.__class__}"
f"Error processing "
f"{hpc_job.basic_upload_job_configs.s3_prefix}"
)
message = (
f"Errors: {json.dumps(hpc_errors)}"
"There were errors submitting jobs to the hpc."
if len(hpc_errors) > 0
else "Submitted Jobs."
)
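
As a usage sketch, the two endpoints touched in this file can be exercised roughly as follows; the base URL and CSV path are placeholder assumptions, and the response shapes follow the content dicts built above and the tests below.

# Hypothetical client-side sketch of the validate -> submit flow.
import requests

BASE_URL = "http://localhost:8000"  # assumed local deployment

with open("tests/resources/sample.csv", "rb") as f:
    validate_resp = requests.post(f"{BASE_URL}/api/validate_csv", files={"file": f})

payload = validate_resp.json()
# Success: {"message": "Valid Data", "data": {"jobs": [...], "errors": []}}
# Failure: status 406 with message "There were errors" and the error strings listed.
if validate_resp.status_code == 200:
    submit_resp = requests.post(
        f"{BASE_URL}/api/submit_basic_jobs", json=payload["data"]
    )
    print(submit_resp.status_code, submit_resp.json()["message"])
else:
    print("Validation failed:", payload["data"]["errors"])
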
4 changes: 4 additions & 0 deletions tests/resources/sample_malformed.csv
@@ -0,0 +1,4 @@
modality0, modality0.source, modality1, modality1.source, s3-bucket, subject-id, experiment_type, acq-date, acq-time
ECEPHYS, dir/data_set_1, , , some_bucket, 123454, ecephys, 2020-10-10, 14-10-10
WRONG_MODALITY_HERE, dir/data_set_2, MRI, dir/data_set_3, some_bucket2, 123456, Other, 10/13/2020, 13:10:10
OPHYS, dir/data_set_2, OPHYS, dir/data_set_3, some_bucket2, 123456, Other, 10/13/2020, 13:10:10
137 changes: 132 additions & 5 deletions tests/test_server.py
@@ -17,7 +17,7 @@

TEST_DIRECTORY = Path(os.path.dirname(os.path.realpath(__file__)))
SAMPLE_FILE = TEST_DIRECTORY / "resources" / "sample.csv"
# MALFORMED_SAMPLE_FILE = TEST_DIRECTORY / "resources" / "sample_malformed.csv"
MALFORMED_SAMPLE_FILE = TEST_DIRECTORY / "resources" / "sample_malformed.csv"
MOCK_DB_FILE = TEST_DIRECTORY / "test_server" / "db.json"


@@ -42,7 +42,7 @@ class TestServer(unittest.TestCase):
"HPC_AWS_PARAM_STORE_NAME": "/some/param/store",
}

with open(TEST_DIRECTORY / "resources" / "sample.csv", "r") as file:
with open(SAMPLE_FILE, "r") as file:
csv_content = file.read()

with open(MOCK_DB_FILE) as f:
@@ -54,21 +54,148 @@ class TestServer(unittest.TestCase):

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
def test_validate_csv(self):
"""Tests that form renders at startup as expected."""
"""Tests that valid csv file is returned."""
with TestClient(app) as client:
with open(TEST_DIRECTORY / "resources" / "sample.csv", "rb") as f:
with open(SAMPLE_FILE, "rb") as f:
files = {
"file": f,
}
response = client.post(url="/api/validate_csv", files=files)
expected_jobs = [j.json() for j in self.expected_job_configs]
expected_response = {
"message": "Valid Data",
"data": {"jobs": expected_jobs},
"data": {"jobs": expected_jobs, "errors": []},
}
self.assertEqual(response.status_code, 200)
self.assertEqual(expected_response, response.json())

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("aind_data_transfer_service.server.sleep", return_value=None)
@patch("aind_data_transfer_service.hpc.client.HpcClient.submit_job")
def test_submit_jobs(
self, mock_submit_job: MagicMock, mock_sleep: MagicMock
):
"""Tests submit jobs success."""
mock_response = Response()
mock_response.status_code = 200
mock_response._content = b'{"message": "success"}'
mock_submit_job.return_value = mock_response
with TestClient(app) as client:
with open(SAMPLE_FILE, "rb") as f:
files = {
"file": f,
}
response = client.post(url="/api/validate_csv", files=files)
basic_jobs = response.json()["data"]
submit_job_response = client.post(
url="/api/submit_basic_jobs", json=basic_jobs
)
expected_response = {
"message": "Submitted Jobs.",
"data": {
"responses": [
{"message": "success"},
{"message": "success"},
{"message": "success"},
],
"errors": [],
},
}
self.assertEqual(expected_response, submit_job_response.json())
self.assertEqual(200, submit_job_response.status_code)
self.assertEqual(3, mock_sleep.call_count)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("aind_data_transfer_service.server.sleep", return_value=None)
@patch("aind_data_transfer_service.hpc.client.HpcClient.submit_job")
@patch("logging.error")
def test_submit_jobs_server_error(
self,
mock_log_error: MagicMock,
mock_submit_job: MagicMock,
mock_sleep: MagicMock,
):
"""Tests that submit jobs returns error if there's an issue with hpc"""
mock_response = Response()
mock_response.status_code = 500
mock_submit_job.return_value = mock_response
with TestClient(app) as client:
with open(SAMPLE_FILE, "rb") as f:
files = {
"file": f,
}
response = client.post(url="/api/validate_csv", files=files)
basic_jobs = response.json()["data"]
submit_job_response = client.post(
url="/api/submit_basic_jobs", json=basic_jobs
)
expected_response = {
"message": "There were errors submitting jobs to the hpc.",
"data": {
"responses": [],
"errors": [
"Error processing ecephys_123454_2020-10-10_14-10-10",
"Error processing Other_123456_2020-10-13_13-10-10",
"Error processing Other_123456_2020-10-13_13-10-10",
],
},
}
self.assertEqual(expected_response, submit_job_response.json())
self.assertEqual(500, submit_job_response.status_code)
self.assertEqual(0, mock_sleep.call_count)
self.assertEqual(3, mock_log_error.call_count)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("aind_data_transfer_service.server.sleep", return_value=None)
@patch("aind_data_transfer_service.hpc.client.HpcClient.submit_job")
@patch("logging.error")
def test_submit_jobs_malformed_json(
self,
mock_log_error: MagicMock,
mock_submit_job: MagicMock,
mock_sleep: MagicMock,
):
"""Tests that submit jobs returns parsing errors."""
mock_response = Response()
mock_response.status_code = 500
mock_submit_job.return_value = mock_response
with TestClient(app) as client:
basic_jobs = {"jobs": ['{"malformed_key": "val"}']}
submit_job_response = client.post(
url="/api/submit_basic_jobs", json=basic_jobs
)
expected_response = {
"message": "There were errors parsing the basic job configs",
"data": {
"responses": [],
"errors": [
(
'Error parsing {"malformed_key": "val"}: '
"<class 'pydantic.error_wrappers.ValidationError'>"
)
],
},
}
self.assertEqual(406, submit_job_response.status_code)
self.assertEqual(expected_response, submit_job_response.json())
self.assertEqual(0, mock_sleep.call_count)
self.assertEqual(0, mock_log_error.call_count)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
def test_validate_malformed_csv(self):
"""Tests that invalid csv returns errors"""
with TestClient(app) as client:
with open(MALFORMED_SAMPLE_FILE, "rb") as f:
files = {
"file": f,
}
response = client.post(url="/api/validate_csv", files=files)
self.assertEqual(response.status_code, 406)
self.assertEqual(
["AttributeError('WRONG_MODALITY_HERE')"],
response.json()["data"]["errors"],
)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
def test_index(self):
"""Tests that form renders at startup as expected."""
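
These tests follow the standard unittest plus TestClient pattern; a minimal way to run just this module from the repository root (assuming the usual project layout) is sketched below.

# Convenience runner, equivalent to: python -m unittest tests.test_server
import unittest

if __name__ == "__main__":
    suite = unittest.defaultTestLoader.discover("tests", pattern="test_server.py")
    unittest.TextTestRunner(verbosity=2).run(suite)
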
