@@ -139,7 +139,7 @@ def CGrid_Velocity(
139139 U = vectorfield .U .data
140140 V = vectorfield .V .data
141141 grid = vectorfield .grid
142- tdim , ydim , xdim = U .shape [0 ], U .shape [2 ], U .shape [3 ]
142+ tdim , zdim , ydim , xdim = U .shape [0 ], U . shape [ 1 ], U .shape [2 ], U .shape [3 ]
143143
144144 if grid .lon .ndim == 1 :
145145 px = np .array ([grid .lon [xi ], grid .lon [xi + 1 ], grid .lon [xi + 1 ], grid .lon [xi ]])
@@ -171,39 +171,39 @@ def CGrid_Velocity(
171171 # Create arrays of corner points for xarray.isel
172172 # TODO C grid may not need all xi and yi cornerpoints, so could speed up here?
173173
174- # Time coordinates: 8 points at ti, then 8 points at ti+1
174+ # Time coordinates: 4 points at ti, then 4 points at ti+1
175175 if lenT == 1 :
176- ti = np .repeat (ti , 4 )
176+ ti_full = np .repeat (ti , 4 )
177177 else :
178178 ti_1 = np .clip (ti + 1 , 0 , tdim - 1 )
179- ti = np .concatenate ([np .repeat (ti , 4 ), np .repeat (ti_1 , 4 )])
179+ ti_full = np .concatenate ([np .repeat (ti , 4 ), np .repeat (ti_1 , 4 )])
180180
181181 # Depth coordinates: 4 points at zi, repeated for both time levels
182- zi = np .repeat (zi , lenT * 4 )
182+ zi_full = np .repeat (zi , lenT * 4 )
183183
184184 # Y coordinates: [yi, yi, yi+1, yi+1] for each spatial point, repeated for time/depth
185185 yi_1 = np .clip (yi + 1 , 0 , ydim - 1 )
186- yi = np .tile (np .repeat (np .column_stack ([yi , yi_1 ]), 2 ), (lenT ))
186+ yi_full = np .tile (np .repeat (np .column_stack ([yi , yi_1 ]), 2 ), (lenT ))
187187 # # TODO check why in some cases minus needed here!!!
188188 # yi_minus_1 = np.clip(yi - 1, 0, ydim - 1)
189189 # yi = np.tile(np.repeat(np.column_stack([yi_minus_1, yi]), 2), (lenT))
190190
191191 # X coordinates: [xi, xi+1, xi, xi+1] for each spatial point, repeated for time/depth
192192 xi_1 = np .clip (xi + 1 , 0 , xdim - 1 )
193- xi = np .tile (np .column_stack ([xi , xi_1 , xi , xi_1 ]).flatten (), (lenT ))
193+ xi_full = np .tile (np .column_stack ([xi , xi_1 , xi , xi_1 ]).flatten (), (lenT ))
194194
195195 for data in [U , V ]:
196196 axis_dim = grid .get_axis_dim_mapping (data .dims )
197197
198198 # Create DataArrays for indexing
199199 selection_dict = {
200- axis_dim ["X" ]: xr .DataArray (xi , dims = ("points" )),
201- axis_dim ["Y" ]: xr .DataArray (yi , dims = ("points" )),
200+ axis_dim ["X" ]: xr .DataArray (xi_full , dims = ("points" )),
201+ axis_dim ["Y" ]: xr .DataArray (yi_full , dims = ("points" )),
202202 }
203203 if "Z" in axis_dim :
204- selection_dict [axis_dim ["Z" ]] = xr .DataArray (zi , dims = ("points" ))
204+ selection_dict [axis_dim ["Z" ]] = xr .DataArray (zi_full , dims = ("points" ))
205205 if "time" in data .dims :
206- selection_dict ["time" ] = xr .DataArray (ti , dims = ("points" ))
206+ selection_dict ["time" ] = xr .DataArray (ti_full , dims = ("points" ))
207207
208208 corner_data = data .isel (selection_dict ).data .reshape (lenT , len (xsi ), 4 )
209209
@@ -271,7 +271,53 @@ def CGrid_Velocity(
271271 xx = (1 - xsi ) * (1 - eta ) * px [0 ] + xsi * (1 - eta ) * px [1 ] + xsi * eta * px [2 ] + (1 - xsi ) * eta * px [3 ]
272272 u = np .where (np .abs (xx - x ) > 1e-4 , np .nan , u )
273273
274- return (u , v , 0 ) # TODO fix and test W also
274+ if vectorfield .W :
275+ data = vectorfield .W .data
276+ # Time coordinates: 2 points at ti, then 2 points at ti+1
277+ if lenT == 1 :
278+ ti_full = np .repeat (ti , 2 )
279+ else :
280+ ti_1 = np .clip (ti + 1 , 0 , tdim - 1 )
281+ ti_full = np .concatenate ([np .repeat (ti , 2 ), np .repeat (ti_1 , 2 )])
282+
283+ # Depth coordinates: 1 points at zi, repeated for both time levels
284+ zi_1 = np .clip (zi + 1 , 0 , zdim - 1 )
285+ zi_full = np .tile (np .array ([zi , zi_1 ]).flatten (), lenT )
286+
287+ # Y coordinates: yi+1 for each spatial point, repeated for time/depth
288+ yi_1 = np .clip (yi + 1 , 0 , ydim - 1 )
289+ yi_full = np .tile (yi_1 , (lenT ) * 2 )
290+
291+ # X coordinates: xi+1 for each spatial point, repeated for time/depth
292+ xi_1 = np .clip (xi + 1 , 0 , xdim - 1 )
293+ xi_full = np .tile (xi_1 , (lenT ) * 2 )
294+
295+ axis_dim = grid .get_axis_dim_mapping (data .dims )
296+
297+ # Create DataArrays for indexing
298+ selection_dict = {
299+ axis_dim ["X" ]: xr .DataArray (xi_full , dims = ("points" )),
300+ axis_dim ["Y" ]: xr .DataArray (yi_full , dims = ("points" )),
301+ axis_dim ["Z" ]: xr .DataArray (zi_full , dims = ("points" )),
302+ }
303+ if "time" in data .dims :
304+ selection_dict ["time" ] = xr .DataArray (ti_full , dims = ("points" ))
305+
306+ corner_data = data .isel (selection_dict ).data .reshape (lenT , 2 , len (xsi ))
307+
308+ if lenT == 2 :
309+ tau_full = tau [np .newaxis , :]
310+ corner_data = corner_data [0 , :, :] * (1 - tau_full ) + corner_data [1 , :, :] * tau_full
311+ else :
312+ corner_data = corner_data [0 , :, :]
313+
314+ w = corner_data [0 , :] * (1 - zeta ) + corner_data [1 , :] * zeta
315+ if isinstance (w , dask .Array ):
316+ w = w .compute ()
317+ else :
318+ w = np .zeros_like (u )
319+
320+ return (u , v , w )
275321
276322
277323def CGrid_Tracer (
0 commit comments