Skip to content

Commit

Permalink
feat: upgrades models to v0.6.2
Browse files Browse the repository at this point in the history
  • Loading branch information
jtyoung84 committed Sep 3, 2024
1 parent e8f680d commit 3cdc1c9
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 17 deletions.
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ dependencies = [
'boto3',
'pydantic>=2.0',
'pydantic-settings>=2.0',
'aind-data-schema==0.33.3',
'aind-data-transfer-models==0.5.1'
'aind-data-transfer-models==0.6.2'
]

[project.optional-dependencies]
Expand Down
6 changes: 3 additions & 3 deletions src/aind_data_transfer_service/configs/job_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from typing import Any, ClassVar, Dict, List, Optional, Union

from aind_data_schema.core.data_description import build_data_name
from aind_data_schema.core.processing import ProcessName
from aind_data_schema.models.modalities import Modality
from aind_data_schema.models.platforms import Platform
from aind_data_schema_models.modalities import Modality
from aind_data_schema_models.platforms import Platform
from aind_data_schema_models.process_names import ProcessName
from pydantic import (
ConfigDict,
Field,
Expand Down
4 changes: 2 additions & 2 deletions src/aind_data_transfer_service/configs/job_upload_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from io import BytesIO
from typing import Any, Dict, List

from aind_data_schema.models.modalities import Modality
from aind_data_schema.models.platforms import Platform
from aind_data_schema_models.modalities import Modality
from aind_data_schema_models.platforms import Platform
from openpyxl import Workbook
from openpyxl.styles import Font
from openpyxl.utils import get_column_letter
Expand Down
10 changes: 8 additions & 2 deletions src/aind_data_transfer_service/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,13 @@ async def validate_csv(request: Request):
job = map_csv_row_to_job(row=row)
# Construct hpc job setting most of the vars from the env
basic_jobs.append(
json.loads(job.model_dump_json(round_trip=True))
json.loads(
job.model_dump_json(
round_trip=True,
exclude_none=True,
warnings=False,
)
)
)
except ValidationError as e:
errors.append(e.json())
Expand Down Expand Up @@ -166,7 +172,7 @@ async def submit_jobs(request: Request):
content = await request.json()
try:
model = SubmitJobRequest.model_validate_json(json.dumps(content))
full_content = json.loads(model.model_dump_json())
full_content = json.loads(model.model_dump_json(warnings=False))
# TODO: Replace with httpx async client
response = requests.post(
url=os.getenv("AIND_AIRFLOW_SERVICE_URL"),
Expand Down
6 changes: 3 additions & 3 deletions tests/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from datetime import datetime
from pathlib import Path, PurePosixPath

from aind_data_schema.core.processing import ProcessName
from aind_data_schema.models.modalities import Modality
from aind_data_schema.models.platforms import Platform
from aind_data_schema_models.modalities import Modality
from aind_data_schema_models.platforms import Platform
from aind_data_schema_models.process_names import ProcessName

from aind_data_transfer_service.configs.job_configs import (
BasicUploadJobConfigs,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_hpc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from pathlib import Path, PurePosixPath
from unittest.mock import patch

from aind_data_schema.models.modalities import Modality
from aind_data_schema.models.platforms import Platform
from aind_data_schema.models.process_names import ProcessName
from aind_data_schema_models.modalities import Modality
from aind_data_schema_models.platforms import Platform
from aind_data_schema_models.process_names import ProcessName
from pydantic import SecretStr

from aind_data_transfer_service.configs.job_configs import (
Expand Down
70 changes: 68 additions & 2 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import os
import unittest
import warnings
from copy import deepcopy
from datetime import datetime, timedelta, timezone
from io import BytesIO
Expand Down Expand Up @@ -71,7 +72,7 @@ class TestServer(unittest.TestCase):
"HPC_AWS_DEFAULT_REGION": "aws_region",
"APP_CSRF_SECRET_KEY": "test_csrf_key",
"APP_SECRET_KEY": "test_app_key",
"HPC_STAGING_DIRECTORY": "/stage/dir",
"HPC_STAGING_DIRECTORY": "stage/dir",
"HPC_AWS_PARAM_STORE_NAME": "/some/param/store",
"OPEN_DATA_AWS_SECRET_ACCESS_KEY": "open_data_aws_key",
"OPEN_DATA_AWS_ACCESS_KEY_ID": "open_data_aws_key_id",
Expand Down Expand Up @@ -1496,7 +1497,7 @@ def test_submit_v1_jobs_200_slurm_settings(
"utf-8"
)
mock_post.return_value = mock_response
ephys_source_dir = PurePosixPath("/shared_drive/ephys_data/690165")
ephys_source_dir = PurePosixPath("shared_drive/ephys_data/690165")

s3_bucket = "private"
subject_id = "690165"
Expand Down Expand Up @@ -1538,6 +1539,71 @@ def test_submit_v1_jobs_200_slurm_settings(
)
self.assertEqual(200, submit_job_response.status_code)

@patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True)
@patch("requests.post")
def test_submit_v1_jobs_200_session_settings_config_file(
self,
mock_post: MagicMock,
):
"""Tests submit jobs success when user adds aind-metadata-mapper
settings pointing to a config file."""

session_settings = {
"session_settings": {
"job_settings": {
"user_settings_config_file": "test_bergamo_settings.json",
"job_settings_name": "Bergamo",
}
}
}

mock_response = Response()
mock_response.status_code = 200
mock_response._content = json.dumps({"message": "sent"}).encode(
"utf-8"
)
mock_post.return_value = mock_response
ephys_source_dir = PurePosixPath("shared_drive/ephys_data/690165")

s3_bucket = "private"
subject_id = "690165"
acq_datetime = datetime(2024, 2, 19, 11, 25, 17)
platform = Platform.ECEPHYS

ephys_config = ModalityConfigs(
modality=Modality.ECEPHYS,
source=ephys_source_dir,
)
project_name = "Ephys Platform"

upload_job_configs = BasicUploadJobConfigs(
project_name=project_name,
s3_bucket=s3_bucket,
platform=platform,
subject_id=subject_id,
acq_datetime=acq_datetime,
modalities=[ephys_config],
metadata_configs=session_settings,
)

upload_jobs = [upload_job_configs]
# Suppress serializer warning when using a dict instead of an object
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
submit_request = SubmitJobRequest(upload_jobs=upload_jobs)

post_request_content = json.loads(
submit_request.model_dump_json(
round_trip=True, warnings=False, exclude_none=True
)
)

with TestClient(app) as client:
submit_job_response = client.post(
url="/api/v1/submit_jobs", json=post_request_content
)
self.assertEqual(200, submit_job_response.status_code)


if __name__ == "__main__":
unittest.main()

0 comments on commit 3cdc1c9

Please sign in to comment.