@@ -14,6 +14,8 @@ class InterpolationContext2D:
1414 ----------
1515 data: np.ndarray
1616 field data of shape (time, y, x)
17+ tau: float
18+ time interpolation coordinate in unit length
1719 eta: float
1820 y-direction interpolation coordinate in unit cube (between 0 and 1)
1921 xsi: float
@@ -24,15 +26,19 @@ class InterpolationContext2D:
2426 y index of cell containing particle
2527 xi: int
2628 x index of cell containing particle
29+ interptime: bool = True
30+ whether to interpolate in time
2731
2832 """
2933
3034 data : np .ndarray
35+ tau : float
3136 eta : float
3237 xsi : float
3338 ti : int
3439 yi : int
3540 xi : int
41+ interptime : bool = True
3642
3743
3844@dataclass
@@ -45,6 +51,8 @@ class InterpolationContext3D:
4551 field data of shape (time, z, y, x). This needs to be complete in the vertical
4652 direction as some interpolation methods need to know whether they are at the
4753 surface or bottom.
54+ tau: float
55+ time interpolation coordinate in unit length
4856 zeta: float
4957 vertical interpolation coordinate in unit cube
5058 eta: float
@@ -61,10 +69,13 @@ class InterpolationContext3D:
6169 x index of cell containing particle
6270 gridindexingtype: GridIndexingType
6371 grid indexing type
72+ interptime: bool = True
73+ whether to interpolate in time
6474
6575 """
6676
6777 data : np .ndarray
78+ tau : float
6879 zeta : float
6980 eta : float
7081 xsi : float
@@ -73,6 +84,7 @@ class InterpolationContext3D:
7384 yi : int
7485 xi : int
7586 gridindexingtype : GridIndexingType # included in 3D as z-face is indexed differently with MOM5 and POP
87+ interptime : bool = True
7688
7789
7890_interpolator_registry_2d : dict [str , Callable [[InterpolationContext2D ], float ]] = {}
@@ -110,7 +122,11 @@ def decorator(interpolator: Callable[[InterpolationContext3D], float]):
110122def _nearest_2d (ctx : InterpolationContext2D ) -> float :
111123 xii = ctx .xi if ctx .xsi <= 0.5 else ctx .xi + 1
112124 yii = ctx .yi if ctx .eta <= 0.5 else ctx .yi + 1
113- return ctx .data [ctx .ti , yii , xii ]
125+ ft0 = ctx .data [ctx .ti , yii , xii ]
126+ if ctx .interptime :
127+ ft1 = ctx .data [ctx .ti + 1 , yii , xii ]
128+ return (1 - ctx .tau ) * ft0 + ctx .tau * ft1
129+ return ft0
114130
115131
116132def _interp_on_unit_square (* , eta : float , xsi : float , data : np .ndarray , yi : int , xi : int ) -> float :
@@ -128,11 +144,15 @@ def _interp_on_unit_square(*, eta: float, xsi: float, data: np.ndarray, yi: int,
128144@register_2d_interpolator ("partialslip" )
129145@register_2d_interpolator ("freeslip" )
130146def _linear_2d (ctx : InterpolationContext2D ) -> float :
131- return _interp_on_unit_square (eta = ctx .eta , xsi = ctx .xsi , data = ctx .data [ctx .ti , :, :], yi = ctx .yi , xi = ctx .xi )
147+ ft0 = _interp_on_unit_square (eta = ctx .eta , xsi = ctx .xsi , data = ctx .data [ctx .ti , :, :], yi = ctx .yi , xi = ctx .xi )
148+ if ctx .interptime :
149+ ft1 = _interp_on_unit_square (eta = ctx .eta , xsi = ctx .xsi , data = ctx .data [ctx .ti + 1 , :, :], yi = ctx .yi , xi = ctx .xi )
150+ return (1 - ctx .tau ) * ft0 + ctx .tau * ft1
151+ return ft0
132152
133153
134154@register_2d_interpolator ("linear_invdist_land_tracer" )
135- def _linear_invdist_land_tracer_2d (ctx : InterpolationContext2D ) -> float :
155+ def _linear_invdist_land_tracer_2d (ctx : InterpolationContext2D ) -> float : # TODO make time-varying
136156 xsi = ctx .xsi
137157 eta = ctx .eta
138158 data = ctx .data
@@ -166,19 +186,27 @@ def _linear_invdist_land_tracer_2d(ctx: InterpolationContext2D) -> float:
166186@register_2d_interpolator ("cgrid_tracer" )
167187@register_2d_interpolator ("bgrid_tracer" )
168188def _tracer_2d (ctx : InterpolationContext2D ) -> float :
169- return ctx .data [ctx .ti , ctx .yi + 1 , ctx .xi + 1 ]
189+ ft0 = ctx .data [ctx .ti , ctx .yi + 1 , ctx .xi + 1 ]
190+ if ctx .interptime :
191+ ft1 = ctx .data [ctx .ti + 1 , ctx .yi + 1 , ctx .xi + 1 ]
192+ return (1 - ctx .tau ) * ft0 + ctx .tau * ft1
193+ return ft0
170194
171195
172196@register_3d_interpolator ("nearest" )
173197def _nearest_3d (ctx : InterpolationContext3D ) -> float :
174198 xii = ctx .xi if ctx .xsi <= 0.5 else ctx .xi + 1
175199 yii = ctx .yi if ctx .eta <= 0.5 else ctx .yi + 1
176200 zii = ctx .zi if ctx .zeta <= 0.5 else ctx .zi + 1
177- return ctx .data [ctx .ti , zii , yii , xii ]
201+ ft0 = ctx .data [ctx .ti , zii , yii , xii ]
202+ if ctx .interptime :
203+ ft1 = ctx .data [ctx .ti + 1 , zii , yii , xii ]
204+ return (1 - ctx .tau ) * ft0 + ctx .tau * ft1
205+ return ft0
178206
179207
180208@register_3d_interpolator ("cgrid_velocity" )
181- def _cgrid_velocity_3d (ctx : InterpolationContext3D ) -> float :
209+ def _cgrid_velocity_3d (ctx : InterpolationContext3D ) -> float : # TODO make time-varying
182210 # evaluating W velocity in c_grid
183211 if ctx .gridindexingtype == "nemo" :
184212 f0 = ctx .data [ctx .ti , ctx .zi , ctx .yi + 1 , ctx .xi + 1 ]
@@ -190,7 +218,7 @@ def _cgrid_velocity_3d(ctx: InterpolationContext3D) -> float:
190218
191219
192220@register_3d_interpolator ("linear_invdist_land_tracer" )
193- def _linear_invdist_land_tracer_3d (ctx : InterpolationContext3D ) -> float :
221+ def _linear_invdist_land_tracer_3d (ctx : InterpolationContext3D ) -> float : # TODO make time-varying
194222 land = np .isclose (ctx .data [ctx .ti , ctx .zi : ctx .zi + 2 , ctx .yi : ctx .yi + 2 , ctx .xi : ctx .xi + 2 ], 0.0 )
195223 nb_land = np .sum (land )
196224 if nb_land == 8 :
@@ -250,7 +278,7 @@ def _z_layer_interp(
250278@register_3d_interpolator ("linear" )
251279@register_3d_interpolator ("partialslip" )
252280@register_3d_interpolator ("freeslip" )
253- def _linear_3d (ctx : InterpolationContext3D ) -> float :
281+ def _linear_3d (ctx : InterpolationContext3D ) -> float : # TODO make time-varying
254282 zdim = ctx .data .shape [1 ]
255283 data_3d = ctx .data [ctx .ti , :, :, :]
256284 f0 , f1 = _get_3d_f0_f1 (eta = ctx .eta , xsi = ctx .xsi , data = data_3d , zi = ctx .zi , yi = ctx .yi , xi = ctx .xi )
@@ -277,4 +305,8 @@ def _linear_3d_bgrid_w_velocity(ctx: InterpolationContext3D) -> float:
277305@register_3d_interpolator ("bgrid_tracer" )
278306@register_3d_interpolator ("cgrid_tracer" )
279307def _tracer_3d (ctx : InterpolationContext3D ) -> float :
280- return ctx .data [ctx .ti , ctx .zi , ctx .yi + 1 , ctx .xi + 1 ]
308+ ft0 = ctx .data [ctx .ti , ctx .zi , ctx .yi + 1 , ctx .xi + 1 ]
309+ if ctx .interptime :
310+ ft1 = ctx .data [ctx .ti + 1 , ctx .zi , ctx .yi + 1 , ctx .xi + 1 ]
311+ return (1 - ctx .tau ) * ft0 + ctx .tau * ft1
312+ return ft0
0 commit comments