Skip to content

Commit b87691d

Browse files
Fixing last failing tests in vectorized kernels
1 parent eeff3ad commit b87691d

File tree

7 files changed

+107
-133
lines changed

7 files changed

+107
-133
lines changed

parcels/application_kernels/advectiondiffusion.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,18 @@ def AdvectionDiffusionM1(particle, fieldset, time): # pragma: no cover
2828
dWx = np.random.normal(0, np.sqrt(np.fabs(dt)))
2929
dWy = np.random.normal(0, np.sqrt(np.fabs(dt)))
3030

31-
Kxp1 = fieldset.Kh_zonal[particle.time, particle.depth, particle.lat, particle.lon + fieldset.dres]
32-
Kxm1 = fieldset.Kh_zonal[particle.time, particle.depth, particle.lat, particle.lon - fieldset.dres]
31+
Kxp1 = fieldset.Kh_zonal[particle.time, particle.depth, particle.lat, particle.lon + fieldset.dres, particle]
32+
Kxm1 = fieldset.Kh_zonal[particle.time, particle.depth, particle.lat, particle.lon - fieldset.dres, particle]
3333
dKdx = (Kxp1 - Kxm1) / (2 * fieldset.dres)
3434

35-
u, v = fieldset.UV[particle.time, particle.depth, particle.lat, particle.lon]
36-
bx = np.sqrt(2 * fieldset.Kh_zonal[particle.time, particle.depth, particle.lat, particle.lon])
35+
u, v = fieldset.UV[particle.time, particle.depth, particle.lat, particle.lon, particle]
36+
bx = np.sqrt(2 * fieldset.Kh_zonal[particle.time, particle.depth, particle.lat, particle.lon, particle])
3737

38-
Kyp1 = fieldset.Kh_meridional[particle.time, particle.depth, particle.lat + fieldset.dres, particle.lon]
39-
Kym1 = fieldset.Kh_meridional[particle.time, particle.depth, particle.lat - fieldset.dres, particle.lon]
38+
Kyp1 = fieldset.Kh_meridional[particle.time, particle.depth, particle.lat + fieldset.dres, particle.lon, particle]
39+
Kym1 = fieldset.Kh_meridional[particle.time, particle.depth, particle.lat - fieldset.dres, particle.lon, particle]
4040
dKdy = (Kyp1 - Kym1) / (2 * fieldset.dres)
4141

42-
by = np.sqrt(2 * fieldset.Kh_meridional[particle.time, particle.depth, particle.lat, particle.lon])
42+
by = np.sqrt(2 * fieldset.Kh_meridional[particle.time, particle.depth, particle.lat, particle.lon, particle])
4343

4444
# Particle positions are updated only after evaluating all terms.
4545
particle.dlon += u * dt + 0.5 * dKdx * (dWx**2 + dt) + bx * dWx
@@ -64,19 +64,19 @@ def AdvectionDiffusionEM(particle, fieldset, time): # pragma: no cover
6464
dWx = np.random.normal(0, np.sqrt(np.fabs(dt)))
6565
dWy = np.random.normal(0, np.sqrt(np.fabs(dt)))
6666

67-
u, v = fieldset.UV[particle.time, particle.depth, particle.lat, particle.lon]
67+
u, v = fieldset.UV[particle.time, particle.depth, particle.lat, particle.lon, particle]
6868

69-
Kxp1 = fieldset.Kh_zonal[particle.time, particle.depth, particle.lat, particle.lon + fieldset.dres]
70-
Kxm1 = fieldset.Kh_zonal[particle.time, particle.depth, particle.lat, particle.lon - fieldset.dres]
69+
Kxp1 = fieldset.Kh_zonal[particle.time, particle.depth, particle.lat, particle.lon + fieldset.dres, particle]
70+
Kxm1 = fieldset.Kh_zonal[particle.time, particle.depth, particle.lat, particle.lon - fieldset.dres, particle]
7171
dKdx = (Kxp1 - Kxm1) / (2 * fieldset.dres)
7272
ax = u + dKdx
73-
bx = np.sqrt(2 * fieldset.Kh_zonal[particle.time, particle.depth, particle.lat, particle.lon])
73+
bx = np.sqrt(2 * fieldset.Kh_zonal[particle.time, particle.depth, particle.lat, particle.lon, particle])
7474

75-
Kyp1 = fieldset.Kh_meridional[particle.time, particle.depth, particle.lat + fieldset.dres, particle.lon]
76-
Kym1 = fieldset.Kh_meridional[particle.time, particle.depth, particle.lat - fieldset.dres, particle.lon]
75+
Kyp1 = fieldset.Kh_meridional[particle.time, particle.depth, particle.lat + fieldset.dres, particle.lon, particle]
76+
Kym1 = fieldset.Kh_meridional[particle.time, particle.depth, particle.lat - fieldset.dres, particle.lon, particle]
7777
dKdy = (Kyp1 - Kym1) / (2 * fieldset.dres)
7878
ay = v + dKdy
79-
by = np.sqrt(2 * fieldset.Kh_meridional[particle.time, particle.depth, particle.lat, particle.lon])
79+
by = np.sqrt(2 * fieldset.Kh_meridional[particle.time, particle.depth, particle.lat, particle.lon, particle])
8080

8181
# Particle positions are updated only after evaluating all terms.
8282
particle.dlon += ax * dt + bx * dWx

parcels/application_kernels/interpolation.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
"UXPiecewiseConstantFace",
1818
"UXPiecewiseLinearNode",
1919
"XBiLinear",
20-
"XBiLinearPeriodic",
2120
"XTriLinear",
2221
]
2322

@@ -57,47 +56,6 @@ def XBiLinear(
5756
return val
5857

5958

60-
def XBiLinearPeriodic(
61-
field: Field,
62-
ti: int,
63-
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
64-
tau: np.float32 | np.float64,
65-
t: np.float32 | np.float64,
66-
z: np.float32 | np.float64,
67-
y: np.float32 | np.float64,
68-
x: np.float32 | np.float64,
69-
):
70-
"""Bilinear interpolation on a regular grid with periodic boundary conditions in horizontal directions."""
71-
xi, xsi = position["X"]
72-
yi, eta = position["Y"]
73-
zi, _ = position["Z"]
74-
75-
xi = np.where(xi > len(field.grid.lon) - 2, 0, xi)
76-
xsi = (x - field.grid.lon[xi]) / (field.grid.lon[xi + 1] - field.grid.lon[xi])
77-
yi = np.where(yi > len(field.grid.lat) - 2, 0, yi)
78-
eta = (y - field.grid.lat[yi]) / (field.grid.lat[yi + 1] - field.grid.lat[yi])
79-
80-
axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
81-
82-
data = field.data
83-
val = np.zeros_like(tau)
84-
85-
timeslices = [ti, ti + 1] if tau.any() > 0 else [ti]
86-
for tii, tau_factor in zip(timeslices, [1 - tau, tau], strict=False):
87-
xi = xr.DataArray(xi, dims="points")
88-
yi = xr.DataArray(yi, dims="points")
89-
zi = xr.DataArray(zi, dims="points")
90-
ti = xr.DataArray(tii, dims="points")
91-
F00 = data.isel({axis_dim["X"]: xi, axis_dim["Y"]: yi, axis_dim["Z"]: zi, "time": ti}).values.flatten()
92-
F10 = data.isel({axis_dim["X"]: xi + 1, axis_dim["Y"]: yi, axis_dim["Z"]: zi, "time": ti}).values.flatten()
93-
F01 = data.isel({axis_dim["X"]: xi, axis_dim["Y"]: yi + 1, axis_dim["Z"]: zi, "time": ti}).values.flatten()
94-
F11 = data.isel({axis_dim["X"]: xi + 1, axis_dim["Y"]: yi + 1, axis_dim["Z"]: zi, "time": ti}).values.flatten()
95-
val += (
96-
(1 - xsi) * (1 - eta) * F00 + xsi * (1 - eta) * F10 + (1 - xsi) * eta * F01 + xsi * eta * F11
97-
) * tau_factor
98-
return val
99-
100-
10159
def XTriLinear(
10260
field: Field,
10361
ti: int,

parcels/field.py

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,10 @@
2323
)
2424
from parcels.tools.statuscodes import (
2525
AllParcelsErrorCodes,
26-
FieldOutOfBoundError,
27-
FieldOutOfBoundSurfaceError,
28-
FieldSamplingError,
26+
StatusCode,
2927
)
3028
from parcels.uxgrid import UxGrid
31-
from parcels.xgrid import XGrid, _transpose_xfield_data_to_tzyx
29+
from parcels.xgrid import LEFT_OUT_OF_BOUNDS, RIGHT_OUT_OF_BOUNDS, XGrid, _transpose_xfield_data_to_tzyx
3230

3331
from ._index_search import _search_time_index
3432

@@ -257,19 +255,11 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
257255
# else:
258256
# _ei = particle.ei[self.igrid]
259257

260-
try:
261-
tau, ti = _search_time_index(self, time)
262-
position = self.grid.search(z, y, x, ei=_ei)
263-
value = self._interp_method(self, ti, position, tau, time, z, y, x)
264-
265-
# TODO fix outof bounds sampling
266-
# if np.isnan(value):
267-
# # Detect Out-of-bounds sampling and raise exception
268-
# _raise_field_out_of_bound_error(z, y, x)
258+
tau, ti = _search_time_index(self, time)
259+
position = self.grid.search(z, y, x, ei=_ei)
260+
_update_particle_states(particle, position)
269261

270-
except (FieldSamplingError, FieldOutOfBoundError, FieldOutOfBoundSurfaceError) as e:
271-
e.add_note(f"Error interpolating field '{self.name}'.")
272-
raise e
262+
value = self._interp_method(self, ti, position, tau, time, z, y, x)
273263

274264
if applyConversion:
275265
value = self.units.to_target(value, z, y, x)
@@ -345,31 +335,28 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
345335
# else:
346336
# _ei = particle.ei[self.igrid]
347337

348-
try:
349-
tau, ti = _search_time_index(self.U, time)
350-
position = self.grid.search(z, y, x, ei=_ei)
351-
if self._vector_interp_method is None:
352-
u = self.U._interp_method(self.U, ti, position, tau, time, z, y, x)
353-
v = self.V._interp_method(self.V, ti, position, tau, time, z, y, x)
354-
if "3D" in self.vector_type:
355-
w = self.W._interp_method(self.W, ti, position, tau, time, z, y, x)
356-
else:
357-
(u, v, w) = self._vector_interp_method(self, ti, position, time, z, y, x)
338+
tau, ti = _search_time_index(self.U, time)
339+
position = self.grid.search(z, y, x, ei=_ei)
340+
_update_particle_states(particle, position)
358341

359-
if applyConversion:
360-
u = self.U.units.to_target(u, z, y, x)
361-
v = self.V.units.to_target(v, z, y, x)
362-
if "3D" in self.vector_type:
363-
w = self.W.units.to_target(w, z, y, x) if self.W else 0.0
342+
if self._vector_interp_method is None:
343+
u = self.U._interp_method(self.U, ti, position, tau, time, z, y, x)
344+
v = self.V._interp_method(self.V, ti, position, tau, time, z, y, x)
345+
if "3D" in self.vector_type:
346+
w = self.W._interp_method(self.W, ti, position, tau, time, z, y, x)
347+
else:
348+
(u, v, w) = self._vector_interp_method(self, ti, position, time, z, y, x)
364349

350+
if applyConversion:
351+
u = self.U.units.to_target(u, z, y, x)
352+
v = self.V.units.to_target(v, z, y, x)
365353
if "3D" in self.vector_type:
366-
return (u, v, w)
367-
else:
368-
return (u, v)
354+
w = self.W.units.to_target(w, z, y, x) if self.W else 0.0
369355

370-
except (FieldSamplingError, FieldOutOfBoundError, FieldOutOfBoundSurfaceError) as e:
371-
e.add_note(f"Error interpolating field '{self.name}'.")
372-
raise e
356+
if "3D" in self.vector_type:
357+
return (u, v, w)
358+
else:
359+
return (u, v)
373360

374361
def __getitem__(self, key):
375362
try:
@@ -381,6 +368,17 @@ def __getitem__(self, key):
381368
return _deal_with_errors(error, key, vector_type=self.vector_type)
382369

383370

371+
def _update_particle_states(particle, position):
372+
"""Update the particle states based on the position dictionary."""
373+
if particle and "X" in position: # TODO also support uxgrid search
374+
particle.state = np.where(position["X"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state)
375+
particle.state = np.where(position["Y"][0] < 0, StatusCode.ErrorOutOfBounds, particle.state)
376+
particle.state = np.where(position["Z"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particle.state)
377+
particle.state = np.where(
378+
position["Z"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particle.state
379+
)
380+
381+
384382
def _assert_valid_uxdataarray(data: ux.UxDataArray):
385383
"""Verifies that all the required attributes are present in the xarray.DataArray or
386384
uxarray.UxDataArray object.

parcels/xgrid.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121

2222
_DEFAULT_XGCM_KWARGS = {"periodic": False}
2323

24+
LEFT_OUT_OF_BOUNDS = -2
25+
RIGHT_OUT_OF_BOUNDS = -1
26+
2427

2528
def get_cell_count_along_dim(axis: xgcm.Axis) -> int:
2629
first_coord = list(axis.coords.items())[0]
@@ -271,15 +274,6 @@ def search(self, z, y, x, ei=None):
271274
ds = self.xgcm_grid._ds
272275

273276
zi, zeta = _search_1d_array(ds.depth.values, z)
274-
# if any(zi == -1): # TODO throw error only for those particles where zi == -1
275-
# if any(zeta < 0):
276-
# raise FieldOutOfBoundError(
277-
# f"Depth {z} is out of bounds for the grid with depth values {ds.depth.values}."
278-
# )
279-
# elif any(zeta > 1):
280-
# raise FieldOutOfBoundSurfaceError(
281-
# f"Depth {z} is out of the surface for the grid with depth values {ds.depth.values}."
282-
# )
283277

284278
if ds.lon.ndim == 1:
285279
yi, eta = _search_1d_array(ds.lat.values, y)
@@ -477,7 +471,6 @@ def _search_1d_array(
477471
Searches for the particle location in a 1D array and returns barycentric coordinate along dimension.
478472
479473
Assumptions:
480-
- particle position x is within the bounds of the array
481474
- array is strictly monotonically increasing.
482475
483476
Parameters
@@ -489,9 +482,9 @@ def _search_1d_array(
489482
490483
Returns
491484
-------
492-
int
493-
Index of the element just before the position x in the array.
494-
float
485+
array of int
486+
Index of the element just before the position x in the array. Note that this index is -2 if the index is left out of bounds and -1 if the index is right out of bounds.
487+
array of float
495488
Barycentric coordinate.
496489
"""
497490
# TODO v4: We probably rework this to deal with 0D arrays before this point (as we already know field dimensionality)
@@ -500,11 +493,11 @@ def _search_1d_array(
500493
index = np.searchsorted(arr, x, side="right") - 1
501494
index_next = np.clip(index + 1, 1, len(arr) - 1) # Ensure we don't go out of bounds
502495

503-
bcoord = np.where(
504-
index <= len(arr) - 2,
505-
(x - arr[index]) / (arr[index_next] - arr[index]),
506-
np.nan, # If at the end of the array, we return np.nan
507-
)
496+
bcoord = (x - arr[index]) / (arr[index_next] - arr[index])
497+
498+
index = np.where(x < arr[0], LEFT_OUT_OF_BOUNDS, index)
499+
index = np.where(x >= arr[-1], RIGHT_OUT_OF_BOUNDS, index)
500+
508501
return np.atleast_1d(index), np.atleast_1d(bcoord)
509502

510503

tests/v4/test_advection.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from parcels._datasets.structured.generated import simple_UV_dataset
66
from parcels.application_kernels.advection import AdvectionEE, AdvectionRK4, AdvectionRK4_3D, AdvectionRK45
77
from parcels.application_kernels.advectiondiffusion import AdvectionDiffusionEM, AdvectionDiffusionM1
8-
from parcels.application_kernels.interpolation import XBiLinear, XBiLinearPeriodic, XTriLinear
8+
from parcels.application_kernels.interpolation import XBiLinear, XTriLinear
99
from parcels.field import Field, VectorField
1010
from parcels.fieldset import FieldSet
1111
from parcels.particle import Particle, Variable
@@ -46,8 +46,7 @@ def test_advection_zonal(mesh_type, npart=10):
4646

4747
def periodicBC(particle, fieldset, time):
4848
particle.total_dlon += particle.dlon
49-
particle.lon = np.fmod(particle.lon, fieldset.U.grid.lon[-1])
50-
particle.lat = np.fmod(particle.lat, fieldset.U.grid.lat[-1])
49+
particle.lon = np.fmod(particle.lon, 2)
5150

5251

5352
def test_advection_zonal_periodic():
@@ -56,9 +55,15 @@ def test_advection_zonal_periodic():
5655
ds["lon"].data = np.array([0, 2])
5756
ds["lat"].data = np.array([0, 2])
5857

58+
# add a halo
59+
halo = ds.isel(XG=0)
60+
halo.lon.values = ds.lon.values[1] + 1
61+
halo.XG.values = ds.XG.values[1] + 2
62+
ds = xr.concat([ds, halo], dim="XG")
63+
5964
grid = XGrid.from_dataset(ds)
60-
U = Field("U", ds["U"], grid, interp_method=XBiLinearPeriodic)
61-
V = Field("V", ds["V"], grid, interp_method=XBiLinearPeriodic)
65+
U = Field("U", ds["U"], grid, interp_method=XBiLinear)
66+
V = Field("V", ds["V"], grid, interp_method=XBiLinear)
6267
UV = VectorField("UV", U, V)
6368
fieldset = FieldSet([U, V, UV])
6469

@@ -95,7 +100,7 @@ def test_advection_3D_outofbounds(direction, wErrorThroughSurface):
95100
ds = simple_UV_dataset(mesh_type="flat")
96101
grid = XGrid.from_dataset(ds)
97102
U = Field("U", ds["U"], grid, interp_method=XTriLinear)
98-
U.data[:] = 0.01 # Set U to 0 at the surface
103+
U.data[:] = 0.01 # Set U to small value (to avoid horizontal out of bounds)
99104
V = Field("V", ds["V"], grid, interp_method=XTriLinear)
100105
W = Field("W", ds["V"], grid, interp_method=XTriLinear) # Use V as W for testing
101106
W.data[:] = -1.0 if direction == "up" else 1.0
@@ -104,18 +109,22 @@ def test_advection_3D_outofbounds(direction, wErrorThroughSurface):
104109
fieldset = FieldSet([U, V, W, UVW, UV])
105110

106111
def DeleteParticle(particle, fieldset, time): # pragma: no cover
107-
if particle.state == StatusCode.ErrorOutOfBounds or particle.state == StatusCode.ErrorThroughSurface:
108-
particle.state = StatusCode.Delete
112+
particle.state = np.where(particle.state == StatusCode.ErrorOutOfBounds, StatusCode.Delete, particle.state)
113+
particle.state = np.where(particle.state == StatusCode.ErrorThroughSurface, StatusCode.Delete, particle.state)
109114

110115
def SubmergeParticle(particle, fieldset, time): # pragma: no cover
111-
if particle.state == StatusCode.ErrorThroughSurface:
112-
dt = particle.dt / np.timedelta64(1, "s")
113-
(u, v) = fieldset.UV[particle]
114-
particle.dlon = u * dt
115-
particle.dlat = v * dt
116-
particle.ddepth = 0.0
117-
particle.depth = 0
118-
particle.state = StatusCode.Evaluate
116+
if len(particle.state) == 0:
117+
return
118+
inds = np.argwhere(particle.state == StatusCode.ErrorThroughSurface).flatten()
119+
if len(inds) == 0:
120+
return
121+
dt = particle.dt / np.timedelta64(1, "s")
122+
(u, v) = fieldset.UV[particle[inds]]
123+
particle.dlon[inds] = u * dt
124+
particle.dlat[inds] = v * dt
125+
particle.ddepth[inds] = 0.0
126+
particle.depth[inds] = 0
127+
particle.state[inds] = StatusCode.Evaluate
119128

120129
kernels = [AdvectionRK4_3D]
121130
if wErrorThroughSurface:

tests/v4/test_particleset_execute.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,16 +103,14 @@ def test_execution_fail_python_exception(fieldset, npart=10):
103103
pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.linspace(1, 0, npart))
104104

105105
def PythonFail(particle, fieldset, time): # pragma: no cover
106-
if particle.time >= fieldset.time_interval.left + np.timedelta64(10, "s"):
106+
inds = np.argwhere(particle.time >= fieldset.time_interval.left + np.timedelta64(10, "s"))
107+
if inds.size > 0:
107108
raise RuntimeError("Enough is enough!")
108-
else:
109-
pass
110109

111110
with pytest.raises(RuntimeError):
112111
pset.execute(PythonFail, runtime=np.timedelta64(20, "s"), dt=np.timedelta64(2, "s"))
113112
assert len(pset) == npart
114-
assert pset.time[0] == fieldset.time_interval.left + np.timedelta64(10, "s")
115-
assert all([time == fieldset.time_interval.left + np.timedelta64(0, "s") for time in pset.time[1:]])
113+
assert all(pset.time == fieldset.time_interval.left + np.timedelta64(10, "s"))
116114

117115

118116
def test_uxstommelgyre_pset_execute():

0 commit comments

Comments
 (0)