Skip to content

Commit

Permalink
change series comparison to use assert_series_equal
Browse files Browse the repository at this point in the history
- allow small floating point differences
- configurable through class variables SERIES_EQUALITY_ABSOLUTE_TOLERANCE (default 1e-10) and SERIES_EQUALITY_RELATIVE_TOLERANCE (default 0.0)
  • Loading branch information
dbrakenhoff committed Oct 9, 2024
1 parent 664a355 commit 4fad416
Showing 1 changed file with 27 additions and 7 deletions.
34 changes: 27 additions & 7 deletions pastastore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pastas as ps
from numpy import isin
from packaging.version import parse as parse_version
from pandas.testing import assert_series_equal
from pastas.io.pas import PastasEncoder
from tqdm.auto import tqdm

Expand All @@ -30,7 +31,7 @@ class BaseConnector(ABC):
Class holds base logic for dealing with time series and Pastas Models. Create your
own Connector to a data source by writing a a class that inherits from this
BaseConnector. Your class has to override each abstractmethod and abstractproperty.
BaseConnector. Your class has to override each abstractmethod and property.
"""

_default_library_names = [
Expand All @@ -47,6 +48,10 @@ class BaseConnector(ABC):
# True for pastas>=0.23.0 and False for pastas<=0.22.0
USE_PASTAS_VALIDATE_SERIES = False if PASTAS_LEQ_022 else True

# set series equality comparison settings (using assert_series_equal)
SERIES_EQUALITY_ABSOLUTE_TOLERANCE = 1e-10
SERIES_EQUALITY_RELATIVE_TOLERANCE = 0.0

def __repr__(self):
"""Representation string of the object."""
return (
Expand Down Expand Up @@ -1665,11 +1670,18 @@ def _check_oseries_in_store(self, ml: Union[ps.Model, dict]):
so = ml.oseries.series_original
else:
so = ml.oseries._series_original
if not so.dropna().equals(s_org):
try:
assert_series_equal(
so.dropna(),
s_org,
atol=self.SERIES_EQUALITY_ABSOLUTE_TOLERANCE,
rtol=self.SERIES_EQUALITY_RELATIVE_TOLERANCE,
)
except AssertionError as e:
raise ValueError(
f"Cannot add model because model oseries '{name}'"
" is different from stored oseries!"
)
" is different from stored oseries! See stacktrace for differences."
) from e

def _check_stresses_in_store(self, ml: Union[ps.Model, dict]):
"""Check if stresses time series are contained in PastaStore (internal method).
Expand Down Expand Up @@ -1699,11 +1711,19 @@ def _check_stresses_in_store(self, ml: Union[ps.Model, dict]):
so = s.series_original
else:
so = s._series_original
if not so.equals(s_org):
try:
assert_series_equal(
so,
s_org,
atol=self.SERIES_EQUALITY_ABSOLUTE_TOLERANCE,
rtol=self.SERIES_EQUALITY_RELATIVE_TOLERANCE,
)
except AssertionError as e:
raise ValueError(
f"Cannot add model because model stress "
f"'{s.name}' is different from stored stress!"
)
f"'{s.name}' is different from stored stress! "
"See stacktrace for differences."
) from e
elif isinstance(ml, dict):
for sm in ml["stressmodels"].values():
classkey = "stressmodel" if PASTAS_LEQ_022 else "class"
Expand Down

0 comments on commit 4fad416

Please sign in to comment.