Skip to content

Commit 5c085b0

Browse files
cleaning up cgrid interpolation
1 parent e25a2a2 commit 5c085b0

File tree

2 files changed

+36
-62
lines changed

2 files changed

+36
-62
lines changed

parcels/_interpolation.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -201,16 +201,31 @@ def _nearest_3d(ctx: InterpolationContext3D) -> float:
201201
return (1 - ctx.tau) * ft0 + ctx.tau * ft1
202202

203203

204+
def _get_cgrid_depth_point(*, zeta: float, data: np.ndarray, zi: int, yi: int, xi: int) -> tuple[float]:
205+
f0 = data[zi, yi, xi]
206+
f1 = data[zi + 1, yi, xi]
207+
return (1 - zeta) * f0 + zeta * f1
208+
209+
204210
@register_3d_interpolator("cgrid_velocity")
205-
def _cgrid_velocity_3d(ctx: InterpolationContext3D) -> float: # TODO make time-varying
211+
def _cgrid_W_velocity_3d(ctx: InterpolationContext3D) -> float:
206212
# evaluating W velocity in c_grid
207213
if ctx.gridindexingtype == "nemo":
208-
f0 = ctx.data[ctx.ti, ctx.zi, ctx.yi + 1, ctx.xi + 1]
209-
f1 = ctx.data[ctx.ti, ctx.zi + 1, ctx.yi + 1, ctx.xi + 1]
214+
ft0 = _get_cgrid_depth_point(
215+
zeta=ctx.zeta, data=ctx.data[ctx.ti, :, :, :], zi=ctx.zi, yi=ctx.yi + 1, xi=ctx.xi + 1
216+
)
210217
elif ctx.gridindexingtype in ["mitgcm", "croco"]:
211-
f0 = ctx.data[ctx.ti, ctx.zi, ctx.yi, ctx.xi]
212-
f1 = ctx.data[ctx.ti, ctx.zi + 1, ctx.yi, ctx.xi]
213-
return (1 - ctx.zeta) * f0 + ctx.zeta * f1
218+
ft0 = _get_cgrid_depth_point(zeta=ctx.zeta, data=ctx.data[ctx.ti, :, :, :], zi=ctx.zi, yi=ctx.yi, xi=ctx.xi)
219+
if ctx.tau < EPS or ctx.ti == ctx.data.shape[0] - 1:
220+
return ft0
221+
222+
if ctx.gridindexingtype == "nemo":
223+
ft1 = _get_cgrid_depth_point(
224+
zeta=ctx.zeta, data=ctx.data[ctx.ti + 1, :, :, :], zi=ctx.zi, yi=ctx.yi + 1, xi=ctx.xi + 1
225+
)
226+
elif ctx.gridindexingtype in ["mitgcm", "croco"]:
227+
ft1 = _get_cgrid_depth_point(zeta=ctx.zeta, data=ctx.data[ctx.ti + 1, :, :, :], zi=ctx.zi, yi=ctx.yi, xi=ctx.xi)
228+
return (1 - ctx.tau) * ft0 + ctx.tau * ft1
214229

215230

216231
@register_3d_interpolator("linear_invdist_land_tracer")

parcels/field.py

Lines changed: 15 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1526,66 +1526,25 @@ def eval(self, time, z, y, x, particle=None, applyConversion=True):
15261526
if applyConversion:
15271527
u = self.U.units.to_target(u, z, y, x)
15281528
v = self.V.units.to_target(v, z, y, x)
1529-
if "3D" in self.vector_type:
1530-
w = self.W.eval(time, z, y, x, particle=particle, applyConversion=False)
1531-
if applyConversion:
1532-
w = self.W.units.to_target(w, z, y, x)
1533-
return (u, v, w)
1534-
else:
1535-
return (u, v)
15361529
else:
15371530
interp = {
1538-
"cgrid_velocity": {
1539-
"2D": self.spatial_c_grid_interpolation2D,
1540-
"3D": self.spatial_c_grid_interpolation3D,
1541-
},
1542-
"partialslip": {"2D": self.spatial_slip_interpolation, "3D": self.spatial_slip_interpolation},
1543-
"freeslip": {"2D": self.spatial_slip_interpolation, "3D": self.spatial_slip_interpolation},
1531+
"cgrid_velocity": self.spatial_c_grid_interpolation2D,
1532+
"partialslip": self.spatial_slip_interpolation,
1533+
"freeslip": self.spatial_slip_interpolation,
15441534
}
1545-
grid = self.U.grid
15461535
tau, ti = self.U._time_index(time)
1547-
if ti < grid.tdim - 1 and time > grid.time[ti]:
1548-
t0 = grid.time[ti]
1549-
t1 = grid.time[ti + 1]
1550-
if "3D" in self.vector_type:
1551-
(u0, v0, w0) = interp[self.U.interp_method]["3D"](
1552-
ti,
1553-
z,
1554-
y,
1555-
x,
1556-
t0,
1557-
particle=particle,
1558-
applyConversion=applyConversion, # TODO see if we can directly call time interpolation for W here
1559-
)
1560-
(u1, v1, w1) = interp[self.U.interp_method]["3D"](
1561-
ti + 1, z, y, x, t1, particle=particle, applyConversion=applyConversion
1562-
)
1563-
w = w0 * (1 - tau) + w1 * tau
1564-
else:
1565-
(u0, v0) = interp[self.U.interp_method]["2D"](
1566-
ti, z, y, x, t0, particle=particle, applyConversion=applyConversion
1567-
)
1568-
(u1, v1) = interp[self.U.interp_method]["2D"](
1569-
ti + 1, z, y, x, t1, particle=particle, applyConversion=applyConversion
1570-
)
1571-
u = u0 * (1 - tau) + u1 * tau
1572-
v = v0 * (1 - tau) + v1 * tau
1573-
if "3D" in self.vector_type:
1574-
return (u, v, w)
1575-
else:
1576-
return (u, v)
1577-
else:
1578-
# Skip temporal interpolation if time is outside
1579-
# of the defined time range or if we have hit an
1580-
# exact value in the time array.
1581-
if "3D" in self.vector_type:
1582-
return interp[self.U.interp_method]["3D"](
1583-
ti, z, y, x, grid.time[ti], particle=particle, applyConversion=applyConversion
1584-
)
1585-
else:
1586-
return interp[self.U.interp_method]["2D"](
1587-
ti, z, y, x, grid.time[ti], particle=particle, applyConversion=applyConversion
1588-
)
1536+
(u, v) = interp[self.U.interp_method](ti, z, y, x, time, particle=particle, applyConversion=applyConversion)
1537+
if ti < self.U.grid.tdim - 1 and time > self.U.grid.time[ti]:
1538+
(u1, v1) = interp[self.U.interp_method](
1539+
ti + 1, z, y, x, time, particle=particle, applyConversion=applyConversion
1540+
)
1541+
u = u * (1 - tau) + u1 * tau
1542+
v = v * (1 - tau) + v1 * tau
1543+
if "3D" in self.vector_type:
1544+
w = self.W.eval(time, z, y, x, particle=particle, applyConversion=applyConversion)
1545+
return (u, v, w)
1546+
else:
1547+
return (u, v)
15891548

15901549
def __getitem__(self, key):
15911550
try:

0 commit comments

Comments
 (0)