feat: adds ability to write to open-data account #68

Merged · 1 commit · Feb 1, 2024

22 changes: 18 additions & 4 deletions src/aind_data_transfer_service/server.py
@@ -43,6 +43,10 @@
 # HPC_STAGING_DIRECTORY
 # HPC_AWS_PARAM_STORE_NAME
 # BASIC_JOB_SCRIPT
+# OPEN_DATA_AWS_SECRET_ACCESS_KEY
+# OPEN_DATA_AWS_ACCESS_KEY_ID
+
+OPEN_DATA_BUCKET_NAME = os.getenv("OPEN_DATA_BUCKET_NAME", "aind-open-data")


 async def validate_csv(request: Request):
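
Two new environment variables hold the open-data account credentials, and the bucket they guard is itself configurable: OPEN_DATA_BUCKET_NAME falls back to "aind-open-data" when unset. A minimal sketch of that lookup behavior (standard-library only; the staging bucket name is hypothetical):

import os

# With no override set, the default open-data bucket name is used.
os.environ.pop("OPEN_DATA_BUCKET_NAME", None)
assert os.getenv("OPEN_DATA_BUCKET_NAME", "aind-open-data") == "aind-open-data"

# A deployment can point the check at another bucket without a code change.
os.environ["OPEN_DATA_BUCKET_NAME"] = "aind-open-data-staging"  # hypothetical
assert os.getenv("OPEN_DATA_BUCKET_NAME", "aind-open-data") == "aind-open-data-staging"
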
@@ -176,17 +180,27 @@ async def submit_hpc_jobs(request: Request):  # noqa: C901
                 job["upload_job_settings"]
             ).s3_prefix
             upload_job_configs = json.loads(job["upload_job_settings"])
+            # The aws creds to use are different for aind-open-data and
+            # everything else
+            if upload_job_configs.get("s3_bucket") == OPEN_DATA_BUCKET_NAME:
+                aws_secret_access_key = SecretStr(
+                    os.getenv("OPEN_DATA_AWS_SECRET_ACCESS_KEY")
+                )
+                aws_access_key_id = os.getenv("OPEN_DATA_AWS_ACCESS_KEY_ID")
+            else:
+                aws_secret_access_key = SecretStr(
+                    os.getenv("HPC_AWS_SECRET_ACCESS_KEY")
+                )
+                aws_access_key_id = os.getenv("HPC_AWS_ACCESS_KEY_ID")
             hpc_settings = json.loads(job["hpc_settings"])
             if basic_job_name is not None:
                 hpc_settings["name"] = basic_job_name
             hpc_job = HpcJobSubmitSettings.from_upload_job_configs(
                 logging_directory=PurePosixPath(
                     os.getenv("HPC_LOGGING_DIRECTORY")
                 ),
-                aws_secret_access_key=SecretStr(
-                    os.getenv("HPC_AWS_SECRET_ACCESS_KEY")
-                ),
-                aws_access_key_id=os.getenv("HPC_AWS_ACCESS_KEY_ID"),
+                aws_secret_access_key=aws_secret_access_key,
+                aws_access_key_id=aws_access_key_id,
                 aws_default_region=os.getenv("HPC_AWS_DEFAULT_REGION"),
                 aws_session_token=(
                     (
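
The branch above is the core of the feature: when a job's s3_bucket matches OPEN_DATA_BUCKET_NAME, the open-data key pair signs the upload; every other bucket keeps the existing HPC credentials. A standalone sketch of the same selection logic (the helper is illustrative, not part of the PR, and it assumes all four env vars are set):

import os

from pydantic import SecretStr


def select_aws_credentials(s3_bucket: str, open_data_bucket: str):
    """Return (secret_key, key_id) for the account that owns the bucket."""
    # Hypothetical helper: the PR inlines this branch in submit_hpc_jobs.
    prefix = "OPEN_DATA" if s3_bucket == open_data_bucket else "HPC"
    secret = SecretStr(os.getenv(f"{prefix}_AWS_SECRET_ACCESS_KEY"))
    key_id = os.getenv(f"{prefix}_AWS_ACCESS_KEY_ID")
    return secret, key_id

Wrapping the secret in SecretStr keeps it from leaking into logs and reprs, mirroring how the server already handles the HPC secret.
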
79 changes: 76 additions & 3 deletions tests/test_server.py
@@ -4,10 +4,11 @@
 import os
 import unittest
 from copy import deepcopy
-from pathlib import Path
+from pathlib import Path, PurePosixPath
 from unittest.mock import MagicMock, patch

 from fastapi.testclient import TestClient
+from pydantic import SecretStr
 from requests import Response

 from aind_data_transfer_service.server import app
@@ -41,6 +42,8 @@ class TestServer(unittest.TestCase):
         "APP_SECRET_KEY": "test_app_key",
         "HPC_STAGING_DIRECTORY": "/stage/dir",
         "HPC_AWS_PARAM_STORE_NAME": "/some/param/store",
+        "OPEN_DATA_AWS_SECRET_ACCESS_KEY": "open_data_aws_key",
+        "OPEN_DATA_AWS_ACCESS_KEY_ID": "open_data_aws_key_id",
     }

     with open(SAMPLE_CSV, "r") as file:
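
The two new fixture values let the tests exercise the open-data branch without real credentials; patch.dict swaps them into os.environ for the duration of each test. A minimal sketch of that pattern (standalone, not from the PR):

import os
from unittest.mock import patch

fake_env = {"OPEN_DATA_AWS_ACCESS_KEY_ID": "open_data_aws_key_id"}

with patch.dict(os.environ, fake_env, clear=True):
    # Inside the context only the patched variables are visible.
    assert os.environ["OPEN_DATA_AWS_ACCESS_KEY_ID"] == "open_data_aws_key_id"
# On exit the original environment is restored automatically.
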
@@ -229,7 +232,7 @@ def test_validate_malformed_csv(self):
             response = client.post(url="/api/validate_csv", files=files)
             self.assertEqual(response.status_code, 406)
             self.assertEqual(
-                [("AttributeError('Unknown Modality: WRONG_MODALITY_HERE',)")],
+                ["AttributeError('Unknown Modality: WRONG_MODALITY_HERE',)"],
                 response.json()["data"]["errors"],
             )

@@ -244,7 +247,7 @@ def test_validate_malformed_xlsx(self):
             response = client.post(url="/api/validate_csv", files=files)
             self.assertEqual(response.status_code, 406)
             self.assertEqual(
-                [("AttributeError('Unknown Modality: WRONG_MODALITY_HERE',)")],
+                ["AttributeError('Unknown Modality: WRONG_MODALITY_HERE',)"],
                 response.json()["data"]["errors"],
             )

@@ -305,6 +308,76 @@ def test_submit_hpc_jobs(
         self.assertEqual(200, submit_job_response.status_code)
         self.assertEqual(2, mock_sleep.call_count)

+    @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
+    @patch("aind_data_transfer_service.server.sleep", return_value=None)
+    @patch("aind_data_transfer_service.hpc.client.HpcClient.submit_hpc_job")
+    @patch(
+        "aind_data_transfer_service.hpc.models.HpcJobSubmitSettings"
+        ".from_upload_job_configs"
+    )
+    def test_submit_hpc_jobs_open_data(
+        self,
+        mock_from_upload_configs: MagicMock,
+        mock_submit_job: MagicMock,
+        mock_sleep: MagicMock,
+    ):
+        """Tests submit hpc jobs success."""
+        mock_response = Response()
+        mock_response.status_code = 200
+        mock_response._content = b'{"message": "success"}'
+        mock_submit_job.return_value = mock_response
+        # When a user specifies aind-open-data in the upload_job_settings,
+        # use the credentials for that account.
+        post_request_content = {
+            "jobs": [
+                {
+                    "hpc_settings": '{"qos":"production", "name": "job1"}',
+                    "upload_job_settings": (
+                        '{"s3_bucket": "aind-open-data", '
+                        '"platform": {"name": "Behavior platform", '
+                        '"abbreviation": "behavior"}, '
+                        '"modalities": ['
+                        '{"modality": {"name": "Behavior videos", '
+                        '"abbreviation": "behavior-videos"}, '
+                        '"source": "dir/data_set_2", '
+                        '"compress_raw_data": true, '
+                        '"skip_staging": false}], '
+                        '"subject_id": "123456", '
+                        '"acq_datetime": "2020-10-13T13:10:10", '
+                        '"process_name": "Other", '
+                        '"log_level": "WARNING", '
+                        '"metadata_dir_force": false, '
+                        '"dry_run": false, '
+                        '"force_cloud_sync": false}'
+                    ),
+                    "script": "",
+                }
+            ]
+        }
+        with TestClient(app) as client:
+            submit_job_response = client.post(
+                url="/api/submit_hpc_jobs", json=post_request_content
+            )
+        expected_response = {
+            "message": "Submitted Jobs.",
+            "data": {
+                "responses": [{"message": "success"}],
+                "errors": [],
+            },
+        }
+        self.assertEqual(expected_response, submit_job_response.json())
+        self.assertEqual(200, submit_job_response.status_code)
+        self.assertEqual(1, mock_sleep.call_count)
+        mock_from_upload_configs.assert_called_with(
+            logging_directory=PurePosixPath("hpc_logs"),
+            aws_secret_access_key=SecretStr("open_data_aws_key"),
+            aws_access_key_id="open_data_aws_key_id",
+            aws_default_region="aws_region",
+            aws_session_token=None,
+            qos="production",
+            name="behavior_123456_2020-10-13_13-10-10",
+        )
+
     @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
     @patch("aind_data_transfer_service.server.sleep", return_value=None)
     @patch("aind_data_transfer_service.hpc.client.HpcClient.submit_hpc_job")
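
The new test doubles as documentation of the request shape: an upload_job_settings payload whose s3_bucket is "aind-open-data" routes through the open-data credentials. A hedged sketch of the equivalent client call (host and port are assumptions; the real values depend on the deployment):

import json

import requests

upload_job_settings = {
    "s3_bucket": "aind-open-data",  # routes the upload to the open-data account
    "platform": {"name": "Behavior platform", "abbreviation": "behavior"},
    "modalities": [
        {
            "modality": {
                "name": "Behavior videos",
                "abbreviation": "behavior-videos",
            },
            "source": "dir/data_set_2",
        }
    ],
    "subject_id": "123456",
    "acq_datetime": "2020-10-13T13:10:10",
}
job = {
    "hpc_settings": json.dumps({"qos": "production", "name": "job1"}),
    "upload_job_settings": json.dumps(upload_job_settings),
    "script": "",
}
# Hypothetical local deployment URL.
response = requests.post(
    "http://localhost:5000/api/submit_hpc_jobs", json={"jobs": [job]}
)
print(response.json())  # e.g. {"message": "Submitted Jobs.", "data": {...}}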