Skip to content

Commit 251219d

Browse files
committed
Add XGrid.get_axis_dim_mapping
1 parent 74b6b84 commit 251219d

File tree

2 files changed

+42
-6
lines changed

2 files changed

+42
-6
lines changed

parcels/application_kernels/interpolation.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,16 @@ def XTriCurviLinear(
3333
yi, eta = position["Y"]
3434
zi, zeta = position["Z"]
3535
data = field.data
36+
axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
3637

3738
return (
3839
(
39-
(1 - xsi) * (1 - eta) * data.isel(YG=yi, XG=xi)
40-
+ xsi * (1 - eta) * data.isel(YG=yi, XG=xi + 1)
41-
+ xsi * eta * data.isel(YG=yi + 1, XG=xi + 1)
42-
+ (1 - xsi) * eta * data.isel(YG=yi + 1, XG=xi)
40+
(1 - xsi) * (1 - eta) * data.isel({axis_dim["Y"]: yi, axis_dim["X"]: xi})
41+
+ xsi * (1 - eta) * data.isel({axis_dim["Y"]: yi, axis_dim["X"]: xi + 1})
42+
+ xsi * eta * data.isel({axis_dim["Y"]: yi + 1, axis_dim["X"]: xi + 1})
43+
+ (1 - xsi) * eta * data.isel({axis_dim["Y"]: yi + 1, axis_dim["X"]: xi})
4344
)
44-
.interp(time=t, ZG=zi + zeta)
45+
.interp(time=t, **{axis_dim["Z"]: zi + zeta})
4546
.values
4647
)
4748

@@ -57,10 +58,13 @@ def XTriRectiLinear(
5758
x: np.float32 | np.float64,
5859
):
5960
"""Trilinear interpolation on a rectilinear grid."""
61+
axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
62+
6063
xi, xsi = position["X"]
6164
yi, eta = position["Y"]
6265
zi, zeta = position["Z"]
63-
return field.data.interp(time=t, ZG=zi + zeta, YG=yi + eta, XG=xi + xsi).values
66+
kwargs = {axis_dim["X"]: xi + xsi, axis_dim["Y"]: yi + eta, axis_dim["Z"]: zi + zeta}
67+
return field.data.interp(time=t, **kwargs).values
6468

6569

6670
def UXPiecewiseConstantFace(

parcels/xgrid.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,38 @@ def search(self, z, y, x, ei=None):
185185

186186
raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.")
187187

188+
def get_axis_dim_mapping(self, dims: list[str]) -> dict[_XGRID_AXES, str]:
189+
"""
190+
Maps xarray dimension names to their corresponding axis (X, Y, Z).
191+
192+
WARNING: This API is unstable and subject to change in future versions.
193+
194+
Parameters
195+
----------
196+
dims : list[str]
197+
List of xarray dimension names
198+
199+
Returns
200+
-------
201+
dict[_XGRID_AXES, str]
202+
Dictionary mapping axes (X, Y, Z) to their corresponding dimension names
203+
204+
Examples
205+
--------
206+
>>> grid.get_axis_dim_mapping(['time', 'lat', 'lon'])
207+
{'Y': 'lat', 'X': 'lon'}
208+
209+
Notes
210+
-----
211+
Only returns mappings for spatial axes (X, Y, Z) that are present in the grid.
212+
"""
213+
result = {}
214+
for dim in dims:
215+
axis = get_axis_from_dim_name(self.xgcm_grid.axes, dim)
216+
if axis in self.axes: # Only include spatial axes (X, Y, Z)
217+
result[cast(_XGRID_AXES, axis)] = dim
218+
return result
219+
188220

189221
def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | None:
190222
"""For a given dimension name in a grid, returns the direction axis it is on."""

0 commit comments

Comments
 (0)