Skip to content

Commit 6392d0e

Browse files
Adding XLinearInvdistLandTracer interpolator
Note that the interpolator now requires nested-for-loops. So either has to be vectorized or written with numba, for speed
1 parent d1ee511 commit 6392d0e

File tree

2 files changed

+81
-0
lines changed

2 files changed

+81
-0
lines changed

src/parcels/interpolators.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"UXPiecewiseLinearNode",
2222
"XFreeslip",
2323
"XLinear",
24+
"XLinearInvdistLandTracer",
2425
"XNearest",
2526
"XPartialslip",
2627
"ZeroInterpolator",
@@ -570,6 +571,52 @@ def XNearest(
570571
return value.compute() if is_dask_collection(value) else value
571572

572573

574+
def XLinearInvdistLandTracer(
575+
particle_positions: dict[str, float | np.ndarray],
576+
grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]],
577+
field: Field,
578+
):
579+
"""Linear spatial interpolation on a regular grid, where points on land are not used."""
580+
values = XLinear(particle_positions, grid_positions, field)
581+
582+
on_land = np.argwhere(np.isnan(values))
583+
584+
xi, xsi = grid_positions["X"]["index"], grid_positions["X"]["bcoord"]
585+
yi, eta = grid_positions["Y"]["index"], grid_positions["Y"]["bcoord"]
586+
zi, zeta = grid_positions["Z"]["index"], grid_positions["Z"]["bcoord"]
587+
ti, tau = grid_positions["T"]["index"], grid_positions["T"]["bcoord"]
588+
589+
axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
590+
lenT = 2 if np.any(tau > 0) else 1
591+
lenZ = 2 if np.any(zeta > 0) else 1
592+
593+
corner_data = _get_corner_data_Agrid(field.data, ti, zi, yi, xi, lenT, lenZ, len(xsi), axis_dim)
594+
595+
def is_land(p: int):
596+
value = corner_data[:, :, :, :, p]
597+
return np.where(np.isnan(value), True, False)
598+
599+
for p in on_land:
600+
land = is_land(p)
601+
nb_land = np.sum(land)
602+
if nb_land == 4 * lenZ * lenT:
603+
values[p] = 0.0
604+
else:
605+
val = 0
606+
w_sum = 0
607+
for t in range(lenT):
608+
for k in range(lenZ):
609+
for j in range(2):
610+
for i in range(2):
611+
if land[t][k][j][i] == 0:
612+
distance = pow((eta[p] - j), 2) + pow((xsi[p] - i), 2)
613+
val += corner_data[t, k, j, i, p] / distance
614+
w_sum += 1 / distance
615+
values[p] = val / w_sum
616+
617+
return values.compute() if is_dask_collection(values) else values
618+
619+
573620
def UXPiecewiseConstantFace(
574621
particle_positions: dict[str, float | np.ndarray],
575622
grid_positions: dict[_XGRID_AXES, dict[str, int | float | np.ndarray]],

tests/test_interpolation.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
UXPiecewiseLinearNode,
2222
XFreeslip,
2323
XLinear,
24+
XLinearInvdistLandTracer,
2425
XNearest,
2526
XPartialslip,
2627
ZeroInterpolator,
@@ -80,6 +81,7 @@ def field():
8081
[1.49, 6.49, 13.99],
8182
id="Linear-3",
8283
),
84+
pytest.param(XLinearInvdistLandTracer, np.timedelta64(1, "s"), 2.5, 0.49, 0.51, 13.99, id="LinearInvDistLand"),
8385
pytest.param(
8486
XNearest,
8587
[np.timedelta64(0, "s"), np.timedelta64(3, "s")],
@@ -131,6 +133,38 @@ def test_spatial_slip_interpolation(field, func, t, z, y, x, expected):
131133
np.testing.assert_array_almost_equal(velocities, expected)
132134

133135

136+
@pytest.mark.parametrize(
137+
"func, t, z, y, x, expected",
138+
[
139+
(XLinearInvdistLandTracer, np.timedelta64(1, "s"), 0, 0.5, 0.5, 1.0),
140+
(XLinearInvdistLandTracer, np.timedelta64(1, "s"), 0, 1.5, 1.5, 0.0),
141+
(
142+
XLinearInvdistLandTracer,
143+
[np.timedelta64(0, "s"), np.timedelta64(1, "s")],
144+
[0, 2],
145+
[0.5, 0.5],
146+
[0.5, 0.5],
147+
1.0,
148+
),
149+
(
150+
XLinearInvdistLandTracer,
151+
[np.timedelta64(0, "s"), np.timedelta64(1, "s")],
152+
[0, 2],
153+
[0.5, 1.5],
154+
[0.5, 1.5],
155+
[1.0, 0.0],
156+
),
157+
],
158+
)
159+
def test_invdistland_interpolation(field, func, t, z, y, x, expected):
160+
field.data[:] = 1.0
161+
field.data[:, :, 1:3, 1:3] = np.nan # Set NaN land value to test inv_dist
162+
field.interp_method = func
163+
164+
value = field[t, z, y, x]
165+
np.testing.assert_array_almost_equal(value, expected)
166+
167+
134168
@pytest.mark.parametrize("mesh", ["spherical", "flat"])
135169
def test_interpolation_mesh_type(mesh, npart=10):
136170
ds = simple_UV_dataset(mesh=mesh)

0 commit comments

Comments
 (0)