feat: updates pydantic and schema version (#66)
* feat: updates pydantic and schema version

* feat: updates error strings

* feat: update error message
jtyoung84 authored Feb 1, 2024
1 parent 0228d12 commit bf6c33a
Showing 8 changed files with 262 additions and 129 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -33,8 +33,9 @@ dev = [
 
 server = [
     'boto3',
-    'aind-data-schema==0.14.6',
-    'pydantic<2.0',
+    'aind-data-schema==0.26.5',
+    'pydantic>=2.0',
+    'pydantic-settings>=2.0',
     'fastapi',
     'httpx',
     'jinja2',
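Note: pydantic 2 split BaseSettings out of the core package, which is why pydantic-settings shows up as a new dependency above. A minimal sketch of the import change (the ServerSettings class is hypothetical, for illustration only):

# pydantic v1 exposed BaseSettings from the core package:
# from pydantic import BaseSettings
# pydantic v2 moves it to the pydantic-settings package:
from pydantic_settings import BaseSettings


class ServerSettings(BaseSettings):  # hypothetical settings class
    s3_bucket: str = "some-default-bucket"  # overridable via S3_BUCKET env var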
150 changes: 117 additions & 33 deletions src/aind_data_transfer_service/configs/job_configs.py
@@ -2,27 +2,39 @@
 data transfer jobs."""
 import re
 from datetime import datetime
-from pathlib import Path
-from typing import Any, Dict, List, Optional
+from pathlib import PurePosixPath
+from typing import Any, ClassVar, Dict, List, Optional, Union
 
-from aind_data_schema.data_description import (
-    Modality,
-    Platform,
-    build_data_name,
-)
-from aind_data_schema.processing import ProcessName
-from pydantic import BaseSettings, Field, PrivateAttr, SecretStr, validator
+from aind_data_schema.core.data_description import build_data_name
+from aind_data_schema.core.processing import ProcessName
+from aind_data_schema.models.modalities import Modality
+from aind_data_schema.models.platforms import Platform
+from pydantic import (
+    Field,
+    PrivateAttr,
+    SecretStr,
+    ValidationInfo,
+    field_validator,
+)
+from pydantic_settings import BaseSettings
 
 
 class ModalityConfigs(BaseSettings):
     """Class to contain configs for each modality type"""
 
+    # Need some way to extract abbreviations. Maybe a public method can be
+    # added to the Modality class
+    _MODALITY_MAP: ClassVar = {
+        m().abbreviation.upper().replace("-", "_"): m().abbreviation
+        for m in Modality._ALL
+    }
+
     # Optional number id to assign to modality config
     _number_id: Optional[int] = PrivateAttr(default=None)
-    modality: Modality = Field(
+    modality: Modality.ONE_OF = Field(
         ..., description="Data collection modality", title="Modality"
     )
-    source: Path = Field(
+    source: PurePosixPath = Field(
         ...,
         description="Location of raw data to be uploaded",
         title="Data Source",
@@ -31,8 +43,9 @@ class ModalityConfigs(BaseSettings):
         default=None,
         description="Run compression on data",
         title="Compress Raw Data",
+        validate_default=True,
     )
-    extra_configs: Optional[Path] = Field(
+    extra_configs: Optional[PurePosixPath] = Field(
         default=None,
         description="Location of additional configuration file",
         title="Extra Configs",
@@ -52,19 +65,34 @@ def number_id(self):
     def default_output_folder_name(self):
         """Construct the default folder name for the modality."""
         if self._number_id is None:
-            return self.modality.value.abbreviation
+            return self.modality.abbreviation
         else:
-            return self.modality.value.abbreviation + str(self._number_id)
+            return self.modality.abbreviation + str(self._number_id)
 
+    @field_validator("modality", mode="before")
+    def parse_modality_string(
+        cls, input_modality: Union[str, dict, Modality]
+    ) -> Union[dict, Modality]:
+        """Attempts to convert strings to a Modality model. Raises an error
+        if unable to do so."""
+        if isinstance(input_modality, str):
+            modality_abbreviation = cls._MODALITY_MAP.get(
+                input_modality.upper()
+            )
+            if modality_abbreviation is None:
+                raise AttributeError(f"Unknown Modality: {input_modality}")
+            return Modality.from_abbreviation(modality_abbreviation)
+        else:
+            return input_modality
+
-    @validator("compress_raw_data", always=True)
+    @field_validator("compress_raw_data", mode="after")
     def get_compress_source_default(
-        cls, compress_source: Optional[bool], values: Dict[str, Any]
+        cls, compress_source: Optional[bool], info: ValidationInfo
     ) -> bool:
         """Set compress source default to True for ecephys data."""
         if (
             compress_source is None
-            and "modality" in values
-            and values["modality"] == Modality.ECEPHYS
+            and info.data.get("modality") == Modality.ECEPHYS
         ):
             return True
         elif compress_source is not None:
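The edit above is the standard pydantic 1 → 2 validator migration: @validator("x", always=True) with a values dict becomes @field_validator("x", mode="after") plus validate_default=True on the field, and previously validated fields are read from ValidationInfo.data. A minimal sketch of the pattern, with hypothetical field names:

from typing import Optional

from pydantic import BaseModel, Field, ValidationInfo, field_validator


class Example(BaseModel):  # hypothetical model for illustration
    kind: str
    compress: Optional[bool] = Field(default=None, validate_default=True)

    @field_validator("compress", mode="after")
    def default_compress(cls, v: Optional[bool], info: ValidationInfo) -> bool:
        # info.data holds the fields validated before this one
        if v is None:
            return info.data.get("kind") == "ecephys"
        return v

Here Example(kind="ecephys").compress evaluates to True, while Example(kind="confocal").compress is False.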
@@ -76,11 +104,16 @@ def get_compress_source_default(
 class BasicUploadJobConfigs(BaseSettings):
     """Configuration for the basic upload job"""
 
-    _MODALITY_ENTRY_PATTERN = re.compile(r"^modality(\d*)$")
-    _DATETIME_PATTERN1 = re.compile(
+    # Need some way to extract abbreviations. Maybe a public method can be
+    # added to the Platform class
+    _PLATFORM_MAP: ClassVar = {
+        p().abbreviation.upper(): p().abbreviation for p in Platform._ALL
+    }
+    _MODALITY_ENTRY_PATTERN: ClassVar = re.compile(r"^modality(\d*)$")
+    _DATETIME_PATTERN1: ClassVar = re.compile(
         r"^\d{4}-\d{2}-\d{2}[ |T]\d{2}:\d{2}:\d{2}$"
     )
-    _DATETIME_PATTERN2 = re.compile(
+    _DATETIME_PATTERN2: ClassVar = re.compile(
         r"^\d{1,2}/\d{1,2}/\d{4} \d{1,2}:\d{2}:\d{2} [APap][Mm]$"
     )
 
@@ -91,7 +124,9 @@ class BasicUploadJobConfigs(BaseSettings):
         description="Bucket where data will be uploaded",
         title="S3 Bucket",
     )
-    platform: Platform = Field(..., description="Platform", title="Platform")
+    platform: Platform.ONE_OF = Field(
+        ..., description="Platform", title="Platform"
+    )
     modalities: List[ModalityConfigs] = Field(
         ...,
         description="Data collection modalities and their directory location",
@@ -109,13 +144,13 @@ class BasicUploadJobConfigs(BaseSettings):
         description="Type of processing performed on the raw data source.",
         title="Process Name",
     )
-    metadata_dir: Optional[Path] = Field(
+    metadata_dir: Optional[PurePosixPath] = Field(
         default=None,
         description="Directory of metadata",
         title="Metadata Directory",
     )
     # Deprecated. Will be removed in future versions.
-    behavior_dir: Optional[Path] = Field(
+    behavior_dir: Optional[PurePosixPath] = Field(
         default=None,
         description=(
             "Directory of behavior data. This field is deprecated and will be "
@@ -149,7 +184,7 @@ class BasicUploadJobConfigs(BaseSettings):
         ),
         title="Force Cloud Sync",
     )
-    temp_directory: Optional[Path] = Field(
+    temp_directory: Optional[PurePosixPath] = Field(
         default=None,
         description=(
             "As default, the file systems temporary directory will be used as "
@@ -163,11 +198,28 @@ def s3_prefix(self):
     def s3_prefix(self):
         """Construct s3_prefix from configs."""
         return build_data_name(
-            label=f"{self.platform.value.abbreviation}_{self.subject_id}",
+            label=f"{self.platform.abbreviation}_{self.subject_id}",
             creation_datetime=self.acq_datetime,
         )
 
-    @validator("acq_datetime", pre=True)
+    @field_validator("platform", mode="before")
+    def parse_platform_string(
+        cls, input_platform: Union[str, dict, Platform]
+    ) -> Union[dict, Platform]:
+        """Attempts to convert strings to a Platform model. Raises an error
+        if unable to do so."""
+        if isinstance(input_platform, str):
+            platform_abbreviation = cls._PLATFORM_MAP.get(
+                input_platform.upper()
+            )
+            if platform_abbreviation is None:
+                raise AttributeError(f"Unknown Platform: {input_platform}")
+            else:
+                return Platform.from_abbreviation(platform_abbreviation)
+        else:
+            return input_platform
+
+    @field_validator("acq_datetime", mode="before")
     def _parse_datetime(cls, datetime_val: Any) -> datetime:
         """Parses datetime string to %YYYY-%MM-%DD HH:mm:ss"""
         is_str = isinstance(datetime_val, str)
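Taken together, the mode="before" validators above let callers supply plain strings and have them coerced into models before field validation runs. A hedged usage sketch (all values are made up):

# hypothetical usage, with made-up values
job = BasicUploadJobConfigs(
    s3_bucket="my-bucket",
    platform="behavior",  # string coerced by parse_platform_string
    modalities=[{"modality": "ecephys", "source": "/data/ecephys_set"}],
    subject_id="123456",
    acq_datetime="2024-01-31 14:00:00",  # parsed by _parse_datetime
)
print(job.s3_prefix)  # e.g. behavior_123456_2024-01-31_14-00-00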
@@ -187,14 +239,46 @@ def _parse_datetime(cls, datetime_val: Any) -> datetime:
         else:
             return datetime_val
 
+    @field_validator("modalities", mode="after")
+    def update_number_ids(
+        cls, modality_list: List[ModalityConfigs]
+    ) -> List[ModalityConfigs]:
+        """
+        Loops through the modality list and assigns a number id
+        to duplicate modalities. For example, if a user inputs
+        multiple behavior modalities, then it will upload them
+        as behavior, behavior1, behavior2, etc. folders.
+        Parameters
+        ----------
+        modality_list : List[ModalityConfigs]
+
+        Returns
+        -------
+        List[ModalityConfigs]
+            Updates the _number_id field in the ModalityConfigs
+
+        """
+        modality_counts = {}
+        updated_list = []
+        for modality in modality_list:
+            modality_abbreviation = modality.modality.abbreviation
+            if modality_counts.get(modality_abbreviation) is None:
+                modality_counts[modality_abbreviation] = 1
+            else:
+                modality_count_num = modality_counts[modality_abbreviation]
+                modality._number_id = modality_count_num
+                modality_counts[modality_abbreviation] += 1
+            updated_list.append(modality)
+        return updated_list
+
     @staticmethod
     def _clean_csv_entry(csv_key: str, csv_value: Optional[str]) -> Any:
         """Tries to set the default value for optional settings if the csv
         entry is blank."""
         if (
             csv_value is None or csv_value == "" or csv_value == " "
-        ) and BasicUploadJobConfigs.__fields__.get(csv_key) is not None:
-            clean_val = BasicUploadJobConfigs.__fields__[csv_key].default
+        ) and BasicUploadJobConfigs.model_fields.get(csv_key) is not None:
+            clean_val = BasicUploadJobConfigs.model_fields[csv_key].default
         else:
             clean_val = csv_value.strip()
         return clean_val
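The __fields__ → model_fields rename above is another pydantic 2 change: model_fields maps field names to FieldInfo objects, and .default carries the declared default. A small sketch with a hypothetical model:

from typing import Optional

from pydantic import BaseModel


class Example(BaseModel):  # hypothetical model for illustration
    metadata_dir: Optional[str] = None


# pydantic v1: Example.__fields__["metadata_dir"].default
# pydantic v2:
print(Example.model_fields["metadata_dir"].default)  # prints None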
@@ -324,13 +408,13 @@ class HpcJobConfigs(BaseSettings):
         default=50, description="Memory requested in GB"
     )
     hpc_partition: str
-    hpc_current_working_directory: Path
-    hpc_logging_directory: Path
+    hpc_current_working_directory: PurePosixPath
+    hpc_logging_directory: PurePosixPath
     hpc_aws_secret_access_key: SecretStr
     hpc_aws_access_key_id: str
     hpc_aws_default_region: str
     hpc_aws_session_token: Optional[str] = Field(default=None)
-    hpc_sif_location: Path = Field(...)
+    hpc_sif_location: PurePosixPath = Field(...)
     hpc_alt_exec_command: Optional[str] = Field(
         default=None,
         description=(
@@ -342,7 +426,7 @@
 
     def _json_args_str(self) -> str:
         """Serialize job configs to json"""
-        return self.basic_upload_job_configs.json()
+        return self.basic_upload_job_configs.model_dump_json()
 
     def _script_command_str(self) -> str:
         """This is the command that will be sent to the hpc"""
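A plausible motivation for the file-wide Path → PurePosixPath swap (the commit does not state one): pathlib.Path instantiates an OS-specific class and would serialize with backslashes on a Windows client, while PurePosixPath keeps forward slashes on every host, matching what the Slurm cluster and S3 expect. A quick illustration:

from pathlib import PurePosixPath, PureWindowsPath

# PureWindowsPath mimics what pathlib.Path yields on a Windows host
print(str(PureWindowsPath("dir", "sub", "file.txt")))  # dir\sub\file.txt
print(str(PurePosixPath("dir", "sub", "file.txt")))    # dir/sub/file.txt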
12 changes: 8 additions & 4 deletions src/aind_data_transfer_service/hpc/client.py
@@ -4,7 +4,8 @@
 from typing import List, Optional, Union
 
 import requests
-from pydantic import BaseSettings, Field, SecretStr, validator
+from pydantic import Field, SecretStr, field_validator
+from pydantic_settings import BaseSettings
 from requests.models import Response
 
 from aind_data_transfer_service.hpc.models import HpcJobSubmitSettings
@@ -20,7 +21,7 @@ class HpcClientConfigs(BaseSettings):
     hpc_password: SecretStr = Field(...)
     hpc_token: SecretStr = Field(...)
 
-    @validator("hpc_host", "hpc_api_endpoint", pre=True)
+    @field_validator("hpc_host", "hpc_api_endpoint", mode="before")
     def _strip_slash(cls, input_str: Optional[str]):
         """Strips trailing slash from domain."""
         return None if input_str is None else input_str.strip("/")
@@ -132,12 +133,15 @@ def submit_hpc_job(
         assert job is None or jobs is None
         if job is not None:
             job_def = {
-                "job": json.loads(job.json(exclude_none=True)),
+                "job": json.loads(job.model_dump_json(exclude_none=True)),
                 "script": script,
             }
         else:
             job_def = {
-                "jobs": [json.loads(j.json(exclude_none=True)) for j in jobs],
+                "jobs": [
+                    json.loads(j.model_dump_json(exclude_none=True))
+                    for j in jobs
+                ],
                 "script": script,
             }
 
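The .json() → .model_dump_json() renames in this file follow pydantic 2's serialization API; keyword arguments such as exclude_none carry over unchanged. A minimal sketch with a hypothetical model:

from typing import Optional

from pydantic import BaseModel


class Example(BaseModel):  # hypothetical model for illustration
    name: str
    partition: Optional[str] = None


e = Example(name="job1")
# pydantic v1: e.json(exclude_none=True)
# pydantic v2:
print(e.model_dump_json(exclude_none=True))  # {"name":"job1"}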
23 changes: 8 additions & 15 deletions src/aind_data_transfer_service/hpc/models.py
@@ -2,18 +2,11 @@
 
 import json
 from datetime import datetime
-from pathlib import Path
-from typing import Any, List, Optional, Union
+from pathlib import PurePosixPath
+from typing import Any, List, Literal, Optional, Union
 
-from pydantic import (
-    BaseModel,
-    BaseSettings,
-    Extra,
-    Field,
-    SecretStr,
-    validator,
-)
-from pydantic.typing import Literal
+from pydantic import BaseModel, Extra, Field, SecretStr, field_validator
+from pydantic_settings import BaseSettings
 
 
 class HpcJobSubmitSettings(BaseSettings):
@@ -404,7 +397,7 @@ def _set_default_val(values: dict, key: str, default_value: Any) -> None:
     @classmethod
     def from_upload_job_configs(
         cls,
-        logging_directory: Path,
+        logging_directory: PurePosixPath,
         aws_secret_access_key: SecretStr,
         aws_access_key_id: str,
         aws_default_region: str,
@@ -415,7 +408,7 @@ def from_upload_job_configs(
         Class constructor to use when submitting a basic upload job request
         Parameters
         ----------
-        logging_directory : Path
+        logging_directory : PurePosixPath
         aws_secret_access_key : SecretStr
         aws_access_key_id : str
         aws_default_region : str
@@ -626,7 +619,7 @@ class JobStatus(BaseModel):
     start_time: Optional[datetime] = Field(None)
     submit_time: Optional[datetime] = Field(None)
 
-    @validator("end_time", "start_time", "submit_time", pre=True)
+    @field_validator("end_time", "start_time", "submit_time", mode="before")
     def _parse_timestamp(
         cls, timestamp: Union[int, datetime, None]
     ) -> Optional[datetime]:
Expand Down Expand Up @@ -656,4 +649,4 @@ def from_hpc_job_status(cls, hpc_job: HpcJobStatusResponse):
@property
def jinja_dict(self):
"""Map model to a dictionary that jinja can render"""
return self.dict(exclude_none=True)
return self.model_dump(exclude_none=True)
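Similarly, .dict() becomes .model_dump() in pydantic 2 (returning a plain dict, where .model_dump_json() returns a JSON string), and Literal now comes from the standard typing module rather than pydantic.typing. A small sketch with a hypothetical model:

from typing import Literal, Optional

from pydantic import BaseModel


class Example(BaseModel):  # hypothetical model for illustration
    status: Literal["RUNNING", "COMPLETED"]
    end_time: Optional[str] = None


e = Example(status="RUNNING")
# pydantic v1: e.dict(exclude_none=True)
# pydantic v2:
print(e.model_dump(exclude_none=True))  # {'status': 'RUNNING'}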