Skip to content

Commit a49e4c0

Browse files
Pushing the time interpolation also to Interpolators
1 parent 02aee73 commit a49e4c0

File tree

2 files changed

+49
-27
lines changed

2 files changed

+49
-27
lines changed

parcels/_interpolation.py

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ class InterpolationContext2D:
1414
----------
1515
data: np.ndarray
1616
field data of shape (time, y, x)
17+
tau: float
18+
time interpolation coordinate in unit length
1719
eta: float
1820
y-direction interpolation coordinate in unit cube (between 0 and 1)
1921
xsi: float
@@ -24,15 +26,19 @@ class InterpolationContext2D:
2426
y index of cell containing particle
2527
xi: int
2628
x index of cell containing particle
29+
interptime: bool = True
30+
whether to interpolate in time
2731
2832
"""
2933

3034
data: np.ndarray
35+
tau: float
3136
eta: float
3237
xsi: float
3338
ti: int
3439
yi: int
3540
xi: int
41+
interptime: bool = True
3642

3743

3844
@dataclass
@@ -45,6 +51,8 @@ class InterpolationContext3D:
4551
field data of shape (time, z, y, x). This needs to be complete in the vertical
4652
direction as some interpolation methods need to know whether they are at the
4753
surface or bottom.
54+
tau: float
55+
time interpolation coordinate in unit length
4856
zeta: float
4957
vertical interpolation coordinate in unit cube
5058
eta: float
@@ -61,10 +69,13 @@ class InterpolationContext3D:
6169
x index of cell containing particle
6270
gridindexingtype: GridIndexingType
6371
grid indexing type
72+
interptime: bool = True
73+
whether to interpolate in time
6474
6575
"""
6676

6777
data: np.ndarray
78+
tau: float
6879
zeta: float
6980
eta: float
7081
xsi: float
@@ -73,6 +84,7 @@ class InterpolationContext3D:
7384
yi: int
7485
xi: int
7586
gridindexingtype: GridIndexingType # included in 3D as z-face is indexed differently with MOM5 and POP
87+
interptime: bool = True
7688

7789

7890
_interpolator_registry_2d: dict[str, Callable[[InterpolationContext2D], float]] = {}
@@ -110,7 +122,11 @@ def decorator(interpolator: Callable[[InterpolationContext3D], float]):
110122
def _nearest_2d(ctx: InterpolationContext2D) -> float:
111123
xii = ctx.xi if ctx.xsi <= 0.5 else ctx.xi + 1
112124
yii = ctx.yi if ctx.eta <= 0.5 else ctx.yi + 1
113-
return ctx.data[ctx.ti, yii, xii]
125+
ft0 = ctx.data[ctx.ti, yii, xii]
126+
if ctx.interptime:
127+
ft1 = ctx.data[ctx.ti + 1, yii, xii]
128+
return (1 - ctx.tau) * ft0 + ctx.tau * ft1
129+
return ft0
114130

115131

116132
def _interp_on_unit_square(*, eta: float, xsi: float, data: np.ndarray, yi: int, xi: int) -> float:
@@ -128,11 +144,15 @@ def _interp_on_unit_square(*, eta: float, xsi: float, data: np.ndarray, yi: int,
128144
@register_2d_interpolator("partialslip")
129145
@register_2d_interpolator("freeslip")
130146
def _linear_2d(ctx: InterpolationContext2D) -> float:
131-
return _interp_on_unit_square(eta=ctx.eta, xsi=ctx.xsi, data=ctx.data[ctx.ti, :, :], yi=ctx.yi, xi=ctx.xi)
147+
ft0 = _interp_on_unit_square(eta=ctx.eta, xsi=ctx.xsi, data=ctx.data[ctx.ti, :, :], yi=ctx.yi, xi=ctx.xi)
148+
if ctx.interptime:
149+
ft1 = _interp_on_unit_square(eta=ctx.eta, xsi=ctx.xsi, data=ctx.data[ctx.ti + 1, :, :], yi=ctx.yi, xi=ctx.xi)
150+
return (1 - ctx.tau) * ft0 + ctx.tau * ft1
151+
return ft0
132152

133153

134154
@register_2d_interpolator("linear_invdist_land_tracer")
135-
def _linear_invdist_land_tracer_2d(ctx: InterpolationContext2D) -> float:
155+
def _linear_invdist_land_tracer_2d(ctx: InterpolationContext2D) -> float: # TODO make time-varying
136156
xsi = ctx.xsi
137157
eta = ctx.eta
138158
data = ctx.data
@@ -166,19 +186,27 @@ def _linear_invdist_land_tracer_2d(ctx: InterpolationContext2D) -> float:
166186
@register_2d_interpolator("cgrid_tracer")
167187
@register_2d_interpolator("bgrid_tracer")
168188
def _tracer_2d(ctx: InterpolationContext2D) -> float:
169-
return ctx.data[ctx.ti, ctx.yi + 1, ctx.xi + 1]
189+
ft0 = ctx.data[ctx.ti, ctx.yi + 1, ctx.xi + 1]
190+
if ctx.interptime:
191+
ft1 = ctx.data[ctx.ti + 1, ctx.yi + 1, ctx.xi + 1]
192+
return (1 - ctx.tau) * ft0 + ctx.tau * ft1
193+
return ft0
170194

171195

172196
@register_3d_interpolator("nearest")
173197
def _nearest_3d(ctx: InterpolationContext3D) -> float:
174198
xii = ctx.xi if ctx.xsi <= 0.5 else ctx.xi + 1
175199
yii = ctx.yi if ctx.eta <= 0.5 else ctx.yi + 1
176200
zii = ctx.zi if ctx.zeta <= 0.5 else ctx.zi + 1
177-
return ctx.data[ctx.ti, zii, yii, xii]
201+
ft0 = ctx.data[ctx.ti, zii, yii, xii]
202+
if ctx.interptime:
203+
ft1 = ctx.data[ctx.ti + 1, zii, yii, xii]
204+
return (1 - ctx.tau) * ft0 + ctx.tau * ft1
205+
return ft0
178206

179207

180208
@register_3d_interpolator("cgrid_velocity")
181-
def _cgrid_velocity_3d(ctx: InterpolationContext3D) -> float:
209+
def _cgrid_velocity_3d(ctx: InterpolationContext3D) -> float: # TODO make time-varying
182210
# evaluating W velocity in c_grid
183211
if ctx.gridindexingtype == "nemo":
184212
f0 = ctx.data[ctx.ti, ctx.zi, ctx.yi + 1, ctx.xi + 1]
@@ -190,7 +218,7 @@ def _cgrid_velocity_3d(ctx: InterpolationContext3D) -> float:
190218

191219

192220
@register_3d_interpolator("linear_invdist_land_tracer")
193-
def _linear_invdist_land_tracer_3d(ctx: InterpolationContext3D) -> float:
221+
def _linear_invdist_land_tracer_3d(ctx: InterpolationContext3D) -> float: # TODO make time-varying
194222
land = np.isclose(ctx.data[ctx.ti, ctx.zi : ctx.zi + 2, ctx.yi : ctx.yi + 2, ctx.xi : ctx.xi + 2], 0.0)
195223
nb_land = np.sum(land)
196224
if nb_land == 8:
@@ -250,7 +278,7 @@ def _z_layer_interp(
250278
@register_3d_interpolator("linear")
251279
@register_3d_interpolator("partialslip")
252280
@register_3d_interpolator("freeslip")
253-
def _linear_3d(ctx: InterpolationContext3D) -> float:
281+
def _linear_3d(ctx: InterpolationContext3D) -> float: # TODO make time-varying
254282
zdim = ctx.data.shape[1]
255283
data_3d = ctx.data[ctx.ti, :, :, :]
256284
f0, f1 = _get_3d_f0_f1(eta=ctx.eta, xsi=ctx.xsi, data=data_3d, zi=ctx.zi, yi=ctx.yi, xi=ctx.xi)
@@ -277,4 +305,8 @@ def _linear_3d_bgrid_w_velocity(ctx: InterpolationContext3D) -> float:
277305
@register_3d_interpolator("bgrid_tracer")
278306
@register_3d_interpolator("cgrid_tracer")
279307
def _tracer_3d(ctx: InterpolationContext3D) -> float:
280-
return ctx.data[ctx.ti, ctx.zi, ctx.yi + 1, ctx.xi + 1]
308+
ft0 = ctx.data[ctx.ti, ctx.zi, ctx.yi + 1, ctx.xi + 1]
309+
if ctx.interptime:
310+
ft1 = ctx.data[ctx.ti + 1, ctx.zi, ctx.yi + 1, ctx.xi + 1]
311+
return (1 - ctx.tau) * ft0 + ctx.tau * ft1
312+
return ft0

parcels/field.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -893,15 +893,9 @@ def _interpolator2D(self, time, z, y, x, particle=None):
893893

894894
(tau, _, eta, xsi, ti, _, yi, xi) = self._search_indices(time, z, y, x, particle=particle)
895895

896-
ctx = InterpolationContext2D(self.data, eta, xsi, ti, yi, xi)
897-
f0 = f(ctx)
898-
899-
if ti < self.grid.tdim - 1 and time > self.grid.time[ti]:
900-
ctx = InterpolationContext2D(self.data, eta, xsi, ti + 1, yi, xi)
901-
f1 = f(ctx)
902-
return f0 * (1 - tau) + f1 * tau
903-
else:
904-
return f0
896+
interptime = True if (ti < self.grid.tdim - 1 and tau > 0) else False
897+
ctx = InterpolationContext2D(self.data, tau, eta, xsi, ti, yi, xi, interptime=interptime)
898+
return f(ctx)
905899

906900
def _interpolator3D(self, time, z, y, x, particle=None):
907901
"""Impelement 3D interpolation with coordinate transformations as seen in Delandmeter and Van Sebille (2019), 10.5194/gmd-12-3571-2019.."""
@@ -912,15 +906,11 @@ def _interpolator3D(self, time, z, y, x, particle=None):
912906

913907
(tau, zeta, eta, xsi, ti, zi, yi, xi) = self._search_indices(time, z, y, x, particle=particle)
914908

915-
ctx = InterpolationContext3D(self.data, zeta, eta, xsi, ti, zi, yi, xi, self.gridindexingtype)
916-
f0 = f(ctx)
917-
918-
if ti < self.grid.tdim - 1 and time > self.grid.time[ti]:
919-
ctx = InterpolationContext3D(self.data, zeta, eta, xsi, ti + 1, zi, yi, xi, self.gridindexingtype)
920-
f1 = f(ctx)
921-
return f0 * (1 - tau) + f1 * tau
922-
else:
923-
return f0
909+
interptime = True if (ti < self.grid.tdim - 1 and tau > 0) else False
910+
ctx = InterpolationContext3D(
911+
self.data, tau, zeta, eta, xsi, ti, zi, yi, xi, self.gridindexingtype, interptime=interptime
912+
)
913+
return f(ctx)
924914

925915
def _spatial_interpolation(self, time, z, y, x, particle=None):
926916
"""Interpolate spatial field values."""

0 commit comments

Comments
 (0)