Commit 38770d1

Further removals of DaskFileBuffer
1 parent 121aa3c commit 38770d1

2 files changed: +36 -53 lines changed

parcels/field.py

Lines changed: 20 additions & 28 deletions
@@ -270,8 +270,7 @@ def __init__(
 
         # Hack around the fact that NaN and ridiculously large values
         # propagate in SciPy's interpolators
-        lib = np if isinstance(self.data, np.ndarray) else da
-        self.data[lib.isnan(self.data)] = 0.0
+        self.data[np.isnan(self.data)] = 0.0
         if self.vmin is not None:
             self.data[self.data < self.vmin] = 0.0
         if self.vmax is not None:
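With the Dask branch gone, the masking is plain NumPy boolean indexing. A minimal standalone sketch of the pattern (hypothetical array and bounds, not the Field class itself; the vmax branch is assumed symmetric to the vmin one shown above):

    import numpy as np

    data = np.array([1.0, np.nan, 1e30, -5.0], dtype=np.float32)
    vmin, vmax = 0.0, 100.0

    data[np.isnan(data)] = 0.0  # zero out NaNs so SciPy interpolators don't propagate them
    data[data < vmin] = 0.0     # values below vmin are zeroed
    data[data > vmax] = 0.0     # assumed: values above vmax are handled the same way
    print(data)                 # [1. 0. 0. 0.]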
@@ -662,16 +661,15 @@ def from_netcdf(
                     if len(filebuffer.indices["depth"]) > 1:
                         data_list.append(buffer_data.reshape(sum(((1,), buffer_data.shape), ())))
                     else:
-                        if type(tslice) not in [list, np.ndarray, da.Array, xr.DataArray]:
+                        if type(tslice) not in [list, np.ndarray, xr.DataArray]:
                             tslice = [tslice]
                         data_list.append(buffer_data.reshape(sum(((len(tslice), 1), buffer_data.shape[1:]), ())))
                 else:
                     data_list.append(buffer_data)
-                if type(tslice) not in [list, np.ndarray, da.Array, xr.DataArray]:
+                if type(tslice) not in [list, np.ndarray, xr.DataArray]:
                     tslice = [tslice]
                 ti += len(tslice)
-            lib = np if isinstance(data_list[0], np.ndarray) else da
-            data = lib.concatenate(data_list, axis=0)
+            data = np.concatenate(data_list, axis=0)
         else:
             grid._defer_load = True
             grid._ti = -1
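The reshapes in this hunk rely on a small tuple idiom: sum((prefix, shape), ()) concatenates tuples, so singleton axes can be prepended to a raw buffer before concatenation along time. A standalone sketch with a hypothetical buffer:

    import numpy as np

    buffer_data = np.zeros((30, 40))  # e.g. one (lat, lon) slice read from a file

    # sum(...) over tuples concatenates them: () + (1, 1) + (30, 40) -> (1, 1, 30, 40)
    new_shape = sum(((1, 1), buffer_data.shape), ())
    reshaped = buffer_data.reshape(new_shape)
    print(reshaped.shape)  # (1, 1, 30, 40): singleton time and depth axes prepended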
@@ -752,24 +750,23 @@ def from_xarray(
 
     def _reshape(self, data, transpose=False):
         # Ensure that field data is the right data type
-        if not isinstance(data, (np.ndarray, da.core.Array)):
+        if not isinstance(data, (np.ndarray)):
             data = np.array(data)
         if (self.cast_data_dtype == np.float32) and (data.dtype != np.float32):
             data = data.astype(np.float32)
         elif (self.cast_data_dtype == np.float64) and (data.dtype != np.float64):
             data = data.astype(np.float64)
-        lib = np if isinstance(data, np.ndarray) else da
         if transpose:
-            data = lib.transpose(data)
+            data = np.transpose(data)
         if self.grid._lat_flipped:
-            data = lib.flip(data, axis=-2)
+            data = np.flip(data, axis=-2)
 
         if self.grid.xdim == 1 or self.grid.ydim == 1:
-            data = lib.squeeze(data)  # First remove all length-1 dimensions in data, so that we can add them below
+            data = np.squeeze(data)  # First remove all length-1 dimensions in data, so that we can add them below
             if self.grid.xdim == 1 and len(data.shape) < 4:
-                data = lib.expand_dims(data, axis=-1)
+                data = np.expand_dims(data, axis=-1)
             if self.grid.ydim == 1 and len(data.shape) < 4:
-                data = lib.expand_dims(data, axis=-2)
+                data = np.expand_dims(data, axis=-2)
         if self.grid.tdim == 1:
             if len(data.shape) < 4:
                 data = data.reshape(sum(((1,), data.shape), ()))
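_reshape normalizes degenerate grid dimensions by squeezing every length-1 axis and then re-adding the ones the 4D layout needs. A minimal sketch of that squeeze/expand round trip (hypothetical shapes):

    import numpy as np

    # A field on a grid with xdim == 1: (time, depth, lat, lon=1)
    data = np.zeros((2, 5, 40, 1))

    data = np.squeeze(data)                   # drop every length-1 axis -> (2, 5, 40)
    if len(data.shape) < 4:
        data = np.expand_dims(data, axis=-1)  # restore the x axis -> (2, 5, 40, 1)
    print(data.shape)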
@@ -913,8 +910,6 @@ def _spatial_interpolation(self, ti, z, y, x, time, particle=None):
             # Detect Out-of-bounds sampling and raise exception
             _raise_field_out_of_bound_error(z, y, x)
         else:
-            if isinstance(val, da.core.Array):
-                val = val.compute()
             return val
 
     except (FieldSamplingError, FieldOutOfBoundError, FieldOutOfBoundSurfaceError) as e:
@@ -1008,26 +1003,25 @@ def add_periodic_halo(self, zonal, meridional, halosize=5, data=None):
         data :
             if data is not None, the periodic halo will be achieved on data instead of self.data and data will be returned (Default value = None)
         """
-        dataNone = not isinstance(data, (np.ndarray, da.core.Array))
+        dataNone = not isinstance(data, np.ndarray)
         if self.grid.defer_load and dataNone:
             return
         data = self.data if dataNone else data
-        lib = np if isinstance(data, np.ndarray) else da
         if zonal:
             if len(data.shape) == 3:
-                data = lib.concatenate((data[:, :, -halosize:], data, data[:, :, 0:halosize]), axis=len(data.shape) - 1)
+                data = np.concatenate((data[:, :, -halosize:], data, data[:, :, 0:halosize]), axis=len(data.shape) - 1)
                 assert data.shape[2] == self.grid.xdim, "Third dim must be x."
             else:
-                data = lib.concatenate(
+                data = np.concatenate(
                     (data[:, :, :, -halosize:], data, data[:, :, :, 0:halosize]), axis=len(data.shape) - 1
                 )
                 assert data.shape[3] == self.grid.xdim, "Fourth dim must be x."
         if meridional:
             if len(data.shape) == 3:
-                data = lib.concatenate((data[:, -halosize:, :], data, data[:, 0:halosize, :]), axis=len(data.shape) - 2)
+                data = np.concatenate((data[:, -halosize:, :], data, data[:, 0:halosize, :]), axis=len(data.shape) - 2)
                 assert data.shape[1] == self.grid.ydim, "Second dim must be y."
             else:
-                data = lib.concatenate(
+                data = np.concatenate(
                     (data[:, :, -halosize:, :], data, data[:, :, 0:halosize, :]), axis=len(data.shape) - 2
                 )
                 assert data.shape[2] == self.grid.ydim, "Third dim must be y."
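A periodic halo is just the array concatenated with wrapped copies of its own edges, so interpolation near a boundary sees neighbours from the opposite side. A minimal sketch of the zonal (x) case for 3D data, with hypothetical sizes:

    import numpy as np

    halosize = 2
    data = np.arange(24, dtype=np.float32).reshape(1, 4, 6)  # (time, y, x)

    # Wrap the last halosize columns to the front and the first ones to the back
    halo = np.concatenate(
        (data[:, :, -halosize:], data, data[:, :, 0:halosize]), axis=len(data.shape) - 1
    )
    print(halo.shape)  # (1, 4, 10): x grew by 2 * halosize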
@@ -1099,11 +1093,10 @@ def _data_concatenate(self, data, data_to_concat, tindex):
             data[tindex] = None
         elif isinstance(data, list):
             del data[tindex]
-        lib = np if isinstance(data, np.ndarray) else da
         if tindex == 0:
-            data = lib.concatenate([data_to_concat, data[tindex + 1 :, :]], axis=0)
+            data = np.concatenate([data_to_concat, data[tindex + 1 :, :]], axis=0)
         elif tindex == 1:
-            data = lib.concatenate([data[:tindex, :], data_to_concat], axis=0)
+            data = np.concatenate([data[:tindex, :], data_to_concat], axis=0)
         else:
             raise ValueError("data_concatenate is used for computeTimeChunk, with tindex in [0, 1]")
         return data
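_data_concatenate keeps exactly two time levels in memory and swaps one of them for a freshly loaded chunk. A standalone sketch of the tindex == 0 branch (hypothetical shapes):

    import numpy as np

    data = np.zeros((2, 3, 4, 5))           # two time levels: (2, depth, y, x)
    data_to_concat = np.ones((1, 3, 4, 5))  # the freshly loaded level

    # Replace time level 0: keep level 1 and prepend the new chunk
    data = np.concatenate([data_to_concat, data[1:, :]], axis=0)
    print(data.shape)  # still (2, 3, 4, 5)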
@@ -1136,13 +1129,12 @@ def computeTimeChunk(self, data, tindex):
         if self.netcdf_engine != "xarray":
             filebuffer.name = filebuffer.parse_name(self.filebuffername)
         buffer_data = filebuffer.data
-        lib = np if isinstance(buffer_data, np.ndarray) else da
         if len(buffer_data.shape) == 2:
-            buffer_data = lib.reshape(buffer_data, sum(((1, 1), buffer_data.shape), ()))
+            buffer_data = np.reshape(buffer_data, sum(((1, 1), buffer_data.shape), ()))
         elif len(buffer_data.shape) == 3 and g.zdim > 1:
-            buffer_data = lib.reshape(buffer_data, sum(((1,), buffer_data.shape), ()))
+            buffer_data = np.reshape(buffer_data, sum(((1,), buffer_data.shape), ()))
         elif len(buffer_data.shape) == 3:
-            buffer_data = lib.reshape(
+            buffer_data = np.reshape(
                 buffer_data,
                 sum(
                     (

parcels/fieldset.py

Lines changed: 16 additions & 25 deletions
@@ -5,7 +5,6 @@
 from copy import deepcopy
 from glob import glob
 
-import dask.array as da
 import numpy as np
 
 from parcels._compat import MPI
@@ -1445,12 +1444,11 @@ def computeTimeChunk(self, time=0.0, dt=1):
                     for i in range(len(f.data)):
                         del f.data[i, :]
 
-                lib = np
                 if f.gridindexingtype == "pop" and g.zdim > 1:
                     zd = g.zdim - 1
                 else:
                     zd = g.zdim
-                data = lib.empty(
+                data = np.empty(
                     (g.tdim, zd, g.ydim - 2 * g.meridional_halo, g.xdim - 2 * g.zonal_halo), dtype=np.float32
                 )
                 f._loaded_time_indices = range(2)
@@ -1467,12 +1465,11 @@ def computeTimeChunk(self, time=0.0, dt=1):
                 f.data = f._reshape(data)
 
             elif g._update_status == "updated":
-                lib = np if isinstance(f.data, np.ndarray) else da
                 if f.gridindexingtype == "pop" and g.zdim > 1:
                     zd = g.zdim - 1
                 else:
                     zd = g.zdim
-                data = lib.empty(
+                data = np.empty(
                     (g.tdim, zd, g.ydim - 2 * g.meridional_halo, g.xdim - 2 * g.zonal_halo), dtype=np.float32
                 )
                 if signdt >= 0:
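Both allocation sites size the staging buffer from the grid extents minus the periodic halo (the halo rows and columns are wrapped duplicates, per add_periodic_halo above). The shape arithmetic, with hypothetical dimensions:

    import numpy as np

    tdim, zd = 2, 10                  # two time levels, 10 depth levels
    ydim, xdim = 44, 66               # grid extents including the halo
    meridional_halo = zonal_halo = 2  # hypothetical halo widths

    data = np.empty(
        (tdim, zd, ydim - 2 * meridional_halo, xdim - 2 * zonal_halo), dtype=np.float32
    )
    print(data.shape)  # (2, 10, 40, 62)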
@@ -1492,28 +1489,22 @@ def computeTimeChunk(self, time=0.0, dt=1):
                 data = f._rescale_and_set_minmax(data)
                 if signdt >= 0:
                     data = f._reshape(data)[1, :]
-                    if lib is da:
-                        f.data = lib.stack([f.data[1, :], data], axis=0)
-                    else:
-                        if not isinstance(f.data, DeferredArray):
-                            if isinstance(f.data, list):
-                                del f.data[0, :]
-                            else:
-                                f.data[0, :] = None
-                        f.data[0, :] = f.data[1, :]
-                        f.data[1, :] = data
+                    if not isinstance(f.data, DeferredArray):
+                        if isinstance(f.data, list):
+                            del f.data[0, :]
+                        else:
+                            f.data[0, :] = None
+                    f.data[0, :] = f.data[1, :]
+                    f.data[1, :] = data
                 else:
                     data = f._reshape(data)[0, :]
-                    if lib is da:
-                        f.data = lib.stack([data, f.data[0, :]], axis=0)
-                    else:
-                        if not isinstance(f.data, DeferredArray):
-                            if isinstance(f.data, list):
-                                del f.data[1, :]
-                            else:
-                                f.data[1, :] = None
-                        f.data[1, :] = f.data[0, :]
-                        f.data[0, :] = data
+                    if not isinstance(f.data, DeferredArray):
+                        if isinstance(f.data, list):
+                            del f.data[1, :]
+                        else:
+                            f.data[1, :] = None
+                    f.data[1, :] = f.data[0, :]
+                    f.data[0, :] = data
         # do user-defined computations on fieldset data
         if self.compute_on_defer:
             self.compute_on_defer(self)
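With the Dask stack() path removed, advancing the two-level time window is an in-place rotation along the first axis. A minimal sketch of the forward-time (signdt >= 0) case with a hypothetical array:

    import numpy as np

    # Two time levels in one array: (2, depth, y, x)
    f_data = np.stack([np.zeros((3, 4, 5)), np.ones((3, 4, 5))], axis=0)
    new_level = np.full((3, 4, 5), 2.0)  # freshly loaded chunk, already reshaped

    f_data[0, :] = f_data[1, :]  # the old "next" level becomes "current"
    f_data[1, :] = new_level     # the new chunk takes the "next" slot
    print(f_data[:, 0, 0, 0])    # [1. 2.]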
