Skip to content

Commit 92a3b75

Browse files
Merge pull request #2165 from OceanParcels/spatial_slip_interpolation
Spatial slip interpolation
2 parents 1c8d5e5 + 1d44cdb commit 92a3b75

File tree

2 files changed

+228
-29
lines changed

2 files changed

+228
-29
lines changed

parcels/application_kernels/interpolation.py

Lines changed: 190 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@
2020
"CGrid_Velocity",
2121
"UXPiecewiseConstantFace",
2222
"UXPiecewiseLinearNode",
23+
"XFreeslip",
2324
"XLinear",
2425
"XNearest",
26+
"XPartialslip",
2527
"ZeroInterpolator",
2628
"ZeroInterpolator_Vector",
2729
]
@@ -30,7 +32,7 @@
3032
def ZeroInterpolator(
3133
field: Field,
3234
ti: int,
33-
position: dict[str, tuple[int, float | np.ndarray]],
35+
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
3436
tau: np.float32 | np.float64,
3537
t: np.float32 | np.float64,
3638
z: np.float32 | np.float64,
@@ -44,7 +46,7 @@ def ZeroInterpolator(
4446
def ZeroInterpolator_Vector(
4547
vectorfield: VectorField,
4648
ti: int,
47-
position: dict[str, tuple[int, float | np.ndarray]],
49+
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
4850
tau: np.float32 | np.float64,
4951
t: np.float32 | np.float64,
5052
z: np.float32 | np.float64,
@@ -56,48 +58,38 @@ def ZeroInterpolator_Vector(
5658
return 0.0
5759

5860

59-
def XLinear(
60-
field: Field,
61+
def _get_corner_data_Agrid(
62+
data: np.ndarray | xr.DataArray,
6163
ti: int,
62-
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
63-
tau: np.float32 | np.float64,
64-
t: np.float32 | np.float64,
65-
z: np.float32 | np.float64,
66-
y: np.float32 | np.float64,
67-
x: np.float32 | np.float64,
68-
):
69-
"""Trilinear interpolation on a regular grid."""
70-
xi, xsi = position["X"]
71-
yi, eta = position["Y"]
72-
zi, zeta = position["Z"]
73-
74-
axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
75-
data = field.data
76-
tdim, zdim, ydim, xdim = data.shape[0], data.shape[1], data.shape[2], data.shape[3]
77-
78-
lenT = 2 if np.any(tau > 0) else 1
79-
lenZ = 2 if np.any(zeta > 0) else 1
80-
64+
zi: int,
65+
yi: int,
66+
xi: int,
67+
lenT: int,
68+
lenZ: int,
69+
npart: int,
70+
axis_dim: dict[str, str],
71+
) -> np.ndarray:
72+
"""Helper function to get the corner data for a given A-grid field and position."""
8173
# Time coordinates: 8 points at ti, then 8 points at ti+1
8274
if lenT == 1:
8375
ti = np.repeat(ti, lenZ * 4)
8476
else:
85-
ti_1 = np.clip(ti + 1, 0, tdim - 1)
77+
ti_1 = np.clip(ti + 1, 0, data.shape[0] - 1)
8678
ti = np.concatenate([np.repeat(ti, lenZ * 4), np.repeat(ti_1, lenZ * 4)])
8779

8880
# Depth coordinates: 4 points at zi, 4 at zi+1, repeated for both time levels
8981
if lenZ == 1:
9082
zi = np.repeat(zi, lenT * 4)
9183
else:
92-
zi_1 = np.clip(zi + 1, 0, zdim - 1)
84+
zi_1 = np.clip(zi + 1, 0, data.shape[1] - 1)
9385
zi = np.tile(np.array([zi, zi, zi, zi, zi_1, zi_1, zi_1, zi_1]).flatten(), lenT)
9486

9587
# Y coordinates: [yi, yi, yi+1, yi+1] for each spatial point, repeated for time/depth
96-
yi_1 = np.clip(yi + 1, 0, ydim - 1)
88+
yi_1 = np.clip(yi + 1, 0, data.shape[2] - 1)
9789
yi = np.tile(np.repeat(np.column_stack([yi, yi_1]), 2), (lenT) * (lenZ))
9890

9991
# X coordinates: [xi, xi+1, xi, xi+1] for each spatial point, repeated for time/depth
100-
xi_1 = np.clip(xi + 1, 0, xdim - 1)
92+
xi_1 = np.clip(xi + 1, 0, data.shape[3] - 1)
10193
xi = np.tile(np.column_stack([xi, xi_1, xi, xi_1]).flatten(), (lenT) * (lenZ))
10294

10395
# Create DataArrays for indexing
@@ -110,7 +102,31 @@ def XLinear(
110102
if "time" in data.dims:
111103
selection_dict["time"] = xr.DataArray(ti, dims=("points"))
112104

113-
corner_data = data.isel(selection_dict).data.reshape(lenT, lenZ, len(xsi), 4)
105+
return data.isel(selection_dict).data.reshape(lenT, lenZ, npart, 4)
106+
107+
108+
def XLinear(
109+
field: Field,
110+
ti: int,
111+
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
112+
tau: np.float32 | np.float64,
113+
t: np.float32 | np.float64,
114+
z: np.float32 | np.float64,
115+
y: np.float32 | np.float64,
116+
x: np.float32 | np.float64,
117+
):
118+
"""Trilinear interpolation on a regular grid."""
119+
xi, xsi = position["X"]
120+
yi, eta = position["Y"]
121+
zi, zeta = position["Z"]
122+
123+
axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
124+
data = field.data
125+
126+
lenT = 2 if np.any(tau > 0) else 1
127+
lenZ = 2 if np.any(zeta > 0) else 1
128+
129+
corner_data = _get_corner_data_Agrid(data, ti, zi, yi, xi, lenT, lenZ, len(xsi), axis_dim)
114130

115131
if lenT == 2:
116132
tau = tau[np.newaxis, :, np.newaxis]
@@ -392,6 +408,152 @@ def CGrid_Tracer(
392408
return value.compute() if is_dask_collection(value) else value
393409

394410

411+
def _Spatialslip(
412+
vectorfield: VectorField,
413+
ti: int,
414+
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
415+
tau: np.float32 | np.float64,
416+
t: np.float32 | np.float64,
417+
z: np.float32 | np.float64,
418+
y: np.float32 | np.float64,
419+
x: np.float32 | np.float64,
420+
a: np.float32,
421+
b: np.float32,
422+
):
423+
"""Helper function for spatial boundary condition interpolation for velocity fields."""
424+
xi, xsi = position["X"]
425+
yi, eta = position["Y"]
426+
zi, zeta = position["Z"]
427+
428+
axis_dim = vectorfield.U.grid.get_axis_dim_mapping(vectorfield.U.data.dims)
429+
lenT = 2 if np.any(tau > 0) else 1
430+
lenZ = 2 if np.any(zeta > 0) else 1
431+
npart = len(xsi)
432+
433+
u = XLinear(vectorfield.U, ti, position, tau, t, z, y, x)
434+
v = XLinear(vectorfield.V, ti, position, tau, t, z, y, x)
435+
if vectorfield.W:
436+
w = XLinear(vectorfield.W, ti, position, tau, t, z, y, x)
437+
438+
corner_dataU = _get_corner_data_Agrid(vectorfield.U.data, ti, zi, yi, xi, lenT, lenZ, npart, axis_dim)
439+
corner_dataV = _get_corner_data_Agrid(vectorfield.V.data, ti, zi, yi, xi, lenT, lenZ, npart, axis_dim)
440+
441+
def is_land(ti: int, zi: int, yi: int, xi: int):
442+
uval = corner_dataU[ti, zi, :, xi + 2 * yi]
443+
vval = corner_dataV[ti, zi, :, xi + 2 * yi]
444+
return np.where(np.isclose(uval, 0.0) & np.isclose(vval, 0.0), True, False)
445+
446+
f_u = np.ones_like(xsi)
447+
f_v = np.ones_like(eta)
448+
449+
if lenZ == 1:
450+
f_u = np.where(is_land(0, 0, 0, 0) & is_land(0, 0, 0, 1) & (eta > 0), f_u * (a + b * eta) / eta, f_u)
451+
f_u = np.where(is_land(0, 0, 1, 0) & is_land(0, 0, 1, 1) & (eta < 1), f_u * (1 - b * eta) / (1 - eta), f_u)
452+
f_v = np.where(is_land(0, 0, 0, 0) & is_land(0, 0, 1, 0) & (xsi > 0), f_v * (a + b * xsi) / xsi, f_v)
453+
f_v = np.where(is_land(0, 0, 0, 1) & is_land(0, 0, 1, 1) & (xsi < 1), f_v * (1 - b * xsi) / (1 - xsi), f_v)
454+
else:
455+
f_u = np.where(
456+
is_land(0, 0, 0, 0) & is_land(0, 0, 0, 1) & is_land(0, 1, 0, 0) & is_land(0, 1, 0, 1) & (eta > 0),
457+
f_u * (a + b * eta) / eta,
458+
f_u,
459+
)
460+
f_u = np.where(
461+
is_land(0, 0, 1, 0) & is_land(0, 0, 1, 1) & is_land(0, 1, 1, 0) & is_land(0, 1, 1, 1) & (eta < 1),
462+
f_u * (1 - b * eta) / (1 - eta),
463+
f_u,
464+
)
465+
f_v = np.where(
466+
is_land(0, 0, 0, 0) & is_land(0, 0, 1, 0) & is_land(0, 1, 0, 0) & is_land(0, 1, 1, 0) & (xsi > 0),
467+
f_v * (a + b * xsi) / xsi,
468+
f_v,
469+
)
470+
f_v = np.where(
471+
is_land(0, 0, 0, 1) & is_land(0, 0, 1, 1) & is_land(0, 1, 0, 1) & is_land(0, 1, 1, 1) & (xsi < 1),
472+
f_v * (1 - b * xsi) / (1 - xsi),
473+
f_v,
474+
)
475+
f_u = np.where(
476+
is_land(0, 0, 0, 0) & is_land(0, 0, 0, 1) & is_land(0, 0, 1, 0 & is_land(0, 0, 1, 1) & (zeta > 0)),
477+
f_u * (a + b * zeta) / zeta,
478+
f_u,
479+
)
480+
f_u = np.where(
481+
is_land(0, 1, 0, 0) & is_land(0, 1, 0, 1) & is_land(0, 1, 1, 0 & is_land(0, 1, 1, 1) & (zeta < 1)),
482+
f_u * (1 - b * zeta) / (1 - zeta),
483+
f_u,
484+
)
485+
f_v = np.where(
486+
is_land(0, 0, 0, 0) & is_land(0, 0, 0, 1) & is_land(0, 0, 1, 0 & is_land(0, 0, 1, 1) & (zeta > 0)),
487+
f_v * (a + b * zeta) / zeta,
488+
f_v,
489+
)
490+
f_v = np.where(
491+
is_land(0, 1, 0, 0) & is_land(0, 1, 0, 1) & is_land(0, 1, 1, 0 & is_land(0, 1, 1, 1) & (zeta < 1)),
492+
f_v * (1 - b * zeta) / (1 - zeta),
493+
f_v,
494+
)
495+
496+
u *= f_u
497+
v *= f_v
498+
if vectorfield.W:
499+
f_w = np.ones_like(zeta)
500+
f_w = np.where(
501+
is_land(0, 0, 0, 0) & is_land(0, 0, 0, 1) & is_land(0, 1, 0, 0) & is_land(0, 1, 0, 1) & (eta > 0),
502+
f_w * (a + b * eta) / eta,
503+
f_w,
504+
)
505+
f_w = np.where(
506+
is_land(0, 0, 1, 0) & is_land(0, 0, 1, 1) & is_land(0, 1, 1, 0) & is_land(0, 1, 1, 1) & (eta < 1),
507+
f_w * (a - b * eta) / (1 - eta),
508+
f_w,
509+
)
510+
f_w = np.where(
511+
is_land(0, 0, 0, 0) & is_land(0, 0, 1, 0) & is_land(0, 1, 0, 0) & is_land(0, 1, 1, 0) & (xsi > 0),
512+
f_w * (a + b * xsi) / xsi,
513+
f_w,
514+
)
515+
f_w = np.where(
516+
is_land(0, 0, 0, 1) & is_land(0, 0, 1, 1) & is_land(0, 1, 0, 1) & is_land(0, 1, 1, 1) & (xsi < 1),
517+
f_w * (a - b * xsi) / (1 - xsi),
518+
f_w,
519+
)
520+
521+
w *= f_w
522+
else:
523+
w = None
524+
return u, v, w
525+
526+
527+
def XFreeslip(
528+
vectorfield: VectorField,
529+
ti: int,
530+
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
531+
tau: np.float32 | np.float64,
532+
t: np.float32 | np.float64,
533+
z: np.float32 | np.float64,
534+
y: np.float32 | np.float64,
535+
x: np.float32 | np.float64,
536+
applyConversion: bool,
537+
):
538+
"""Free-slip boundary condition interpolation for velocity fields."""
539+
return _Spatialslip(vectorfield, ti, position, tau, t, z, y, x, a=1.0, b=0.0)
540+
541+
542+
def XPartialslip(
543+
vectorfield: VectorField,
544+
ti: int,
545+
position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
546+
tau: np.float32 | np.float64,
547+
t: np.float32 | np.float64,
548+
z: np.float32 | np.float64,
549+
y: np.float32 | np.float64,
550+
x: np.float32 | np.float64,
551+
applyConversion: bool,
552+
):
553+
"""Partial-slip boundary condition interpolation for velocity fields."""
554+
return _Spatialslip(vectorfield, ti, position, tau, t, z, y, x, a=0.5, b=0.5)
555+
556+
395557
def XNearest(
396558
field: Field,
397559
ti: int,

tests/v4/test_interpolation.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@
66
from parcels._datasets.unstructured.generic import datasets as datasets_unstructured
77
from parcels._index_search import _search_time_index
88
from parcels.application_kernels.advection import AdvectionRK4_3D
9-
from parcels.application_kernels.interpolation import UXPiecewiseLinearNode, XLinear, XNearest, ZeroInterpolator
9+
from parcels.application_kernels.interpolation import (
10+
UXPiecewiseLinearNode,
11+
XFreeslip,
12+
XLinear,
13+
XNearest,
14+
XPartialslip,
15+
ZeroInterpolator,
16+
)
1017
from parcels.field import Field, VectorField
1118
from parcels.fieldset import FieldSet
1219
from parcels.particle import Particle, Variable
@@ -80,6 +87,36 @@ def test_raw_2d_interpolation(field, func, t, z, y, x, expected):
8087
np.testing.assert_equal(value, expected)
8188

8289

90+
@pytest.mark.parametrize(
91+
"func, t, z, y, x, expected",
92+
[
93+
(XPartialslip, np.timedelta64(1, "s"), 0, 0, 0.0, [[1], [1]]),
94+
(XFreeslip, np.timedelta64(1, "s"), 0, 0.5, 1.5, [[1], [0.5]]),
95+
(XPartialslip, np.timedelta64(1, "s"), 0, 2.5, 1.5, [[0.75], [0.5]]),
96+
(XFreeslip, np.timedelta64(1, "s"), 0, 2.5, 1.5, [[1], [0.5]]),
97+
(XPartialslip, np.timedelta64(1, "s"), 0, 1.5, 0.5, [[0.5], [0.75]]),
98+
(XFreeslip, np.timedelta64(1, "s"), 0, 1.5, 0.5, [[0.5], [1]]),
99+
(
100+
XFreeslip,
101+
[np.timedelta64(1, "s"), np.timedelta64(0, "s")],
102+
[0, 2],
103+
[1.5, 1.5],
104+
[2.5, 0.5],
105+
[[0.5, 0.5], [1, 1]],
106+
),
107+
],
108+
)
109+
def test_spatial_slip_interpolation(field, func, t, z, y, x, expected):
110+
field.data[:] = 1.0
111+
field.data[:, :, 1:3, 1:3] = 0.0 # Set zero land value to test spatial slip
112+
U = field
113+
V = field
114+
UV = VectorField("UV", U, V, vector_interp_method=func)
115+
116+
velocities = UV[t, z, y, x]
117+
np.testing.assert_array_almost_equal(velocities, expected)
118+
119+
83120
@pytest.mark.parametrize("mesh", ["spherical", "flat"])
84121
def test_interpolation_mesh_type(mesh, npart=10):
85122
ds = simple_UV_dataset(mesh=mesh)

0 commit comments

Comments
 (0)