
Commit 1c8d5e5

Merge pull request #2152 from OceanParcels/c-grid-interpolation
C grid interpolation for VectorFields
2 parents 4218098 + 38622fc commit 1c8d5e5

File tree: 7 files changed (+473 / -46 lines changed)


parcels/_index_search.py

Lines changed: 10 additions & 8 deletions
@@ -83,14 +83,16 @@ def _search_indices_curvilinear_2d(
     cc = a[1] * b[0] - a[0] * b[1] + x * b[1] - y * a[1]

     det2 = bb * bb - 4 * aa * cc
-    det = np.where(det2 > 0, np.sqrt(det2), eta)
-    eta = np.where(abs(aa) < 1e-12, -cc / bb, np.where(det2 > 0, (-bb + det) / (2 * aa), eta))
-
-    xsi = np.where(
-        abs(a[1] + a[3] * eta) < 1e-12,
-        ((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5,
-        (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta),
-    )
+    with np.errstate(divide="ignore", invalid="ignore"):
+        det = np.where(det2 > 0, np.sqrt(det2), eta)
+
+        eta = np.where(abs(aa) < 1e-12, -cc / bb, np.where(det2 > 0, (-bb + det) / (2 * aa), eta))
+
+        xsi = np.where(
+            abs(a[1] + a[3] * eta) < 1e-12,
+            ((y - py[0]) / (py[1] - py[0]) + (y - py[3]) / (py[2] - py[3])) * 0.5,
+            (x - a[0] - a[2] * eta) / (a[1] + a[3] * eta),
+        )

     xi = np.where(xsi < -tol, xi - 1, np.where(xsi > 1 + tol, xi + 1, xi))
     yi = np.where(eta < -tol, yi - 1, np.where(eta > 1 + tol, yi + 1, yi))
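The new `np.errstate` context addresses a practical NumPy issue: `np.where` evaluates both branches eagerly, so the fallback divisions (`-cc / bb` and `(-bb + det) / (2 * aa)`) still run at points where `aa` or `bb` is effectively zero, emitting RuntimeWarnings even though those results are discarded. A minimal standalone sketch of the pattern (made-up values, not Parcels code):

```python
# Minimal sketch (not Parcels code): np.where evaluates both branches eagerly,
# so the discarded branch can still produce 0/0 and emit a RuntimeWarning.
# np.errstate silences warnings for values that the other branch overwrites anyway.
import numpy as np

aa = np.array([0.0, 2.0])   # degenerate (linear) cell and a regular cell
bb = np.array([4.0, -3.0])
cc = np.array([-2.0, 1.0])

det2 = bb * bb - 4 * aa * cc
with np.errstate(divide="ignore", invalid="ignore"):
    det = np.where(det2 > 0, np.sqrt(det2), 0.0)
    # linear fallback -cc / bb where aa ~ 0, quadratic root elsewhere
    eta = np.where(np.abs(aa) < 1e-12, -cc / bb, (-bb + det) / (2 * aa))

print(eta)  # [0.5 1. ] with no divide-by-zero / invalid-value warnings
```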

parcels/application_kernels/interpolation.py

Lines changed: 288 additions & 8 deletions
@@ -4,21 +4,26 @@

 from typing import TYPE_CHECKING

-import dask.array as dask
 import numpy as np
 import xarray as xr
+from dask import is_dask_collection
+
+import parcels.tools.interpolation_utils as i_u

 if TYPE_CHECKING:
-    from parcels.field import Field
+    from parcels.field import Field, VectorField
     from parcels.uxgrid import _UXGRID_AXES
     from parcels.xgrid import _XGRID_AXES

 __all__ = [
+    "CGrid_Tracer",
+    "CGrid_Velocity",
     "UXPiecewiseConstantFace",
     "UXPiecewiseLinearNode",
     "XLinear",
     "XNearest",
     "ZeroInterpolator",
+    "ZeroInterpolator_Vector",
 ]

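This hunk also swaps the module-level `import dask.array as dask` for `dask.is_dask_collection`, which the kernels below use to decide whether to call `.compute()`. Unlike `isinstance(value, dask.array.Array)`, `is_dask_collection` returns True for any lazy dask object, so the guard keeps working whatever lazy container the xarray indexing hands back. A small illustration (not taken from the repository):

```python
# The guard used in the kernels: compute only when the result is a lazy dask object.
import numpy as np
import dask.array as da
from dask import is_dask_collection

lazy = da.ones((4,), chunks=2)
eager = np.ones(4)

for value in (lazy, eager):
    result = value.compute() if is_dask_collection(value) else value
    print(type(result))  # both end up as numpy.ndarray
```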

@@ -36,6 +41,21 @@ def ZeroInterpolator(
     return 0.0


+def ZeroInterpolator_Vector(
+    vectorfield: VectorField,
+    ti: int,
+    position: dict[str, tuple[int, float | np.ndarray]],
+    tau: np.float32 | np.float64,
+    t: np.float32 | np.float64,
+    z: np.float32 | np.float64,
+    y: np.float32 | np.float64,
+    x: np.float32 | np.float64,
+    applyConversion: bool,
+) -> np.float32 | np.float64:
+    """Template function used for the signature check of the interpolation methods for velocity fields."""
+    return 0.0
+
+
 def XLinear(
     field: Field,
     ti: int,
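Per its docstring, `ZeroInterpolator_Vector` only serves as a signature template for velocity-field interpolators, mirroring `ZeroInterpolator` for scalar fields. A hypothetical sketch of how such a signature check could look; the helper below is illustrative and not part of this PR:

```python
# Hypothetical helper (not from the PR): compare a user-supplied interpolator's
# parameters against the template function's parameters.
import inspect

def matches_template(func, template) -> bool:
    """Return True if func accepts the same parameter names, in the same order, as template."""
    return list(inspect.signature(func).parameters) == list(inspect.signature(template).parameters)

# e.g. matches_template(my_velocity_interpolator, ZeroInterpolator_Vector)
```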
@@ -53,6 +73,7 @@ def XLinear(

     axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
     data = field.data
+    tdim, zdim, ydim, xdim = data.shape[0], data.shape[1], data.shape[2], data.shape[3]

     lenT = 2 if np.any(tau > 0) else 1
     lenZ = 2 if np.any(zeta > 0) else 1
@@ -61,22 +82,22 @@
     if lenT == 1:
         ti = np.repeat(ti, lenZ * 4)
     else:
-        ti_1 = np.clip(ti + 1, 0, data.shape[0] - 1)
+        ti_1 = np.clip(ti + 1, 0, tdim - 1)
         ti = np.concatenate([np.repeat(ti, lenZ * 4), np.repeat(ti_1, lenZ * 4)])

     # Depth coordinates: 4 points at zi, 4 at zi+1, repeated for both time levels
     if lenZ == 1:
         zi = np.repeat(zi, lenT * 4)
     else:
-        zi_1 = np.clip(zi + 1, 0, data.shape[1] - 1)
+        zi_1 = np.clip(zi + 1, 0, zdim - 1)
         zi = np.tile(np.array([zi, zi, zi, zi, zi_1, zi_1, zi_1, zi_1]).flatten(), lenT)

     # Y coordinates: [yi, yi, yi+1, yi+1] for each spatial point, repeated for time/depth
-    yi_1 = np.clip(yi + 1, 0, data.shape[2] - 1)
+    yi_1 = np.clip(yi + 1, 0, ydim - 1)
     yi = np.tile(np.repeat(np.column_stack([yi, yi_1]), 2), (lenT) * (lenZ))

     # X coordinates: [xi, xi+1, xi, xi+1] for each spatial point, repeated for time/depth
-    xi_1 = np.clip(xi + 1, 0, data.shape[3] - 1)
+    xi_1 = np.clip(xi + 1, 0, xdim - 1)
     xi = np.tile(np.column_stack([xi, xi_1, xi, xi_1]).flatten(), (lenT) * (lenZ))

     # Create DataArrays for indexing
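The hunks above replace direct `data.shape[...]` lookups with the cached `tdim`/`zdim`/`ydim`/`xdim` sizes, but the corner-indexing scheme itself is unchanged: flat index arrays are built so that a single vectorized `isel` call fetches every cell corner for every particle. A small sketch of that layout for two particles, assuming `lenT = lenZ = 1` (index clipping omitted):

```python
# Corner-index layout built above, for two particles and lenT = lenZ = 1,
# so only the 4 horizontal corners per particle remain.
import numpy as np

xi = np.array([3, 7])          # left cell index per particle
yi = np.array([2, 5])          # lower cell index per particle
xi_1, yi_1 = xi + 1, yi + 1    # clipping to the grid extent omitted here

x_corners = np.tile(np.column_stack([xi, xi_1, xi, xi_1]).flatten(), 1)
y_corners = np.tile(np.repeat(np.column_stack([yi, yi_1]), 2), 1)

print(x_corners)  # [3 4 3 4 7 8 7 8]
print(y_corners)  # [2 2 3 3 5 5 6 6]
# Per particle the pairs enumerate (lower-left, lower-right, upper-left, upper-right),
# matching the (1-xsi)(1-eta), xsi(1-eta), (1-xsi)eta, xsi*eta weights applied afterwards.
```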
@@ -109,7 +130,266 @@ def XLinear(
         + (1 - xsi) * eta * corner_data[:, 2]
         + xsi * eta * corner_data[:, 3]
     )
-    return value.compute() if isinstance(value, dask.Array) else value
+    return value.compute() if is_dask_collection(value) else value
+
+
+def CGrid_Velocity(
+    vectorfield: VectorField,
+    ti: int,
+    position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
+    tau: np.float32 | np.float64,
+    t: np.float32 | np.float64,
+    z: np.float32 | np.float64,
+    y: np.float32 | np.float64,
+    x: np.float32 | np.float64,
+    applyConversion: bool,
+):
+    """
+    Interpolation kernel for velocity fields on a C-Grid.
+    Following Delandmeter and Van Sebille (2019), velocity fields should be interpolated
+    only in the direction of the grid cell faces.
+    """
+    xi, xsi = position["X"]
+    yi, eta = position["Y"]
+    zi, zeta = position["Z"]
+
+    U = vectorfield.U.data
+    V = vectorfield.V.data
+    grid = vectorfield.grid
+    tdim, zdim, ydim, xdim = U.shape[0], U.shape[1], U.shape[2], U.shape[3]
+
+    if grid.lon.ndim == 1:
+        px = np.array([grid.lon[xi], grid.lon[xi + 1], grid.lon[xi + 1], grid.lon[xi]])
+        py = np.array([grid.lat[yi], grid.lat[yi], grid.lat[yi + 1], grid.lat[yi + 1]])
+    else:
+        px = np.array([grid.lon[yi, xi], grid.lon[yi, xi + 1], grid.lon[yi + 1, xi + 1], grid.lon[yi + 1, xi]])
+        py = np.array([grid.lat[yi, xi], grid.lat[yi, xi + 1], grid.lat[yi + 1, xi + 1], grid.lat[yi + 1, xi]])
+
+    if grid._mesh == "spherical":
+        px[0] = np.where(px[0] < x - 225, px[0] + 360, px[0])
+        px[0] = np.where(px[0] > x + 225, px[0] - 360, px[0])
+        px[1:] = np.where(px[1:] - px[0] > 180, px[1:] - 360, px[1:])
+        px[1:] = np.where(-px[1:] + px[0] > 180, px[1:] + 360, px[1:])
+    c1 = i_u._geodetic_distance(
+        py[0], py[1], px[0], px[1], grid._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(0.0, xsi), py)
+    )
+    c2 = i_u._geodetic_distance(
+        py[1], py[2], px[1], px[2], grid._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(eta, 1.0), py)
+    )
+    c3 = i_u._geodetic_distance(
+        py[2], py[3], px[2], px[3], grid._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(1.0, xsi), py)
+    )
+    c4 = i_u._geodetic_distance(
+        py[3], py[0], px[3], px[0], grid._mesh, np.einsum("ij,ji->i", i_u.phi2D_lin(eta, 0.0), py)
+    )
+
+    lenT = 2 if np.any(tau > 0) else 1
+
+    # Create arrays of corner points for xarray.isel
+    # TODO C grid may not need all xi and yi cornerpoints, so could speed up here?
+
+    # Time coordinates: 4 points at ti, then 4 points at ti+1
+    if lenT == 1:
+        ti_full = np.repeat(ti, 4)
+    else:
+        ti_1 = np.clip(ti + 1, 0, tdim - 1)
+        ti_full = np.concatenate([np.repeat(ti, 4), np.repeat(ti_1, 4)])
+
+    # Depth coordinates: 4 points at zi, repeated for both time levels
+    zi_full = np.repeat(zi, lenT * 4)
+
+    # Y coordinates: [yi, yi, yi+1, yi+1] for each spatial point, repeated for time/depth
+    yi_1 = np.clip(yi + 1, 0, ydim - 1)
+    yi_full = np.tile(np.repeat(np.column_stack([yi, yi_1]), 2), (lenT))
+    # # TODO check why in some cases minus needed here!!!
+    # yi_minus_1 = np.clip(yi - 1, 0, ydim - 1)
+    # yi = np.tile(np.repeat(np.column_stack([yi_minus_1, yi]), 2), (lenT))
+
+    # X coordinates: [xi, xi+1, xi, xi+1] for each spatial point, repeated for time/depth
+    xi_1 = np.clip(xi + 1, 0, xdim - 1)
+    xi_full = np.tile(np.column_stack([xi, xi_1, xi, xi_1]).flatten(), (lenT))
+
+    for data in [U, V]:
+        axis_dim = grid.get_axis_dim_mapping(data.dims)
+
+        # Create DataArrays for indexing
+        selection_dict = {
+            axis_dim["X"]: xr.DataArray(xi_full, dims=("points")),
+            axis_dim["Y"]: xr.DataArray(yi_full, dims=("points")),
+        }
+        if "Z" in axis_dim:
+            selection_dict[axis_dim["Z"]] = xr.DataArray(zi_full, dims=("points"))
+        if "time" in data.dims:
+            selection_dict["time"] = xr.DataArray(ti_full, dims=("points"))
+
+        corner_data = data.isel(selection_dict).data.reshape(lenT, len(xsi), 4)
+
+        if lenT == 2:
+            tau_full = tau[:, np.newaxis]
+            corner_data = corner_data[0, :, :] * (1 - tau_full) + corner_data[1, :, :] * tau_full
+        else:
+            corner_data = corner_data[0, :, :]
+        # # See code below for v3 version
+        # # if self.gridindexingtype == "nemo":
+        # # U0 = self.U.data[ti, zi, yi + 1, xi] * c4
+        # # U1 = self.U.data[ti, zi, yi + 1, xi + 1] * c2
+        # # V0 = self.V.data[ti, zi, yi, xi + 1] * c1
+        # # V1 = self.V.data[ti, zi, yi + 1, xi + 1] * c3
+        # # elif self.gridindexingtype in ["mitgcm", "croco"]:
+        # # U0 = self.U.data[ti, zi, yi, xi] * c4
+        # # U1 = self.U.data[ti, zi, yi, xi + 1] * c2
+        # # V0 = self.V.data[ti, zi, yi, xi] * c1
+        # # V1 = self.V.data[ti, zi, yi + 1, xi] * c3
+        # # TODO Nick can you help use xgcm to fix this implementation?
+
+        # # CROCO and MITgcm grid indexing,
+        # if data is U:
+        # U0 = corner_data[:, 0] * c4
+        # U1 = corner_data[:, 1] * c2
+        # elif data is V:
+        # V0 = corner_data[:, 0] * c1
+        # V1 = corner_data[:, 2] * c3
+        # # NEMO grid indexing
+        if data is U:
+            U0 = corner_data[:, 2] * c4
+            U1 = corner_data[:, 3] * c2
+        elif data is V:
+            V0 = corner_data[:, 1] * c1
+            V1 = corner_data[:, 3] * c3
+
+    U = (1 - xsi) * U0 + xsi * U1
+    V = (1 - eta) * V0 + eta * V1
+
+    deg2m = 1852 * 60.0
+    if applyConversion:
+        meshJac = (deg2m * deg2m * np.cos(np.deg2rad(y))) if grid._mesh == "spherical" else 1
+    else:
+        meshJac = deg2m if grid._mesh == "spherical" else 1
+
+    jac = i_u._compute_jacobian_determinant(py, px, eta, xsi) * meshJac
+
+    u = (
+        (-(1 - eta) * U - (1 - xsi) * V) * px[0]
+        + ((1 - eta) * U - xsi * V) * px[1]
+        + (eta * U + xsi * V) * px[2]
+        + (-eta * U + (1 - xsi) * V) * px[3]
+    ) / jac
+    v = (
+        (-(1 - eta) * U - (1 - xsi) * V) * py[0]
+        + ((1 - eta) * U - xsi * V) * py[1]
+        + (eta * U + xsi * V) * py[2]
+        + (-eta * U + (1 - xsi) * V) * py[3]
+    ) / jac
+    if is_dask_collection(u):
+        u = u.compute()
+        v = v.compute()
+
+    # check whether the grid conversion has been applied correctly
+    xx = (1 - xsi) * (1 - eta) * px[0] + xsi * (1 - eta) * px[1] + xsi * eta * px[2] + (1 - xsi) * eta * px[3]
+    u = np.where(np.abs((xx - x) / x) > 1e-4, np.nan, u)
+
+    if vectorfield.W:
+        data = vectorfield.W.data
+        # Time coordinates: 2 points at ti, then 2 points at ti+1
+        if lenT == 1:
+            ti_full = np.repeat(ti, 2)
+        else:
+            ti_1 = np.clip(ti + 1, 0, tdim - 1)
+            ti_full = np.concatenate([np.repeat(ti, 2), np.repeat(ti_1, 2)])
+
+        # Depth coordinates: 1 points at zi, repeated for both time levels
+        zi_1 = np.clip(zi + 1, 0, zdim - 1)
+        zi_full = np.tile(np.array([zi, zi_1]).flatten(), lenT)
+
+        # Y coordinates: yi+1 for each spatial point, repeated for time/depth
+        yi_1 = np.clip(yi + 1, 0, ydim - 1)
+        yi_full = np.tile(yi_1, (lenT) * 2)
+
+        # X coordinates: xi+1 for each spatial point, repeated for time/depth
+        xi_1 = np.clip(xi + 1, 0, xdim - 1)
+        xi_full = np.tile(xi_1, (lenT) * 2)
+
+        axis_dim = grid.get_axis_dim_mapping(data.dims)
+
+        # Create DataArrays for indexing
+        selection_dict = {
+            axis_dim["X"]: xr.DataArray(xi_full, dims=("points")),
+            axis_dim["Y"]: xr.DataArray(yi_full, dims=("points")),
+            axis_dim["Z"]: xr.DataArray(zi_full, dims=("points")),
+        }
+        if "time" in data.dims:
+            selection_dict["time"] = xr.DataArray(ti_full, dims=("points"))
+
+        corner_data = data.isel(selection_dict).data.reshape(lenT, 2, len(xsi))
+
+        if lenT == 2:
+            tau_full = tau[np.newaxis, :]
+            corner_data = corner_data[0, :, :] * (1 - tau_full) + corner_data[1, :, :] * tau_full
+        else:
+            corner_data = corner_data[0, :, :]
+
+        w = corner_data[0, :] * (1 - zeta) + corner_data[1, :] * zeta
+        if is_dask_collection(w):
+            w = w.compute()
+    else:
+        w = np.zeros_like(u)
+
+    return (u, v, w)
+
+
+def CGrid_Tracer(
+    field: Field,
+    ti: int,
+    position: dict[_XGRID_AXES, tuple[int, float | np.ndarray]],
+    tau: np.float32 | np.float64,
+    t: np.float32 | np.float64,
+    z: np.float32 | np.float64,
+    y: np.float32 | np.float64,
+    x: np.float32 | np.float64,
+):
+    """Interpolation kernel for tracer fields on a C-Grid.
+
+    Following Delandmeter and Van Sebille (2019), tracer fields should be interpolated
+    constant over the grid cell
+    """
+    xi, _ = position["X"]
+    yi, _ = position["Y"]
+    zi, _ = position["Z"]
+
+    axis_dim = field.grid.get_axis_dim_mapping(field.data.dims)
+    data = field.data
+
+    lenT = 2 if np.any(tau > 0) else 1
+
+    if lenT == 2:
+        ti_1 = np.clip(ti + 1, 0, data.shape[0] - 1)
+        ti = np.concatenate([np.repeat(ti), np.repeat(ti_1)])
+        zi_1 = np.clip(zi + 1, 0, data.shape[1] - 1)
+        zi = np.concatenate([np.repeat(zi), np.repeat(zi_1)])
+        yi_1 = np.clip(yi + 1, 0, data.shape[2] - 1)
+        yi = np.concatenate([np.repeat(yi), np.repeat(yi_1)])
+        xi_1 = np.clip(xi + 1, 0, data.shape[3] - 1)
+        xi = np.concatenate([np.repeat(xi), np.repeat(xi_1)])
+
+    # Create DataArrays for indexing
+    selection_dict = {
+        axis_dim["X"]: xr.DataArray(xi, dims=("points")),
+        axis_dim["Y"]: xr.DataArray(yi, dims=("points")),
+    }
+    if "Z" in axis_dim:
+        selection_dict[axis_dim["Z"]] = xr.DataArray(zi, dims=("points"))
+    if "time" in field.data.dims:
+        selection_dict["time"] = xr.DataArray(ti, dims=("points"))
+
+    value = data.isel(selection_dict).data.reshape(lenT, len(xi))
+
+    if lenT == 2:
+        tau = tau[:, np.newaxis]
+        value = value[0, :] * (1 - tau) + value[1, :] * tau
+    else:
+        value = value[0, :]
+
+    return value.compute() if is_dask_collection(value) else value


 def XNearest(
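The new `CGrid_Velocity` kernel follows the Delandmeter and Van Sebille (2019) C-grid scheme visible above: U lives on the left/right cell faces and V on the lower/upper faces, each face value is scaled by its geodetic face length (`c1`-`c4`), U is interpolated linearly only in xsi and V only in eta, and the resulting face fluxes are mapped back to zonal/meridional velocities through the cell Jacobian. A minimal single-cell sketch of the core interpolation step, assuming a flat mesh and a unit cell with made-up face velocities (face-length scaling and the Jacobian mapping are omitted):

```python
# Single-cell sketch of C-grid face interpolation: each velocity component is
# interpolated only between its own pair of opposing faces (made-up values).
import numpy as np

U0, U1 = 0.2, 0.4    # normal velocity on the left (xsi=0) and right (xsi=1) faces
V0, V1 = -0.1, 0.3   # normal velocity on the lower (eta=0) and upper (eta=1) faces
xsi, eta = 0.25, 0.5  # particle position in cell coordinates

u = (1 - xsi) * U0 + xsi * U1  # varies only with xsi
v = (1 - eta) * V0 + eta * V1  # varies only with eta
print(u, v)  # 0.25 0.1
```

On a curvilinear or spherical grid the kernel additionally weights each face by its length and divides by the Jacobian of the (xsi, eta) to (lon, lat) mapping, which is what the `px`/`py` terms and `_compute_jacobian_determinant` in the kernel above provide.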
@@ -172,7 +452,7 @@ def XNearest(
     else:
         value = corner_data[0, :]

-    return value.compute() if isinstance(value, dask.Array) else value
+    return value.compute() if is_dask_collection(value) else value


 def UXPiecewiseConstantFace(
