diff --git a/pastastore/base.py b/pastastore/base.py index 21a2f72..333e4ab 100644 --- a/pastastore/base.py +++ b/pastastore/base.py @@ -15,6 +15,7 @@ import pastas as ps from numpy import isin from packaging.version import parse as parse_version +from pandas.testing import assert_series_equal from pastas.io.pas import PastasEncoder from tqdm.auto import tqdm @@ -30,7 +31,7 @@ class BaseConnector(ABC): Class holds base logic for dealing with time series and Pastas Models. Create your own Connector to a data source by writing a a class that inherits from this - BaseConnector. Your class has to override each abstractmethod and abstractproperty. + BaseConnector. Your class has to override each abstractmethod and property. """ _default_library_names = [ @@ -47,6 +48,10 @@ class BaseConnector(ABC): # True for pastas>=0.23.0 and False for pastas<=0.22.0 USE_PASTAS_VALIDATE_SERIES = False if PASTAS_LEQ_022 else True + # set series equality comparison settings (using assert_series_equal) + SERIES_EQUALITY_ABSOLUTE_TOLERANCE = 1e-10 + SERIES_EQUALITY_RELATIVE_TOLERANCE = 0.0 + def __repr__(self): """Representation string of the object.""" return ( @@ -1665,11 +1670,18 @@ def _check_oseries_in_store(self, ml: Union[ps.Model, dict]): so = ml.oseries.series_original else: so = ml.oseries._series_original - if not so.dropna().equals(s_org): + try: + assert_series_equal( + so.dropna(), + s_org, + atol=self.SERIES_EQUALITY_ABSOLUTE_TOLERANCE, + rtol=self.SERIES_EQUALITY_RELATIVE_TOLERANCE, + ) + except AssertionError as e: raise ValueError( f"Cannot add model because model oseries '{name}'" - " is different from stored oseries!" - ) + " is different from stored oseries! See stacktrace for differences." + ) from e def _check_stresses_in_store(self, ml: Union[ps.Model, dict]): """Check if stresses time series are contained in PastaStore (internal method). @@ -1699,11 +1711,19 @@ def _check_stresses_in_store(self, ml: Union[ps.Model, dict]): so = s.series_original else: so = s._series_original - if not so.equals(s_org): + try: + assert_series_equal( + so, + s_org, + atol=self.SERIES_EQUALITY_ABSOLUTE_TOLERANCE, + rtol=self.SERIES_EQUALITY_RELATIVE_TOLERANCE, + ) + except AssertionError as e: raise ValueError( f"Cannot add model because model stress " - f"'{s.name}' is different from stored stress!" - ) + f"'{s.name}' is different from stored stress! " + "See stacktrace for differences." + ) from e elif isinstance(ml, dict): for sm in ml["stressmodels"].values(): classkey = "stressmodel" if PASTAS_LEQ_022 else "class"