wip: robustify array dtype inference, add pt custom fields to Field()
brendancooley committed Nov 6, 2023
1 parent a63db3f commit db44aa3
Showing 2 changed files with 39 additions and 40 deletions.
71 changes: 35 additions & 36 deletions src/patito/pydantic.py
@@ -20,7 +20,7 @@
get_args,
Sequence,
Tuple,
Callable
Callable,
)

import polars as pl
@@ -88,7 +88,7 @@

class classproperty:
"""Equivalent to @property, but works on a class (doesn't require an instance).
https://github.com/pola-rs/polars/blob/8d29d3cebec713363db4ad5d782c74047e24314d/py-polars/polars/datatypes/classes.py#L25C12-L25C12
"""

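For context, this is the same class-level property pattern polars uses (see the permalink in the docstring); it is what lets attributes such as Model.DataFrame further down be read straight off the class. A tiny usage sketch, assuming the classproperty defined here and a hypothetical Example model:

class Example:
    @classproperty
    def table_name(cls) -> str:
        # the getter receives the class itself, so no instance is required
        return cls.__name__.lower()

Example.table_name  # -> "example", accessed directly on the class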
@@ -103,7 +103,6 @@ def getter(self, method: Callable[..., Any]) -> Any: # noqa: D102
return self



class ModelMetaclass(PydanticModelMetaclass):
"""
Metaclass used by patito.Model.
@@ -119,7 +118,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):

if TYPE_CHECKING:
model_fields: ClassVar[Dict[str, FieldInfo]]

model_config = ConfigDict(
ignored_types=(classproperty,),
)
@@ -230,20 +229,13 @@ def _valid_dtypes( # noqa: C901
Returns:
List of valid dtypes. None if no mapping exists.
"""
if props.get("type") == "array":
array_props = props["items"]
item_dtypes = cls._valid_dtypes(column, array_props)
if item_dtypes is None:
raise NotImplementedError(
f"No valid dtype mapping found for column '{column}'."
)
return [pl.List(dtype) for dtype in item_dtypes]

if "dtype" in props:

def dtype_invalid(props: Dict) -> Tuple[bool, List[PolarsDataType]]:
if "type" in props:
valid_pl_types = cls._pydantic_type_to_valid_polars_types(props)
valid_pl_types = cls._pydantic_type_to_valid_polars_types(
column, props
)
if props["dtype"] not in valid_pl_types:
return True, valid_pl_types or []
elif "anyOf" in props:
@@ -252,7 +244,7 @@ def dtype_invalid(props: Dict) -> Tuple[bool, List[PolarsDataType]]:
continue
else:
valid_pl_types = cls._pydantic_type_to_valid_polars_types(
sub_props
column, sub_props
)
if props["dtype"] not in valid_pl_types:
return True, valid_pl_types or []
@@ -281,13 +273,25 @@ def dtype_invalid(props: Dict) -> Tuple[bool, List[PolarsDataType]]:
)
return None

return cls._pydantic_type_to_valid_polars_types(props)
return cls._pydantic_type_to_valid_polars_types(column, props)

@staticmethod
@classmethod
def _pydantic_type_to_valid_polars_types(
cls,
column: str,
props: Dict,
) -> Optional[List[PolarsDataType]]:
if props["type"] == "integer":
if props["type"] == "array":

@ion-elgreco (Contributor), Nov 6, 2023:
Will this support list[list[list[str]]] for example?

@brendancooley (Author, Contributor), Nov 8, 2023:
I believe so but we ought to add a test!

array_props = props["items"]
item_dtypes = (
cls._valid_dtypes(column, array_props) if array_props else None
)
if item_dtypes is None:
raise NotImplementedError(
f"No valid dtype mapping found for column '{column}'."
)
return [pl.List(dtype) for dtype in item_dtypes]
elif props["type"] == "integer":
return PL_INTEGER_DTYPES
elif props["type"] == "number":
if props.get("format") == "time-delta":
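For context on the hunk above: routing "array" schemas through _pydantic_type_to_valid_polars_types means list annotations resolve to nested pl.List dtypes, recursing through the schema's "items" entry. A rough behavioural sketch, unverified against this WIP commit; the Inventory model is hypothetical:

from typing import List, Optional

import polars as pl
import patito as pt


class Inventory(pt.Model):
    tags: List[str]                                  # expected: pl.List(pl.Utf8)
    shelf_counts: Optional[List[List[int]]] = None   # expected: pl.List(pl.List(pl.Int64))


df = pl.DataFrame({"tags": [["new", "sale"]], "shelf_counts": [[[1, 2], [3]]]})
Inventory.validate(df)  # should pass once nested list inference resolves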
@@ -593,7 +597,7 @@ def DataFrame(
model=cls, # type: ignore
)

@classproperty
@classproperty
def LazyFrame(
cls: Type[ModelType], # type: ignore[misc]
) -> Type[LazyFrame[ModelType]]: # pyright: ignore
@@ -1509,7 +1513,7 @@ def __init__(
self,
constraints: Optional[Union[pl.Expr, Sequence[pl.Expr]]] = None,
derived_from: Optional[Union[str, pl.Expr]] = None,
dtype: Optional[pl.DataType] = None,
dtype: Optional[PolarsDataType] = None,
unique: bool = False,
**kwargs,
):
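For reference, polars' PolarsDataType alias covers both dtype classes and instantiated dtypes, so the widened annotation lets dtype= accept parameterised types such as pl.List(pl.Utf8). A small sketch with a hypothetical model:

from typing import List, Optional

import polars as pl
import patito as pt


class Reading(pt.Model):
    value: float = pt.Field(dtype=pl.Float32)  # bare dtype class, the style shown in the docs
    labels: Optional[List[str]] = pt.Field(dtype=pl.List(pl.Utf8))  # parameterised dtype the wider alias admits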
@@ -1530,20 +1534,24 @@ def __init__(

def Field( # noqa: C901
*args,
constraints: Optional[Union[pl.Expr, Sequence[pl.Expr]]] = None,
derived_from: Optional[Union[str, pl.Expr]] = None,
dtype: Optional[PolarsDataType] = None,
unique: bool = False,
**kwargs,
) -> Any:
pt_kwargs = {k: kwargs.pop(k, None) for k in get_args(PT_INFO)}
meta_kwargs = {
k: v for k, v in kwargs.items() if k in fields.FieldInfo.metadata_lookup
}
base_kwargs = {
k: v for k, v in kwargs.items() if k not in {**pt_kwargs, **meta_kwargs}
}
base_kwargs = {k: v for k, v in kwargs.items() if k not in meta_kwargs}
finfo = fields.Field(*args, **base_kwargs)
return FieldInfo(
**finfo._attributes_set,
**meta_kwargs,
**pt_kwargs,
constraints=constraints,
derived_from=derived_from,
dtype=dtype,
unique=unique,
)
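Taken together, the pt_kwargs / meta_kwargs / base_kwargs split is what lets patito-specific keywords (constraints, derived_from, dtype, unique) ride alongside ordinary pydantic keywords in a single Field() call. A hedged usage sketch with a hypothetical Order model:

import polars as pl
import patito as pt


class Order(pt.Model):
    order_id: int = pt.Field(unique=True, ge=0)            # pt kwarg plus pydantic metadata kwarg
    quantity: int = pt.Field(constraints=pt.field > 0)     # row-level constraint expression
    total: float = pt.Field(derived_from=pl.col("quantity") * 2.5)  # filled in by DataFrame.derive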


@@ -1562,19 +1570,10 @@ class FieldDoc:
All rows must satisfy the given constraint. You can refer to the given column
with ``pt.field``, which will automatically be replaced with
``polars.col(<field_name>)`` before evaluation.
unique (bool): All row values must be unique.
derived_from (Union[str, polars.Expr]): used to mark fields that are meant to be derived from other fields. Users can specify a polars expression that will be called to derive the column value when `pt.DataFrame.derive` is called.
dtype (polars.datatype.DataType): The given dataframe column must have the given
polars dtype, for instance ``polars.UInt64`` or ``pl.Float32``.
gt: All values must be greater than ``gt``.
ge: All values must be greater than or equal to ``ge``.
lt: All values must be less than ``lt``.
le: All values must be less than or equal to ``le``.
multiple_of: All values must be multiples of the given value.
const (bool): If set to ``True`` `all` values must be equal to the provided
default value, the first argument provided to the ``Field`` constructor.
regex (str): UTF-8 string column must match regex pattern for all row values.
min_length (int): Minimum length of all string values in a UTF-8 column.
max_length (int): Maximum length of all string values in a UTF-8 column.
unique (bool): All row values must be unique.
Return:
FieldInfo: Object used to represent additional constraints put upon the given
8 changes: 4 additions & 4 deletions tests/test_dummy_data.py
@@ -1,10 +1,9 @@
"""Test of functionality related to the generation of dummy data."""
from datetime import date, datetime
from typing import Optional
from typing import Optional, Literal, List

import polars as pl
import pytest
from typing_extensions import Literal

import patito as pt

@@ -52,11 +51,12 @@ class MyModel(pt.Model):
a: int
b: Optional[str]
c: Optional[int]
d: Optional[List[str]] = pt.Field(dtype=pl.List(pl.Utf8))

df = MyModel.examples({"a": [1, 2]})
assert isinstance(df, pl.DataFrame)
assert df.dtypes == [pl.Int64, pl.Utf8, pl.Int64]
assert df.columns == ["a", "b", "c"]
assert df.dtypes == [pl.Int64, pl.Utf8, pl.Int64, pl.List]
assert df.columns == ["a", "b", "c", "d"]

# A TypeError should be raised when you provide no column names
with pytest.raises(

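As a follow-up to the nested-list question in the review comments above, a sketch of the kind of test the author suggests adding; it is hypothetical, not part of this commit, and the asserted behaviour is unverified:

from typing import List

import polars as pl
import patito as pt


def test_deeply_nested_list_dtype_inference() -> None:
    """Dtype inference should recurse through several levels of list nesting."""

    class DeepModel(pt.Model):
        values: List[List[List[str]]]

    df = pl.DataFrame({"values": [[[["a", "b"], ["c"]], [["d"]]]]})
    assert df.dtypes == [pl.List(pl.List(pl.List(pl.Utf8)))]
    DeepModel.validate(df)  # expected to pass if every nesting level resolves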