44
55from typing import TYPE_CHECKING
66
7- import dask .array as dask
87import numpy as np
98import xarray as xr
9+ from dask import is_dask_collection
10+
11+ import parcels .tools .interpolation_utils as i_u
1012
1113if TYPE_CHECKING :
12- from parcels .field import Field
14+ from parcels .field import Field , VectorField
1315 from parcels .uxgrid import _UXGRID_AXES
1416 from parcels .xgrid import _XGRID_AXES
1517
1618__all__ = [
19+ "CGrid_Tracer" ,
20+ "CGrid_Velocity" ,
1721 "UXPiecewiseConstantFace" ,
1822 "UXPiecewiseLinearNode" ,
1923 "XLinear" ,
2024 "XNearest" ,
2125 "ZeroInterpolator" ,
26+ "ZeroInterpolator_Vector" ,
2227]
2328
2429
@@ -36,6 +41,21 @@ def ZeroInterpolator(
3641 return 0.0
3742
3843
44+ def ZeroInterpolator_Vector (
45+ vectorfield : VectorField ,
46+ ti : int ,
47+ position : dict [str , tuple [int , float | np .ndarray ]],
48+ tau : np .float32 | np .float64 ,
49+ t : np .float32 | np .float64 ,
50+ z : np .float32 | np .float64 ,
51+ y : np .float32 | np .float64 ,
52+ x : np .float32 | np .float64 ,
53+ applyConversion : bool ,
54+ ) -> np .float32 | np .float64 :
55+ """Template function used for the signature check of the interpolation methods for velocity fields."""
56+ return 0.0
57+
58+
3959def XLinear (
4060 field : Field ,
4161 ti : int ,
@@ -53,6 +73,7 @@ def XLinear(
5373
5474 axis_dim = field .grid .get_axis_dim_mapping (field .data .dims )
5575 data = field .data
76+ tdim , zdim , ydim , xdim = data .shape [0 ], data .shape [1 ], data .shape [2 ], data .shape [3 ]
5677
5778 lenT = 2 if np .any (tau > 0 ) else 1
5879 lenZ = 2 if np .any (zeta > 0 ) else 1
@@ -61,22 +82,22 @@ def XLinear(
6182 if lenT == 1 :
6283 ti = np .repeat (ti , lenZ * 4 )
6384 else :
64- ti_1 = np .clip (ti + 1 , 0 , data . shape [ 0 ] - 1 )
85+ ti_1 = np .clip (ti + 1 , 0 , tdim - 1 )
6586 ti = np .concatenate ([np .repeat (ti , lenZ * 4 ), np .repeat (ti_1 , lenZ * 4 )])
6687
6788 # Depth coordinates: 4 points at zi, 4 at zi+1, repeated for both time levels
6889 if lenZ == 1 :
6990 zi = np .repeat (zi , lenT * 4 )
7091 else :
71- zi_1 = np .clip (zi + 1 , 0 , data . shape [ 1 ] - 1 )
92+ zi_1 = np .clip (zi + 1 , 0 , zdim - 1 )
7293 zi = np .tile (np .array ([zi , zi , zi , zi , zi_1 , zi_1 , zi_1 , zi_1 ]).flatten (), lenT )
7394
7495 # Y coordinates: [yi, yi, yi+1, yi+1] for each spatial point, repeated for time/depth
75- yi_1 = np .clip (yi + 1 , 0 , data . shape [ 2 ] - 1 )
96+ yi_1 = np .clip (yi + 1 , 0 , ydim - 1 )
7697 yi = np .tile (np .repeat (np .column_stack ([yi , yi_1 ]), 2 ), (lenT ) * (lenZ ))
7798
7899 # X coordinates: [xi, xi+1, xi, xi+1] for each spatial point, repeated for time/depth
79- xi_1 = np .clip (xi + 1 , 0 , data . shape [ 3 ] - 1 )
100+ xi_1 = np .clip (xi + 1 , 0 , xdim - 1 )
80101 xi = np .tile (np .column_stack ([xi , xi_1 , xi , xi_1 ]).flatten (), (lenT ) * (lenZ ))
81102
82103 # Create DataArrays for indexing
@@ -109,7 +130,266 @@ def XLinear(
109130 + (1 - xsi ) * eta * corner_data [:, 2 ]
110131 + xsi * eta * corner_data [:, 3 ]
111132 )
112- return value .compute () if isinstance (value , dask .Array ) else value
133+ return value .compute () if is_dask_collection (value ) else value
134+
135+
136+ def CGrid_Velocity (
137+ vectorfield : VectorField ,
138+ ti : int ,
139+ position : dict [_XGRID_AXES , tuple [int , float | np .ndarray ]],
140+ tau : np .float32 | np .float64 ,
141+ t : np .float32 | np .float64 ,
142+ z : np .float32 | np .float64 ,
143+ y : np .float32 | np .float64 ,
144+ x : np .float32 | np .float64 ,
145+ applyConversion : bool ,
146+ ):
147+ """
148+ Interpolation kernel for velocity fields on a C-Grid.
149+ Following Delandmeter and Van Sebille (2019), velocity fields should be interpolated
150+ only in the direction of the grid cell faces.
151+ """
152+ xi , xsi = position ["X" ]
153+ yi , eta = position ["Y" ]
154+ zi , zeta = position ["Z" ]
155+
156+ U = vectorfield .U .data
157+ V = vectorfield .V .data
158+ grid = vectorfield .grid
159+ tdim , zdim , ydim , xdim = U .shape [0 ], U .shape [1 ], U .shape [2 ], U .shape [3 ]
160+
161+ if grid .lon .ndim == 1 :
162+ px = np .array ([grid .lon [xi ], grid .lon [xi + 1 ], grid .lon [xi + 1 ], grid .lon [xi ]])
163+ py = np .array ([grid .lat [yi ], grid .lat [yi ], grid .lat [yi + 1 ], grid .lat [yi + 1 ]])
164+ else :
165+ px = np .array ([grid .lon [yi , xi ], grid .lon [yi , xi + 1 ], grid .lon [yi + 1 , xi + 1 ], grid .lon [yi + 1 , xi ]])
166+ py = np .array ([grid .lat [yi , xi ], grid .lat [yi , xi + 1 ], grid .lat [yi + 1 , xi + 1 ], grid .lat [yi + 1 , xi ]])
167+
168+ if grid ._mesh == "spherical" :
169+ px [0 ] = np .where (px [0 ] < x - 225 , px [0 ] + 360 , px [0 ])
170+ px [0 ] = np .where (px [0 ] > x + 225 , px [0 ] - 360 , px [0 ])
171+ px [1 :] = np .where (px [1 :] - px [0 ] > 180 , px [1 :] - 360 , px [1 :])
172+ px [1 :] = np .where (- px [1 :] + px [0 ] > 180 , px [1 :] + 360 , px [1 :])
173+ c1 = i_u ._geodetic_distance (
174+ py [0 ], py [1 ], px [0 ], px [1 ], grid ._mesh , np .einsum ("ij,ji->i" , i_u .phi2D_lin (0.0 , xsi ), py )
175+ )
176+ c2 = i_u ._geodetic_distance (
177+ py [1 ], py [2 ], px [1 ], px [2 ], grid ._mesh , np .einsum ("ij,ji->i" , i_u .phi2D_lin (eta , 1.0 ), py )
178+ )
179+ c3 = i_u ._geodetic_distance (
180+ py [2 ], py [3 ], px [2 ], px [3 ], grid ._mesh , np .einsum ("ij,ji->i" , i_u .phi2D_lin (1.0 , xsi ), py )
181+ )
182+ c4 = i_u ._geodetic_distance (
183+ py [3 ], py [0 ], px [3 ], px [0 ], grid ._mesh , np .einsum ("ij,ji->i" , i_u .phi2D_lin (eta , 0.0 ), py )
184+ )
185+
186+ lenT = 2 if np .any (tau > 0 ) else 1
187+
188+ # Create arrays of corner points for xarray.isel
189+ # TODO C grid may not need all xi and yi cornerpoints, so could speed up here?
190+
191+ # Time coordinates: 4 points at ti, then 4 points at ti+1
192+ if lenT == 1 :
193+ ti_full = np .repeat (ti , 4 )
194+ else :
195+ ti_1 = np .clip (ti + 1 , 0 , tdim - 1 )
196+ ti_full = np .concatenate ([np .repeat (ti , 4 ), np .repeat (ti_1 , 4 )])
197+
198+ # Depth coordinates: 4 points at zi, repeated for both time levels
199+ zi_full = np .repeat (zi , lenT * 4 )
200+
201+ # Y coordinates: [yi, yi, yi+1, yi+1] for each spatial point, repeated for time/depth
202+ yi_1 = np .clip (yi + 1 , 0 , ydim - 1 )
203+ yi_full = np .tile (np .repeat (np .column_stack ([yi , yi_1 ]), 2 ), (lenT ))
204+ # # TODO check why in some cases minus needed here!!!
205+ # yi_minus_1 = np.clip(yi - 1, 0, ydim - 1)
206+ # yi = np.tile(np.repeat(np.column_stack([yi_minus_1, yi]), 2), (lenT))
207+
208+ # X coordinates: [xi, xi+1, xi, xi+1] for each spatial point, repeated for time/depth
209+ xi_1 = np .clip (xi + 1 , 0 , xdim - 1 )
210+ xi_full = np .tile (np .column_stack ([xi , xi_1 , xi , xi_1 ]).flatten (), (lenT ))
211+
212+ for data in [U , V ]:
213+ axis_dim = grid .get_axis_dim_mapping (data .dims )
214+
215+ # Create DataArrays for indexing
216+ selection_dict = {
217+ axis_dim ["X" ]: xr .DataArray (xi_full , dims = ("points" )),
218+ axis_dim ["Y" ]: xr .DataArray (yi_full , dims = ("points" )),
219+ }
220+ if "Z" in axis_dim :
221+ selection_dict [axis_dim ["Z" ]] = xr .DataArray (zi_full , dims = ("points" ))
222+ if "time" in data .dims :
223+ selection_dict ["time" ] = xr .DataArray (ti_full , dims = ("points" ))
224+
225+ corner_data = data .isel (selection_dict ).data .reshape (lenT , len (xsi ), 4 )
226+
227+ if lenT == 2 :
228+ tau_full = tau [:, np .newaxis ]
229+ corner_data = corner_data [0 , :, :] * (1 - tau_full ) + corner_data [1 , :, :] * tau_full
230+ else :
231+ corner_data = corner_data [0 , :, :]
232+ # # See code below for v3 version
233+ # # if self.gridindexingtype == "nemo":
234+ # # U0 = self.U.data[ti, zi, yi + 1, xi] * c4
235+ # # U1 = self.U.data[ti, zi, yi + 1, xi + 1] * c2
236+ # # V0 = self.V.data[ti, zi, yi, xi + 1] * c1
237+ # # V1 = self.V.data[ti, zi, yi + 1, xi + 1] * c3
238+ # # elif self.gridindexingtype in ["mitgcm", "croco"]:
239+ # # U0 = self.U.data[ti, zi, yi, xi] * c4
240+ # # U1 = self.U.data[ti, zi, yi, xi + 1] * c2
241+ # # V0 = self.V.data[ti, zi, yi, xi] * c1
242+ # # V1 = self.V.data[ti, zi, yi + 1, xi] * c3
243+ # # TODO Nick can you help use xgcm to fix this implementation?
244+
245+ # # CROCO and MITgcm grid indexing,
246+ # if data is U:
247+ # U0 = corner_data[:, 0] * c4
248+ # U1 = corner_data[:, 1] * c2
249+ # elif data is V:
250+ # V0 = corner_data[:, 0] * c1
251+ # V1 = corner_data[:, 2] * c3
252+ # # NEMO grid indexing
253+ if data is U :
254+ U0 = corner_data [:, 2 ] * c4
255+ U1 = corner_data [:, 3 ] * c2
256+ elif data is V :
257+ V0 = corner_data [:, 1 ] * c1
258+ V1 = corner_data [:, 3 ] * c3
259+
260+ U = (1 - xsi ) * U0 + xsi * U1
261+ V = (1 - eta ) * V0 + eta * V1
262+
263+ deg2m = 1852 * 60.0
264+ if applyConversion :
265+ meshJac = (deg2m * deg2m * np .cos (np .deg2rad (y ))) if grid ._mesh == "spherical" else 1
266+ else :
267+ meshJac = deg2m if grid ._mesh == "spherical" else 1
268+
269+ jac = i_u ._compute_jacobian_determinant (py , px , eta , xsi ) * meshJac
270+
271+ u = (
272+ (- (1 - eta ) * U - (1 - xsi ) * V ) * px [0 ]
273+ + ((1 - eta ) * U - xsi * V ) * px [1 ]
274+ + (eta * U + xsi * V ) * px [2 ]
275+ + (- eta * U + (1 - xsi ) * V ) * px [3 ]
276+ ) / jac
277+ v = (
278+ (- (1 - eta ) * U - (1 - xsi ) * V ) * py [0 ]
279+ + ((1 - eta ) * U - xsi * V ) * py [1 ]
280+ + (eta * U + xsi * V ) * py [2 ]
281+ + (- eta * U + (1 - xsi ) * V ) * py [3 ]
282+ ) / jac
283+ if is_dask_collection (u ):
284+ u = u .compute ()
285+ v = v .compute ()
286+
287+ # check whether the grid conversion has been applied correctly
288+ xx = (1 - xsi ) * (1 - eta ) * px [0 ] + xsi * (1 - eta ) * px [1 ] + xsi * eta * px [2 ] + (1 - xsi ) * eta * px [3 ]
289+ u = np .where (np .abs ((xx - x ) / x ) > 1e-4 , np .nan , u )
290+
291+ if vectorfield .W :
292+ data = vectorfield .W .data
293+ # Time coordinates: 2 points at ti, then 2 points at ti+1
294+ if lenT == 1 :
295+ ti_full = np .repeat (ti , 2 )
296+ else :
297+ ti_1 = np .clip (ti + 1 , 0 , tdim - 1 )
298+ ti_full = np .concatenate ([np .repeat (ti , 2 ), np .repeat (ti_1 , 2 )])
299+
300+ # Depth coordinates: 1 points at zi, repeated for both time levels
301+ zi_1 = np .clip (zi + 1 , 0 , zdim - 1 )
302+ zi_full = np .tile (np .array ([zi , zi_1 ]).flatten (), lenT )
303+
304+ # Y coordinates: yi+1 for each spatial point, repeated for time/depth
305+ yi_1 = np .clip (yi + 1 , 0 , ydim - 1 )
306+ yi_full = np .tile (yi_1 , (lenT ) * 2 )
307+
308+ # X coordinates: xi+1 for each spatial point, repeated for time/depth
309+ xi_1 = np .clip (xi + 1 , 0 , xdim - 1 )
310+ xi_full = np .tile (xi_1 , (lenT ) * 2 )
311+
312+ axis_dim = grid .get_axis_dim_mapping (data .dims )
313+
314+ # Create DataArrays for indexing
315+ selection_dict = {
316+ axis_dim ["X" ]: xr .DataArray (xi_full , dims = ("points" )),
317+ axis_dim ["Y" ]: xr .DataArray (yi_full , dims = ("points" )),
318+ axis_dim ["Z" ]: xr .DataArray (zi_full , dims = ("points" )),
319+ }
320+ if "time" in data .dims :
321+ selection_dict ["time" ] = xr .DataArray (ti_full , dims = ("points" ))
322+
323+ corner_data = data .isel (selection_dict ).data .reshape (lenT , 2 , len (xsi ))
324+
325+ if lenT == 2 :
326+ tau_full = tau [np .newaxis , :]
327+ corner_data = corner_data [0 , :, :] * (1 - tau_full ) + corner_data [1 , :, :] * tau_full
328+ else :
329+ corner_data = corner_data [0 , :, :]
330+
331+ w = corner_data [0 , :] * (1 - zeta ) + corner_data [1 , :] * zeta
332+ if is_dask_collection (w ):
333+ w = w .compute ()
334+ else :
335+ w = np .zeros_like (u )
336+
337+ return (u , v , w )
338+
339+
340+ def CGrid_Tracer (
341+ field : Field ,
342+ ti : int ,
343+ position : dict [_XGRID_AXES , tuple [int , float | np .ndarray ]],
344+ tau : np .float32 | np .float64 ,
345+ t : np .float32 | np .float64 ,
346+ z : np .float32 | np .float64 ,
347+ y : np .float32 | np .float64 ,
348+ x : np .float32 | np .float64 ,
349+ ):
350+ """Interpolation kernel for tracer fields on a C-Grid.
351+
352+ Following Delandmeter and Van Sebille (2019), tracer fields should be interpolated
353+ constant over the grid cell
354+ """
355+ xi , _ = position ["X" ]
356+ yi , _ = position ["Y" ]
357+ zi , _ = position ["Z" ]
358+
359+ axis_dim = field .grid .get_axis_dim_mapping (field .data .dims )
360+ data = field .data
361+
362+ lenT = 2 if np .any (tau > 0 ) else 1
363+
364+ if lenT == 2 :
365+ ti_1 = np .clip (ti + 1 , 0 , data .shape [0 ] - 1 )
366+ ti = np .concatenate ([np .repeat (ti ), np .repeat (ti_1 )])
367+ zi_1 = np .clip (zi + 1 , 0 , data .shape [1 ] - 1 )
368+ zi = np .concatenate ([np .repeat (zi ), np .repeat (zi_1 )])
369+ yi_1 = np .clip (yi + 1 , 0 , data .shape [2 ] - 1 )
370+ yi = np .concatenate ([np .repeat (yi ), np .repeat (yi_1 )])
371+ xi_1 = np .clip (xi + 1 , 0 , data .shape [3 ] - 1 )
372+ xi = np .concatenate ([np .repeat (xi ), np .repeat (xi_1 )])
373+
374+ # Create DataArrays for indexing
375+ selection_dict = {
376+ axis_dim ["X" ]: xr .DataArray (xi , dims = ("points" )),
377+ axis_dim ["Y" ]: xr .DataArray (yi , dims = ("points" )),
378+ }
379+ if "Z" in axis_dim :
380+ selection_dict [axis_dim ["Z" ]] = xr .DataArray (zi , dims = ("points" ))
381+ if "time" in field .data .dims :
382+ selection_dict ["time" ] = xr .DataArray (ti , dims = ("points" ))
383+
384+ value = data .isel (selection_dict ).data .reshape (lenT , len (xi ))
385+
386+ if lenT == 2 :
387+ tau = tau [:, np .newaxis ]
388+ value = value [0 , :] * (1 - tau ) + value [1 , :] * tau
389+ else :
390+ value = value [0 , :]
391+
392+ return value .compute () if is_dask_collection (value ) else value
113393
114394
115395def XNearest (
@@ -172,7 +452,7 @@ def XNearest(
172452 else :
173453 value = corner_data [0 , :]
174454
175- return value .compute () if isinstance (value , dask . Array ) else value
455+ return value .compute () if is_dask_collection (value ) else value
176456
177457
178458def UXPiecewiseConstantFace (
0 commit comments