Skip to content

Commit 196c6e7

Browse files
Fixing W interpolation for CGrid
1 parent 76d015b commit 196c6e7

File tree

2 files changed

+99
-20
lines changed

2 files changed

+99
-20
lines changed

parcels/application_kernels/interpolation.py

Lines changed: 58 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def CGrid_Velocity(
139139
U = vectorfield.U.data
140140
V = vectorfield.V.data
141141
grid = vectorfield.grid
142-
tdim, ydim, xdim = U.shape[0], U.shape[2], U.shape[3]
142+
tdim, zdim, ydim, xdim = U.shape[0], U.shape[1], U.shape[2], U.shape[3]
143143

144144
if grid.lon.ndim == 1:
145145
px = np.array([grid.lon[xi], grid.lon[xi + 1], grid.lon[xi + 1], grid.lon[xi]])
@@ -171,39 +171,39 @@ def CGrid_Velocity(
171171
# Create arrays of corner points for xarray.isel
172172
# TODO C grid may not need all xi and yi cornerpoints, so could speed up here?
173173

174-
# Time coordinates: 8 points at ti, then 8 points at ti+1
174+
# Time coordinates: 4 points at ti, then 4 points at ti+1
175175
if lenT == 1:
176-
ti = np.repeat(ti, 4)
176+
ti_full = np.repeat(ti, 4)
177177
else:
178178
ti_1 = np.clip(ti + 1, 0, tdim - 1)
179-
ti = np.concatenate([np.repeat(ti, 4), np.repeat(ti_1, 4)])
179+
ti_full = np.concatenate([np.repeat(ti, 4), np.repeat(ti_1, 4)])
180180

181181
# Depth coordinates: 4 points at zi, repeated for both time levels
182-
zi = np.repeat(zi, lenT * 4)
182+
zi_full = np.repeat(zi, lenT * 4)
183183

184184
# Y coordinates: [yi, yi, yi+1, yi+1] for each spatial point, repeated for time/depth
185185
yi_1 = np.clip(yi + 1, 0, ydim - 1)
186-
yi = np.tile(np.repeat(np.column_stack([yi, yi_1]), 2), (lenT))
186+
yi_full = np.tile(np.repeat(np.column_stack([yi, yi_1]), 2), (lenT))
187187
# # TODO check why in some cases minus needed here!!!
188188
# yi_minus_1 = np.clip(yi - 1, 0, ydim - 1)
189189
# yi = np.tile(np.repeat(np.column_stack([yi_minus_1, yi]), 2), (lenT))
190190

191191
# X coordinates: [xi, xi+1, xi, xi+1] for each spatial point, repeated for time/depth
192192
xi_1 = np.clip(xi + 1, 0, xdim - 1)
193-
xi = np.tile(np.column_stack([xi, xi_1, xi, xi_1]).flatten(), (lenT))
193+
xi_full = np.tile(np.column_stack([xi, xi_1, xi, xi_1]).flatten(), (lenT))
194194

195195
for data in [U, V]:
196196
axis_dim = grid.get_axis_dim_mapping(data.dims)
197197

198198
# Create DataArrays for indexing
199199
selection_dict = {
200-
axis_dim["X"]: xr.DataArray(xi, dims=("points")),
201-
axis_dim["Y"]: xr.DataArray(yi, dims=("points")),
200+
axis_dim["X"]: xr.DataArray(xi_full, dims=("points")),
201+
axis_dim["Y"]: xr.DataArray(yi_full, dims=("points")),
202202
}
203203
if "Z" in axis_dim:
204-
selection_dict[axis_dim["Z"]] = xr.DataArray(zi, dims=("points"))
204+
selection_dict[axis_dim["Z"]] = xr.DataArray(zi_full, dims=("points"))
205205
if "time" in data.dims:
206-
selection_dict["time"] = xr.DataArray(ti, dims=("points"))
206+
selection_dict["time"] = xr.DataArray(ti_full, dims=("points"))
207207

208208
corner_data = data.isel(selection_dict).data.reshape(lenT, len(xsi), 4)
209209

@@ -271,7 +271,53 @@ def CGrid_Velocity(
271271
xx = (1 - xsi) * (1 - eta) * px[0] + xsi * (1 - eta) * px[1] + xsi * eta * px[2] + (1 - xsi) * eta * px[3]
272272
u = np.where(np.abs(xx - x) > 1e-4, np.nan, u)
273273

274-
return (u, v, 0) # TODO fix and test W also
274+
if vectorfield.W:
275+
data = vectorfield.W.data
276+
# Time coordinates: 2 points at ti, then 2 points at ti+1
277+
if lenT == 1:
278+
ti_full = np.repeat(ti, 2)
279+
else:
280+
ti_1 = np.clip(ti + 1, 0, tdim - 1)
281+
ti_full = np.concatenate([np.repeat(ti, 2), np.repeat(ti_1, 2)])
282+
283+
# Depth coordinates: 1 points at zi, repeated for both time levels
284+
zi_1 = np.clip(zi + 1, 0, zdim - 1)
285+
zi_full = np.tile(np.array([zi, zi_1]).flatten(), lenT)
286+
287+
# Y coordinates: yi+1 for each spatial point, repeated for time/depth
288+
yi_1 = np.clip(yi + 1, 0, ydim - 1)
289+
yi_full = np.tile(yi_1, (lenT) * 2)
290+
291+
# X coordinates: xi+1 for each spatial point, repeated for time/depth
292+
xi_1 = np.clip(xi + 1, 0, xdim - 1)
293+
xi_full = np.tile(xi_1, (lenT) * 2)
294+
295+
axis_dim = grid.get_axis_dim_mapping(data.dims)
296+
297+
# Create DataArrays for indexing
298+
selection_dict = {
299+
axis_dim["X"]: xr.DataArray(xi_full, dims=("points")),
300+
axis_dim["Y"]: xr.DataArray(yi_full, dims=("points")),
301+
axis_dim["Z"]: xr.DataArray(zi_full, dims=("points")),
302+
}
303+
if "time" in data.dims:
304+
selection_dict["time"] = xr.DataArray(ti_full, dims=("points"))
305+
306+
corner_data = data.isel(selection_dict).data.reshape(lenT, 2, len(xsi))
307+
308+
if lenT == 2:
309+
tau_full = tau[np.newaxis, :]
310+
corner_data = corner_data[0, :, :] * (1 - tau_full) + corner_data[1, :, :] * tau_full
311+
else:
312+
corner_data = corner_data[0, :, :]
313+
314+
w = corner_data[0, :] * (1 - zeta) + corner_data[1, :] * zeta
315+
if isinstance(w, dask.Array):
316+
w = w.compute()
317+
else:
318+
w = np.zeros_like(u)
319+
320+
return (u, v, w)
275321

276322

277323
def CGrid_Tracer(

tests/v4/test_advection.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from parcels.particleset import ParticleSet
2121
from parcels.tools.statuscodes import StatusCode
2222
from parcels.xgrid import XGrid
23+
from tests.utils import round_and_hash_float_array
2324

2425
kernel = {
2526
"EE": AdvectionEE,
@@ -444,7 +445,8 @@ def periodicBC(particle, fieldSet, time): # pragma: no cover
444445
np.testing.assert_allclose(pset.lat_nextloop, latp, atol=1e-1)
445446

446447

447-
def test_nemo_3D_curvilinear_fieldset():
448+
@pytest.mark.parametrize("method", ["RK4", "RK4_3D"])
449+
def test_nemo_3D_curvilinear_fieldset(method):
448450
download_dir = parcels.download_example_dataset("NemoNorthSeaORCA025-N006_data")
449451
ufiles = download_dir.glob("*U.nc")
450452
dsu = xr.open_mfdataset(ufiles, decode_times=False, drop_variables=["nav_lat", "nav_lon"])
@@ -466,10 +468,34 @@ def test_nemo_3D_curvilinear_fieldset():
466468

467469
coord_file = f"{download_dir}/coordinates.nc"
468470
dscoord = xr.open_dataset(coord_file, decode_times=False).rename({"glamf": "lon", "gphif": "lat"})
471+
dscoord = dscoord.isel(time=0, drop=True)
469472

470473
ds = xr.merge([dsu, dsv, dsw, dscoord])
474+
ds = ds.drop_vars(
475+
[
476+
"uos",
477+
"vos",
478+
"nav_lev",
479+
"nav_lon",
480+
"nav_lat",
481+
"tauvo",
482+
"tauuo",
483+
"time_steps",
484+
"gphiu",
485+
"gphiv",
486+
"gphit",
487+
"glamu",
488+
"glamv",
489+
"glamt",
490+
"time_centered_bounds",
491+
"time_counter_bounds",
492+
"time_centered",
493+
]
494+
)
495+
ds = ds.drop_vars(["e1f", "e1t", "e1u", "e1v", "e2f", "e2t", "e2u", "e2v"])
496+
ds["time"] = [np.timedelta64(int(t), "s") + np.datetime64("1900-01-01") for t in ds["time"]]
471497

472-
ds.load() # TODO remove for debug
498+
ds["W"] *= -1 # Invert W velocity
473499

474500
xgcm_grid = parcels.xgcm.Grid(
475501
ds,
@@ -488,14 +514,21 @@ def test_nemo_3D_curvilinear_fieldset():
488514
W = parcels.Field("W", ds["W"], grid, mesh_type="spherical")
489515
U.units = parcels.GeographicPolar()
490516
V.units = parcels.Geographic()
491-
W.units = parcels.Geographic()
492-
UV = parcels.VectorField("UV", U, V, vector_interp_method=CGrid_Velocity) # TODO remove
517+
UV = parcels.VectorField("UV", U, V, vector_interp_method=CGrid_Velocity)
493518
UVW = parcels.VectorField("UVW", U, V, W, vector_interp_method=CGrid_Velocity)
494519
fieldset = parcels.FieldSet([U, V, W, UV, UVW])
495520

496-
lons = np.linspace(1.9, 3.4, 10)
497-
lats = np.linspace(52.5, 51.6, 10)
521+
npart = 10
522+
lons = np.linspace(1.9, 3.4, npart)
523+
lats = np.linspace(52.5, 51.6, npart)
498524
pset = parcels.ParticleSet(fieldset, lon=lons, lat=lats, depth=np.ones_like(lons))
499525

500-
pset.execute(parcels.AdvectionRK4, runtime=np.timedelta64(4, "D"), dt=np.timedelta64(6, "h"))
501-
print(pset.depth)
526+
pset.execute(kernel[method], runtime=np.timedelta64(4, "D"), dt=np.timedelta64(6, "h"))
527+
528+
if method == "RK4":
529+
np.testing.assert_equal(round_and_hash_float_array([p.lon for p in pset], decimals=5), 29977383852960156017546)
530+
elif method == "RK4_3D":
531+
# TODO check why decimals needs to be so low in RK4_3D (compare to v3)
532+
np.testing.assert_equal(
533+
round_and_hash_float_array([p.depth for p in pset], decimals=1), 29747210774230389239432
534+
)

0 commit comments

Comments
 (0)