44import numpy as np
55
66from parcels ._typing import GridIndexingType
7-
8- EPS = np .finfo (float ).eps
7+ from parcels .tools ._helpers import calculate_next_ti
98
109
1110@dataclass
@@ -119,7 +118,7 @@ def _nearest_2d(ctx: InterpolationContext2D) -> float:
119118 xii = ctx .xi if ctx .xsi <= 0.5 else ctx .xi + 1
120119 yii = ctx .yi if ctx .eta <= 0.5 else ctx .yi + 1
121120 ft0 = ctx .data [ctx .ti , yii , xii ]
122- if ctx .tau < EPS or ctx .ti >= ctx .data .shape [0 ] - 1 :
121+ if calculate_next_ti ( ctx .ti , ctx .tau , ctx .data .shape [0 ]) :
123122 return ft0
124123 ft1 = ctx .data [ctx .ti + 1 , yii , xii ]
125124 return (1 - ctx .tau ) * ft0 + ctx .tau * ft1
@@ -141,7 +140,7 @@ def _interp_on_unit_square(*, eta: float, xsi: float, data: np.ndarray, yi: int,
141140@register_2d_interpolator ("freeslip" )
142141def _linear_2d (ctx : InterpolationContext2D ) -> float :
143142 ft0 = _interp_on_unit_square (eta = ctx .eta , xsi = ctx .xsi , data = ctx .data [ctx .ti , :, :], yi = ctx .yi , xi = ctx .xi )
144- if ctx .tau < EPS or ctx .ti >= ctx .data .shape [0 ] - 1 :
143+ if not calculate_next_ti ( ctx .ti , ctx .tau , ctx .data .shape [0 ]) :
145144 return ft0
146145 ft1 = _interp_on_unit_square (eta = ctx .eta , xsi = ctx .xsi , data = ctx .data [ctx .ti + 1 , :, :], yi = ctx .yi , xi = ctx .xi )
147146 return (1 - ctx .tau ) * ft0 + ctx .tau * ft1
@@ -160,7 +159,7 @@ def _linear_invdist_land_tracer_2d(ctx: InterpolationContext2D) -> float:
160159
161160 def _get_data_temporalinterp (ti , yi , xi ):
162161 dt0 = data [ti , yi , xi ]
163- if ctx .tau < EPS or ctx .ti >= ctx .data .shape [0 ] - 1 :
162+ if not calculate_next_ti ( ctx .ti , ctx .tau , ctx .data .shape [0 ]) :
164163 return dt0
165164 dt1 = data [ti + 1 , yi , xi ]
166165 return (1 - ctx .tau ) * dt0 + ctx .tau * dt1
@@ -190,7 +189,7 @@ def _get_data_temporalinterp(ti, yi, xi):
190189@register_2d_interpolator ("bgrid_tracer" )
191190def _tracer_2d (ctx : InterpolationContext2D ) -> float :
192191 ft0 = ctx .data [ctx .ti , ctx .yi + 1 , ctx .xi + 1 ]
193- if ctx .tau < EPS or ctx .ti >= ctx .data .shape [0 ] - 1 :
192+ if not calculate_next_ti ( ctx .ti , ctx .tau , ctx .data .shape [0 ]) :
194193 return ft0
195194 ft1 = ctx .data [ctx .ti + 1 , ctx .yi + 1 , ctx .xi + 1 ]
196195 return (1 - ctx .tau ) * ft0 + ctx .tau * ft1
@@ -202,7 +201,7 @@ def _nearest_3d(ctx: InterpolationContext3D) -> float:
202201 yii = ctx .yi if ctx .eta <= 0.5 else ctx .yi + 1
203202 zii = ctx .zi if ctx .zeta <= 0.5 else ctx .zi + 1
204203 ft0 = ctx .data [ctx .ti , zii , yii , xii ]
205- if ctx .tau < EPS or ctx .ti == ctx .data .shape [0 ] - 1 :
204+ if not calculate_next_ti ( ctx .ti , ctx .tau , ctx .data .shape [0 ]) :
206205 return ft0
207206 ft1 = ctx .data [ctx .ti + 1 , zii , yii , xii ]
208207 return (1 - ctx .tau ) * ft0 + ctx .tau * ft1
@@ -223,7 +222,7 @@ def _cgrid_W_velocity_3d(ctx: InterpolationContext3D) -> float:
223222 )
224223 elif ctx .gridindexingtype in ["mitgcm" , "croco" ]:
225224 ft0 = _get_cgrid_depth_point (zeta = ctx .zeta , data = ctx .data [ctx .ti , :, :, :], zi = ctx .zi , yi = ctx .yi , xi = ctx .xi )
226- if ctx .tau < EPS or ctx .ti == ctx .data .shape [0 ] - 1 :
225+ if not calculate_next_ti ( ctx .ti , ctx .tau , ctx .data .shape [0 ]) :
227226 return ft0
228227
229228 if ctx .gridindexingtype == "nemo" :
@@ -242,7 +241,7 @@ def _linear_invdist_land_tracer_3d(ctx: InterpolationContext3D) -> float:
242241
243242 def _get_data_temporalinterp (ti , zi , yi , xi ):
244243 dt0 = ctx .data [ti , zi , yi , xi ]
245- if ctx .tau < EPS or ctx .ti >= ctx .data .shape [0 ] - 1 :
244+ if not calculate_next_ti ( ctx .ti , ctx .tau , ctx .data .shape [0 ]) :
246245 return dt0
247246 dt1 = data [ti + 1 , zi , yi , xi ]
248247 return (1 - ctx .tau ) * dt0 + ctx .tau * dt1
@@ -308,7 +307,7 @@ def _linear_3d(ctx: InterpolationContext3D) -> float:
308307 zdim = ctx .data .shape [1 ]
309308 data_3d = ctx .data [ctx .ti , :, :, :]
310309 fz0 , fz1 = _get_3d_f0_f1 (eta = ctx .eta , xsi = ctx .xsi , data = data_3d , zi = ctx .zi , yi = ctx .yi , xi = ctx .xi )
311- if ctx .tau > EPS and ctx .ti < ctx .data .shape [0 ] - 1 :
310+ if calculate_next_ti ( ctx .ti , ctx .tau , ctx .data .shape [0 ]) :
312311 data_3d = ctx .data [ctx .ti + 1 , :, :, :]
313312 fz0_t1 , fz1_t1 = _get_3d_f0_f1 (eta = ctx .eta , xsi = ctx .xsi , data = data_3d , zi = ctx .zi , yi = ctx .yi , xi = ctx .xi )
314313 fz0 = (1 - ctx .tau ) * fz0 + ctx .tau * fz0_t1
@@ -338,7 +337,8 @@ def _linear_3d_bgrid_w_velocity(ctx: InterpolationContext3D) -> float:
338337@register_3d_interpolator ("cgrid_tracer" )
339338def _tracer_3d (ctx : InterpolationContext3D ) -> float :
340339 ft0 = ctx .data [ctx .ti , ctx .zi , ctx .yi + 1 , ctx .xi + 1 ]
341- if ctx .tau < EPS or ctx .ti >= ctx .data .shape [0 ] - 1 :
340+ if not calculate_next_ti ( ctx .ti , ctx .tau , ctx .data .shape [0 ]) :
342341 return ft0
343- ft1 = ctx .data [ctx .ti + 1 , ctx .zi , ctx .yi + 1 , ctx .xi + 1 ]
344- return (1 - ctx .tau ) * ft0 + ctx .tau * ft1
342+ else :
343+ ft1 = ctx .data [ctx .ti + 1 , ctx .zi , ctx .yi + 1 , ctx .xi + 1 ]
344+ return (1 - ctx .tau ) * ft0 + ctx .tau * ft1
0 commit comments