Skip to content

Commit aeada5d

Browse files
committed
Add XGrid.get_axis_dim_mapping
1 parent b7347d3 commit aeada5d

File tree

3 files changed

+144
-6
lines changed

3 files changed

+144
-6
lines changed

parcels/application_kernels/interpolation.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,16 @@ def XTriCurviLinear(
3333
yi, eta = position["Y"]
3434
zi, zeta = position["Z"]
3535
data = field.data
36+
axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
3637

3738
return (
3839
(
39-
(1 - xsi) * (1 - eta) * data.isel(YG=yi, XG=xi)
40-
+ xsi * (1 - eta) * data.isel(YG=yi, XG=xi + 1)
41-
+ xsi * eta * data.isel(YG=yi + 1, XG=xi + 1)
42-
+ (1 - xsi) * eta * data.isel(YG=yi + 1, XG=xi)
40+
(1 - xsi) * (1 - eta) * data.isel({axis_dim["Y"]: yi, axis_dim["X"]: xi})
41+
+ xsi * (1 - eta) * data.isel({axis_dim["Y"]: yi, axis_dim["X"]: xi + 1})
42+
+ xsi * eta * data.isel({axis_dim["Y"]: yi + 1, axis_dim["X"]: xi + 1})
43+
+ (1 - xsi) * eta * data.isel({axis_dim["Y"]: yi + 1, axis_dim["X"]: xi})
4344
)
44-
.interp(time=t, ZG=zi + zeta)
45+
.interp(time=t, **{axis_dim["Z"]: zi + zeta})
4546
.values
4647
)
4748

@@ -57,10 +58,13 @@ def XTriRectiLinear(
5758
x: np.float32 | np.float64,
5859
):
5960
"""Trilinear interpolation on a rectilinear grid."""
61+
axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
62+
6063
xi, xsi = position["X"]
6164
yi, eta = position["Y"]
6265
zi, zeta = position["Z"]
63-
return field.data.interp(time=t, ZG=zi + zeta, YG=yi + eta, XG=xi + xsi).values
66+
kwargs = {axis_dim["X"]: xi + xsi, axis_dim["Y"]: yi + eta, axis_dim["Z"]: zi + zeta}
67+
return field.data.interp(time=t, **kwargs).values
6468

6569

6670
def UXPiecewiseConstantFace(

parcels/xgrid.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,38 @@ def _fpoint_info(self):
304304

305305
return axis_position_mapping
306306

307+
def get_axis_dim_mapping(self, dims: list[str]) -> dict[_XGRID_AXES, str]:
308+
"""
309+
Maps xarray dimension names to their corresponding axis (X, Y, Z).
310+
311+
WARNING: This API is unstable and subject to change in future versions.
312+
313+
Parameters
314+
----------
315+
dims : list[str]
316+
List of xarray dimension names
317+
318+
Returns
319+
-------
320+
dict[_XGRID_AXES, str]
321+
Dictionary mapping axes (X, Y, Z) to their corresponding dimension names
322+
323+
Examples
324+
--------
325+
>>> grid.get_axis_dim_mapping(['time', 'lat', 'lon'])
326+
{'Y': 'lat', 'X': 'lon'}
327+
328+
Notes
329+
-----
330+
Only returns mappings for spatial axes (X, Y, Z) that are present in the grid.
331+
"""
332+
result = {}
333+
for dim in dims:
334+
axis = get_axis_from_dim_name(self.xgcm_grid.axes, dim)
335+
if axis in self.axes: # Only include spatial axes (X, Y, Z)
336+
result[cast(_XGRID_AXES, axis)] = dim
337+
return result
338+
307339

308340
def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | None:
309341
"""For a given dimension name in a grid, returns the direction axis it is on."""

tests/v4/test_interpolation.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import itertools
2+
3+
import numpy as np
4+
import xarray as xr
5+
6+
from parcels.application_kernels.interpolation import XTriCurviLinear
7+
from parcels.field import Field
8+
from parcels.xgcm import Grid
9+
from parcels.xgrid import XGrid
10+
11+
12+
def get_unit_square_ds():
13+
T, Z, Y, X = 2, 2, 2, 2
14+
TIME = xr.date_range("2000", "2001", T)
15+
16+
_, data_z, data_y, data_x = np.meshgrid(
17+
np.zeros(T),
18+
np.linspace(0, 1, Z),
19+
np.linspace(0, 1, Y),
20+
np.linspace(0, 1, X),
21+
indexing="ij",
22+
)
23+
24+
return xr.Dataset(
25+
{
26+
"0 to 1 in X": (["time", "ZG", "YG", "XG"], data_x),
27+
"0 to 1 in Y": (["time", "ZG", "YG", "XG"], data_y),
28+
"0 to 1 in Z": (["time", "ZG", "YG", "XG"], data_z),
29+
"0 to 1 in X (T-points)": (["time", "ZC", "YC", "XC"], data_x + 0.5),
30+
"0 to 1 in Y (T-points)": (["time", "ZC", "YC", "XC"], data_y + 0.5),
31+
"0 to 1 in Z (T-points)": (["time", "ZC", "YC", "XC"], data_z + 0.5),
32+
"0 to 1 in X (U velocity C-grid points)": (["time", "ZC", "YC", "XG"], data_x),
33+
"0 to 1 in Y (V velocity C-grid points)": (["time", "ZC", "YG", "XC"], data_y),
34+
},
35+
coords={
36+
"XG": (
37+
["XG"],
38+
np.arange(0, X),
39+
{"axis": "X", "c_grid_axis_shift": -0.5},
40+
),
41+
"XC": (["XC"], np.arange(0, X) + 0.5, {"axis": "X"}),
42+
"YG": (
43+
["YG"],
44+
np.arange(0, Y),
45+
{"axis": "Y", "c_grid_axis_shift": -0.5},
46+
),
47+
"YC": (
48+
["YC"],
49+
np.arange(0, Y) + 0.5,
50+
{"axis": "Y"},
51+
),
52+
"ZG": (
53+
["ZG"],
54+
np.arange(Z),
55+
{"axis": "Z", "c_grid_axis_shift": -0.5},
56+
),
57+
"ZC": (
58+
["ZC"],
59+
np.arange(Z) + 0.5,
60+
{"axis": "Z"},
61+
),
62+
"lon": (["XG"], np.arange(0, X)),
63+
"lat": (["YG"], np.arange(0, Y)),
64+
"depth": (["ZG"], np.arange(Z)),
65+
"time": (["time"], TIME, {"axis": "T"}),
66+
},
67+
)
68+
69+
70+
def test_XTriRectiLinear_interpolation():
71+
ds = get_unit_square_ds()
72+
grid = XGrid(Grid(ds))
73+
field = Field("test", ds["0 to 1 in X"], grid=grid, interp_method=XTriCurviLinear)
74+
left = field.time_interval.left
75+
76+
epsilon = 1e-6
77+
N = 4
78+
79+
# Interpolate wrt. items on f-points
80+
for x, y, z in itertools.product(np.linspace(0 + epsilon, 1 - epsilon, N), repeat=3):
81+
assert np.isclose(x, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}"
82+
83+
field = Field("test", ds["0 to 1 in Y"], grid=grid, interp_method=XTriCurviLinear)
84+
for x, y, z in itertools.product(np.linspace(0 + epsilon, 1 - epsilon, N), repeat=3):
85+
assert np.isclose(y, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}"
86+
87+
field = Field("test", ds["0 to 1 in Z"], grid=grid, interp_method=XTriCurviLinear)
88+
for x, y, z in itertools.product(np.linspace(0 + epsilon, 1 - epsilon, N), repeat=3):
89+
assert np.isclose(z, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}"
90+
91+
# Interpolate wrt. items on T-points
92+
field = Field("test", ds["0 to 1 in X (T-points)"], grid=grid, interp_method=XTriCurviLinear)
93+
for x, y, z in itertools.product(np.linspace(0.5 + epsilon, 1 - epsilon, N), repeat=3):
94+
assert np.isclose(x, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}"
95+
96+
field = Field("test", ds["0 to 1 in Y (T-points)"], grid=grid, interp_method=XTriCurviLinear)
97+
for x, y, z in itertools.product(np.linspace(0.5 + epsilon, 1 - epsilon, N), repeat=3):
98+
assert np.isclose(y, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}"
99+
100+
field = Field("test", ds["0 to 1 in Z (T-points)"], grid=grid, interp_method=XTriCurviLinear)
101+
for x, y, z in itertools.product(np.linspace(0.5 + epsilon, 1 - epsilon, N), repeat=3):
102+
assert np.isclose(z, field.eval(left, z, y, x)), f"Failed for x={x}, y={y}, z={z}"

0 commit comments

Comments
 (0)