Skip to content

Commit aa6d2f0

Browse files
Adding calculate_next_ti helper function
1 parent 2b24b47 commit aa6d2f0

File tree

3 files changed

+31
-20
lines changed

3 files changed

+31
-20
lines changed

parcels/_interpolation.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
import numpy as np
55

66
from parcels._typing import GridIndexingType
7-
8-
EPS = np.finfo(float).eps
7+
from parcels.tools._helpers import calculate_next_ti
98

109

1110
@dataclass
@@ -119,7 +118,7 @@ def _nearest_2d(ctx: InterpolationContext2D) -> float:
119118
xii = ctx.xi if ctx.xsi <= 0.5 else ctx.xi + 1
120119
yii = ctx.yi if ctx.eta <= 0.5 else ctx.yi + 1
121120
ft0 = ctx.data[ctx.ti, yii, xii]
122-
if ctx.tau < EPS or ctx.ti >= ctx.data.shape[0] - 1:
121+
if calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]):
123122
return ft0
124123
ft1 = ctx.data[ctx.ti + 1, yii, xii]
125124
return (1 - ctx.tau) * ft0 + ctx.tau * ft1
@@ -141,7 +140,7 @@ def _interp_on_unit_square(*, eta: float, xsi: float, data: np.ndarray, yi: int,
141140
@register_2d_interpolator("freeslip")
142141
def _linear_2d(ctx: InterpolationContext2D) -> float:
143142
ft0 = _interp_on_unit_square(eta=ctx.eta, xsi=ctx.xsi, data=ctx.data[ctx.ti, :, :], yi=ctx.yi, xi=ctx.xi)
144-
if ctx.tau < EPS or ctx.ti >= ctx.data.shape[0] - 1:
143+
if not calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]):
145144
return ft0
146145
ft1 = _interp_on_unit_square(eta=ctx.eta, xsi=ctx.xsi, data=ctx.data[ctx.ti + 1, :, :], yi=ctx.yi, xi=ctx.xi)
147146
return (1 - ctx.tau) * ft0 + ctx.tau * ft1
@@ -160,7 +159,7 @@ def _linear_invdist_land_tracer_2d(ctx: InterpolationContext2D) -> float:
160159

161160
def _get_data_temporalinterp(ti, yi, xi):
162161
dt0 = data[ti, yi, xi]
163-
if ctx.tau < EPS or ctx.ti >= ctx.data.shape[0] - 1:
162+
if not calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]):
164163
return dt0
165164
dt1 = data[ti + 1, yi, xi]
166165
return (1 - ctx.tau) * dt0 + ctx.tau * dt1
@@ -190,7 +189,7 @@ def _get_data_temporalinterp(ti, yi, xi):
190189
@register_2d_interpolator("bgrid_tracer")
191190
def _tracer_2d(ctx: InterpolationContext2D) -> float:
192191
ft0 = ctx.data[ctx.ti, ctx.yi + 1, ctx.xi + 1]
193-
if ctx.tau < EPS or ctx.ti >= ctx.data.shape[0] - 1:
192+
if not calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]):
194193
return ft0
195194
ft1 = ctx.data[ctx.ti + 1, ctx.yi + 1, ctx.xi + 1]
196195
return (1 - ctx.tau) * ft0 + ctx.tau * ft1
@@ -202,7 +201,7 @@ def _nearest_3d(ctx: InterpolationContext3D) -> float:
202201
yii = ctx.yi if ctx.eta <= 0.5 else ctx.yi + 1
203202
zii = ctx.zi if ctx.zeta <= 0.5 else ctx.zi + 1
204203
ft0 = ctx.data[ctx.ti, zii, yii, xii]
205-
if ctx.tau < EPS or ctx.ti == ctx.data.shape[0] - 1:
204+
if not calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]):
206205
return ft0
207206
ft1 = ctx.data[ctx.ti + 1, zii, yii, xii]
208207
return (1 - ctx.tau) * ft0 + ctx.tau * ft1
@@ -223,7 +222,7 @@ def _cgrid_W_velocity_3d(ctx: InterpolationContext3D) -> float:
223222
)
224223
elif ctx.gridindexingtype in ["mitgcm", "croco"]:
225224
ft0 = _get_cgrid_depth_point(zeta=ctx.zeta, data=ctx.data[ctx.ti, :, :, :], zi=ctx.zi, yi=ctx.yi, xi=ctx.xi)
226-
if ctx.tau < EPS or ctx.ti == ctx.data.shape[0] - 1:
225+
if not calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]):
227226
return ft0
228227

229228
if ctx.gridindexingtype == "nemo":
@@ -242,7 +241,7 @@ def _linear_invdist_land_tracer_3d(ctx: InterpolationContext3D) -> float:
242241

243242
def _get_data_temporalinterp(ti, zi, yi, xi):
244243
dt0 = ctx.data[ti, zi, yi, xi]
245-
if ctx.tau < EPS or ctx.ti >= ctx.data.shape[0] - 1:
244+
if not calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]):
246245
return dt0
247246
dt1 = data[ti + 1, zi, yi, xi]
248247
return (1 - ctx.tau) * dt0 + ctx.tau * dt1
@@ -308,7 +307,7 @@ def _linear_3d(ctx: InterpolationContext3D) -> float:
308307
zdim = ctx.data.shape[1]
309308
data_3d = ctx.data[ctx.ti, :, :, :]
310309
fz0, fz1 = _get_3d_f0_f1(eta=ctx.eta, xsi=ctx.xsi, data=data_3d, zi=ctx.zi, yi=ctx.yi, xi=ctx.xi)
311-
if ctx.tau > EPS and ctx.ti < ctx.data.shape[0] - 1:
310+
if calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]):
312311
data_3d = ctx.data[ctx.ti + 1, :, :, :]
313312
fz0_t1, fz1_t1 = _get_3d_f0_f1(eta=ctx.eta, xsi=ctx.xsi, data=data_3d, zi=ctx.zi, yi=ctx.yi, xi=ctx.xi)
314313
fz0 = (1 - ctx.tau) * fz0 + ctx.tau * fz0_t1
@@ -338,7 +337,8 @@ def _linear_3d_bgrid_w_velocity(ctx: InterpolationContext3D) -> float:
338337
@register_3d_interpolator("cgrid_tracer")
339338
def _tracer_3d(ctx: InterpolationContext3D) -> float:
340339
ft0 = ctx.data[ctx.ti, ctx.zi, ctx.yi + 1, ctx.xi + 1]
341-
if ctx.tau < EPS or ctx.ti >= ctx.data.shape[0] - 1:
340+
if not calculate_next_ti(ctx.ti, ctx.tau, ctx.data.shape[0]):
342341
return ft0
343-
ft1 = ctx.data[ctx.ti + 1, ctx.zi, ctx.yi + 1, ctx.xi + 1]
344-
return (1 - ctx.tau) * ft0 + ctx.tau * ft1
342+
else:
343+
ft1 = ctx.data[ctx.ti + 1, ctx.zi, ctx.yi + 1, ctx.xi + 1]
344+
return (1 - ctx.tau) * ft0 + ctx.tau * ft1

parcels/field.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
assert_valid_gridindexingtype,
2727
assert_valid_interp_method,
2828
)
29-
from parcels.tools._helpers import default_repr, field_repr
29+
from parcels.tools._helpers import calculate_next_ti, default_repr, field_repr
3030
from parcels.tools.converters import (
3131
TimeConverter,
3232
UnitConverter,
@@ -1030,17 +1030,15 @@ def _calc_UV(ti, yi, xi):
10301030
return (u, v)
10311031

10321032
u, v = _calc_UV(ti, yi, xi)
1033-
if ti < self.U.grid.tdim - 1 and time > self.U.grid.time[ti]:
1033+
if calculate_next_ti(ti, tau, self.U.grid.tdim):
10341034
ut1, vt1 = _calc_UV(ti + 1, yi, xi)
10351035
u = (1 - tau) * u + tau * ut1
10361036
v = (1 - tau) * v + tau * vt1
10371037
return (u, v)
10381038

1039-
def spatial_c_grid_interpolation3D_full(self, ti, z, y, x, time, particle=None):
1039+
def spatial_c_grid_interpolation3D_full(self, time, z, y, x, particle=None):
10401040
grid = self.U.grid
1041-
(_, zeta, eta, xsi, _, zi, yi, xi) = self.U._search_indices(
1042-
time, z, y, x, particle=particle
1043-
) # TODO use tau here too
1041+
(tau, zeta, eta, xsi, ti, zi, yi, xi) = self.U._search_indices(time, z, y, x, particle=particle)
10441042

10451043
if grid._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]:
10461044
px = np.array([grid.lon[xi], grid.lon[xi + 1], grid.lon[xi + 1], grid.lon[xi]])
@@ -1093,6 +1091,14 @@ def spatial_c_grid_interpolation3D_full(self, ti, z, y, x, time, particle=None):
10931091
w0 = self.W.data[ti, zi, yi + 1, xi + 1]
10941092
w1 = self.W.data[ti, zi + 1, yi + 1, xi + 1]
10951093

1094+
if calculate_next_ti(ti, tau, self.U.grid.tdim):
1095+
u0 = (1 - tau) * u0 + tau * self.U.data[ti + 1, zi, yi + 1, xi]
1096+
u1 = (1 - tau) * u1 + tau * self.U.data[ti + 1, zi, yi + 1, xi + 1]
1097+
v0 = (1 - tau) * v0 + tau * self.V.data[ti + 1, zi, yi, xi + 1]
1098+
v1 = (1 - tau) * v1 + tau * self.V.data[ti + 1, zi, yi + 1, xi + 1]
1099+
w0 = (1 - tau) * w0 + tau * self.W.data[ti + 1, zi, yi + 1, xi + 1]
1100+
w1 = (1 - tau) * w1 + tau * self.W.data[ti + 1, zi + 1, yi + 1, xi + 1]
1101+
10961102
U0 = u0 * i_u.jacobian3D_lin_face(pz, py, px, zeta, eta, 0, "zonal", grid.mesh)
10971103
U1 = u1 * i_u.jacobian3D_lin_face(pz, py, px, zeta, eta, 1, "zonal", grid.mesh)
10981104
V0 = v0 * i_u.jacobian3D_lin_face(pz, py, px, zeta, 0, xsi, "meridional", grid.mesh)
@@ -1260,7 +1266,7 @@ def spatial_c_grid_interpolation3D(self, ti, z, y, x, time, particle=None, apply
12601266
Curvilinear grids are treated properly, since the element is projected to a rectilinear parent element.
12611267
"""
12621268
if self.U.grid._gtype in [GridType.RectilinearSGrid, GridType.CurvilinearSGrid]:
1263-
(u, v, w) = self.spatial_c_grid_interpolation3D_full(ti, z, y, x, time, particle=particle)
1269+
(u, v, w) = self.spatial_c_grid_interpolation3D_full(time, z, y, x, particle=particle)
12641270
else:
12651271
if self.gridindexingtype == "croco":
12661272
z = _croco_from_z_to_sigma_scipy(self.fieldset, time, z, y, x, particle=particle)

parcels/tools/_helpers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,8 @@ def timedelta_to_float(dt: float | timedelta | np.timedelta64) -> float:
152152
if isinstance(dt, np.timedelta64):
153153
return float(dt / np.timedelta64(1, "s"))
154154
return float(dt)
155+
156+
157+
def calculate_next_ti(ti, tau, tdim):
158+
"""Check if the time is beyond the last time in the field"""
159+
return tau > np.finfo(float).eps and ti < tdim - 1

0 commit comments

Comments
 (0)