|
| 1 | +import numpy |
| 2 | +from pydantic import TypeAdapter, ValidationError |
| 3 | +from typing import Any, Dict, List, Set, Tuple, Type, Union, get_args |
| 4 | + |
| 5 | +import csp.typing |
| 6 | +from csp.impl.types.container_type_normalizer import ContainerTypeNormalizer |
| 7 | +from csp.impl.types.instantiation_type_resolver import UpcastRegistry |
| 8 | +from csp.impl.types.numpy_type_util import map_numpy_dtype_to_python_type |
| 9 | +from csp.impl.types.pydantic_types import CspTypeVarType, adjust_annotations |
| 10 | +from csp.impl.types.typing_utils import CspTypingUtils, TsTypeValidator |
| 11 | + |
| 12 | + |
| 13 | +class TVarValidationContext: |
| 14 | + """Custom validation context class for handling the special csp TVAR logic.""" |
| 15 | + |
| 16 | + # Note: some of the implementation is borrowed from InputInstanceTypeResolver |
| 17 | + |
| 18 | + def __init__( |
| 19 | + self, |
| 20 | + forced_tvars: Union[Dict[str, Type], None] = None, |
| 21 | + allow_none_ts: bool = False, |
| 22 | + ): |
| 23 | + # Can be set by a field validator to help track the source field of the different tvar refs |
| 24 | + self.field_name = None |
| 25 | + self._allow_none_ts = allow_none_ts |
| 26 | + self._forced_tvars: Dict[str, Type] = forced_tvars or {} |
| 27 | + self._tvar_type_refs: Dict[str, Set[Tuple[str, Type]]] = {} |
| 28 | + self._tvar_refs: Dict[str, Dict[str, List[Any]]] = {} |
| 29 | + self._tvars: Dict[str, Type] = {} |
| 30 | + self._conflicting_tvar_types = {} |
| 31 | + |
| 32 | + if self._forced_tvars: |
| 33 | + config = {"arbitrary_types_allowed": True, "strict": True} |
| 34 | + self._forced_tvars = {k: ContainerTypeNormalizer.normalize_type(v) for k, v in self._forced_tvars.items()} |
| 35 | + self._forced_tvar_adapters = { |
| 36 | + tvar: TypeAdapter(List[t], config=config) for tvar, t in self._forced_tvars.items() |
| 37 | + } |
| 38 | + self._forced_tvar_validators = {tvar: TsTypeValidator(t) for tvar, t in self._forced_tvars.items()} |
| 39 | + self._tvars.update(**self._forced_tvars) |
| 40 | + |
| 41 | + @property |
| 42 | + def tvars(self) -> Dict[str, Type]: |
| 43 | + return self._tvars |
| 44 | + |
| 45 | + @property |
| 46 | + def allow_none_ts(self) -> bool: |
| 47 | + return self._allow_none_ts |
| 48 | + |
| 49 | + def add_tvar_type_ref(self, tvar, value_type): |
| 50 | + if value_type is not numpy.ndarray: |
| 51 | + # Need to convert, i.e. [float] into List[float] when passed as a tref |
| 52 | + # Exclude ndarray because otherwise will get converted to NumpyNDArray[float], even for non-float |
| 53 | + # See, i.e. TestParquetReader.test_numpy_array_on_struct_with_field_map |
| 54 | + # TODO: This should be fixed in the ContainerTypeNormalizer |
| 55 | + value_type = ContainerTypeNormalizer.normalize_type(value_type) |
| 56 | + self._tvar_type_refs.setdefault(tvar, set()).add((self.field_name, value_type)) |
| 57 | + |
| 58 | + def add_tvar_ref(self, tvar, value): |
| 59 | + self._tvar_refs.setdefault(tvar, {}).setdefault(self.field_name, []).append(value) |
| 60 | + |
| 61 | + def resolve_tvars(self): |
| 62 | + # Validate instances against forced tvars |
| 63 | + if self._forced_tvars: |
| 64 | + for tvar, adapter in self._forced_tvar_adapters.items(): |
| 65 | + for field_name, field_values in self._tvar_refs.get(tvar, {}).items(): |
| 66 | + # Validate using TypeAdapter(List[t]) in pydantic as it's faster than iterating through in python |
| 67 | + adapter.validate_python(field_values, strict=True) |
| 68 | + |
| 69 | + for tvar, validator in self._forced_tvar_validators.items(): |
| 70 | + for field_name, v in self._tvar_type_refs.get(tvar, set()): |
| 71 | + validator.validate(v) |
| 72 | + |
| 73 | + # Add resolutions for references to tvar types (where type is inferred directly from type) |
| 74 | + for tvar, type_refs in self._tvar_type_refs.items(): |
| 75 | + for field_name, value_type in type_refs: |
| 76 | + self._add_t_var_resolution(tvar, field_name, value_type) |
| 77 | + |
| 78 | + # Add resolutions for references to tvar values (where type is inferred from type of value) |
| 79 | + for tvar, field_refs in self._tvar_refs.items(): |
| 80 | + if self._forced_tvars and tvar in self._forced_tvars: |
| 81 | + # Already handled these |
| 82 | + continue |
| 83 | + for field_name, values in field_refs.items(): |
| 84 | + for value in values: |
| 85 | + typ = type(value) |
| 86 | + if not CspTypingUtils.is_type_spec(typ): |
| 87 | + typ = ContainerTypeNormalizer.normalize_type(typ) |
| 88 | + self._add_t_var_resolution(tvar, field_name, typ, value if value is not typ else None) |
| 89 | + self._try_resolve_tvar_conflicts() |
| 90 | + |
| 91 | + def revalidate(self, model): |
| 92 | + """Once tvars have been resolved, need to revalidate input values against resolved tvars""" |
| 93 | + # Determine the fields that need to be revalidated because of tvar resolution |
| 94 | + # At the moment, that's only int fields that need to be converted to float |
| 95 | + # What does revalidation do? |
| 96 | + # - It makes sure that, edges declared as ts[float] inside a data structure, i.e. List[ts[float]], |
| 97 | + # get properly converted from, ts[int] |
| 98 | + # - It makes sure that scalar int values get converted to float |
| 99 | + # - It ignores validating a pass "int" type as a "float" type. |
| 100 | + fields_to_revalidate = set() |
| 101 | + for tvar, type_refs in self._tvar_type_refs.items(): |
| 102 | + if self._tvars[tvar] is float: |
| 103 | + for field_name, value_type in type_refs: |
| 104 | + if field_name and value_type is int: |
| 105 | + fields_to_revalidate.add(field_name) |
| 106 | + for tvar, field_refs in self._tvar_refs.items(): |
| 107 | + for field_name, values in field_refs.items(): |
| 108 | + if field_name and any(type(value) is int for value in values): # noqa E721 |
| 109 | + fields_to_revalidate.add(field_name) |
| 110 | + # Do the conversion only for the relevant fields |
| 111 | + for field in fields_to_revalidate: |
| 112 | + value = getattr(model, field) |
| 113 | + annotation = model.__annotations__[field] |
| 114 | + args = get_args(annotation) |
| 115 | + if args and args[0] is CspTypeVarType: |
| 116 | + # Skip revalidation of top-level type var types, as these have been handled via tvar resolution |
| 117 | + continue |
| 118 | + new_annotation = adjust_annotations(annotation, forced_tvars=self.tvars) |
| 119 | + try: |
| 120 | + new_value = TypeAdapter(new_annotation).validate_python(value) |
| 121 | + except ValidationError as e: |
| 122 | + msg = "\t" + str(e).replace("\n", "\n\t") |
| 123 | + raise ValueError( |
| 124 | + f"failed to revalidate field `{field}` after applying Tvars: {self._tvars}\n{msg}\n" |
| 125 | + ) from None |
| 126 | + setattr(model, field, new_value) |
| 127 | + return model |
| 128 | + |
| 129 | + def _add_t_var_resolution(self, tvar, field_name, resolved_type, arg=None): |
| 130 | + old_tvar_type = self._tvars.get(tvar) |
| 131 | + if old_tvar_type is None: |
| 132 | + self._tvars[tvar] = self._resolve_tvar_container_internal_types(tvar, resolved_type, arg) |
| 133 | + return |
| 134 | + elif self._forced_tvars and tvar in self._forced_tvars: |
| 135 | + # We must not change types, it's forced. So we will have to make sure that the new resolution matches the old one |
| 136 | + return |
| 137 | + |
| 138 | + combined_type = UpcastRegistry.instance().resolve_type(resolved_type, old_tvar_type, raise_on_error=False) |
| 139 | + if combined_type is None: |
| 140 | + self._conflicting_tvar_types.setdefault(tvar, []).append(resolved_type) |
| 141 | + |
| 142 | + if combined_type is not None and combined_type != old_tvar_type: |
| 143 | + self._tvars[tvar] = combined_type |
| 144 | + |
| 145 | + def _resolve_tvar_container_internal_types(self, tvar, container_typ, arg, raise_on_error=True): |
| 146 | + """This function takes, a container type (i.e. list) and an arg (i.e. 6) and infers the type of the TVar, |
| 147 | + i.e. typing.List[int]. For simple types, this function is a pass-through (i.e. arg is None). |
| 148 | + """ |
| 149 | + if arg is None: |
| 150 | + return container_typ |
| 151 | + if container_typ not in (set, dict, list, numpy.ndarray): |
| 152 | + return container_typ |
| 153 | + # It's possible that we provided type as scalar argument, that's illegal for containers, it must specify explicitly typed |
| 154 | + # list |
| 155 | + if arg is container_typ: |
| 156 | + if raise_on_error: |
| 157 | + raise ValueError(f"unable to resolve container type for type variable {tvar}: invalid argument {arg}") |
| 158 | + else: |
| 159 | + return False |
| 160 | + if len(arg) == 0: |
| 161 | + return container_typ |
| 162 | + res = None |
| 163 | + if isinstance(arg, set): |
| 164 | + first_val = arg.__iter__().__next__() |
| 165 | + first_val_t = self._resolve_tvar_container_internal_types(tvar, type(first_val), first_val) |
| 166 | + if first_val_t: |
| 167 | + res = Set[first_val_t] |
| 168 | + elif isinstance(arg, list): |
| 169 | + first_val = arg.__iter__().__next__() |
| 170 | + first_val_t = self._resolve_tvar_container_internal_types(tvar, type(first_val), first_val) |
| 171 | + if first_val_t: |
| 172 | + res = List[first_val_t] |
| 173 | + elif isinstance(arg, numpy.ndarray): |
| 174 | + python_type = map_numpy_dtype_to_python_type(arg.dtype) |
| 175 | + if arg.ndim > 1: |
| 176 | + res = csp.typing.NumpyNDArray[python_type] |
| 177 | + else: |
| 178 | + res = csp.typing.Numpy1DArray[python_type] |
| 179 | + else: |
| 180 | + first_k, first_val = arg.items().__iter__().__next__() |
| 181 | + first_key_t = self._resolve_tvar_container_internal_types(tvar, type(first_k), first_k) |
| 182 | + first_val_t = self._resolve_tvar_container_internal_types(tvar, type(first_val), first_val) |
| 183 | + if first_key_t and first_val_t: |
| 184 | + res = Dict[first_key_t, first_val_t] |
| 185 | + if not res and raise_on_error: |
| 186 | + raise ValueError(f"unable to resolve container type for type variable {tvar}.") |
| 187 | + return res |
| 188 | + |
| 189 | + def _try_resolve_tvar_conflicts(self): |
| 190 | + for tvar, conflicting_types in self._conflicting_tvar_types.items(): |
| 191 | + # Consider the case: |
| 192 | + # f(x : 'T', y:'T', z : 'T') |
| 193 | + # f(1, Dummy(), object()) |
| 194 | + # The resolution between x and y will fail, while resolution between x and z will be object. After we resolve all, |
| 195 | + # the tvars resolution should have the most primitive subtype (object in this case) and we can now resolve Dummy to |
| 196 | + # object as well |
| 197 | + resolved_type = self._tvars.get(tvar) |
| 198 | + assert resolved_type, f'"{tvar}" was not resolved' |
| 199 | + for conflicting_type in conflicting_types: |
| 200 | + if ( |
| 201 | + UpcastRegistry.instance().resolve_type(resolved_type, conflicting_type, raise_on_error=False) |
| 202 | + is not resolved_type |
| 203 | + ): |
| 204 | + raise ValueError(f"Conflicting type resolution for {tvar}: {resolved_type, conflicting_type}") |
0 commit comments