Skip to content

Commit 089b8f9

Browse files
Merge pull request #2164 from OceanParcels/deal_with_search_interpolation_errors
Deal with search interpolation errors
2 parents b185396 + bbc8e28 commit 089b8f9

File tree

5 files changed

+114
-37
lines changed

5 files changed

+114
-37
lines changed

parcels/_index_search.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77

88
from parcels._typing import Mesh
99
from parcels.tools.statuscodes import (
10-
_raise_field_out_of_bound_error,
11-
_raise_field_sampling_error,
10+
_raise_grid_searching_error,
1211
_raise_time_extrapolation_error,
1312
)
1413

@@ -65,12 +64,12 @@ def _search_indices_curvilinear_2d(
6564
# TODO: Re-enable in some capacity
6665
# if x < field.lonlat_minmax[0] or x > field.lonlat_minmax[1]:
6766
# if grid.lon[0, 0] < grid.lon[0, -1]:
68-
# _raise_field_out_of_bound_error(y, x)
67+
# _raise_grid_searching_error(y, x)
6968
# elif x < grid.lon[0, 0] and x > grid.lon[0, -1]: # This prevents from crashing in [160, -160]
70-
# _raise_field_out_of_bound_error(z, y, x)
69+
# _raise_grid_searching_error(z, y, x)
7170

7271
# if y < field.lonlat_minmax[2] or y > field.lonlat_minmax[3]:
73-
# _raise_field_out_of_bound_error(z, y, x)
72+
# _raise_grid_searching_error(z, y, x)
7473

7574
while np.any(xsi < -tol) or np.any(xsi > 1 + tol) or np.any(eta < -tol) or np.any(eta > 1 + tol):
7675
px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
@@ -100,12 +99,12 @@ def _search_indices_curvilinear_2d(
10099
it += 1
101100
if it > maxIterSearch:
102101
print(f"Correct cell not found after {maxIterSearch} iterations")
103-
_raise_field_out_of_bound_error(0, y, x)
102+
_raise_grid_searching_error(0, y, x)
104103
xsi = np.where(xsi < 0.0, 0.0, np.where(xsi > 1.0, 1.0, xsi))
105104
eta = np.where(eta < 0.0, 0.0, np.where(eta > 1.0, 1.0, eta))
106105

107106
if np.any((xsi < 0) | (xsi > 1) | (eta < 0) | (eta > 1)):
108-
_raise_field_sampling_error(y, x)
107+
_raise_grid_searching_error(y, x)
109108
return (yi, eta, xi, xsi)
110109

111110

parcels/field.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,12 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
233233

234234
tau, ti = _search_time_index(self, time)
235235
position = self.grid.search(z, y, x, ei=_ei)
236-
_update_particle_states(particle, position)
236+
_update_particle_states_position(particle, position)
237237

238238
value = self._interp_method(self, ti, position, tau, time, z, y, x)
239239

240+
_update_particle_states_interp_value(particle, value)
241+
240242
if applyConversion:
241243
value = self.units.to_target(value, z, y, x)
242244
return value
@@ -313,16 +315,21 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
313315

314316
tau, ti = _search_time_index(self.U, time)
315317
position = self.grid.search(z, y, x, ei=_ei)
316-
_update_particle_states(particle, position)
318+
_update_particle_states_position(particle, position)
317319

318320
if self._vector_interp_method is None:
319321
u = self.U._interp_method(self.U, ti, position, tau, time, z, y, x)
320322
v = self.V._interp_method(self.V, ti, position, tau, time, z, y, x)
321323
if "3D" in self.vector_type:
322324
w = self.W._interp_method(self.W, ti, position, tau, time, z, y, x)
325+
else:
326+
w = 0.0
323327
else:
324328
(u, v, w) = self._vector_interp_method(self, ti, position, time, z, y, x)
325329

330+
for vel in (u, v, w):
331+
_update_particle_states_interp_value(particle, vel)
332+
326333
if applyConversion:
327334
u = self.U.units.to_target(u, z, y, x)
328335
v = self.V.units.to_target(v, z, y, x)
@@ -344,14 +351,30 @@ def __getitem__(self, key):
344351
return _deal_with_errors(error, key, vector_type=self.vector_type)
345352

346353

347-
def _update_particle_states(particle, position):
354+
def _update_particle_states_position(particle, position):
348355
"""Update the particle states based on the position dictionary."""
349356
if particle and "X" in position: # TODO also support uxgrid search
350-
particle.state = np.where(position["X"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state)
351-
particle.state = np.where(position["Y"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state)
352-
particle.state = np.where(position["Z"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particle.state)
353-
particle.state = np.where(
354-
position["Z"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particle.state
357+
particle.state = np.maximum(
358+
np.where(position["X"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state), particle.state
359+
)
360+
particle.state = np.maximum(
361+
np.where(position["Y"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state), particle.state
362+
)
363+
particle.state = np.maximum(
364+
np.where(position["Z"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particle.state),
365+
particle.state,
366+
)
367+
particle.state = np.maximum(
368+
np.where(position["Z"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particle.state),
369+
particle.state,
370+
)
371+
372+
373+
def _update_particle_states_interp_value(particle, value):
374+
"""Update the particle states based on the interpolated value, but only if state is not an Error already."""
375+
if particle:
376+
particle.state = np.maximum(
377+
np.where(np.isnan(value), StatusCode.ErrorInterpolation, particle.state), particle.state
355378
)
356379

357380

parcels/kernel.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
from parcels.basegrid import GridType
1616
from parcels.tools.statuscodes import (
1717
StatusCode,
18+
_raise_field_interpolation_error,
1819
_raise_field_out_of_bound_error,
1920
_raise_field_out_of_bound_surface_error,
20-
_raise_field_sampling_error,
21+
_raise_general_error,
22+
_raise_grid_searching_error,
2123
_raise_time_extrapolation_error,
2224
)
2325
from parcels.tools.warnings import KernelWarning
@@ -28,6 +30,16 @@
2830
__all__ = ["Kernel"]
2931

3032

33+
ErrorsToThrow = {
34+
StatusCode.ErrorTimeExtrapolation: _raise_time_extrapolation_error,
35+
StatusCode.ErrorOutOfBounds: _raise_field_out_of_bound_error,
36+
StatusCode.ErrorThroughSurface: _raise_field_out_of_bound_surface_error,
37+
StatusCode.ErrorInterpolation: _raise_field_interpolation_error,
38+
StatusCode.ErrorGridSearching: _raise_grid_searching_error,
39+
StatusCode.Error: _raise_general_error,
40+
}
41+
42+
3143
class Kernel:
3244
"""Kernel object that encapsulates auto-generated code.
3345
@@ -270,14 +282,7 @@ def execute(self, pset, endtime, dt):
270282
if np.any(pset.state == StatusCode.StopAllExecution):
271283
return StatusCode.StopAllExecution
272284

273-
errors_to_throw = {
274-
StatusCode.ErrorTimeExtrapolation: _raise_time_extrapolation_error,
275-
StatusCode.ErrorOutOfBounds: _raise_field_out_of_bound_error,
276-
StatusCode.ErrorThroughSurface: _raise_field_out_of_bound_surface_error,
277-
StatusCode.Error: _raise_field_sampling_error,
278-
}
279-
280-
for error_code, error_func in errors_to_throw.items():
285+
for error_code, error_func in ErrorsToThrow.items():
281286
if np.any(pset.state == error_code):
282287
inds = pset.state == error_code
283288
if error_code == StatusCode.ErrorTimeExtrapolation:

parcels/tools/statuscodes.py

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
"KernelError",
88
"StatusCode",
99
"TimeExtrapolationError",
10+
"_raise_field_interpolation_error",
1011
"_raise_field_out_of_bound_error",
1112
"_raise_field_out_of_bound_surface_error",
12-
"_raise_field_sampling_error",
13+
"_raise_general_error",
14+
"_raise_grid_searching_error",
1315
"_raise_time_extrapolation_error",
1416
]
1517

@@ -25,37 +27,38 @@ class StatusCode:
2527
StopAllExecution = 41
2628
Error = 50
2729
ErrorInterpolation = 51
30+
ErrorGridSearching = 52
2831
ErrorOutOfBounds = 60
2932
ErrorThroughSurface = 61
3033
ErrorTimeExtrapolation = 70
3134

3235

33-
class FieldSamplingError(RuntimeError):
34-
"""Utility error class to propagate erroneous field sampling."""
36+
class FieldInterpolationError(RuntimeError):
37+
"""Utility error class to propagate NaN field interpolation."""
3538

3639
pass
3740

3841

42+
def _raise_field_interpolation_error(z, y, x):
43+
raise FieldInterpolationError(f"Field interpolation returned NaN at (depth={z}, lat={y}, lon={x})")
44+
45+
3946
class FieldOutOfBoundError(RuntimeError):
4047
"""Utility error class to propagate out-of-bound field sampling."""
4148

4249
pass
4350

4451

52+
def _raise_field_out_of_bound_error(z, y, x):
53+
raise FieldOutOfBoundError(f"Field sampled out-of-bound, at (depth={z}, lat={y}, lon={x})")
54+
55+
4556
class FieldOutOfBoundSurfaceError(RuntimeError):
4657
"""Utility error class to propagate out-of-bound field sampling at the surface."""
4758

4859
pass
4960

5061

51-
def _raise_field_sampling_error(z, y, x):
52-
raise FieldSamplingError(f"Field sampled at (depth={z}, lat={y}, lon={x})")
53-
54-
55-
def _raise_field_out_of_bound_error(z, y, x):
56-
raise FieldOutOfBoundError(f"Field sampled out-of-bound, at (depth={z}, lat={y}, lon={x})")
57-
58-
5962
def _raise_field_out_of_bound_surface_error(z: float | None, y: float | None, x: float | None) -> None:
6063
def format_out(val):
6164
return "unknown" if val is None else val
@@ -65,6 +68,32 @@ def format_out(val):
6568
)
6669

6770

71+
class FieldSamplingError(RuntimeError):
72+
"""Utility error class to propagate field sampling errors."""
73+
74+
pass
75+
76+
77+
class GridSearchingError(RuntimeError):
78+
"""Utility error class to propagate grid searching errors."""
79+
80+
pass
81+
82+
83+
def _raise_grid_searching_error(z, y, x):
84+
raise GridSearchingError(f"Grid searching failed at (depth={z}, lat={y}, lon={x})")
85+
86+
87+
class GeneralError(RuntimeError):
88+
"""Utility error class to propagate general errors."""
89+
90+
pass
91+
92+
93+
def _raise_general_error(z, y, x):
94+
raise GeneralError(f"General error occurred at (depth={z}, lat={y}, lon={x})")
95+
96+
6897
class TimeExtrapolationError(RuntimeError):
6998
"""Utility error class to propagate erroneous time extrapolation sampling."""
7099

@@ -91,9 +120,11 @@ def __init__(self, particle, fieldset=None, msg=None):
91120

92121

93122
AllParcelsErrorCodes = {
94-
FieldSamplingError: StatusCode.Error,
123+
FieldInterpolationError: StatusCode.ErrorInterpolation,
95124
FieldOutOfBoundError: StatusCode.ErrorOutOfBounds,
96125
FieldOutOfBoundSurfaceError: StatusCode.ErrorThroughSurface,
126+
GridSearchingError: StatusCode.ErrorGridSearching,
97127
TimeExtrapolationError: StatusCode.ErrorTimeExtrapolation,
98128
KernelError: StatusCode.Error,
129+
GeneralError: StatusCode.Error,
99130
}

tests/v4/test_particleset_execute.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from parcels._datasets.structured.generated import simple_UV_dataset
1818
from parcels._datasets.structured.generic import datasets as datasets_structured
1919
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
20-
from parcels.tools.statuscodes import FieldOutOfBoundError, TimeExtrapolationError
20+
from parcels.tools.statuscodes import FieldInterpolationError, FieldOutOfBoundError, TimeExtrapolationError
2121
from parcels.uxgrid import UxGrid
2222
from parcels.xgrid import XGrid
2323
from tests import utils
@@ -267,6 +267,25 @@ def FieldAccessOutsideTime(particle, fieldset, time): # pragma: no cover
267267
pset.execute(FieldAccessOutsideTime, runtime=np.timedelta64(1, "D"), dt=np.timedelta64(10, "D"))
268268

269269

270+
def test_raise_grid_searching_error(): ...
271+
272+
273+
def test_raise_general_error(): ...
274+
275+
276+
def test_errorinterpolation(fieldset):
277+
def NaNInterpolator(field, ti, position, tau, t, z, y, x): # pragma: no cover
278+
return np.nan * np.zeros_like(x)
279+
280+
def SampleU(particle, fieldset, time): # pragma: no cover
281+
fieldset.U[particle.time, particle.depth, particle.lat, particle.lon, particle]
282+
283+
fieldset.U.interp_method = NaNInterpolator
284+
pset = ParticleSet(fieldset, lon=[0, 2], lat=[0, 0])
285+
with pytest.raises(FieldInterpolationError):
286+
pset.execute(SampleU, runtime=np.timedelta64(2, "s"), dt=np.timedelta64(1, "s"))
287+
288+
270289
def test_execution_check_stopallexecution(fieldset):
271290
def addoneLon(particle, fieldset, time): # pragma: no cover
272291
particle.dlon += 1

0 commit comments

Comments
 (0)