Skip to content

Commit 2c13739

Browse files
authored
Pydantic-based type checking (#179)
* End-to-end implementation of pydantic-based type checking, controlled via CSP_PYDANTIC environment variable. Signed-off-by: Pascal Tomecek <[email protected]>
1 parent 70f43cb commit 2c13739

23 files changed

+1623
-131
lines changed

conda/dev-environment-unix.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ dependencies:
3232
- polars
3333
- psutil
3434
- pyarrow=16
35+
- pydantic>=2
3536
- pytest
3637
- pytest-asyncio
3738
- pytest-cov

conda/dev-environment-win.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ dependencies:
3131
- polars
3232
- psutil
3333
- pyarrow=16
34+
- pydantic>=2
3435
- pytest
3536
- pytest-asyncio
3637
- pytest-cov

csp/impl/types/common_definitions.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from enum import Enum, IntEnum, auto
44
from typing import Dict, List, Optional, Union
55

6-
from .container_type_normalizer import ContainerTypeNormalizer
7-
from .tstype import isTsBasket
8-
from .typing_utils import CspTypingUtils
6+
from csp.impl.types.container_type_normalizer import ContainerTypeNormalizer
7+
from csp.impl.types.tstype import isTsBasket
8+
from csp.impl.types.typing_utils import CspTypingUtils
99

1010

1111
class OutputTypeError(TypeError):
@@ -53,7 +53,11 @@ def __new__(cls, *args, **kwargs):
5353
kwargs = {k: v if not isTsBasket(v) else OutputBasket(v) for k, v in kwargs.items()}
5454

5555
# stash for convenience later
56-
kwargs["__annotations__"] = kwargs
56+
kwargs["__annotations__"] = kwargs.copy()
57+
try:
58+
_make_pydantic_outputs(kwargs)
59+
except ImportError:
60+
pass
5761
return type("Outputs", (Outputs,), kwargs)
5862

5963
def __init__(self, *args, **kwargs):
@@ -62,6 +66,30 @@ def __init__(self, *args, **kwargs):
6266
...
6367

6468

69+
def _make_pydantic_outputs(kwargs):
70+
"""Add pydantic functionality to Outputs, if necessary"""
71+
from pydantic import create_model
72+
from pydantic_core import core_schema
73+
74+
from csp.impl.wiring.outputs import OutputsContainer
75+
76+
if None in kwargs:
77+
typ = ContainerTypeNormalizer.normalize_type(kwargs[None])
78+
model_fields = {"out": (typ, ...)}
79+
else:
80+
model_fields = {
81+
name: (ContainerTypeNormalizer.normalize_type(annotation), ...)
82+
for name, annotation in kwargs["__annotations__"].items()
83+
}
84+
config = {"arbitrary_types_allowed": True, "extra": "forbid", "strict": True}
85+
kwargs["__pydantic_model__"] = create_model("OutputsModel", __config__=config, **model_fields)
86+
kwargs["__get_pydantic_core_schema__"] = classmethod(
87+
lambda cls, source_type, handler: core_schema.no_info_after_validator_function(
88+
lambda v: OutputsContainer(**v.model_dump()), handler(cls.__pydantic_model__)
89+
)
90+
)
91+
92+
6593
class OutputBasket(object):
6694
def __new__(cls, typ, shape: Optional[Union[List, int, str]] = None, shape_of: Optional[str] = None):
6795
"""we are abusing class construction here because we can't use classgetitem.
@@ -78,8 +106,10 @@ def __new__(cls, typ, shape: Optional[Union[List, int, str]] = None, shape_of: O
78106
if shape and shape_of:
79107
raise OutputBasketMixedShapeAndShapeOf()
80108
elif shape:
81-
if not isinstance(shape, (list, int, str)):
82-
raise OutputBasketWrongShapeType((list, int, str), shape)
109+
if CspTypingUtils.get_origin(typ) is Dict and not isinstance(shape, (list, tuple, str)):
110+
raise OutputBasketWrongShapeType((list, tuple, str), shape)
111+
if CspTypingUtils.get_origin(typ) is List and not isinstance(shape, (int, str)):
112+
raise OutputBasketWrongShapeType((int, str), shape)
83113
kwargs["shape"] = shape
84114
kwargs["shape_func"] = "with_shape"
85115
elif shape_of:
@@ -94,8 +124,23 @@ def __new__(cls, typ, shape: Optional[Union[List, int, str]] = None, shape_of: O
94124
# if shape is required, it will be enforced in the parser
95125
kwargs["shape"] = None
96126
kwargs["shape_func"] = None
127+
97128
return type("OutputBasket", (OutputBasket,), kwargs)
98129

130+
@classmethod
131+
def __get_pydantic_core_schema__(cls, source_type, handler):
132+
from pydantic_core import core_schema
133+
134+
def validate_shape(v, info):
135+
shape = cls.shape
136+
if isinstance(shape, int) and len(v) != shape:
137+
raise ValueError(f"Wrong shape: got {len(v)}, expecting {shape}")
138+
if isinstance(shape, (list, tuple)) and v.keys() != set(shape):
139+
raise ValueError(f"Wrong dict shape: got {v.keys()}, expecting {set(shape)}")
140+
return v
141+
142+
return core_schema.with_info_after_validator_function(validate_shape, handler(cls.typ))
143+
99144

100145
class OutputBasketContainer:
101146
SHAPE_FUNCS = None
@@ -170,7 +215,7 @@ def is_list_basket(self):
170215
return CspTypingUtils.get_origin(self.typ) is List
171216

172217
def __str__(self):
173-
return f"OutputBasketContainer(typ={self.typ}, shape={self.shape}, eval_type={self.eval_type}, lineno={self.lineno}, col_offset={self.col_offset})"
218+
return f"OutputBasketContainer(typ={self.typ}, shape={self.shape}, eval_type={self.eval_type})"
174219

175220
def __repr__(self):
176221
return str(self)
@@ -185,6 +230,7 @@ def create_wrapper(cls, eval_typ):
185230
"with_shape_of": OutputBasketContainer.create_wrapper(OutputBasketContainer.EvalType.WITH_SHAPE_OF),
186231
}
187232

233+
188234
InputDef = namedtuple("InputDef", ["name", "typ", "kind", "basket_kind", "ts_idx", "arg_idx"])
189235
OutputDef = namedtuple("OutputDef", ["name", "typ", "kind", "ts_idx", "shape"])
190236

csp/impl/types/instantiation_type_resolver.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def resolve_type(self, expected_type: type, new_type: type, raise_on_error=True)
3434
if CspTypingUtils.is_generic_container(expected_type):
3535
expected_type_base = CspTypingUtils.get_orig_base(expected_type)
3636
if expected_type_base is new_type:
37-
return expected_type
37+
return expected_type_base # If new_type is Generic and expected type is Generic[T], return Generic
3838
if CspTypingUtils.is_generic_container(new_type):
3939
expected_origin = CspTypingUtils.get_origin(expected_type)
4040
new_type_origin = CspTypingUtils.get_origin(new_type)
@@ -99,14 +99,7 @@ def __reduce__(self):
9999
class TypeMismatchError(TypeError):
100100
@classmethod
101101
def pretty_typename(cls, typ):
102-
if CspTypingUtils.is_generic_container(typ):
103-
return str(typ)
104-
elif CspTypingUtils.is_forward_ref(typ):
105-
return cls.pretty_typename(typ.__forward_arg__)
106-
elif isinstance(typ, type):
107-
return typ.__name__
108-
else:
109-
return str(typ)
102+
return CspTypingUtils.pretty_typename(typ)
110103

111104
@classmethod
112105
def get_tvar_info_str(cls, tvar_info):
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
import numpy
2+
from pydantic import TypeAdapter, ValidationError
3+
from typing import Any, Dict, List, Set, Tuple, Type, Union, get_args
4+
5+
import csp.typing
6+
from csp.impl.types.container_type_normalizer import ContainerTypeNormalizer
7+
from csp.impl.types.instantiation_type_resolver import UpcastRegistry
8+
from csp.impl.types.numpy_type_util import map_numpy_dtype_to_python_type
9+
from csp.impl.types.pydantic_types import CspTypeVarType, adjust_annotations
10+
from csp.impl.types.typing_utils import CspTypingUtils, TsTypeValidator
11+
12+
13+
class TVarValidationContext:
14+
"""Custom validation context class for handling the special csp TVAR logic."""
15+
16+
# Note: some of the implementation is borrowed from InputInstanceTypeResolver
17+
18+
def __init__(
19+
self,
20+
forced_tvars: Union[Dict[str, Type], None] = None,
21+
allow_none_ts: bool = False,
22+
):
23+
# Can be set by a field validator to help track the source field of the different tvar refs
24+
self.field_name = None
25+
self._allow_none_ts = allow_none_ts
26+
self._forced_tvars: Dict[str, Type] = forced_tvars or {}
27+
self._tvar_type_refs: Dict[str, Set[Tuple[str, Type]]] = {}
28+
self._tvar_refs: Dict[str, Dict[str, List[Any]]] = {}
29+
self._tvars: Dict[str, Type] = {}
30+
self._conflicting_tvar_types = {}
31+
32+
if self._forced_tvars:
33+
config = {"arbitrary_types_allowed": True, "strict": True}
34+
self._forced_tvars = {k: ContainerTypeNormalizer.normalize_type(v) for k, v in self._forced_tvars.items()}
35+
self._forced_tvar_adapters = {
36+
tvar: TypeAdapter(List[t], config=config) for tvar, t in self._forced_tvars.items()
37+
}
38+
self._forced_tvar_validators = {tvar: TsTypeValidator(t) for tvar, t in self._forced_tvars.items()}
39+
self._tvars.update(**self._forced_tvars)
40+
41+
@property
42+
def tvars(self) -> Dict[str, Type]:
43+
return self._tvars
44+
45+
@property
46+
def allow_none_ts(self) -> bool:
47+
return self._allow_none_ts
48+
49+
def add_tvar_type_ref(self, tvar, value_type):
50+
if value_type is not numpy.ndarray:
51+
# Need to convert, i.e. [float] into List[float] when passed as a tref
52+
# Exclude ndarray because otherwise will get converted to NumpyNDArray[float], even for non-float
53+
# See, i.e. TestParquetReader.test_numpy_array_on_struct_with_field_map
54+
# TODO: This should be fixed in the ContainerTypeNormalizer
55+
value_type = ContainerTypeNormalizer.normalize_type(value_type)
56+
self._tvar_type_refs.setdefault(tvar, set()).add((self.field_name, value_type))
57+
58+
def add_tvar_ref(self, tvar, value):
59+
self._tvar_refs.setdefault(tvar, {}).setdefault(self.field_name, []).append(value)
60+
61+
def resolve_tvars(self):
62+
# Validate instances against forced tvars
63+
if self._forced_tvars:
64+
for tvar, adapter in self._forced_tvar_adapters.items():
65+
for field_name, field_values in self._tvar_refs.get(tvar, {}).items():
66+
# Validate using TypeAdapter(List[t]) in pydantic as it's faster than iterating through in python
67+
adapter.validate_python(field_values, strict=True)
68+
69+
for tvar, validator in self._forced_tvar_validators.items():
70+
for field_name, v in self._tvar_type_refs.get(tvar, set()):
71+
validator.validate(v)
72+
73+
# Add resolutions for references to tvar types (where type is inferred directly from type)
74+
for tvar, type_refs in self._tvar_type_refs.items():
75+
for field_name, value_type in type_refs:
76+
self._add_t_var_resolution(tvar, field_name, value_type)
77+
78+
# Add resolutions for references to tvar values (where type is inferred from type of value)
79+
for tvar, field_refs in self._tvar_refs.items():
80+
if self._forced_tvars and tvar in self._forced_tvars:
81+
# Already handled these
82+
continue
83+
for field_name, values in field_refs.items():
84+
for value in values:
85+
typ = type(value)
86+
if not CspTypingUtils.is_type_spec(typ):
87+
typ = ContainerTypeNormalizer.normalize_type(typ)
88+
self._add_t_var_resolution(tvar, field_name, typ, value if value is not typ else None)
89+
self._try_resolve_tvar_conflicts()
90+
91+
def revalidate(self, model):
92+
"""Once tvars have been resolved, need to revalidate input values against resolved tvars"""
93+
# Determine the fields that need to be revalidated because of tvar resolution
94+
# At the moment, that's only int fields that need to be converted to float
95+
# What does revalidation do?
96+
# - It makes sure that, edges declared as ts[float] inside a data structure, i.e. List[ts[float]],
97+
# get properly converted from, ts[int]
98+
# - It makes sure that scalar int values get converted to float
99+
# - It ignores validating a pass "int" type as a "float" type.
100+
fields_to_revalidate = set()
101+
for tvar, type_refs in self._tvar_type_refs.items():
102+
if self._tvars[tvar] is float:
103+
for field_name, value_type in type_refs:
104+
if field_name and value_type is int:
105+
fields_to_revalidate.add(field_name)
106+
for tvar, field_refs in self._tvar_refs.items():
107+
for field_name, values in field_refs.items():
108+
if field_name and any(type(value) is int for value in values): # noqa E721
109+
fields_to_revalidate.add(field_name)
110+
# Do the conversion only for the relevant fields
111+
for field in fields_to_revalidate:
112+
value = getattr(model, field)
113+
annotation = model.__annotations__[field]
114+
args = get_args(annotation)
115+
if args and args[0] is CspTypeVarType:
116+
# Skip revalidation of top-level type var types, as these have been handled via tvar resolution
117+
continue
118+
new_annotation = adjust_annotations(annotation, forced_tvars=self.tvars)
119+
try:
120+
new_value = TypeAdapter(new_annotation).validate_python(value)
121+
except ValidationError as e:
122+
msg = "\t" + str(e).replace("\n", "\n\t")
123+
raise ValueError(
124+
f"failed to revalidate field `{field}` after applying Tvars: {self._tvars}\n{msg}\n"
125+
) from None
126+
setattr(model, field, new_value)
127+
return model
128+
129+
def _add_t_var_resolution(self, tvar, field_name, resolved_type, arg=None):
130+
old_tvar_type = self._tvars.get(tvar)
131+
if old_tvar_type is None:
132+
self._tvars[tvar] = self._resolve_tvar_container_internal_types(tvar, resolved_type, arg)
133+
return
134+
elif self._forced_tvars and tvar in self._forced_tvars:
135+
# We must not change types, it's forced. So we will have to make sure that the new resolution matches the old one
136+
return
137+
138+
combined_type = UpcastRegistry.instance().resolve_type(resolved_type, old_tvar_type, raise_on_error=False)
139+
if combined_type is None:
140+
self._conflicting_tvar_types.setdefault(tvar, []).append(resolved_type)
141+
142+
if combined_type is not None and combined_type != old_tvar_type:
143+
self._tvars[tvar] = combined_type
144+
145+
def _resolve_tvar_container_internal_types(self, tvar, container_typ, arg, raise_on_error=True):
146+
"""This function takes, a container type (i.e. list) and an arg (i.e. 6) and infers the type of the TVar,
147+
i.e. typing.List[int]. For simple types, this function is a pass-through (i.e. arg is None).
148+
"""
149+
if arg is None:
150+
return container_typ
151+
if container_typ not in (set, dict, list, numpy.ndarray):
152+
return container_typ
153+
# It's possible that we provided type as scalar argument, that's illegal for containers, it must specify explicitly typed
154+
# list
155+
if arg is container_typ:
156+
if raise_on_error:
157+
raise ValueError(f"unable to resolve container type for type variable {tvar}: invalid argument {arg}")
158+
else:
159+
return False
160+
if len(arg) == 0:
161+
return container_typ
162+
res = None
163+
if isinstance(arg, set):
164+
first_val = arg.__iter__().__next__()
165+
first_val_t = self._resolve_tvar_container_internal_types(tvar, type(first_val), first_val)
166+
if first_val_t:
167+
res = Set[first_val_t]
168+
elif isinstance(arg, list):
169+
first_val = arg.__iter__().__next__()
170+
first_val_t = self._resolve_tvar_container_internal_types(tvar, type(first_val), first_val)
171+
if first_val_t:
172+
res = List[first_val_t]
173+
elif isinstance(arg, numpy.ndarray):
174+
python_type = map_numpy_dtype_to_python_type(arg.dtype)
175+
if arg.ndim > 1:
176+
res = csp.typing.NumpyNDArray[python_type]
177+
else:
178+
res = csp.typing.Numpy1DArray[python_type]
179+
else:
180+
first_k, first_val = arg.items().__iter__().__next__()
181+
first_key_t = self._resolve_tvar_container_internal_types(tvar, type(first_k), first_k)
182+
first_val_t = self._resolve_tvar_container_internal_types(tvar, type(first_val), first_val)
183+
if first_key_t and first_val_t:
184+
res = Dict[first_key_t, first_val_t]
185+
if not res and raise_on_error:
186+
raise ValueError(f"unable to resolve container type for type variable {tvar}.")
187+
return res
188+
189+
def _try_resolve_tvar_conflicts(self):
190+
for tvar, conflicting_types in self._conflicting_tvar_types.items():
191+
# Consider the case:
192+
# f(x : 'T', y:'T', z : 'T')
193+
# f(1, Dummy(), object())
194+
# The resolution between x and y will fail, while resolution between x and z will be object. After we resolve all,
195+
# the tvars resolution should have the most primitive subtype (object in this case) and we can now resolve Dummy to
196+
# object as well
197+
resolved_type = self._tvars.get(tvar)
198+
assert resolved_type, f'"{tvar}" was not resolved'
199+
for conflicting_type in conflicting_types:
200+
if (
201+
UpcastRegistry.instance().resolve_type(resolved_type, conflicting_type, raise_on_error=False)
202+
is not resolved_type
203+
):
204+
raise ValueError(f"Conflicting type resolution for {tvar}: {resolved_type, conflicting_type}")

0 commit comments

Comments
 (0)