diff --git a/src/patito/validators.py b/src/patito/validators.py index 7112898..399ddea 100644 --- a/src/patito/validators.py +++ b/src/patito/validators.py @@ -175,6 +175,8 @@ def _find_errors( # noqa: C901 .select(column) # Remove those rows that do not contain lists at all .filter(pl.col(column).is_not_null()) + # Remove empty lists + .filter(pl.col(column).list.len() > 0) # Convert lists of N items to N individual rows .explode(column) # Calculate how many nulls are present in lists diff --git a/tests/test_validators.py b/tests/test_validators.py index d0c669d..62367a1 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -638,6 +638,35 @@ class NestedPositiveStructModel(pt.Model): NestedPositiveStructModel.validate(bad_df) +def test_empty_list_validation() -> None: + """Test validation of model with empty lists.""" + + class TestModel(pt.Model): + list_field: list[str] + + # validate presence of an empty list + df = pl.DataFrame({"list_field": [["a", "b"], []]}) + TestModel.validate(df) + + # validate when all lists are empty, so long as the schema is correct + df = pl.DataFrame( + {"list_field": [[], []]}, schema={"list_field": pl.List(pl.String)} + ) + TestModel.validate(df) + + class NestedTestModel(pt.Model): + nested_list_field: list[list[str]] + + df = pl.DataFrame({"nested_list_field": [[["a", "b"], ["c"]], []]}) + NestedTestModel.validate(df) + + df = pl.DataFrame( + {"nested_list_field": [[], []]}, + schema={"nested_list_field": pl.List(pl.List(pl.String))}, + ) + NestedTestModel.validate(df) + + def test_list_struct_validation() -> None: """Test validation of model with list of structs column."""