Skip to content

Commit

Permalink
ser/deser datatypes
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasaarholt committed Dec 20, 2023
1 parent daf8f59 commit cead6ed
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 27 deletions.
72 changes: 60 additions & 12 deletions src/patito/_pydantic/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import json
from enum import Enum
from typing import (
Any,
Expand Down Expand Up @@ -88,10 +89,55 @@ def parse_composite_dtype(dtype: DataTypeClass | DataType) -> str:
return convert.DataTypeMappings.DTYPE_TO_FFINAME[dtype]


def dtype_from_string(v: str):
    """Turn a short dtype string representation into a dtype (for deserialization).

    Thin wrapper that forwards ``v`` to ``convert.dtype_short_repr_to_dtype``
    and returns whatever that helper returns.
    """
    # TODO test all dtypes
    parsed_dtype = convert.dtype_short_repr_to_dtype(v)
    return parsed_dtype
def dtype_to_json(dtype: pl.DataType) -> str:
    """Encode a polars dtype as a JSON string.

    The dtype is rendered with ``str()`` first and that text is then
    JSON-encoded, so the result is a quoted JSON string literal.
    """
    dtype_repr = str(dtype)
    return json.dumps(dtype_repr)


def json_to_dtype(json_dtype_str: str) -> pl.DataType:
    """Decode a polars dtype from its JSON string representation.

    The JSON payload is parsed first; the decoded value is then handed to
    ``str_to_dtype`` to obtain the dtype.
    """
    return str_to_dtype(json.loads(json_dtype_str))


def str_to_dtype(dtype_str: str) -> pl.DataType:
    """Return the polars dtype described by ``dtype_str``.

    ``dtype_str`` is evaluated as a Python expression in which only the
    polars dtype class names imported below are defined (e.g. ``"Float32"``).
    If the expression yields a dtype *class* rather than an instance, the
    class is instantiated so that an instance is returned.

    Args:
        dtype_str: String form of a polars dtype.

    Returns:
        The evaluated dtype. Note that the result of the expression is not
        type-checked here; a string that evaluates to something else (for
        example a quoted string literal) is returned as-is.

    Raises:
        NameError: If the expression refers to a name that is not one of the
            imported dtype classes.
        SyntaxError: If ``dtype_str`` is not a valid Python expression.
    """
    from polars.datatypes import DataTypeClass
    from polars.datatypes.classes import (
        Array,
        Binary,
        Boolean,
        Categorical,
        Date,
        Datetime,
        Decimal,
        Duration,
        Enum,
        Float32,
        Float64,
        Int8,
        Int16,
        Int32,
        Int64,
        List,
        Null,
        Object,
        Struct,
        Time,
        UInt8,
        UInt16,
        UInt32,
        UInt64,
        Unknown,
        Utf8,
    )

    dtype_classes = (
        Array,
        Binary,
        Boolean,
        Categorical,
        Date,
        Datetime,
        Decimal,
        Duration,
        Enum,
        Float32,
        Float64,
        Int8,
        Int16,
        Int32,
        Int64,
        List,
        Null,
        Object,
        Struct,
        Time,
        UInt8,
        UInt16,
        UInt32,
        UInt64,
        Unknown,
        Utf8,
    )

    # SECURITY: ``eval`` runs arbitrary expressions and ``dtype_str`` may come
    # from deserialized (external) data. Evaluate against an explicit
    # namespace holding only the dtype classes, with builtins removed, instead
    # of this module's globals. This narrows what an input can reach but is
    # NOT a sandbox; a dedicated parser would be needed for untrusted input.
    namespace = {"__builtins__": {}}
    namespace.update({dtype_cls.__name__: dtype_cls for dtype_cls in dtype_classes})

    dtype = eval(dtype_str, namespace)
    if isinstance(dtype, DataTypeClass):
        # Float32() has string representation Float32, so we need to call it
        dtype = dtype()
    return dtype


def validate_polars_dtype(
Expand Down Expand Up @@ -150,13 +196,13 @@ def validate_annotation(annotation: type[Any] | None, column: Optional[str] = No


def valid_polars_dtypes_for_annotation(
annotation: type[Any] | None
annotation: type[Any] | None,
) -> FrozenSet[DataTypeClass | DataType]:
"""Returns a set of polars types that are valid for the given annotation. If the annotation is Any, returns all supported polars dtypes.
Args:
annotation (type[Any] | None): python type annotation
Returns:
FrozenSet[DataTypeClass | DataType]: set of polars dtypes
"""
Expand All @@ -167,7 +213,7 @@ def valid_polars_dtypes_for_annotation(


def default_polars_dtype_for_annotation(
annotation: type[Any] | None
annotation: type[Any] | None,
) -> DataTypeClass | DataType | None:
"""Returns the default polars dtype for the given annotation. If the annotation is Any, returns pl.Utf8. If no default dtype can be determined, returns None.
Expand All @@ -184,7 +230,7 @@ def default_polars_dtype_for_annotation(


def _valid_polars_dtypes_for_schema(
schema: Dict
schema: Dict,
) -> FrozenSet[DataTypeClass | DataType]:
valid_type_sets = []
if "anyOf" in schema:
Expand All @@ -195,7 +241,9 @@ def _valid_polars_dtypes_for_schema(
)
else:
valid_type_sets.append(set(_pydantic_subschema_to_valid_polars_types(schema)))
return set.intersection(*valid_type_sets) if valid_type_sets else frozenset() # pyright: ignore
return (
set.intersection(*valid_type_sets) if valid_type_sets else frozenset()
) # pyright: ignore


def _default_polars_dtype_for_schema(schema: Dict) -> DataTypeClass | DataType | None:
Expand Down Expand Up @@ -242,7 +290,7 @@ def _pydantic_subschema_to_valid_polars_types(


def _pydantic_subschema_to_default_dtype(
props: Dict
props: Dict,
) -> DataTypeClass | DataType | None:
if "type" not in props:
if "enum" in props:
Expand Down Expand Up @@ -317,7 +365,7 @@ def _pyd_type_to_default_dtype(


def _pyd_string_format_to_valid_dtypes(
string_format: PydanticStringFormat | None
string_format: PydanticStringFormat | None,
) -> FrozenSet[DataTypeClass | DataType]:
if string_format is None:
return STRING_DTYPES
Expand All @@ -334,7 +382,7 @@ def _pyd_string_format_to_valid_dtypes(


def _pyd_string_format_to_default_dtype(
string_format: PydanticStringFormat | None
string_format: PydanticStringFormat | None,
) -> DataTypeClass | DataType:
if string_format is None:
return pl.Utf8
Expand Down
40 changes: 25 additions & 15 deletions src/patito/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,22 @@
ConfigDict,
create_model,
field_serializer,
field_validator,
fields,
JsonDict,
)

# JsonDict,
from pydantic._internal._model_construction import (
ModelMetaclass as PydanticModelMetaclass,
)
from pydantic_core.core_schema import ValidationInfo

from patito._pydantic.dtypes import (
default_polars_dtype_for_annotation,
dtype_from_string,
dtype_to_json,
json_to_dtype,
parse_composite_dtype,
str_to_dtype,
valid_polars_dtypes_for_annotation,
validate_annotation,
validate_polars_dtype,
Expand Down Expand Up @@ -134,12 +139,16 @@ def column_infos(cls) -> Dict[str, ColumnInfo]:

def get_column_info(field: fields.FieldInfo) -> ColumnInfo:
if field.json_schema_extra is None:
return ColumnInfo()
return ColumnInfo(
dtype=pl.Null()
) # not sure if we should allow ColumnInfo without dtype
elif callable(field.json_schema_extra):
raise NotImplementedError(
"Callable json_schema_extra not supported by patito."
)
return field.json_schema_extra["column_info"] # pyright: ignore # TODO JsonDict fix
return field.json_schema_extra[
"column_info"
] # pyright: ignore # TODO JsonDict fix

return {k: get_column_info(v) for k, v in fields.items()}

Expand Down Expand Up @@ -1426,7 +1435,7 @@ class ColumnInfo(BaseModel, arbitrary_types_allowed=True):
unique (bool): All row values must be unique.
"""

dtype: DataTypeClass | DataType | None = None
dtype: DataType
constraints: pl.Expr | Sequence[pl.Expr] | None = None
derived_from: str | pl.Expr | None = None
unique: bool | None = None
Expand All @@ -1453,26 +1462,27 @@ def _serialize_expr(self, expr: pl.Expr) -> Dict:
raise ValueError(f"Invalid type for expr: {type(expr)}")

@field_serializer("dtype")
def serialize_dtype(self, dtype: DataTypeClass | DataType | None) -> Any:
def serialize_dtype(self, dtype: DataType) -> Any:
"""
References
----------
[1] https://stackoverflow.com/questions/76572310/how-to-serialize-deserialize-polars-datatypes
"""
if dtype is None:
return None
elif isinstance(dtype, DataTypeClass) or isinstance(dtype, DataType):
return parse_composite_dtype(dtype)
else:
raise ValueError(f"Invalid type for dtype: {type(dtype)}")
return dtype_to_json(dtype)

@field_validator("dtype", mode="before")
@classmethod
def parse_json_dtype(cls, v: Any, info: ValidationInfo) -> pl.DataType:
    """Convert a string ``dtype`` value into a polars dtype before validation.

    Two string forms are accepted:

    * a JSON-quoted string (starts with ``"``), which is the form produced by
      ``dtype_to_json`` and therefore by the ``dtype`` field serializer; it is
      decoded with ``json_to_dtype``;
    * a bare dtype string, which is passed straight to ``str_to_dtype``.

    Previously the JSON-quoted form was sent to ``str_to_dtype`` directly,
    which evaluated it to a plain ``str`` instead of a dtype, so a dumped
    model could not be validated back.

    Args:
        v: Raw input for the ``dtype`` field.
        info: Validation context supplied by pydantic (unused here;
            ``info.field_name`` is the name of the field being validated).

    Returns:
        The parsed dtype for string input; any non-string input unchanged.
    """
    if isinstance(v, str):
        if v.startswith('"'):
            # Serialized form: undo the JSON quoting before parsing the dtype.
            return json_to_dtype(v)
        return str_to_dtype(v)
    return v


def Field(
*args,
dtype: DataTypeClass
| DataType
| None = None, # TODO figure out how to make nice signature
dtype: DataType, # TODO figure out how to make nice signature
constraints: pl.Expr | Sequence[pl.Expr] | None = None,
derived_from: str | pl.Expr | None = None,
unique: bool | None = None,
Expand Down

0 comments on commit cead6ed

Please sign in to comment.