diff --git a/docs/requirements.txt b/docs/requirements.txt index a575fdd3..5e98bc20 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -5,7 +5,7 @@ Jinja2<3.1 networkx openpyxl pandas>=1.1 -pydantic<2 +pydantic>=2 pydot pyyaml # Requirements file for ReadTheDocs, check .readthedocs.yml. diff --git a/setup.cfg b/setup.cfg index 00cc9ce3..d14c3729 100644 --- a/setup.cfg +++ b/setup.cfg @@ -52,7 +52,7 @@ install_requires = networkx flatten_dict openpyxl - pydantic<2 + pydantic>=2 [options.packages.find] where = src exclude = diff --git a/src/otoole/preprocess/validate_config.py b/src/otoole/preprocess/validate_config.py index b6d8d17b..7903b0ab 100644 --- a/src/otoole/preprocess/validate_config.py +++ b/src/otoole/preprocess/validate_config.py @@ -3,7 +3,10 @@ import logging from typing import List, Optional, Union -from pydantic import BaseModel, root_validator, validator +from pydantic import BaseModel, ConfigDict, field_validator, model_validator + +# from pydantic import FieldValidationInfo + logger = logging.getLogger(__name__) @@ -11,39 +14,41 @@ class UserDefinedValue(BaseModel): """Represents any user defined value""" + model_config = ConfigDict(extra="forbid") + name: str type: str dtype: str - defined_sets: Optional[List[str]] - indices: Optional[List[str]] - default: Optional[Union[int, float]] - calculated: Optional[bool] - short_name: Optional[str] + defined_sets: Optional[List[str]] = None + indices: Optional[List[str]] = None + default: Optional[Union[int, float]] = None + calculated: Optional[bool] = None + short_name: Optional[str] = None - @validator("type") + @field_validator("type") @classmethod - def check_param_type(cls, value, values): + def check_param_type(cls, value, info): if value not in ["param", "result", "set"]: raise ValueError( - f"{values['name']} -> Type must be 'param', 'result', or 'set'" + f"{info.field_name} -> Type must be 'param', 'result', or 'set'" ) return value - @validator("name", "short_name") + @field_validator("name", "short_name") @classmethod # for linting purposes def check_name_for_spaces(cls, value): if " " in value: raise ValueError(f"{value} -> Name can not have spaces") return value - @validator("name", "short_name") + @field_validator("name", "short_name") @classmethod def check_name_for_numbers(cls, value): if any(char.isdigit() for char in value): raise ValueError(f"{value} -> Name can not have digits") return value - @validator("name", "short_name") + @field_validator("name", "short_name") @classmethod def check_name_for_special_chars(cls, value): # removed underscore from the recommeded special char list @@ -54,7 +59,7 @@ def check_name_for_special_chars(cls, value): ) return value - @root_validator(pre=True) + @model_validator(mode="before") @classmethod def check_name_length(cls, values): if len(values["name"]) > 31: @@ -69,32 +74,29 @@ def check_name_length(cls, values): ) return values - class Config: - extra = "forbid" - class UserDefinedSet(UserDefinedValue): """Represents a set""" - @validator("dtype") + @field_validator("dtype") @classmethod - def check_dtype(cls, value, values): + def check_dtype(cls, value, info): if value not in ["str", "int"]: - raise ValueError(f"{values['name']} -> Value must be a 'str' or 'int'") + raise ValueError(f"{info.field_name} -> Value must be a 'str' or 'int'") return value class UserDefinedParameter(UserDefinedValue): """Represents a parameter""" - @validator("dtype") + @field_validator("dtype") @classmethod - def check_dtype(cls, value, values): + def check_dtype(cls, value, info): if value not in ["float", "int"]: - raise ValueError(f"{values['name']} -> Value must be an 'int' or 'float'") + raise ValueError(f"{info.field_name} -> Value must be an 'int' or 'float'") return value - @root_validator(pre=True) + @model_validator(mode="before") @classmethod def check_required_inputs(cls, values): required = ["default", "defined_sets", "indices"] @@ -104,38 +106,36 @@ def check_required_inputs(cls, values): ) return values - @root_validator(pre=True) - @classmethod - def check_index_in_set(cls, values): - if not all(i in values["defined_sets"] for i in values["indices"]): - raise ValueError(f"{values['name']} -> Index not in user supplied sets") - return values + @model_validator(mode="after") + def check_index_in_set(self): + if not all(i in self.defined_sets for i in self.indices): + raise ValueError(f"{self.name} -> Index not in user supplied sets") + return self - @root_validator(pre=True) - @classmethod - def check_dtype_default(cls, values): - dtype_input = values["dtype"] - dtype_default = type(values["default"]).__name__ + @model_validator(mode="after") + def check_dtype_default(self): + dtype_input = self.dtype + dtype_default = type(self.default).__name__ if dtype_input != dtype_default: # allow ints to be cast as floats if not ((dtype_default == "int") and (dtype_input == "float")): raise ValueError( - f"{values['name']} -> User dtype is {dtype_input} while default value dtype is {dtype_default}" + f"{self.name} -> User dtype is {dtype_input} while default value dtype is {dtype_default}" ) - return values + return self class UserDefinedResult(UserDefinedValue): """Represents a result""" - @validator("dtype") + @field_validator("dtype") @classmethod - def check_dtype(cls, value, values): + def check_dtype(cls, value, info): if value not in ["float", "int"]: - raise ValueError(f"{values['name']} -> Value must be an 'int' or 'float'") + raise ValueError(f"{info.field_name} -> Value must be an 'int' or 'float'") return value - @root_validator(pre=True) + @model_validator(mode="before") @classmethod def check_required_inputs(cls, values): required = ["default", "defined_sets", "indices"] @@ -145,7 +145,7 @@ def check_required_inputs(cls, values): ) return values - @root_validator(pre=True) + @model_validator(mode="before") @classmethod def check_deprecated_values(cls, values): deprecated = ["calculated", "Calculated"] @@ -156,22 +156,20 @@ def check_deprecated_values(cls, values): ) return values - @root_validator(pre=True) - @classmethod - def check_index_in_set(cls, values): - if not all(i in values["defined_sets"] for i in values["indices"]): - raise ValueError(f"{values['name']} -> Index not in user supplied sets") - return values + @model_validator(mode="after") + def check_index_in_set(self): + if not all(i in self.defined_sets for i in self.indices): + raise ValueError(f"{self.name} -> Index not in user supplied sets") + return self - @root_validator(pre=True) - @classmethod - def check_dtype_default(cls, values): - dtype_input = values["dtype"] - dtype_default = type(values["default"]).__name__ + @model_validator(mode="after") + def check_dtype_default(self): + dtype_input = self.dtype + dtype_default = type(self.default).__name__ if dtype_input != dtype_default: # allow ints to be cast as floats if not ((dtype_default == "int") and (dtype_input == "float")): raise ValueError( - f"{values['name']} -> User dtype is {dtype_input} while default value dtype is {dtype_default}" + f"{self.name} -> User dtype is {dtype_input} while default value dtype is {dtype_default}" ) - return values + return self