diff --git a/pyproject.toml b/pyproject.toml index fcb1b2d..c342487 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,8 +20,7 @@ dependencies = [ 'boto3', 'pydantic>=2.0', 'pydantic-settings>=2.0', - 'aind-data-schema==0.33.3', - 'aind-data-transfer-models==0.5.1' + 'aind-data-transfer-models==0.6.2' ] [project.optional-dependencies] diff --git a/src/aind_data_transfer_service/configs/job_configs.py b/src/aind_data_transfer_service/configs/job_configs.py index c59ef37..a1f8ef1 100644 --- a/src/aind_data_transfer_service/configs/job_configs.py +++ b/src/aind_data_transfer_service/configs/job_configs.py @@ -6,9 +6,9 @@ from typing import Any, ClassVar, Dict, List, Optional, Union from aind_data_schema.core.data_description import build_data_name -from aind_data_schema.core.processing import ProcessName -from aind_data_schema.models.modalities import Modality -from aind_data_schema.models.platforms import Platform +from aind_data_schema_models.modalities import Modality +from aind_data_schema_models.platforms import Platform +from aind_data_schema_models.process_names import ProcessName from pydantic import ( ConfigDict, Field, diff --git a/src/aind_data_transfer_service/configs/job_upload_template.py b/src/aind_data_transfer_service/configs/job_upload_template.py index 7da80b1..057132d 100644 --- a/src/aind_data_transfer_service/configs/job_upload_template.py +++ b/src/aind_data_transfer_service/configs/job_upload_template.py @@ -3,8 +3,8 @@ from io import BytesIO from typing import Any, Dict, List -from aind_data_schema.models.modalities import Modality -from aind_data_schema.models.platforms import Platform +from aind_data_schema_models.modalities import Modality +from aind_data_schema_models.platforms import Platform from openpyxl import Workbook from openpyxl.styles import Font from openpyxl.utils import get_column_letter diff --git a/src/aind_data_transfer_service/server.py b/src/aind_data_transfer_service/server.py index 3a9c890..69dc342 100644 --- a/src/aind_data_transfer_service/server.py +++ b/src/aind_data_transfer_service/server.py @@ -97,7 +97,13 @@ async def validate_csv(request: Request): job = map_csv_row_to_job(row=row) # Construct hpc job setting most of the vars from the env basic_jobs.append( - json.loads(job.model_dump_json(round_trip=True)) + json.loads( + job.model_dump_json( + round_trip=True, + exclude_none=True, + warnings=False, + ) + ) ) except ValidationError as e: errors.append(e.json()) @@ -166,7 +172,7 @@ async def submit_jobs(request: Request): content = await request.json() try: model = SubmitJobRequest.model_validate_json(json.dumps(content)) - full_content = json.loads(model.model_dump_json()) + full_content = json.loads(model.model_dump_json(warnings=False)) # TODO: Replace with httpx async client response = requests.post( url=os.getenv("AIND_AIRFLOW_SERVICE_URL"), diff --git a/tests/test_configs.py b/tests/test_configs.py index 7ce51eb..da65ceb 100644 --- a/tests/test_configs.py +++ b/tests/test_configs.py @@ -6,9 +6,9 @@ from datetime import datetime from pathlib import Path, PurePosixPath -from aind_data_schema.core.processing import ProcessName -from aind_data_schema.models.modalities import Modality -from aind_data_schema.models.platforms import Platform +from aind_data_schema_models.modalities import Modality +from aind_data_schema_models.platforms import Platform +from aind_data_schema_models.process_names import ProcessName from aind_data_transfer_service.configs.job_configs import ( BasicUploadJobConfigs, diff --git a/tests/test_hpc_models.py b/tests/test_hpc_models.py index 4d7b493..628adc0 100644 --- a/tests/test_hpc_models.py +++ b/tests/test_hpc_models.py @@ -6,9 +6,9 @@ from pathlib import Path, PurePosixPath from unittest.mock import patch -from aind_data_schema.models.modalities import Modality -from aind_data_schema.models.platforms import Platform -from aind_data_schema.models.process_names import ProcessName +from aind_data_schema_models.modalities import Modality +from aind_data_schema_models.platforms import Platform +from aind_data_schema_models.process_names import ProcessName from pydantic import SecretStr from aind_data_transfer_service.configs.job_configs import ( diff --git a/tests/test_server.py b/tests/test_server.py index 9f3ff45..15d81a4 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -3,6 +3,7 @@ import json import os import unittest +import warnings from copy import deepcopy from datetime import datetime, timedelta, timezone from io import BytesIO @@ -71,7 +72,7 @@ class TestServer(unittest.TestCase): "HPC_AWS_DEFAULT_REGION": "aws_region", "APP_CSRF_SECRET_KEY": "test_csrf_key", "APP_SECRET_KEY": "test_app_key", - "HPC_STAGING_DIRECTORY": "/stage/dir", + "HPC_STAGING_DIRECTORY": "stage/dir", "HPC_AWS_PARAM_STORE_NAME": "/some/param/store", "OPEN_DATA_AWS_SECRET_ACCESS_KEY": "open_data_aws_key", "OPEN_DATA_AWS_ACCESS_KEY_ID": "open_data_aws_key_id", @@ -1496,7 +1497,7 @@ def test_submit_v1_jobs_200_slurm_settings( "utf-8" ) mock_post.return_value = mock_response - ephys_source_dir = PurePosixPath("/shared_drive/ephys_data/690165") + ephys_source_dir = PurePosixPath("shared_drive/ephys_data/690165") s3_bucket = "private" subject_id = "690165" @@ -1538,6 +1539,71 @@ def test_submit_v1_jobs_200_slurm_settings( ) self.assertEqual(200, submit_job_response.status_code) + @patch.dict(os.environ, EXAMPLE_ENV_VAR1, clear=True) + @patch("requests.post") + def test_submit_v1_jobs_200_session_settings_config_file( + self, + mock_post: MagicMock, + ): + """Tests submit jobs success when user adds aind-metadata-mapper + settings pointing to a config file.""" + + session_settings = { + "session_settings": { + "job_settings": { + "user_settings_config_file": "test_bergamo_settings.json", + "job_settings_name": "Bergamo", + } + } + } + + mock_response = Response() + mock_response.status_code = 200 + mock_response._content = json.dumps({"message": "sent"}).encode( + "utf-8" + ) + mock_post.return_value = mock_response + ephys_source_dir = PurePosixPath("shared_drive/ephys_data/690165") + + s3_bucket = "private" + subject_id = "690165" + acq_datetime = datetime(2024, 2, 19, 11, 25, 17) + platform = Platform.ECEPHYS + + ephys_config = ModalityConfigs( + modality=Modality.ECEPHYS, + source=ephys_source_dir, + ) + project_name = "Ephys Platform" + + upload_job_configs = BasicUploadJobConfigs( + project_name=project_name, + s3_bucket=s3_bucket, + platform=platform, + subject_id=subject_id, + acq_datetime=acq_datetime, + modalities=[ephys_config], + metadata_configs=session_settings, + ) + + upload_jobs = [upload_job_configs] + # Suppress serializer warning when using a dict instead of an object + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + submit_request = SubmitJobRequest(upload_jobs=upload_jobs) + + post_request_content = json.loads( + submit_request.model_dump_json( + round_trip=True, warnings=False, exclude_none=True + ) + ) + + with TestClient(app) as client: + submit_job_response = client.post( + url="/api/v1/submit_jobs", json=post_request_content + ) + self.assertEqual(200, submit_job_response.status_code) + if __name__ == "__main__": unittest.main()