From 3d51cbb6957a88c00e4eae3690e640eccd0392b3 Mon Sep 17 00:00:00 2001 From: Dan LaManna Date: Thu, 24 Oct 2024 13:52:59 -0400 Subject: [PATCH] Make hierarchical diagnosis support passing in multiple values --- isic_metadata/metadata.py | 26 ++++++++- tests/test_fields.py | 40 ------------- tests/test_hierarchical_diagnosis.py | 87 ++++++++++++++++++++++++++++ 3 files changed, 111 insertions(+), 42 deletions(-) create mode 100644 tests/test_hierarchical_diagnosis.py diff --git a/isic_metadata/metadata.py b/isic_metadata/metadata.py index 08dc9fc..c849857 100644 --- a/isic_metadata/metadata.py +++ b/isic_metadata/metadata.py @@ -284,10 +284,32 @@ def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - # See https://github.com/samuelcolvin/pydantic/issues/2285 for more detail @model_validator(mode="before") @classmethod - def build_extra(cls, values: dict[str, Any]) -> dict[str, Any]: + def handle_hierarchical_diagnosis_modes_and_unstructured_fields( + cls, values: dict[str, Any] + ) -> dict[str, Any]: + """ + Handle the case where hierarchical diagnosis values are passed in as multiple fields. + + Practically, ingesting data should never pass in multiple values but instead use the + colon-separated `diagnosis` field. This method is provided for the scenario where + data needs to be retrieved from the database (where it's stored multi-valued) and + revalidated. This method also handles putting any unrecognized fields into an unstructured + field. Unfortunately, pydantic doesn't yet support ordering different model validators so + these both need to be combined into one method. + """ + using_diagnoses_multi_values = any(f"diagnosis_{i}" in values for i in range(1, 6)) + using_diagnosis_single_value = bool(values.get("diagnosis")) + + if using_diagnoses_multi_values and using_diagnosis_single_value: + [values.pop(f"diagnosis_{i}", "") for i in range(1, 6)] + elif using_diagnoses_multi_values: + values["diagnosis"] = ":".join(values.pop(f"diagnosis_{i}", "") for i in range(1, 6)) + values["diagnosis"] = values["diagnosis"].rstrip(":") + + # handle unstructured fields + # See https://github.com/samuelcolvin/pydantic/issues/2285 for more detail structured_field_names = {field for field in cls.model_fields if field != "unstructured"} unstructured: dict[str, Any] = {} diff --git a/tests/test_fields.py b/tests/test_fields.py index 8b15e09..7a19db7 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -124,43 +124,3 @@ def test_clin_size_long_diam_mm_invalid(): MetadataRow.model_validate({"clin_size_long_diam_mm": "foo"}) assert len(excinfo.value.errors()) == 1 assert "Unable to parse value as a number" in convert_errors(excinfo.value)[0]["msg"] - - -@pytest.mark.parametrize( - ("raw", "parsed"), - [ - ("Benign", ["Benign"]), - ("Benign - Other", ["Benign", "Benign - Other"]), - ("Blue nevus", ["Benign", "Benign melanocytic proliferations", "Nevus", "Blue nevus"]), - ( - "Squamous cell carcinoma, NOS", - ["Malignant", "Malignant epidermal proliferations", "Squamous cell carcinoma, NOS"], - ), - ( - "Blue nevus, Sclerosing", - [ - "Benign", - "Benign melanocytic proliferations", - "Nevus", - "Blue nevus", - "Blue nevus, Sclerosing", - ], - ), - ], -) -def test_diagnosis(raw, parsed): - metadata = MetadataRow.model_validate({"diagnosis": raw}) - - for i, diagnosis in enumerate(parsed, start=1): - assert getattr(metadata, f"diagnosis_{i}") == diagnosis - - -def test_top_level_diagnosis_is_never_exported(): - metadata = MetadataRow.model_validate({"diagnosis": "Benign"}) - assert "diagnosis" not in metadata.model_dump() - assert metadata.diagnosis_1 == "Benign" - - -def test_diagnosis_enum_has_unique_terminal_values(): - terminal_nodes = [member.value.split(":")[-1] for member in DiagnosisEnum] - assert len(terminal_nodes) == len(set(terminal_nodes)) diff --git a/tests/test_hierarchical_diagnosis.py b/tests/test_hierarchical_diagnosis.py new file mode 100644 index 0000000..223809f --- /dev/null +++ b/tests/test_hierarchical_diagnosis.py @@ -0,0 +1,87 @@ +from pydantic import ValidationError +import pytest + +from isic_metadata.diagnosis_hierarchical import DiagnosisEnum +from isic_metadata.metadata import MetadataRow + + +@pytest.mark.parametrize( + ("raw", "parsed"), + [ + ("Benign", ["Benign"]), + ("Benign - Other", ["Benign", "Benign - Other"]), + ("Blue nevus", ["Benign", "Benign melanocytic proliferations", "Nevus", "Blue nevus"]), + ( + "Squamous cell carcinoma, NOS", + ["Malignant", "Malignant epidermal proliferations", "Squamous cell carcinoma, NOS"], + ), + ( + "Blue nevus, Sclerosing", + [ + "Benign", + "Benign melanocytic proliferations", + "Nevus", + "Blue nevus", + "Blue nevus, Sclerosing", + ], + ), + ], +) +def test_diagnosis(raw, parsed): + metadata = MetadataRow.model_validate({"diagnosis": raw}) + + for i, diagnosis in enumerate(parsed, start=1): + assert getattr(metadata, f"diagnosis_{i}") == diagnosis + + +def test_top_level_diagnosis_is_never_exported(): + metadata = MetadataRow.model_validate({"diagnosis": "Benign"}) + assert "diagnosis" not in metadata.model_dump() + assert metadata.diagnosis_1 == "Benign" + + +def test_diagnosis_enum_has_unique_terminal_values(): + terminal_nodes = [member.value.split(":")[-1] for member in DiagnosisEnum] + assert len(terminal_nodes) == len(set(terminal_nodes)) + + +def test_single_value_diagnosis_is_favored(): + # test that passing in a single diagnosis value is favored over multiple values. used + # for when data is coming from the database and potentially contains an existing + # 1..5 diagnosis and a newly updated single diagnosis. + with pytest.raises(ValidationError) as excinfo: + MetadataRow.model_validate( + { + "diagnosis": "Melanoma Invasive", + "nevus_type": "blue", + # these should be ignored + "diagnosis_1": "Benign", + "diagnosis_2": "Benign melanocytic proliferations", + "diagnosis_3": "Nevus", + } + ) + assert "Setting nevus_type is incompatible with diagnosis" in excinfo.value.errors()[0]["msg"] + + +def test_diagnosis_multiple_levels_is_coerced(): + # test that passing in diagnosis_1..5 is coerced into a single diagnosis field to handle + # cross field input validation + metadata = MetadataRow.model_validate({"diagnosis_1": "Benign"}) + assert metadata.diagnosis_1 == "Benign" + assert metadata.diagnosis_2 is None + assert metadata.diagnosis_3 is None + assert metadata.diagnosis_4 is None + assert metadata.diagnosis_5 is None + + +def test_diagnosis_validation_is_idempotent(): + # test that running model_validate on a MetadataRow multiple times does not change the + # output + metadata = MetadataRow.model_validate({"diagnosis": "Melanoma Invasive"}) + assert metadata.diagnosis_1 == "Malignant" + assert metadata.diagnosis_2 == "Malignant melanocytic proliferations (Melanoma)" + assert metadata.diagnosis_3 == "Melanoma Invasive" + metadata_2 = MetadataRow.model_validate( + metadata.model_dump(exclude_unset=True, exclude_none=True, exclude={"unstructured"}) + ) + assert metadata == metadata_2