@@ -270,8 +270,7 @@ def __init__(
 
         # Hack around the fact that NaN and ridiculously large values
         # propagate in SciPy's interpolators
-        lib = np if isinstance(self.data, np.ndarray) else da
-        self.data[lib.isnan(self.data)] = 0.0
+        self.data[np.isnan(self.data)] = 0.0
         if self.vmin is not None:
             self.data[self.data < self.vmin] = 0.0
         if self.vmax is not None:
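Note: the simplification above relies on plain NumPy boolean masking. A minimal stand-alone sketch of the NaN-zeroing and vmin/vmax clipping pattern (array and bounds invented for illustration):

    import numpy as np

    data = np.array([[1.0, np.nan, 3.0], [50.0, -2.0, np.nan]])
    vmin, vmax = 0.0, 10.0

    data[np.isnan(data)] = 0.0  # zero out NaNs so they cannot propagate through SciPy's interpolators
    data[data < vmin] = 0.0     # clip values below vmin
    data[data > vmax] = 0.0     # and values above vmax
    print(data)                 # [[1. 0. 3.] [0. 0. 0.]]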
@@ -662,16 +661,15 @@ def from_netcdf(
                         if len(filebuffer.indices["depth"]) > 1:
                             data_list.append(buffer_data.reshape(sum(((1,), buffer_data.shape), ())))
                         else:
-                            if type(tslice) not in [list, np.ndarray, da.Array, xr.DataArray]:
+                            if type(tslice) not in [list, np.ndarray, xr.DataArray]:
                                 tslice = [tslice]
                             data_list.append(buffer_data.reshape(sum(((len(tslice), 1), buffer_data.shape[1:]), ())))
                     else:
                         data_list.append(buffer_data)
-                    if type(tslice) not in [list, np.ndarray, da.Array, xr.DataArray]:
+                    if type(tslice) not in [list, np.ndarray, xr.DataArray]:
                         tslice = [tslice]
                 ti += len(tslice)
-            lib = np if isinstance(data_list[0], np.ndarray) else da
-            data = lib.concatenate(data_list, axis=0)
+            data = np.concatenate(data_list, axis=0)
         else:
             grid._defer_load = True
             grid._ti = -1
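Note: with dask gone, the per-file time slices are stacked with plain np.concatenate along the time axis. A hedged sketch with invented buffer shapes of (time, depth, lat, lon):

    import numpy as np

    data_list = [np.zeros((1, 1, 4, 5)), np.ones((2, 1, 4, 5))]  # two hypothetical file buffers
    data = np.concatenate(data_list, axis=0)
    print(data.shape)  # (3, 1, 4, 5): all time slices stacked along axis 0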
@@ -752,24 +750,23 @@ def from_xarray(
 
     def _reshape(self, data, transpose=False):
         # Ensure that field data is the right data type
-        if not isinstance(data, (np.ndarray, da.core.Array)):
+        if not isinstance(data, (np.ndarray)):
             data = np.array(data)
         if (self.cast_data_dtype == np.float32) and (data.dtype != np.float32):
             data = data.astype(np.float32)
         elif (self.cast_data_dtype == np.float64) and (data.dtype != np.float64):
             data = data.astype(np.float64)
-        lib = np if isinstance(data, np.ndarray) else da
         if transpose:
-            data = lib.transpose(data)
+            data = np.transpose(data)
         if self.grid._lat_flipped:
-            data = lib.flip(data, axis=-2)
+            data = np.flip(data, axis=-2)
 
         if self.grid.xdim == 1 or self.grid.ydim == 1:
-            data = lib.squeeze(data)  # First remove all length-1 dimensions in data, so that we can add them below
+            data = np.squeeze(data)  # First remove all length-1 dimensions in data, so that we can add them below
             if self.grid.xdim == 1 and len(data.shape) < 4:
-                data = lib.expand_dims(data, axis=-1)
+                data = np.expand_dims(data, axis=-1)
             if self.grid.ydim == 1 and len(data.shape) < 4:
-                data = lib.expand_dims(data, axis=-2)
+                data = np.expand_dims(data, axis=-2)
         if self.grid.tdim == 1:
             if len(data.shape) < 4:
                 data = data.reshape(sum(((1,), data.shape), ()))
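Note: _reshape normalises field data to a 4-D (time, depth, y, x) layout; squeeze drops every length-1 axis and expand_dims re-inserts the degenerate x or y axis. A rough illustration with an invented array that has ydim == 1:

    import numpy as np

    data = np.random.rand(3, 2, 1, 5)     # (time, depth, y=1, x)
    data = np.squeeze(data)               # drop all length-1 axes -> (3, 2, 5)
    data = np.expand_dims(data, axis=-2)  # re-insert the y axis   -> (3, 2, 1, 5)
    print(data.shape)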
@@ -913,8 +910,6 @@ def _spatial_interpolation(self, ti, z, y, x, time, particle=None):
                 # Detect Out-of-bounds sampling and raise exception
                 _raise_field_out_of_bound_error(z, y, x)
             else:
-                if isinstance(val, da.core.Array):
-                    val = val.compute()
                 return val
 
         except (FieldSamplingError, FieldOutOfBoundError, FieldOutOfBoundSurfaceError) as e:
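Note: with the field held as an in-memory NumPy array, SciPy interpolation already yields concrete values, so the dask .compute() step removed above has no NumPy counterpart. A stand-alone illustration (grid and values invented, and assuming a SciPy regular-grid interpolator as hinted at in the comment near line 272):

    import numpy as np
    from scipy.interpolate import RegularGridInterpolator

    lat, lon = np.arange(4.0), np.arange(5.0)
    interp = RegularGridInterpolator((lat, lon), np.random.rand(4, 5))
    val = interp([[1.5, 2.5]])[0]  # already an eager float; nothing lazy left to materialise
    print(val)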
@@ -1008,26 +1003,25 @@ def add_periodic_halo(self, zonal, meridional, halosize=5, data=None):
         data :
             if data is not None, the periodic halo will be achieved on data instead of self.data and data will be returned (Default value = None)
         """
-        dataNone = not isinstance(data, (np.ndarray, da.core.Array))
+        dataNone = not isinstance(data, np.ndarray)
         if self.grid.defer_load and dataNone:
             return
         data = self.data if dataNone else data
-        lib = np if isinstance(data, np.ndarray) else da
         if zonal:
             if len(data.shape) == 3:
-                data = lib.concatenate((data[:, :, -halosize:], data, data[:, :, 0:halosize]), axis=len(data.shape) - 1)
+                data = np.concatenate((data[:, :, -halosize:], data, data[:, :, 0:halosize]), axis=len(data.shape) - 1)
                 assert data.shape[2] == self.grid.xdim, "Third dim must be x."
             else:
-                data = lib.concatenate(
+                data = np.concatenate(
                     (data[:, :, :, -halosize:], data, data[:, :, :, 0:halosize]), axis=len(data.shape) - 1
                 )
                 assert data.shape[3] == self.grid.xdim, "Fourth dim must be x."
         if meridional:
             if len(data.shape) == 3:
-                data = lib.concatenate((data[:, -halosize:, :], data, data[:, 0:halosize, :]), axis=len(data.shape) - 2)
+                data = np.concatenate((data[:, -halosize:, :], data, data[:, 0:halosize, :]), axis=len(data.shape) - 2)
                 assert data.shape[1] == self.grid.ydim, "Second dim must be y."
             else:
-                data = lib.concatenate(
+                data = np.concatenate(
                     (data[:, :, -halosize:, :], data, data[:, :, 0:halosize, :]), axis=len(data.shape) - 2
                 )
                 assert data.shape[2] == self.grid.ydim, "Third dim must be y."
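Note: the halo is built by wrapping the last and first halosize columns (or rows) around the array with np.concatenate. A minimal sketch for the zonal case with an invented (time, y, x) field:

    import numpy as np

    halosize = 2
    data = np.arange(2 * 3 * 6).reshape(2, 3, 6)  # hypothetical (time, y, x) field
    halo = np.concatenate(
        (data[:, :, -halosize:], data, data[:, :, 0:halosize]), axis=len(data.shape) - 1
    )
    print(halo.shape)  # (2, 3, 10): last two and first two x-columns wrapped around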
@@ -1099,11 +1093,10 @@ def _data_concatenate(self, data, data_to_concat, tindex):
                 data[tindex] = None
             elif isinstance(data, list):
                 del data[tindex]
-        lib = np if isinstance(data, np.ndarray) else da
         if tindex == 0:
-            data = lib.concatenate([data_to_concat, data[tindex + 1 :, :]], axis=0)
+            data = np.concatenate([data_to_concat, data[tindex + 1 :, :]], axis=0)
         elif tindex == 1:
-            data = lib.concatenate([data[:tindex, :], data_to_concat], axis=0)
+            data = np.concatenate([data[:tindex, :], data_to_concat], axis=0)
         else:
             raise ValueError("data_concatenate is used for computeTimeChunk, with tindex in [0, 1]")
         return data
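Note: _data_concatenate swaps one of the two loaded time chunks for freshly read data; tindex 0 replaces the leading chunk and tindex 1 the trailing one. A hedged sketch with invented (time, depth, y, x) shapes:

    import numpy as np

    data = np.zeros((2, 1, 3, 4))           # hypothetical two-chunk time buffer
    data_to_concat = np.ones((1, 1, 3, 4))  # newly read chunk

    tindex = 0
    new = np.concatenate([data_to_concat, data[tindex + 1 :, :]], axis=0)  # replace the first chunk
    print(new.shape, new[0].max(), new[1].max())  # (2, 1, 3, 4) 1.0 0.0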
@@ -1136,13 +1129,12 @@ def computeTimeChunk(self, data, tindex):
         if self.netcdf_engine != "xarray":
             filebuffer.name = filebuffer.parse_name(self.filebuffername)
         buffer_data = filebuffer.data
-        lib = np if isinstance(buffer_data, np.ndarray) else da
         if len(buffer_data.shape) == 2:
-            buffer_data = lib.reshape(buffer_data, sum(((1, 1), buffer_data.shape), ()))
+            buffer_data = np.reshape(buffer_data, sum(((1, 1), buffer_data.shape), ()))
         elif len(buffer_data.shape) == 3 and g.zdim > 1:
-            buffer_data = lib.reshape(buffer_data, sum(((1,), buffer_data.shape), ()))
+            buffer_data = np.reshape(buffer_data, sum(((1,), buffer_data.shape), ()))
         elif len(buffer_data.shape) == 3:
-            buffer_data = lib.reshape(
+            buffer_data = np.reshape(
                 buffer_data,
                 sum(
                     (