Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to pydantic 2 #188

Merged
merged 1 commit into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Jinja2<3.1
networkx
openpyxl
pandas>=1.1
pydantic<2
pydantic>=2
pydot
pyyaml
# Requirements file for ReadTheDocs, check .readthedocs.yml.
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ install_requires =
networkx
flatten_dict
openpyxl
pydantic<2
pydantic>=2
[options.packages.find]
where = src
exclude =
Expand Down
106 changes: 52 additions & 54 deletions src/otoole/preprocess/validate_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,47 +3,52 @@
import logging
from typing import List, Optional, Union

from pydantic import BaseModel, root_validator, validator
from pydantic import BaseModel, ConfigDict, field_validator, model_validator

# from pydantic import FieldValidationInfo


logger = logging.getLogger(__name__)


class UserDefinedValue(BaseModel):
"""Represents any user defined value"""

model_config = ConfigDict(extra="forbid")

name: str
type: str
dtype: str
defined_sets: Optional[List[str]]
indices: Optional[List[str]]
default: Optional[Union[int, float]]
calculated: Optional[bool]
short_name: Optional[str]
defined_sets: Optional[List[str]] = None
indices: Optional[List[str]] = None
default: Optional[Union[int, float]] = None
calculated: Optional[bool] = None
short_name: Optional[str] = None

@validator("type")
@field_validator("type")
@classmethod
def check_param_type(cls, value, values):
def check_param_type(cls, value, info):
if value not in ["param", "result", "set"]:
raise ValueError(
f"{values['name']} -> Type must be 'param', 'result', or 'set'"
f"{info.field_name} -> Type must be 'param', 'result', or 'set'"
)
return value

@validator("name", "short_name")
@field_validator("name", "short_name")
@classmethod # for linting purposes
def check_name_for_spaces(cls, value):
if " " in value:
raise ValueError(f"{value} -> Name can not have spaces")
return value

@validator("name", "short_name")
@field_validator("name", "short_name")
@classmethod
def check_name_for_numbers(cls, value):
if any(char.isdigit() for char in value):
raise ValueError(f"{value} -> Name can not have digits")
return value

@validator("name", "short_name")
@field_validator("name", "short_name")
@classmethod
def check_name_for_special_chars(cls, value):
# removed underscore from the recommeded special char list
Expand All @@ -54,7 +59,7 @@ def check_name_for_special_chars(cls, value):
)
return value

@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def check_name_length(cls, values):
if len(values["name"]) > 31:
Expand All @@ -69,32 +74,29 @@ def check_name_length(cls, values):
)
return values

class Config:
extra = "forbid"


class UserDefinedSet(UserDefinedValue):
"""Represents a set"""

@validator("dtype")
@field_validator("dtype")
@classmethod
def check_dtype(cls, value, values):
def check_dtype(cls, value, info):
if value not in ["str", "int"]:
raise ValueError(f"{values['name']} -> Value must be a 'str' or 'int'")
raise ValueError(f"{info.field_name} -> Value must be a 'str' or 'int'")
return value


class UserDefinedParameter(UserDefinedValue):
"""Represents a parameter"""

@validator("dtype")
@field_validator("dtype")
@classmethod
def check_dtype(cls, value, values):
def check_dtype(cls, value, info):
if value not in ["float", "int"]:
raise ValueError(f"{values['name']} -> Value must be an 'int' or 'float'")
raise ValueError(f"{info.field_name} -> Value must be an 'int' or 'float'")
return value

@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def check_required_inputs(cls, values):
required = ["default", "defined_sets", "indices"]
Expand All @@ -104,38 +106,36 @@ def check_required_inputs(cls, values):
)
return values

@root_validator(pre=True)
@classmethod
def check_index_in_set(cls, values):
if not all(i in values["defined_sets"] for i in values["indices"]):
raise ValueError(f"{values['name']} -> Index not in user supplied sets")
return values
@model_validator(mode="after")
def check_index_in_set(self):
if not all(i in self.defined_sets for i in self.indices):
raise ValueError(f"{self.name} -> Index not in user supplied sets")
return self

@root_validator(pre=True)
@classmethod
def check_dtype_default(cls, values):
dtype_input = values["dtype"]
dtype_default = type(values["default"]).__name__
@model_validator(mode="after")
def check_dtype_default(self):
dtype_input = self.dtype
dtype_default = type(self.default).__name__
if dtype_input != dtype_default:
# allow ints to be cast as floats
if not ((dtype_default == "int") and (dtype_input == "float")):
raise ValueError(
f"{values['name']} -> User dtype is {dtype_input} while default value dtype is {dtype_default}"
f"{self.name} -> User dtype is {dtype_input} while default value dtype is {dtype_default}"
)
return values
return self


class UserDefinedResult(UserDefinedValue):
"""Represents a result"""

@validator("dtype")
@field_validator("dtype")
@classmethod
def check_dtype(cls, value, values):
def check_dtype(cls, value, info):
if value not in ["float", "int"]:
raise ValueError(f"{values['name']} -> Value must be an 'int' or 'float'")
raise ValueError(f"{info.field_name} -> Value must be an 'int' or 'float'")
return value

@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def check_required_inputs(cls, values):
required = ["default", "defined_sets", "indices"]
Expand All @@ -145,7 +145,7 @@ def check_required_inputs(cls, values):
)
return values

@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def check_deprecated_values(cls, values):
deprecated = ["calculated", "Calculated"]
Expand All @@ -156,22 +156,20 @@ def check_deprecated_values(cls, values):
)
return values

@root_validator(pre=True)
@classmethod
def check_index_in_set(cls, values):
if not all(i in values["defined_sets"] for i in values["indices"]):
raise ValueError(f"{values['name']} -> Index not in user supplied sets")
return values
@model_validator(mode="after")
def check_index_in_set(self):
if not all(i in self.defined_sets for i in self.indices):
raise ValueError(f"{self.name} -> Index not in user supplied sets")
return self

@root_validator(pre=True)
@classmethod
def check_dtype_default(cls, values):
dtype_input = values["dtype"]
dtype_default = type(values["default"]).__name__
@model_validator(mode="after")
def check_dtype_default(self):
dtype_input = self.dtype
dtype_default = type(self.default).__name__
if dtype_input != dtype_default:
# allow ints to be cast as floats
if not ((dtype_default == "int") and (dtype_input == "float")):
raise ValueError(
f"{values['name']} -> User dtype is {dtype_input} while default value dtype is {dtype_default}"
f"{self.name} -> User dtype is {dtype_input} while default value dtype is {dtype_default}"
)
return values
return self