Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat 51: Enable xlsx file upload #62

Merged
merged 5 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ server = [
'starlette_wtf',
'uvicorn[standard]',
'wtforms',
'requests==2.25.0'
'requests==2.25.0',
'openpyxl'
]

[tool.setuptools.packages.find]
Expand Down
38 changes: 25 additions & 13 deletions src/aind_data_transfer_service/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from asyncio import sleep
from pathlib import Path

import openpyxl
from fastapi import Request
from fastapi.responses import JSONResponse
from fastapi.templating import Jinja2Templates
Expand Down Expand Up @@ -45,22 +46,33 @@


async def validate_csv(request: Request):
"""Validate a csv file. Return parsed contents as json."""
"""Validate a csv or xlsx file. Return parsed contents as json."""
async with request.form() as form:
content = await form["file"].read()
# A few csv files created from excel have extra unicode byte chars.
# Adding "utf-8-sig" should remove them.
data = content.decode("utf-8-sig")
csv_reader = csv.DictReader(io.StringIO(data))
basic_jobs = []
errors = []
for row in csv_reader:
try:
job = BasicUploadJobConfigs.from_csv_row(row=row)
# Construct hpc job setting most of the vars from the env
basic_jobs.append(job.json())
except Exception as e:
errors.append(repr(e))
if not form["file"].filename.endswith((".csv", ".xlsx")):
errors.append("Invalid input file type")
else:
content = await form["file"].read()
if form["file"].filename.endswith(".csv"):
# A few csv files created from excel have extra unicode
# byte chars. Adding "utf-8-sig" should remove them.
data = content.decode("utf-8-sig")
else:
xlsx_sheet = openpyxl.load_workbook(io.BytesIO(content)).active
csv_io = io.StringIO()
csv_writer = csv.writer(csv_io)
for r in xlsx_sheet.rows:
csv_writer.writerow([cell.value for cell in r])
data = csv_io.getvalue()
csv_reader = csv.DictReader(io.StringIO(data))
for row in csv_reader:
try:
job = BasicUploadJobConfigs.from_csv_row(row=row)
# Construct hpc job setting most of the vars from the env
basic_jobs.append(job.json())
except Exception as e:
errors.append(repr(e))
message = "There were errors" if len(errors) > 0 else "Valid Data"
status_code = 406 if len(errors) > 0 else 200
content = {
Expand Down
12 changes: 10 additions & 2 deletions src/aind_data_transfer_service/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ <h2>Submit Jobs</h2>
</fieldset>
</div><br><br>
<form id="preview_form" method="post" enctype="multipart/form-data">
<label for="file">Please select a CSV file:</label>
<input type="file" id="file" name="file"><br><br>
<label for="file">Please select a .csv or .xlsx file:</label>
<input type="file" id="file" name="file" accept=".csv,.xlsx"><br><br>
<input type="submit" id="preview" value="preview"><br><br>
</form>
<button type="button" onclick="submitJobs()">Submit</button>
Expand All @@ -63,6 +63,14 @@ <h2>Submit Jobs</h2>
$(function() {
$("#preview_form").on("submit", function(e) {
e.preventDefault();
if ($("#file").prop("files").length != 1) {
alert("No file selected. Please attach a .csv or .xlsx file.");
return;
}
if (![".csv", ".xlsx"].some(ext => $("#file").prop("files")[0].name.endsWith(ext))) {
alert("Invalid file type. Please attach a .csv or .xlsx file.");
return;
}
var formData = new FormData(this);
$.ajax({
url: "/api/validate_csv",
Expand Down
Binary file added tests/resources/sample.xlsx
Binary file not shown.
4 changes: 4 additions & 0 deletions tests/resources/sample_invalid_ext.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
modality0, modality0.source, modality1, modality1.source, s3-bucket, subject-id, platform, acq-datetime
ECEPHYS, dir/data_set_1, ,, some_bucket, 123454, ecephys, 2020-10-10 14:10:10
BEHAVIOR_VIDEOS, dir/data_set_2, MRI, dir/data_set_3, some_bucket2, 123456, BEHAVIOR, 10/13/2020 1:10:10 PM
BEHAVIOR_VIDEOS, dir/data_set_2, BEHAVIOR_VIDEOS, dir/data_set_3, some_bucket2, 123456, BEHAVIOR, 10/13/2020 1:10:10 PM
Binary file added tests/resources/sample_malformed.xlsx
Binary file not shown.
3 changes: 2 additions & 1 deletion tests/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ def test_from_job_and_server_configs(self):
'[{"modality": {"name": '
'"Extracellular electrophysiology", '
'"abbreviation": "ecephys"}, '
'"source": "dir/data_set_1", "compress_raw_data": true, '
f'"source": "{repr(str(Path("dir/data_set_1")))[1:-1]}", '
'"compress_raw_data": true, '
'"extra_configs": null,'
' "skip_staging": false}],'
' "subject_id": "123454",'
Expand Down
4 changes: 2 additions & 2 deletions tests/test_hpc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ def test_from_upload_job_configs(self):
"ecephys_123454_2020-10-10_14-10-10", hpc_settings.name
)
self.assertEqual(
"dir/logs/ecephys_123454_2020-10-10_14-10-10_error.out",
str(Path("dir/logs/ecephys_123454_2020-10-10_14-10-10_error.out")),
hpc_settings.standard_error,
)
self.assertEqual(
"dir/logs/ecephys_123454_2020-10-10_14-10-10.out",
str(Path("dir/logs/ecephys_123454_2020-10-10_14-10-10.out")),
hpc_settings.standard_out,
)
self.assertEqual(180, hpc_settings.time_limit)
Expand Down
64 changes: 57 additions & 7 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
from tests.test_configs import TestJobConfigs

TEST_DIRECTORY = Path(os.path.dirname(os.path.realpath(__file__)))
SAMPLE_FILE = TEST_DIRECTORY / "resources" / "sample.csv"
MALFORMED_SAMPLE_FILE = TEST_DIRECTORY / "resources" / "sample_malformed.csv"
SAMPLE_INVALID_EXT = TEST_DIRECTORY / "resources" / "sample_invalid_ext.txt"
SAMPLE_CSV = TEST_DIRECTORY / "resources" / "sample.csv"
MALFORMED_SAMPLE_CSV = TEST_DIRECTORY / "resources" / "sample_malformed.csv"
SAMPLE_XLSX = TEST_DIRECTORY / "resources" / "sample.xlsx"
MALFORMED_SAMPLE_XLSX = TEST_DIRECTORY / "resources" / "sample_malformed.xlsx"
MOCK_DB_FILE = TEST_DIRECTORY / "test_server" / "db.json"


Expand All @@ -40,7 +43,7 @@ class TestServer(unittest.TestCase):
"HPC_AWS_PARAM_STORE_NAME": "/some/param/store",
}

with open(SAMPLE_FILE, "r") as file:
with open(SAMPLE_CSV, "r") as file:
csv_content = file.read()

with open(MOCK_DB_FILE) as f:
Expand All @@ -54,7 +57,24 @@ class TestServer(unittest.TestCase):
def test_validate_csv(self):
"""Tests that valid csv file is returned."""
with TestClient(app) as client:
with open(SAMPLE_FILE, "rb") as f:
with open(SAMPLE_CSV, "rb") as f:
files = {
"file": f,
}
response = client.post(url="/api/validate_csv", files=files)
expected_jobs = [j.json() for j in self.expected_job_configs]
expected_response = {
"message": "Valid Data",
"data": {"jobs": expected_jobs, "errors": []},
}
self.assertEqual(response.status_code, 200)
self.assertEqual(expected_response, response.json())

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
def test_validate_csv_xlsx(self):
"""Tests that valid xlsx file is returned."""
with TestClient(app) as client:
with open(SAMPLE_XLSX, "rb") as f:
files = {
"file": f,
}
Expand All @@ -79,7 +99,7 @@ def test_submit_jobs(
mock_response._content = b'{"message": "success"}'
mock_submit_job.return_value = mock_response
with TestClient(app) as client:
with open(SAMPLE_FILE, "rb") as f:
with open(SAMPLE_CSV, "rb") as f:
files = {
"file": f,
}
Expand Down Expand Up @@ -118,7 +138,7 @@ def test_submit_jobs_server_error(
mock_response.status_code = 500
mock_submit_job.return_value = mock_response
with TestClient(app) as client:
with open(SAMPLE_FILE, "rb") as f:
with open(SAMPLE_CSV, "rb") as f:
files = {
"file": f,
}
Expand Down Expand Up @@ -179,11 +199,41 @@ def test_submit_jobs_malformed_json(
self.assertEqual(0, mock_sleep.call_count)
self.assertEqual(0, mock_log_error.call_count)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
def test_validate_null_csv(self):
"""Tests that a file with an unsupported extension is rejected with a 406 status and an 'Invalid input file type' error."""
with TestClient(app) as client:
with open(SAMPLE_INVALID_EXT, "rb") as f:
files = {
"file": f,
}
response = client.post(url="/api/validate_csv", files=files)
self.assertEqual(response.status_code, 406)
self.assertEqual(
["Invalid input file type"],
response.json()["data"]["errors"],
)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
def test_validate_malformed_csv(self):
"""Tests that invalid csv returns errors"""
with TestClient(app) as client:
with open(MALFORMED_SAMPLE_FILE, "rb") as f:
with open(MALFORMED_SAMPLE_CSV, "rb") as f:
files = {
"file": f,
}
response = client.post(url="/api/validate_csv", files=files)
self.assertEqual(response.status_code, 406)
self.assertEqual(
["AttributeError('WRONG_MODALITY_HERE')"],
response.json()["data"]["errors"],
)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
def test_validate_malformed_xlsx(self):
"""Tests that invalid xlsx returns errors"""
with TestClient(app) as client:
with open(MALFORMED_SAMPLE_XLSX, "rb") as f:
files = {
"file": f,
}
Expand Down