From db44aa344edfa51fec9a9bc114a187b5e780f081 Mon Sep 17 00:00:00 2001 From: Brendan Cooley Date: Mon, 6 Nov 2023 13:07:50 -0500 Subject: [PATCH] wip: robustify array dtype inference, add pt custom fields to `Field()` --- src/patito/pydantic.py | 71 ++++++++++++++++++++-------------------- tests/test_dummy_data.py | 8 ++--- 2 files changed, 39 insertions(+), 40 deletions(-) diff --git a/src/patito/pydantic.py b/src/patito/pydantic.py index bd6fb3e..7ebcaa7 100644 --- a/src/patito/pydantic.py +++ b/src/patito/pydantic.py @@ -20,7 +20,7 @@ get_args, Sequence, Tuple, - Callable + Callable, ) import polars as pl @@ -88,7 +88,7 @@ class classproperty: """Equivalent to @property, but works on a class (doesn't require an instance). - + https://github.com/pola-rs/polars/blob/8d29d3cebec713363db4ad5d782c74047e24314d/py-polars/polars/datatypes/classes.py#L25C12-L25C12 """ @@ -103,7 +103,6 @@ def getter(self, method: Callable[..., Any]) -> Any: # noqa: D102 return self - class ModelMetaclass(PydanticModelMetaclass): """ Metclass used by patito.Model. @@ -119,7 +118,7 @@ class Model(BaseModel, metaclass=ModelMetaclass): if TYPE_CHECKING: model_fields: ClassVar[Dict[str, FieldInfo]] - + model_config = ConfigDict( ignored_types=(classproperty,), ) @@ -230,20 +229,13 @@ def _valid_dtypes( # noqa: C901 Returns: List of valid dtypes. None if no mapping exists. """ - if props.get("type") == "array": - array_props = props["items"] - item_dtypes = cls._valid_dtypes(column, array_props) - if item_dtypes is None: - raise NotImplementedError( - f"No valid dtype mapping found for column '{column}'." - ) - return [pl.List(dtype) for dtype in item_dtypes] - if "dtype" in props: def dtype_invalid(props: Dict) -> Tuple[bool, List[PolarsDataType]]: if "type" in props: - valid_pl_types = cls._pydantic_type_to_valid_polars_types(props) + valid_pl_types = cls._pydantic_type_to_valid_polars_types( + column, props + ) if props["dtype"] not in valid_pl_types: return True, valid_pl_types or [] elif "anyOf" in props: @@ -252,7 +244,7 @@ def dtype_invalid(props: Dict) -> Tuple[bool, List[PolarsDataType]]: continue else: valid_pl_types = cls._pydantic_type_to_valid_polars_types( - sub_props + column, sub_props ) if props["dtype"] not in valid_pl_types: return True, valid_pl_types or [] @@ -281,13 +273,25 @@ def dtype_invalid(props: Dict) -> Tuple[bool, List[PolarsDataType]]: ) return None - return cls._pydantic_type_to_valid_polars_types(props) + return cls._pydantic_type_to_valid_polars_types(column, props) - @staticmethod + @classmethod def _pydantic_type_to_valid_polars_types( + cls, + column: str, props: Dict, ) -> Optional[List[PolarsDataType]]: - if props["type"] == "integer": + if props["type"] == "array": + array_props = props["items"] + item_dtypes = ( + cls._valid_dtypes(column, array_props) if array_props else None + ) + if item_dtypes is None: + raise NotImplementedError( + f"No valid dtype mapping found for column '{column}'." + ) + return [pl.List(dtype) for dtype in item_dtypes] + elif props["type"] == "integer": return PL_INTEGER_DTYPES elif props["type"] == "number": if props.get("format") == "time-delta": @@ -593,7 +597,7 @@ def DataFrame( model=cls, # type: ignore ) - @classproperty + @classproperty def LazyFrame( cls: Type[ModelType], # type: ignore[misc] ) -> Type[LazyFrame[ModelType]]: # pyright: ignore @@ -1509,7 +1513,7 @@ def __init__( self, constraints: Optional[Union[pl.Expr, Sequence[pl.Expr]]] = None, derived_from: Optional[Union[str, pl.Expr]] = None, - dtype: Optional[pl.DataType] = None, + dtype: Optional[PolarsDataType] = None, unique: bool = False, **kwargs, ): @@ -1530,20 +1534,24 @@ def __init__( def Field( # noqa: C901 *args, + constraints: Optional[Union[pl.Expr, Sequence[pl.Expr]]] = None, + derived_from: Optional[Union[str, pl.Expr]] = None, + dtype: Optional[PolarsDataType] = None, + unique: bool = False, **kwargs, ) -> Any: - pt_kwargs = {k: kwargs.pop(k, None) for k in get_args(PT_INFO)} meta_kwargs = { k: v for k, v in kwargs.items() if k in fields.FieldInfo.metadata_lookup } - base_kwargs = { - k: v for k, v in kwargs.items() if k not in {**pt_kwargs, **meta_kwargs} - } + base_kwargs = {k: v for k, v in kwargs.items() if k not in meta_kwargs} finfo = fields.Field(*args, **base_kwargs) return FieldInfo( **finfo._attributes_set, **meta_kwargs, - **pt_kwargs, + constraints=constraints, + derived_from=derived_from, + dtype=dtype, + unique=unique, ) @@ -1562,19 +1570,10 @@ class FieldDoc: All rows must satisfy the given constraint. You can refer to the given column with ``pt.field``, which will automatically be replaced with ``polars.col()`` before evaluation. - unique (bool): All row values must be unique. + derived_from (Union[str, polars.Expr]): used to mark fields that are meant to be derived from other fields. Users can specify a polars expression that will be called to derive the column value when `pt.DataFrame.derive` is called. dtype (polars.datatype.DataType): The given dataframe column must have the given polars dtype, for instance ``polars.UInt64`` or ``pl.Float32``. - gt: All values must be greater than ``gt``. - ge: All values must be greater than or equal to ``ge``. - lt: All values must be less than ``lt``. - le: All values must be less than or equal to ``lt``. - multiple_of: All values must be multiples of the given value. - const (bool): If set to ``True`` `all` values must be equal to the provided - default value, the first argument provided to the ``Field`` constructor. - regex (str): UTF-8 string column must match regex pattern for all row values. - min_length (int): Minimum length of all string values in a UTF-8 column. - max_length (int): Maximum length of all string values in a UTF-8 column. + unique (bool): All row values must be unique. Return: FieldInfo: Object used to represent additional constraints put upon the given diff --git a/tests/test_dummy_data.py b/tests/test_dummy_data.py index ace5765..e3e6db2 100644 --- a/tests/test_dummy_data.py +++ b/tests/test_dummy_data.py @@ -1,10 +1,9 @@ """Test of functionality related to the generation of dummy data.""" from datetime import date, datetime -from typing import Optional +from typing import Optional, Literal, List import polars as pl import pytest -from typing_extensions import Literal import patito as pt @@ -52,11 +51,12 @@ class MyModel(pt.Model): a: int b: Optional[str] c: Optional[int] + d: Optional[List[str]] = pt.Field(dtype=pl.List(pl.Utf8)) df = MyModel.examples({"a": [1, 2]}) assert isinstance(df, pl.DataFrame) - assert df.dtypes == [pl.Int64, pl.Utf8, pl.Int64] - assert df.columns == ["a", "b", "c"] + assert df.dtypes == [pl.Int64, pl.Utf8, pl.Int64, pl.List] + assert df.columns == ["a", "b", "c", "d"] # A TypeError should be raised when you provide no column names with pytest.raises(