Skip to content

Commit

Permalink
Remove deprecated json_encoder from ConfigDict for serializing SympyE…
Browse files Browse the repository at this point in the history
…xprStr objects. Use updated field serializer decorator
  • Loading branch information
nanglo123 committed Sep 12, 2024
1 parent 44c1699 commit 5269fcd
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 54 deletions.
10 changes: 2 additions & 8 deletions mira/metamodel/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,8 @@ class IntraModelEdge(DataEdge):

class ModelComparisonGraphdata(BaseModel):
"""A data structure holding a graph representation of TemplateModel delta"""
# TODO[pydantic]: The following keys were removed: `json_encoders`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(arbitrary_types_allowed=True, json_encoders={
SympyExprStr: lambda e: str(e),
}, json_decoders={
SympyExprStr: lambda e: safe_parse_expr(e),
Template: lambda t: Template.from_json(data=t),
})

model_config = ConfigDict(arbitrary_types_allowed=True)

template_models: Dict[int, TemplateModel] = Field(
..., description="A mapping of template model keys to template models"
Expand Down
29 changes: 13 additions & 16 deletions mira/metamodel/template_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import networkx as nx
import sympy
import mira.metamodel.io
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_serializer
from .templates import *
from .units import Unit
from .utils import safe_parse_expr, SympyExprStr
Expand All @@ -36,11 +36,8 @@ class Initial(BaseModel):
expression: SympyExprStr = Field(
description="The expression for the initial."
)
# TODO[pydantic]: The following keys were removed: `json_encoders`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(arbitrary_types_allowed=True, json_encoders={
SympyExprStr: lambda e: str(e),
}, json_decoders={SympyExprStr: lambda e: sympy.parse_expr(e)})

model_config = ConfigDict(arbitrary_types_allowed=True)

@classmethod
def from_json(cls, data: Dict[str, Any], locals_dict=None) -> "Initial":
Expand Down Expand Up @@ -68,6 +65,10 @@ def from_json(cls, data: Dict[str, Any], locals_dict=None) -> "Initial":
expression = safe_parse_expr(expression_str, local_dict=locals_dict)
return cls(concept=concept, expression=SympyExprStr(expression))

@field_serializer('expression')
def serialize_expression(self, expression):
return str(expression)

def substitute_parameter(self, name, value):
"""
Substitute a parameter value into the initial expression.
Expand Down Expand Up @@ -135,16 +136,17 @@ class Observable(Concept):
readout is not defined as a state variable but is rather a function of
state variables.
"""
# TODO[pydantic]: The following keys were removed: `json_encoders`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(arbitrary_types_allowed=True, json_encoders={
SympyExprStr: lambda e: str(e),
}, json_decoders={SympyExprStr: lambda e: safe_parse_expr(e)})

model_config = ConfigDict(arbitrary_types_allowed=True)

expression: SympyExprStr = Field(
description="The expression for the observable."
)

@field_serializer('expression')
def serialize_expression(self, expression):
return str(expression)

def substitute_parameter(self, name, value):
"""
Substitute a parameter value into the observable expression.
Expand Down Expand Up @@ -367,11 +369,6 @@ class TemplateModel(BaseModel):
description="A structure containing time-related annotations. "
"Note that all annotations are optional.",
)
# TODO[pydantic]: The following keys were removed: `json_encoders`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(json_encoders={
SympyExprStr: lambda e: str(e),
}, json_decoders={SympyExprStr: lambda e: safe_parse_expr(e)})

def get_parameters_from_rate_law(self, rate_law) -> Set[str]:
"""Given a rate law, find its elements that are model parameters.
Expand Down
24 changes: 9 additions & 15 deletions mira/metamodel/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
import networkx as nx
import pydantic
import sympy
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_serializer

try:
from typing import Annotated # py39+
Expand Down Expand Up @@ -128,13 +128,8 @@ class Concept(BaseModel):
None, description="The units of the concept."
)
_base_name: str = pydantic.PrivateAttr(None)
# TODO[pydantic]: The following keys were removed: `json_encoders`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(arbitrary_types_allowed=True, json_encoders={
SympyExprStr: lambda e: str(e),
}, json_decoders={
SympyExprStr: lambda e: sympy.parse_expr(e)
})

model_config = ConfigDict(arbitrary_types_allowed=True)

def __eq__(self, other):
if isinstance(other, Concept):
Expand Down Expand Up @@ -399,13 +394,8 @@ def from_json(cls, data) -> "Concept":

class Template(BaseModel):
"""The Template is a parent class for model processes"""
# TODO[pydantic]: The following keys were removed: `json_encoders`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(arbitrary_types_allowed=True, json_encoders={
SympyExprStr: lambda e: str(e),
}, json_decoders={
SympyExprStr: lambda e: safe_parse_expr(e)
})

model_config = ConfigDict(arbitrary_types_allowed=True)

rate_law: Optional[SympyExprStr] = Field(
default=None, description="The rate law for the template."
Expand Down Expand Up @@ -466,6 +456,10 @@ def from_json(cls, data, rate_symbols=None) -> "Template":
if k not in {'rate_law', 'type'}},
rate_law=rate)

@field_serializer('rate_law')
def serialize_expression(self, rate_law):
return str(rate_law)

def is_equal_to(self, other: "Template", with_context: bool = False,
config: Config = None) -> bool:
"""Check if this template is equal to another template
Expand Down
18 changes: 6 additions & 12 deletions mira/metamodel/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Dict, Any

import sympy
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_serializer
from .utils import SympyExprStr


Expand All @@ -34,17 +34,7 @@ def load_units():

class Unit(BaseModel):
"""A unit of measurement."""
# TODO[pydantic]: The following keys were removed: `json_encoders`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(
arbitrary_types_allowed=True,
json_encoders={
SympyExprStr: lambda e: str(e),
},
#json_decoders={
# SympyExprStr: lambda e: sympy.parse_expr(e)
#}
)
model_config = ConfigDict(arbitrary_types_allowed=True)

expression: SympyExprStr = Field(
description="The expression for the unit."
Expand All @@ -66,6 +56,10 @@ def model_validate(cls, obj):
obj['expression'] = SympyExprStr(obj['expression'])
return super().model_validate(obj)

@field_serializer('expression')
def serialize_expression(self, expression):
return str(expression)


person_units = Unit(expression=sympy.Symbol('person'))
day_units = Unit(expression=sympy.Symbol('day'))
Expand Down
8 changes: 5 additions & 3 deletions tests/test_model_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,16 +586,18 @@ def test_n_way_comparison_askenet(self):
"exclude_defaults": True,
"exclude_unset": True,
"exclude_none": True,
# This key is outdated in Pydantic2
# "skip_defaults": True,
}
# Compare the ModelComparisonResponse models
assert local_response == resp_model # If assertion fails the diff is printed
# If assertion fails the diff is printed
assert local_response.model_dump() == resp_model.model_dump()
local_sorted_str = sorted_json_str(
json.loads(local_response.model_dump_json(**dict_options)),
json.loads(local_response.model_dump_json()),
skip_empty=True
)
resp_sorted_str = sorted_json_str(
json.loads(resp_model.model_dump_json(**dict_options)),
json.loads(resp_model.model_dump_json()),
skip_empty=True
)
self.assertEqual(local_sorted_str, resp_sorted_str)
Expand Down

0 comments on commit 5269fcd

Please sign in to comment.