diff --git a/src/aind_slims_api/core.py b/src/aind_slims_api/core.py index d22f4ec..f7646cf 100644 --- a/src/aind_slims_api/core.py +++ b/src/aind_slims_api/core.py @@ -8,13 +8,24 @@ methods and integration with SlimsBaseModel subtypes """ -import logging +from datetime import datetime from functools import lru_cache +from pydantic import ( + BaseModel, + ValidationInfo, + field_serializer, + field_validator, +) +from pydantic.fields import FieldInfo +import logging from typing import Literal, Optional -from slims.criteria import Criterion, conjunction, equals -from slims.internal import Record as SlimsRecord from slims.slims import Slims, _SlimsApiException +from slims.internal import ( + Column as SlimsColumn, + Record as SlimsRecord, +) +from slims.criteria import Criterion, conjunction, equals from aind_slims_api import config @@ -33,6 +44,94 @@ ] +class UnitSpec: + """Used in type annotation metadata to specify units""" + + units: list[str] + preferred_unit: str = None + + def __init__(self, *args, preferred_unit=None): + """Set list of acceptable units from args, and preferred_unit""" + self.units = args + if len(self.units) == 0: + raise ValueError("One or more units must be specified") + if preferred_unit is None: + self.preferred_unit = self.units[0] + + +def _find_unit_spec(field: FieldInfo) -> UnitSpec | None: + """Given a Pydantic FieldInfo, find the UnitSpec in its metadata""" + metadata = field.metadata + for m in metadata: + if isinstance(m, UnitSpec): + return m + return None + + +class SlimsBaseModel( + BaseModel, + from_attributes=True, + validate_assignment=True, +): + """Pydantic model to represent a SLIMS record. + Subclass with fields matching those in the SLIMS record. + + For Quantities, specify acceptable units like so: + + class MyModel(SlimsBaseModel): + myfield: Annotated[float | None, UnitSpec("g","kg")] + + Quantities will be serialized using the first unit passed + + Datetime fields will be serialized to an integer ms timestamp + """ + + pk: int = None + json_entity: dict = None + _slims_table: SLIMSTABLES + + @field_validator("*", mode="before") + def _validate(cls, value, info: ValidationInfo): + """Validates a field, accounts for Quantities""" + if isinstance(value, SlimsColumn): + if value.datatype == "QUANTITY": + unit_spec = _find_unit_spec(cls.model_fields[info.field_name]) + if unit_spec is None: + msg = ( + f'Quantity field "{info.field_name}"' + "must be annotated with a UnitSpec" + ) + raise TypeError(msg) + if value.unit not in unit_spec.units: + msg = ( + f'Unexpected unit "{value.unit}" for field ' + f"{info.field_name}, Expected {unit_spec.units}" + ) + raise ValueError(msg) + return value.value + else: + return value + + @field_serializer("*") + def _serialize(self, field, info): + """Serialize a field, accounts for Quantities and datetime""" + unit_spec = _find_unit_spec(self.model_fields[info.field_name]) + if unit_spec and field is not None: + quantity = { + "amount": field, + "unit_display": unit_spec.preferred_unit, + } + return quantity + elif isinstance(field, datetime): + return int(field.timestamp() * 10**3) + else: + return field + + # TODO: Add links - need Record.json_entity['links']['self'] + # TODO: Add Table - need Record.json_entity['tableName'] + # TODO: Support attachments + + class SlimsClient: """Wrapper around slims-python-api client with convenience methods""" @@ -133,3 +232,52 @@ def rest_link(self, table: SLIMSTABLES, **kwargs): base_url = f"{self.url}/rest/{table}" queries = [f"?{k}={v}" for k, v in kwargs.items()] return base_url + "".join(queries) + + def add_model(self, model: SlimsBaseModel, *args, **kwargs) -> SlimsBaseModel: + """Given a SlimsBaseModel object, add it to SLIMS + Args + model (SlimsBaseModel): object to add + *args (str): fields to include in the serialization + **kwargs: passed to model.model_dump() + + Returns + An instance of the same type of model, with data from + the resulting SLIMS record + """ + fields_to_include = set(args) or None + fields_to_exclude = set(kwargs.get("exclude", [])) + fields_to_exclude.add("pk") + rtn = self.add( + model._slims_table, + model.model_dump( + include=fields_to_include, + exclude=fields_to_exclude, + **kwargs, + by_alias=True, + ), + ) + return type(model).model_validate(rtn) + + def update_model(self, model: SlimsBaseModel, *args, **kwargs): + """Given a SlimsBaseModel object, update its (existing) SLIMS record + + Args + model (SlimsBaseModel): object to update + *args (str): fields to include in the serialization + **kwargs: passed to model.model_dump() + + Returns + An instance of the same type of model, with data from + the resulting SLIMS record + """ + fields_to_include = set(args) or None + rtn = self.update( + model._slims_table, + model.pk, + model.model_dump( + include=fields_to_include, + by_alias=True, + **kwargs, + ), + ) + return type(model).model_validate(rtn) diff --git a/src/aind_slims_api/mouse.py b/src/aind_slims_api/mouse.py index b9f5aaa..f11717e 100644 --- a/src/aind_slims_api/mouse.py +++ b/src/aind_slims_api/mouse.py @@ -1,17 +1,51 @@ """Contains a model for the mouse content, and a method for fetching it""" import logging -from typing import Optional +from typing import Annotated -from aind_slims_api.core import SlimsClient +from pydantic import Field, BeforeValidator, ValidationError + +from aind_slims_api.core import SlimsBaseModel, SlimsClient, UnitSpec, SLIMSTABLES logger = logging.getLogger() +class SlimsMouseContent(SlimsBaseModel): + """Model for an instance of the Mouse ContentType""" + + baseline_weight_g: Annotated[float | None, UnitSpec("g")] = Field( + ..., alias="cntn_cf_baselineWeight" + ) + point_of_contact: str | None = Field(..., alias="cntn_cf_scientificPointOfContact") + water_restricted: Annotated[bool, BeforeValidator(lambda x: x or False)] = Field( + ..., alias="cntn_cf_waterRestricted" + ) + barcode: str = Field(..., alias="cntn_barCode") + pk: int = Field(..., alias="cntn_pk") + + _slims_table: SLIMSTABLES = "Content" + + # TODO: Include other helpful fields (genotype, gender...) + + # pk: callable + # cntn_fk_category: SlimsColumn + # cntn_fk_contentType: SlimsColumn + # cntn_barCode: SlimsColumn + # cntn_id: SlimsColumn + # cntn_cf_contactPerson: SlimsColumn + # cntn_status: SlimsColumn + # cntn_fk_status: SlimsColumn + # cntn_fk_user: SlimsColumn + # cntn_cf_fk_fundingCode: SlimsColumn + # cntn_cf_genotype: SlimsColumn + # cntn_cf_labtracksId: SlimsColumn + # cntn_cf_parentBarcode: SlimsColumn + + def fetch_mouse_content( client: SlimsClient, mouse_name: str, -) -> Optional[dict]: +) -> SlimsMouseContent | dict | None: """Fetches mouse information for a mouse with labtracks id {mouse_name}""" mice = client.fetch( "Content", @@ -28,6 +62,12 @@ def fetch_mouse_content( ) else: logger.warning("Warning, Mouse not in SLIMS") - mouse_details = None + return + + try: + mouse = SlimsMouseContent.model_validate(mouse_details) + except ValidationError as e: + logger.error(f"SLIMS data validation failed, {repr(e)}") + return mouse_details.json_entity - return None if mouse_details is None else mouse_details.json_entity + return mouse diff --git a/src/aind_slims_api/unit.py b/src/aind_slims_api/unit.py new file mode 100644 index 0000000..73c2efc --- /dev/null +++ b/src/aind_slims_api/unit.py @@ -0,0 +1,20 @@ +"""Contains a model for a unit""" + +import logging +from typing import Optional + +from pydantic import Field + +from aind_slims_api.core import SlimsBaseModel + +logger = logging.getLogger() + + +class SlimsUnit(SlimsBaseModel): + """Model for unit information in SLIMS""" + + name: str = Field(..., alias="unit_name") + abbreviation: Optional[str] = Field("", alias="unit_abbreviation") + pk: int = Field(..., alias="unit_pk") + + _slims_table: str = "Unit" diff --git a/src/aind_slims_api/user.py b/src/aind_slims_api/user.py index 1e5ac46..74a6994 100644 --- a/src/aind_slims_api/user.py +++ b/src/aind_slims_api/user.py @@ -3,15 +3,31 @@ import logging from typing import Optional -from aind_slims_api.core import SlimsClient +from pydantic import Field, ValidationError + +from aind_slims_api.core import SlimsBaseModel, SlimsClient logger = logging.getLogger() +# TODO: Tighten this up once users are more commonly used +class SlimsUser(SlimsBaseModel): + """Model for user information in SLIMS""" + + username: str = Field(..., alias="user_userName") + first_name: Optional[str] = Field("", alias="user_firstName") + last_name: Optional[str] = Field("", alias="user_lastName") + full_name: Optional[str] = Field("", alias="user_fullName") + email: Optional[str] = Field("", alias="user_email") + pk: int = Field(..., alias="user_pk") + + _slims_table: str = "User" + + def fetch_user( client: SlimsClient, username: str, -) -> Optional[dict]: +) -> SlimsUser | dict | None: """Fetches user information for a user with username {username}""" users = client.fetch( "User", @@ -23,11 +39,17 @@ def fetch_user( if len(users) > 1: logger.warning( f"Warning, Multiple users in SLIMS with " - f"username {[u.json_entity for u in users]}, " + f"username {username}, " f"using pk={user_details.pk()}" ) else: logger.warning("Warning, User not in SLIMS") - user_details = None + return + + try: + user = SlimsUser.model_validate(user_details) + except ValidationError as e: + logger.error(f"SLIMS data validation failed, {repr(e)}") + return user_details.json_entity - return None if user_details is None else user_details.json_entity + return user diff --git a/tests/test_configuration.py b/tests/test_configuration.py index d1c5f92..ab12d38 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -10,6 +10,11 @@ class TestAindSlimsApiSettings(unittest.TestCase): """Tests methods in AindSlimsApiSettings class""" + @patch.dict( + os.environ, + {}, + clear=True, + ) def test_default_settings(self): """Tests that the class will be set with defaults""" default_settings = AindSlimsApiSettings() diff --git a/tests/test_core.py b/tests/test_core.py index 5445b91..6ac5923 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -11,6 +11,7 @@ from slims.internal import Record, _SlimsApiException from aind_slims_api.core import SlimsClient +from aind_slims_api.unit import SlimsUnit RESOURCES_DIR = Path(os.path.dirname(os.path.realpath(__file__))) / "resources" @@ -133,7 +134,10 @@ def test_add(self, mock_slims_add: MagicMock, mock_log: MagicMock): @patch("logging.Logger.info") @patch("slims.internal.Record.update") def test_update( - self, mock_update: MagicMock, mock_log: MagicMock, mock_fetch_by_pk: MagicMock + self, + mock_update: MagicMock, + mock_log: MagicMock, + mock_fetch_by_pk: MagicMock, ): """Tests update method success""" input_data = deepcopy(self.example_fetch_unit_response[0].json_entity) @@ -155,18 +159,60 @@ def test_update( @patch("logging.Logger.info") @patch("slims.internal.Record.update") def test_update_failure( - self, mock_update: MagicMock, mock_log: MagicMock, mock_fetch_by_pk: MagicMock + self, + mock_update: MagicMock, + mock_log: MagicMock, + mock_fetch_by_pk: MagicMock, ): """Tests update method when a failure occurs""" mock_fetch_by_pk.return_value = None with self.assertRaises(ValueError) as e: self.example_client.update(table="Unit", pk=30000, data={}) self.assertEqual( - 'No data in SLIMS "Unit" table for pk "30000"', e.exception.args[0] + 'No data in SLIMS "Unit" table for pk "30000"', + e.exception.args[0], ) mock_update.assert_not_called() mock_log.assert_not_called() + @patch("logging.Logger.info") + @patch("slims.slims.Slims.add") + def test_add_model(self, mock_slims_add: MagicMock, mock_log: MagicMock): + """Tests add_model method with mock mouse data""" + record = self.example_fetch_unit_response[0] + mock_slims_add.return_value = record + input_model = SlimsUnit.model_validate(record) + added = self.example_client.add_model(input_model) + self.assertEqual(input_model, added) + mock_log.assert_called_once_with("SLIMS Add: Unit/31") + + @patch("slims.slims.Slims.fetch_by_pk") + @patch("logging.Logger.info") + @patch("slims.internal.Record.update") + def test_update_model( + self, + mock_update: MagicMock, + mock_log: MagicMock, + mock_fetch_by_pk: MagicMock, + ): + """Tests update method success""" + input_data = deepcopy(self.example_fetch_unit_response[0].json_entity) + mock_record = Record( + json_entity=input_data, slims_api=self.example_client.db.slims_api + ) + mock_fetch_by_pk.return_value = mock_record + updated_model = SlimsUnit.model_validate(mock_record) + new_data = deepcopy(input_data) + new_data["columns"][0]["value"] = "PM^3" + mocked_updated_record = Record( + json_entity=new_data, slims_api=self.example_client.db.slims_api + ) + mock_update.return_value = mocked_updated_record + updated_model = SlimsUnit.model_validate(mocked_updated_record) + returned_model = self.example_client.update_model(updated_model) + self.assertEqual(updated_model, returned_model) + mock_log.assert_called_once_with("SLIMS Update: Unit/31") + if __name__ == "__main__": unittest.main() diff --git a/tests/test_mouse.py b/tests/test_mouse.py index 9c698fb..9a8f72d 100644 --- a/tests/test_mouse.py +++ b/tests/test_mouse.py @@ -5,6 +5,7 @@ import unittest from pathlib import Path from unittest.mock import MagicMock, patch +from copy import deepcopy from slims.internal import Record @@ -37,7 +38,7 @@ def test_fetch_mouse_content_success(self, mock_fetch: MagicMock): mock_fetch.return_value = self.example_fetch_mouse_response mouse_details = fetch_mouse_content(self.example_client, mouse_name="123456") self.assertEqual( - self.example_fetch_mouse_response[0].json_entity, mouse_details + self.example_fetch_mouse_response[0].json_entity, mouse_details.json_entity ) @patch("logging.Logger.warning") @@ -63,12 +64,28 @@ def test_fetch_mouse_content_many_mouse( ] mouse_details = fetch_mouse_content(self.example_client, mouse_name="123456") self.assertEqual( - self.example_fetch_mouse_response[0].json_entity, mouse_details + self.example_fetch_mouse_response[0].json_entity, mouse_details.json_entity ) mock_log_warn.assert_called_with( "Warning, Multiple mice in SLIMS with barcode 123456, using pk=3038" ) + @patch("logging.Logger.error") + @patch("slims.slims.Slims.fetch") + def test_fetch_mouse_content_validation_fail( + self, mock_fetch: MagicMock, mock_log_error: MagicMock + ): + """Test fetch_mouse when successful""" + wrong_return = deepcopy(self.example_fetch_mouse_response) + wrong_return[0].cntn_cf_waterRestricted.value = 14 + mock_fetch.return_value = wrong_return + mouse_info = fetch_mouse_content(self.example_client, mouse_name="123456") + self.assertEqual( + self.example_fetch_mouse_response[0].json_entity, + mouse_info, + ) + mock_log_error.assert_called() + if __name__ == "__main__": unittest.main() diff --git a/tests/test_slimsmodel.py b/tests/test_slimsmodel.py new file mode 100644 index 0000000..a3ff69b --- /dev/null +++ b/tests/test_slimsmodel.py @@ -0,0 +1,104 @@ +""" Tests the generic SlimsBaseModel""" + +from datetime import datetime +from typing import Annotated +import unittest + +from pydantic import Field +from slims.internal import Record, Column + +from aind_slims_api.core import SlimsBaseModel, UnitSpec + + +class TestSlimsModel(unittest.TestCase): + """Example Test Class""" + + class TestModel(SlimsBaseModel, validate_assignment=True): + """Test case""" + + datefield: datetime = None + stringfield: str = None + quantfield: Annotated[float, UnitSpec("um", "nm")] = None + + def test_string_field(self): + """Test basic usage for SLIMS column to Model field""" + obj = self.TestModel() + obj.stringfield = Column( + { + "datatype": "STRING", + "name": "stringfield", + "value": "value", + } + ) + + self.assertEqual(obj.stringfield, "value") + + def test_quantity_field(self): + """Test validation/serialization of a quantity type, with unit""" + obj = self.TestModel() + obj.quantfield = Column( + { + "datatype": "QUANTITY", + "name": "quantfield", + "value": 28.28, + "unit": "um", + } + ) + + self.assertEqual(obj.quantfield, 28.28) + + serialized = obj.model_dump()["quantfield"] + expected = {"amount": 28.28, "unit_display": "um"} + + self.assertEqual(serialized, expected) + + def test_quantity_wrong_unit(self): + """Ensure you get an error with an unexpected unit""" + obj = self.TestModel() + with self.assertRaises(ValueError): + obj.quantfield = Column( + { + "datatype": "QUANTITY", + "name": "quantfield", + "value": 28.28, + "unit": "erg", + } + ) + + def test_alias(self): + """Test aliasing of fields""" + + class TestModelAlias(SlimsBaseModel): + """model with field aliases""" + + field: str = Field(..., alias="alias") + pk: int = Field(None, alias="cntn_pk") + + record = Record( + json_entity={ + "columns": [ + { + "datatype": "STRING", + "name": "alias", + "value": "value", + } + ] + }, + slims_api=None, + ) + obj = TestModelAlias.model_validate(record) + + self.assertEqual(obj.field, "value") + obj.field = "value2" + self.assertEqual(obj.field, "value2") + serialized = obj.model_dump(include="field", by_alias=True) + expected = {"alias": "value2"} + self.assertEqual(serialized, expected) + + def test_unitspec(self): + """Test unitspec with no arguments""" + self.assertRaises(ValueError, UnitSpec) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_user.py b/tests/test_user.py index 02a705d..1ea9a81 100644 --- a/tests/test_user.py +++ b/tests/test_user.py @@ -5,6 +5,7 @@ import unittest from pathlib import Path from unittest.mock import MagicMock, patch +from copy import deepcopy from slims.internal import Record @@ -36,7 +37,26 @@ def test_fetch_user_content_success(self, mock_fetch: MagicMock): """Test fetch_user when successful""" mock_fetch.return_value = self.example_fetch_user_response user_info = fetch_user(self.example_client, username="PersonA") - self.assertEqual(self.example_fetch_user_response[0].json_entity, user_info) + self.assertEqual( + self.example_fetch_user_response[0].json_entity, + user_info.json_entity, + ) + + @patch("logging.Logger.error") + @patch("slims.slims.Slims.fetch") + def test_fetch_user_content_validation_fail( + self, mock_fetch: MagicMock, mock_log_error: MagicMock + ): + """Test fetch_user when successful""" + wrong_return = deepcopy(self.example_fetch_user_response) + wrong_return[0].user_userName.value = 14 + mock_fetch.return_value = wrong_return + user_info = fetch_user(self.example_client, username="PersonA") + self.assertEqual( + self.example_fetch_user_response[0].json_entity, + user_info, + ) + mock_log_error.assert_called() @patch("logging.Logger.warning") @patch("slims.slims.Slims.fetch") @@ -60,11 +80,15 @@ def test_fetch_user_content_many_users( self.example_fetch_user_response[0], ] mock_fetch.return_value = mocked_response - user_info = fetch_user(self.example_client, username="PersonA") - self.assertEqual(self.example_fetch_user_response[0].json_entity, user_info) + username = "PersonA" + user_info = fetch_user(self.example_client, username=username) + self.assertEqual( + self.example_fetch_user_response[0].json_entity, + user_info.json_entity, + ) expected_warning = ( f"Warning, Multiple users in SLIMS with " - f"username {[u.json_entity for u in mocked_response]}, " + f"username {username}, " f"using pk={mocked_response[0].pk()}" ) mock_log_warn.assert_called_with(expected_warning)