Skip to content

Commit 76d015b

Browse files
Updating CGrid interpolation to not interpolate over depth for U and V
1 parent dda469d commit 76d015b

File tree

1 file changed

+14
-22
lines changed

1 file changed

+14
-22
lines changed

parcels/application_kernels/interpolation.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def CGrid_Velocity(
139139
U = vectorfield.U.data
140140
V = vectorfield.V.data
141141
grid = vectorfield.grid
142-
tdim, zdim, ydim, xdim = U.shape[0], U.shape[1], U.shape[2], U.shape[3]
142+
tdim, ydim, xdim = U.shape[0], U.shape[2], U.shape[3]
143143

144144
if grid.lon.ndim == 1:
145145
px = np.array([grid.lon[xi], grid.lon[xi + 1], grid.lon[xi + 1], grid.lon[xi]])
@@ -167,32 +167,30 @@ def CGrid_Velocity(
167167
)
168168

169169
lenT = 2 if np.any(tau > 0) else 1
170-
lenZ = 2 if np.any(zeta > 0) else 1
171170

172171
# Create arrays of corner points for xarray.isel
173172
# TODO C grid may not need all xi and yi cornerpoints, so could speed up here?
174173

175174
# Time coordinates: 8 points at ti, then 8 points at ti+1
176175
if lenT == 1:
177-
ti = np.repeat(ti, lenZ * 4)
176+
ti = np.repeat(ti, 4)
178177
else:
179178
ti_1 = np.clip(ti + 1, 0, tdim - 1)
180-
ti = np.concatenate([np.repeat(ti, lenZ * 4), np.repeat(ti_1, lenZ * 4)])
179+
ti = np.concatenate([np.repeat(ti, 4), np.repeat(ti_1, 4)])
181180

182-
# Depth coordinates: 4 points at zi, 4 at zi+1, repeated for both time levels
183-
if lenZ == 1:
184-
zi = np.repeat(zi, lenT * 4)
185-
else:
186-
zi_1 = np.clip(zi + 1, 0, zdim - 1)
187-
zi = np.tile(np.array([zi, zi, zi, zi, zi_1, zi_1, zi_1, zi_1]).flatten(), lenT)
181+
# Depth coordinates: 4 points at zi, repeated for both time levels
182+
zi = np.repeat(zi, lenT * 4)
188183

189184
# Y coordinates: [yi, yi, yi+1, yi+1] for each spatial point, repeated for time/depth
190-
yi_minus_1 = np.clip(yi - 1, 0, ydim - 1) # TODO check why minus here!!!
191-
yi = np.tile(np.repeat(np.column_stack([yi_minus_1, yi]), 2), (lenT) * (lenZ))
185+
yi_1 = np.clip(yi + 1, 0, ydim - 1)
186+
yi = np.tile(np.repeat(np.column_stack([yi, yi_1]), 2), (lenT))
187+
# # TODO check why in some cases minus needed here!!!
188+
# yi_minus_1 = np.clip(yi - 1, 0, ydim - 1)
189+
# yi = np.tile(np.repeat(np.column_stack([yi_minus_1, yi]), 2), (lenT))
192190

193191
# X coordinates: [xi, xi+1, xi, xi+1] for each spatial point, repeated for time/depth
194192
xi_1 = np.clip(xi + 1, 0, xdim - 1)
195-
xi = np.tile(np.column_stack([xi, xi_1, xi, xi_1]).flatten(), (lenT) * (lenZ))
193+
xi = np.tile(np.column_stack([xi, xi_1, xi, xi_1]).flatten(), (lenT))
196194

197195
for data in [U, V]:
198196
axis_dim = grid.get_axis_dim_mapping(data.dims)
@@ -207,17 +205,11 @@ def CGrid_Velocity(
207205
if "time" in data.dims:
208206
selection_dict["time"] = xr.DataArray(ti, dims=("points"))
209207

210-
corner_data = data.isel(selection_dict).data.reshape(lenT, lenZ, len(xsi), 4)
208+
corner_data = data.isel(selection_dict).data.reshape(lenT, len(xsi), 4)
211209

212210
if lenT == 2:
213-
tau_full = tau[np.newaxis, :, np.newaxis]
214-
corner_data = corner_data[0, :, :, :] * (1 - tau_full) + corner_data[1, :, :, :] * tau_full
215-
else:
216-
corner_data = corner_data[0, :, :, :]
217-
218-
if lenZ == 2:
219-
zeta_full = zeta[:, np.newaxis]
220-
corner_data = corner_data[0, :, :] * (1 - zeta_full) + corner_data[1, :, :] * zeta_full
211+
tau_full = tau[:, np.newaxis]
212+
corner_data = corner_data[0, :, :] * (1 - tau_full) + corner_data[1, :, :] * tau_full
221213
else:
222214
corner_data = corner_data[0, :, :]
223215
# # See code below for v3 version

0 commit comments

Comments
 (0)