diff --git a/parcels/_index_search.py b/parcels/_index_search.py index f560d8167..dde987261 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -7,8 +7,7 @@ from parcels._typing import Mesh from parcels.tools.statuscodes import ( - _raise_field_out_of_bound_error, - _raise_field_sampling_error, + _raise_grid_searching_error, _raise_time_extrapolation_error, ) @@ -65,12 +64,12 @@ def _search_indices_curvilinear_2d( # TODO: Re-enable in some capacity # if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]: # if grid.lon[0, 0] < grid.lon[0, -1]: - # _raise_field_out_of_bound_error(y, x) + # _raise_grid_searching_error(y, x) # elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160] - # _raise_field_out_of_bound_error(z, y, x) + # _raise_grid_searching_error(z, y, x) # if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]: - # _raise_field_out_of_bound_error(z, y, x) + # _raise_grid_searching_error(z, y, x) while np.any(xsi < -tol) or np.any(xsi > 1 + tol) or np.any(eta < -tol) or np.any(eta > 1 + tol): px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]]) @@ -100,12 +99,12 @@ def _search_indices_curvilinear_2d( it += 1 if it > maxIterSearch: print(f"Correct cell not found after {maxIterSearch} iterations") - _raise_field_out_of_bound_error(0, y, x) + _raise_grid_searching_error(0, y, x) xsi = np.where(xsi < 0.0, 0.0, np.where(xsi > 1.0, 1.0, xsi)) eta = np.where(eta < 0.0, 0.0, np.where(eta > 1.0, 1.0, eta)) if np.any((xsi < 0) | (xsi > 1) | (eta < 0) | (eta > 1)): - _raise_field_sampling_error(y, x) + _raise_grid_searching_error(y, x) return (yi, eta, xi, xsi) diff --git a/parcels/field.py b/parcels/field.py index 7d14eff60..360e84393 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -233,10 +233,12 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): tau, ti = _search_time_index(self, time) position = self.grid.search(z, y, x, ei=_ei) - _update_particle_states(particle, position) + _update_particle_states_position(particle, position) value = self._interp_method(self, ti, position, tau, time, z, y, x) + _update_particle_states_interp_value(particle, value) + if applyConversion: value = self.units.to_target(value, z, y, x) return value @@ -313,16 +315,21 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): tau, ti = _search_time_index(self.U, time) position = self.grid.search(z, y, x, ei=_ei) - _update_particle_states(particle, position) + _update_particle_states_position(particle, position) if self._vector_interp_method is None: u = self.U._interp_method(self.U, ti, position, tau, time, z, y, x) v = self.V._interp_method(self.V, ti, position, tau, time, z, y, x) if "3D" in self.vector_type: w = self.W._interp_method(self.W, ti, position, tau, time, z, y, x) + else: + w = 0.0 else: (u, v, w) = self._vector_interp_method(self, ti, position, time, z, y, x) + for vel in (u, v, w): + _update_particle_states_interp_value(particle, vel) + if applyConversion: u = self.U.units.to_target(u, z, y, x) v = self.V.units.to_target(v, z, y, x) @@ -344,14 +351,30 @@ def __getitem__(self, key): return _deal_with_errors(error, key, vector_type=self.vector_type) -def _update_particle_states(particle, position): +def _update_particle_states_position(particle, position): """Update the particle states based on the position dictionary.""" if particle and "X" in position: # TODO also support uxgrid search - particle.state = np.where(position["X"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state) - particle.state = np.where(position["Y"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state) - particle.state = np.where(position["Z"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particle.state) - particle.state = np.where( - position["Z"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particle.state + particle.state = np.maximum( + np.where(position["X"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state), particle.state + ) + particle.state = np.maximum( + np.where(position["Y"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state), particle.state + ) + particle.state = np.maximum( + np.where(position["Z"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particle.state), + particle.state, + ) + particle.state = np.maximum( + np.where(position["Z"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particle.state), + particle.state, + ) + + +def _update_particle_states_interp_value(particle, value): + """Update the particle states based on the interpolated value, but only if state is not an Error already.""" + if particle: + particle.state = np.maximum( + np.where(np.isnan(value), StatusCode.ErrorInterpolation, particle.state), particle.state ) diff --git a/parcels/kernel.py b/parcels/kernel.py index 467724ee8..e4c42238f 100644 --- a/parcels/kernel.py +++ b/parcels/kernel.py @@ -15,9 +15,11 @@ from parcels.basegrid import GridType from parcels.tools.statuscodes import ( StatusCode, + _raise_field_interpolation_error, _raise_field_out_of_bound_error, _raise_field_out_of_bound_surface_error, - _raise_field_sampling_error, + _raise_general_error, + _raise_grid_searching_error, _raise_time_extrapolation_error, ) from parcels.tools.warnings import KernelWarning @@ -28,6 +30,16 @@ __all__ = ["Kernel"] +ErrorsToThrow = { + StatusCode.ErrorTimeExtrapolation: _raise_time_extrapolation_error, + StatusCode.ErrorOutOfBounds: _raise_field_out_of_bound_error, + StatusCode.ErrorThroughSurface: _raise_field_out_of_bound_surface_error, + StatusCode.ErrorInterpolation: _raise_field_interpolation_error, + StatusCode.ErrorGridSearching: _raise_grid_searching_error, + StatusCode.Error: _raise_general_error, +} + + class Kernel: """Kernel object that encapsulates auto-generated code. @@ -270,14 +282,7 @@ def execute(self, pset, endtime, dt): if np.any(pset.state == StatusCode.StopAllExecution): return StatusCode.StopAllExecution - errors_to_throw = { - StatusCode.ErrorTimeExtrapolation: _raise_time_extrapolation_error, - StatusCode.ErrorOutOfBounds: _raise_field_out_of_bound_error, - StatusCode.ErrorThroughSurface: _raise_field_out_of_bound_surface_error, - StatusCode.Error: _raise_field_sampling_error, - } - - for error_code, error_func in errors_to_throw.items(): + for error_code, error_func in ErrorsToThrow.items(): if np.any(pset.state == error_code): inds = pset.state == error_code if error_code == StatusCode.ErrorTimeExtrapolation: diff --git a/parcels/tools/statuscodes.py b/parcels/tools/statuscodes.py index 94589b8d8..aa84b3ef4 100644 --- a/parcels/tools/statuscodes.py +++ b/parcels/tools/statuscodes.py @@ -7,9 +7,11 @@ "KernelError", "StatusCode", "TimeExtrapolationError", + "_raise_field_interpolation_error", "_raise_field_out_of_bound_error", "_raise_field_out_of_bound_surface_error", - "_raise_field_sampling_error", + "_raise_general_error", + "_raise_grid_searching_error", "_raise_time_extrapolation_error", ] @@ -25,37 +27,38 @@ class StatusCode: StopAllExecution = 41 Error = 50 ErrorInterpolation = 51 + ErrorGridSearching = 52 ErrorOutOfBounds = 60 ErrorThroughSurface = 61 ErrorTimeExtrapolation = 70 -class FieldSamplingError(RuntimeError): - """Utility error class to propagate erroneous field sampling.""" +class FieldInterpolationError(RuntimeError): + """Utility error class to propagate NaN field interpolation.""" pass +def _raise_field_interpolation_error(z, y, x): + raise FieldInterpolationError(f"Field interpolation returned NaN at (depth={z}, lat={y}, lon={x})") + + class FieldOutOfBoundError(RuntimeError): """Utility error class to propagate out-of-bound field sampling.""" pass +def _raise_field_out_of_bound_error(z, y, x): + raise FieldOutOfBoundError(f"Field sampled out-of-bound, at (depth={z}, lat={y}, lon={x})") + + class FieldOutOfBoundSurfaceError(RuntimeError): """Utility error class to propagate out-of-bound field sampling at the surface.""" pass -def _raise_field_sampling_error(z, y, x): - raise FieldSamplingError(f"Field sampled at (depth={z}, lat={y}, lon={x})") - - -def _raise_field_out_of_bound_error(z, y, x): - raise FieldOutOfBoundError(f"Field sampled out-of-bound, at (depth={z}, lat={y}, lon={x})") - - def _raise_field_out_of_bound_surface_error(z: float | None, y: float | None, x: float | None) -> None: def format_out(val): return "unknown" if val is None else val @@ -65,6 +68,32 @@ def format_out(val): ) +class FieldSamplingError(RuntimeError): + """Utility error class to propagate field sampling errors.""" + + pass + + +class GridSearchingError(RuntimeError): + """Utility error class to propagate grid searching errors.""" + + pass + + +def _raise_grid_searching_error(z, y, x): + raise GridSearchingError(f"Grid searching failed at (depth={z}, lat={y}, lon={x})") + + +class GeneralError(RuntimeError): + """Utility error class to propagate general errors.""" + + pass + + +def _raise_general_error(z, y, x): + raise GeneralError(f"General error occurred at (depth={z}, lat={y}, lon={x})") + + class TimeExtrapolationError(RuntimeError): """Utility error class to propagate erroneous time extrapolation sampling.""" @@ -91,9 +120,11 @@ def __init__(self, particle, fieldset=None, msg=None): AllParcelsErrorCodes = { - FieldSamplingError: StatusCode.Error, + FieldInterpolationError: StatusCode.ErrorInterpolation, FieldOutOfBoundError: StatusCode.ErrorOutOfBounds, FieldOutOfBoundSurfaceError: StatusCode.ErrorThroughSurface, + GridSearchingError: StatusCode.ErrorGridSearching, TimeExtrapolationError: StatusCode.ErrorTimeExtrapolation, KernelError: StatusCode.Error, + GeneralError: StatusCode.Error, } diff --git a/tests/v4/test_particleset_execute.py b/tests/v4/test_particleset_execute.py index af631da29..0c6a4baf6 100644 --- a/tests/v4/test_particleset_execute.py +++ b/tests/v4/test_particleset_execute.py @@ -17,7 +17,7 @@ from parcels._datasets.structured.generated import simple_UV_dataset from parcels._datasets.structured.generic import datasets as datasets_structured from parcels._datasets.unstructured.generic import datasets as datasets_unstructured -from parcels.tools.statuscodes import FieldOutOfBoundError, TimeExtrapolationError +from parcels.tools.statuscodes import FieldInterpolationError, FieldOutOfBoundError, TimeExtrapolationError from parcels.uxgrid import UxGrid from parcels.xgrid import XGrid from tests import utils @@ -267,6 +267,25 @@ def FieldAccessOutsideTime(particle, fieldset, time): # pragma: no cover pset.execute(FieldAccessOutsideTime, runtime=np.timedelta64(1, "D"), dt=np.timedelta64(10, "D")) +def test_raise_grid_searching_error(): ... + + +def test_raise_general_error(): ... + + +def test_errorinterpolation(fieldset): + def NaNInterpolator(field, ti, position, tau, t, z, y, x): # pragma: no cover + return np.nan * np.zeros_like(x) + + def SampleU(particle, fieldset, time): # pragma: no cover + fieldset.U[particle.time, particle.depth, particle.lat, particle.lon, particle] + + fieldset.U.interp_method = NaNInterpolator + pset = ParticleSet(fieldset, lon=[0, 2], lat=[0, 0]) + with pytest.raises(FieldInterpolationError): + pset.execute(SampleU, runtime=np.timedelta64(2, "s"), dt=np.timedelta64(1, "s")) + + def test_execution_check_stopallexecution(fieldset): def addoneLon(particle, fieldset, time): # pragma: no cover particle.dlon += 1