Skip to content
Merged
13 changes: 6 additions & 7 deletions parcels/_index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@

from parcels._typing import Mesh
from parcels.tools.statuscodes import (
_raise_field_out_of_bound_error,
_raise_field_sampling_error,
_raise_grid_searching_error,
_raise_time_extrapolation_error,
)

Expand Down Expand Up @@ -65,12 +64,12 @@ def _search_indices_curvilinear_2d(
# TODO: Re-enable in some capacity
# if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]:
# if grid.lon[0, 0] < grid.lon[0, -1]:
# _raise_field_out_of_bound_error(y, x)
# _raise_grid_searching_error(y, x)
# elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160]
# _raise_field_out_of_bound_error(z, y, x)
# _raise_grid_searching_error(z, y, x)

# if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]:
# _raise_field_out_of_bound_error(z, y, x)
# _raise_grid_searching_error(z, y, x)

while np.any(xsi < -tol) or np.any(xsi > 1 + tol) or np.any(eta < -tol) or np.any(eta > 1 + tol):
px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
Expand Down Expand Up @@ -100,12 +99,12 @@ def _search_indices_curvilinear_2d(
it += 1
if it > maxIterSearch:
print(f"Correct cell not found after {maxIterSearch} iterations")
_raise_field_out_of_bound_error(0, y, x)
_raise_grid_searching_error(0, y, x)
xsi = np.where(xsi < 0.0, 0.0, np.where(xsi > 1.0, 1.0, xsi))
eta = np.where(eta < 0.0, 0.0, np.where(eta > 1.0, 1.0, eta))

if np.any((xsi < 0) | (xsi > 1) | (eta < 0) | (eta > 1)):
_raise_field_sampling_error(y, x)
_raise_grid_searching_error(y, x)
return (yi, eta, xi, xsi)


Expand Down
39 changes: 31 additions & 8 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,12 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):

tau, ti = _search_time_index(self, time)
position = self.grid.search(z, y, x, ei=_ei)
_update_particle_states(particle, position)
_update_particle_states_position(particle, position)

value = self._interp_method(self, ti, position, tau, time, z, y, x)

_update_particle_states_interp_value(particle, value)

if applyConversion:
value = self.units.to_target(value, z, y, x)
return value
Expand Down Expand Up @@ -313,16 +315,21 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):

tau, ti = _search_time_index(self.U, time)
position = self.grid.search(z, y, x, ei=_ei)
_update_particle_states(particle, position)
_update_particle_states_position(particle, position)

if self._vector_interp_method is None:
u = self.U._interp_method(self.U, ti, position, tau, time, z, y, x)
v = self.V._interp_method(self.V, ti, position, tau, time, z, y, x)
if "3D" in self.vector_type:
w = self.W._interp_method(self.W, ti, position, tau, time, z, y, x)
else:
w = 0.0
else:
(u, v, w) = self._vector_interp_method(self, ti, position, time, z, y, x)

for vel in (u, v, w):
_update_particle_states_interp_value(particle, vel)

if applyConversion:
u = self.U.units.to_target(u, z, y, x)
v = self.V.units.to_target(v, z, y, x)
Expand All @@ -344,14 +351,30 @@ def __getitem__(self, key):
return _deal_with_errors(error, key, vector_type=self.vector_type)


def _update_particle_states(particle, position):
def _update_particle_states_position(particle, position):
"""Update the particle states based on the position dictionary."""
if particle and "X" in position: # TODO also support uxgrid search
particle.state = np.where(position["X"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state)
particle.state = np.where(position["Y"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state)
particle.state = np.where(position["Z"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particle.state)
particle.state = np.where(
position["Z"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particle.state
particle.state = np.maximum(
np.where(position["X"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state), particle.state
)
particle.state = np.maximum(
np.where(position["Y"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state), particle.state
)
particle.state = np.maximum(
np.where(position["Z"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particle.state),
particle.state,
)
particle.state = np.maximum(
np.where(position["Z"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particle.state),
particle.state,
)


def _update_particle_states_interp_value(particle, value):
"""Update the particle states based on the interpolated value, but only if state is not an Error already."""
if particle:
particle.state = np.maximum(
np.where(np.isnan(value), StatusCode.ErrorInterpolation, particle.state), particle.state
)


Expand Down
23 changes: 14 additions & 9 deletions parcels/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
from parcels.basegrid import GridType
from parcels.tools.statuscodes import (
StatusCode,
_raise_field_interpolation_error,
_raise_field_out_of_bound_error,
_raise_field_out_of_bound_surface_error,
_raise_field_sampling_error,
_raise_general_error,
_raise_grid_searching_error,
_raise_time_extrapolation_error,
)
from parcels.tools.warnings import KernelWarning
Expand All @@ -28,6 +30,16 @@
__all__ = ["Kernel"]


ErrorsToThrow = {
StatusCode.ErrorTimeExtrapolation: _raise_time_extrapolation_error,
StatusCode.ErrorOutOfBounds: _raise_field_out_of_bound_error,
StatusCode.ErrorThroughSurface: _raise_field_out_of_bound_surface_error,
StatusCode.ErrorInterpolation: _raise_field_interpolation_error,
StatusCode.ErrorGridSearching: _raise_grid_searching_error,
StatusCode.Error: _raise_general_error,
}


class Kernel:
"""Kernel object that encapsulates auto-generated code.

Expand Down Expand Up @@ -270,14 +282,7 @@ def execute(self, pset, endtime, dt):
if np.any(pset.state == StatusCode.StopAllExecution):
return StatusCode.StopAllExecution

errors_to_throw = {
StatusCode.ErrorTimeExtrapolation: _raise_time_extrapolation_error,
StatusCode.ErrorOutOfBounds: _raise_field_out_of_bound_error,
StatusCode.ErrorThroughSurface: _raise_field_out_of_bound_surface_error,
StatusCode.Error: _raise_field_sampling_error,
}

for error_code, error_func in errors_to_throw.items():
for error_code, error_func in ErrorsToThrow.items():
if np.any(pset.state == error_code):
inds = pset.state == error_code
if error_code == StatusCode.ErrorTimeExtrapolation:
Expand Down
55 changes: 43 additions & 12 deletions parcels/tools/statuscodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
"KernelError",
"StatusCode",
"TimeExtrapolationError",
"_raise_field_interpolation_error",
"_raise_field_out_of_bound_error",
"_raise_field_out_of_bound_surface_error",
"_raise_field_sampling_error",
"_raise_general_error",
"_raise_grid_searching_error",
"_raise_time_extrapolation_error",
]

Expand All @@ -25,37 +27,38 @@ class StatusCode:
StopAllExecution = 41
Error = 50
ErrorInterpolation = 51
ErrorGridSearching = 52
ErrorOutOfBounds = 60
ErrorThroughSurface = 61
ErrorTimeExtrapolation = 70


class FieldSamplingError(RuntimeError):
"""Utility error class to propagate erroneous field sampling."""
class FieldInterpolationError(RuntimeError):
"""Utility error class to propagate NaN field interpolation."""

pass


def _raise_field_interpolation_error(z, y, x):
raise FieldInterpolationError(f"Field interpolation returned NaN at (depth={z}, lat={y}, lon={x})")


class FieldOutOfBoundError(RuntimeError):
"""Utility error class to propagate out-of-bound field sampling."""

pass


def _raise_field_out_of_bound_error(z, y, x):
raise FieldOutOfBoundError(f"Field sampled out-of-bound, at (depth={z}, lat={y}, lon={x})")


class FieldOutOfBoundSurfaceError(RuntimeError):
"""Utility error class to propagate out-of-bound field sampling at the surface."""

pass


def _raise_field_sampling_error(z, y, x):
raise FieldSamplingError(f"Field sampled at (depth={z}, lat={y}, lon={x})")


def _raise_field_out_of_bound_error(z, y, x):
raise FieldOutOfBoundError(f"Field sampled out-of-bound, at (depth={z}, lat={y}, lon={x})")


def _raise_field_out_of_bound_surface_error(z: float | None, y: float | None, x: float | None) -> None:
def format_out(val):
return "unknown" if val is None else val
Expand All @@ -65,6 +68,32 @@ def format_out(val):
)


class FieldSamplingError(RuntimeError):
"""Utility error class to propagate field sampling errors."""

pass


class GridSearchingError(RuntimeError):
"""Utility error class to propagate grid searching errors."""

pass


def _raise_grid_searching_error(z, y, x):
raise GridSearchingError(f"Grid searching failed at (depth={z}, lat={y}, lon={x})")


class GeneralError(RuntimeError):
"""Utility error class to propagate general errors."""

pass


def _raise_general_error(z, y, x):
raise GeneralError(f"General error occurred at (depth={z}, lat={y}, lon={x})")


class TimeExtrapolationError(RuntimeError):
"""Utility error class to propagate erroneous time extrapolation sampling."""

Expand All @@ -91,9 +120,11 @@ def __init__(self, particle, fieldset=None, msg=None):


AllParcelsErrorCodes = {
FieldSamplingError: StatusCode.Error,
FieldInterpolationError: StatusCode.ErrorInterpolation,
FieldOutOfBoundError: StatusCode.ErrorOutOfBounds,
FieldOutOfBoundSurfaceError: StatusCode.ErrorThroughSurface,
GridSearchingError: StatusCode.ErrorGridSearching,
TimeExtrapolationError: StatusCode.ErrorTimeExtrapolation,
KernelError: StatusCode.Error,
GeneralError: StatusCode.Error,
}
21 changes: 20 additions & 1 deletion tests/v4/test_particleset_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from parcels._datasets.structured.generated import simple_UV_dataset
from parcels._datasets.structured.generic import datasets as datasets_structured
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
from parcels.tools.statuscodes import FieldOutOfBoundError, TimeExtrapolationError
from parcels.tools.statuscodes import FieldInterpolationError, FieldOutOfBoundError, TimeExtrapolationError
from parcels.uxgrid import UxGrid
from parcels.xgrid import XGrid
from tests import utils
Expand Down Expand Up @@ -267,6 +267,25 @@ def FieldAccessOutsideTime(particle, fieldset, time): # pragma: no cover
pset.execute(FieldAccessOutsideTime, runtime=np.timedelta64(1, "D"), dt=np.timedelta64(10, "D"))


def test_raise_grid_searching_error(): ...


def test_raise_general_error(): ...


def test_errorinterpolation(fieldset):
def NaNInterpolator(field, ti, position, tau, t, z, y, x): # pragma: no cover
return np.nan * np.zeros_like(x)

def SampleU(particle, fieldset, time): # pragma: no cover
fieldset.U[particle.time, particle.depth, particle.lat, particle.lon, particle]

fieldset.U.interp_method = NaNInterpolator
pset = ParticleSet(fieldset, lon=[0, 2], lat=[0, 0])
with pytest.raises(FieldInterpolationError):
pset.execute(SampleU, runtime=np.timedelta64(2, "s"), dt=np.timedelta64(1, "s"))


def test_execution_check_stopallexecution(fieldset):
def addoneLon(particle, fieldset, time): # pragma: no cover
particle.dlon += 1
Expand Down
Loading