Skip to content

Commit 8f0565e

Browse files
Merge pull request #1908 from OceanParcels/deferred_load
Remove deferred loading of FieldSets
2 parents ec29021 + 20a0caa commit 8f0565e

File tree

12 files changed

+141
-459
lines changed

12 files changed

+141
-459
lines changed

docs/examples/example_globcurrent.py

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
def set_globcurrent_fieldset(
1212
filename=None,
1313
indices=None,
14-
deferred_load=True,
1514
use_xarray=False,
1615
timestamps=None,
1716
):
@@ -41,7 +40,6 @@ def set_globcurrent_fieldset(
4140
variables,
4241
dimensions,
4342
indices,
44-
deferred_load=deferred_load,
4543
timestamps=timestamps,
4644
)
4745

@@ -80,9 +78,7 @@ def test_globcurrent_fieldset_advancetime(dt, lonstart, latstart, use_xarray):
8078
lat=[latstart],
8179
)
8280

83-
fieldsetall = set_globcurrent_fieldset(
84-
files[0:10], deferred_load=False, use_xarray=use_xarray
85-
)
81+
fieldsetall = set_globcurrent_fieldset(files[0:10], use_xarray=use_xarray)
8682
psetall = parcels.ParticleSet.from_list(
8783
fieldset=fieldsetall,
8884
pclass=parcels.Particle,
@@ -118,32 +114,6 @@ def test_globcurrent_particles(use_xarray):
118114
assert abs(pset[0].lat - -35.3) < 1
119115

120116

121-
@pytest.mark.v4remove
122-
@pytest.mark.xfail(reason="time_periodic removed in v4")
123-
@pytest.mark.parametrize("rundays", [300, 900])
124-
def test_globcurrent_time_periodic(rundays):
125-
sample_var = []
126-
for deferred_load in [True, False]:
127-
fieldset = set_globcurrent_fieldset(
128-
time_periodic=timedelta(days=365), deferred_load=deferred_load
129-
)
130-
131-
MyParticle = parcels.Particle.add_variable("sample_var", initial=0.0)
132-
133-
pset = parcels.ParticleSet(
134-
fieldset, pclass=MyParticle, lon=25, lat=-35, time=fieldset.U.grid.time[0]
135-
)
136-
137-
def SampleU(particle, fieldset, time): # pragma: no cover
138-
u, v = fieldset.UV[time, particle.depth, particle.lat, particle.lon]
139-
particle.sample_var += u
140-
141-
pset.execute(SampleU, runtime=timedelta(days=rundays), dt=timedelta(days=1))
142-
sample_var.append(pset[0].sample_var)
143-
144-
assert np.allclose(sample_var[0], sample_var[1])
145-
146-
147117
@pytest.mark.parametrize("dt", [-300, 300])
148118
def test_globcurrent_xarray_vs_netcdf(dt):
149119
fieldsetNetcdf = set_globcurrent_fieldset(use_xarray=False)
@@ -241,9 +211,19 @@ def test_globcurrent_time_extrapolation_error(use_xarray):
241211
)
242212

243213

214+
@pytest.mark.v4alpha
215+
@pytest.mark.xfail(
216+
reason="This was always broken when using eager loading `deferred_load=False` for the P field. Needs to be fixed."
217+
)
244218
@pytest.mark.parametrize("dt", [-300, 300])
245219
@pytest.mark.parametrize("with_starttime", [True, False])
246220
def test_globcurrent_startparticles_between_time_arrays(dt, with_starttime):
221+
"""Test for correctly initialising particle start times.
222+
223+
When using Fields with different temporal domains, it's important to initialise particles
224+
at the beginning of the time period where all Fields have available data (i.e., the
225+
intersection of the temporal domains)
226+
"""
247227
fieldset = set_globcurrent_fieldset()
248228

249229
data_folder = parcels.download_example_dataset("GlobCurrent_example_data")

docs/examples/example_ofam.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import parcels
99

1010

11-
def set_ofam_fieldset(deferred_load=True, use_xarray=False):
11+
def set_ofam_fieldset(use_xarray=False):
1212
data_folder = parcels.download_example_dataset("OFAM_example_data")
1313
filenames = {
1414
"U": f"{data_folder}/OFAM_simple_U.nc",
@@ -32,13 +32,12 @@ def set_ofam_fieldset(deferred_load=True, use_xarray=False):
3232
variables,
3333
dimensions,
3434
allow_time_extrapolation=True,
35-
deferred_load=deferred_load,
3635
)
3736

3837

3938
@pytest.mark.parametrize("use_xarray", [True, False])
4039
def test_ofam_fieldset_fillvalues(use_xarray):
41-
fieldset = set_ofam_fieldset(deferred_load=False, use_xarray=use_xarray)
40+
fieldset = set_ofam_fieldset(use_xarray=use_xarray)
4241
# V.data[0, 0, 150] is a landpoint, that makes NetCDF4 generate a masked array, instead of an ndarray
4342
assert fieldset.V.data[0, 0, 150] == 0
4443

parcels/_typing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
Mesh = Literal["spherical", "flat"] # corresponds with `mesh`
3030
VectorType = Literal["3D", "3DSigma", "2D"] | None # corresponds with `vector_type`
3131
GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo", "croco"] # corresponds with `gridindexingtype`
32-
UpdateStatus = Literal["not_updated", "first_updated", "updated"] # corresponds with `_update_status`
3332
NetcdfEngine = Literal["netcdf4", "xarray", "scipy"]
3433

3534

parcels/field.py

Lines changed: 29 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,6 @@ def __init__(
177177
self.filebuffername = name[1]
178178
self.data = data
179179
if grid:
180-
if grid.defer_load and isinstance(data, np.ndarray):
181-
raise ValueError(
182-
"Cannot combine Grid from defer_loaded Field with np.ndarray data. please specify lon, lat, depth and time dimensions separately"
183-
)
184180
self._grid = grid
185181
else:
186182
if (time is not None) and isinstance(time[0], np.datetime64):
@@ -225,14 +221,12 @@ def __init__(
225221
else:
226222
self.allow_time_extrapolation = allow_time_extrapolation
227223

228-
if not self.grid.defer_load:
229-
self.data = self._reshape(self.data)
230-
self._loaded_time_indices = range(self.grid.tdim)
231-
232-
# Hack around the fact that NaN and ridiculously large values
233-
# propagate in SciPy's interpolators
234-
self.data[np.isnan(self.data)] = 0.0
224+
self.data = self._reshape(self.data)
225+
self._loaded_time_indices = range(self.grid.tdim)
235226

227+
# Hack around the fact that NaN and ridiculously large values
228+
# propagate in SciPy's interpolators
229+
self.data[np.isnan(self.data)] = 0.0
236230
self._scaling_factor = None
237231

238232
self._dimensions = kwargs.pop("dimensions", None)
@@ -355,7 +349,6 @@ def from_netcdf(
355349
mesh: Mesh = "spherical",
356350
timestamps=None,
357351
allow_time_extrapolation: bool | None = None,
358-
deferred_load: bool = True,
359352
**kwargs,
360353
) -> "Field":
361354
"""Create field from netCDF file.
@@ -388,11 +381,6 @@ def from_netcdf(
388381
boolean whether to allow for extrapolation in time
389382
(i.e. beyond the last available time snapshot)
390383
Default is False if dimensions includes time, else True
391-
deferred_load : bool
392-
boolean whether to only pre-load data (in deferred mode) or
393-
fully load them (default: True). It is advised to deferred load the data, since in
394-
that case Parcels deals with a better memory management during particle set execution.
395-
deferred_load=False is however sometimes necessary for plotting the fields.
396384
gridindexingtype : str
397385
The type of gridindexing. Either 'nemo' (default), 'mitgcm', 'mom5', 'pop', or 'croco' are supported.
398386
See also the Grid indexing documentation on oceanparcels.org
@@ -551,56 +539,29 @@ def from_netcdf(
551539
"time dimension in indices is not necessary anymore. It is then ignored.", FieldSetWarning, stacklevel=2
552540
)
553541

554-
if grid.time.size <= 2:
555-
deferred_load = False
556-
557-
if not deferred_load:
558-
# Pre-allocate data before reading files into buffer
559-
data_list = []
560-
ti = 0
561-
for tslice, fname in zip(grid.timeslices, data_filenames, strict=True):
562-
with NetcdfFileBuffer( # type: ignore[operator]
563-
fname,
564-
dimensions,
565-
indices,
566-
netcdf_engine,
567-
interp_method=interp_method,
568-
data_full_zdim=data_full_zdim,
569-
) as filebuffer:
570-
# If Field.from_netcdf is called directly, it may not have a 'data' dimension
571-
# In that case, assume that 'name' is the data dimension
572-
filebuffer.name = variable[1]
573-
buffer_data = filebuffer.data
574-
if len(buffer_data.shape) == 4:
575-
errormessage = (
576-
f"Field {filebuffer.name} expecting a data shape of [tdim={grid.tdim}, zdim={grid.zdim}, "
577-
f"ydim={grid.ydim}, xdim={grid.xdim }] "
578-
f"but got shape {buffer_data.shape}."
579-
)
580-
assert buffer_data.shape[0] == grid.tdim, errormessage
581-
assert buffer_data.shape[2] == grid.ydim, errormessage
582-
assert buffer_data.shape[3] == grid.xdim, errormessage
583-
584-
if len(buffer_data.shape) == 2:
585-
data_list.append(buffer_data.reshape(sum(((len(tslice), 1), buffer_data.shape), ())))
586-
elif len(buffer_data.shape) == 3:
587-
if len(filebuffer.indices["depth"]) > 1:
588-
data_list.append(buffer_data.reshape(sum(((1,), buffer_data.shape), ())))
589-
else:
590-
if type(tslice) not in [list, np.ndarray, xr.DataArray]:
591-
tslice = [tslice]
592-
data_list.append(buffer_data.reshape(sum(((len(tslice), 1), buffer_data.shape[1:]), ())))
593-
else:
594-
data_list.append(buffer_data)
595-
if type(tslice) not in [list, np.ndarray, xr.DataArray]:
596-
tslice = [tslice]
597-
ti += len(tslice)
598-
data = np.concatenate(data_list, axis=0)
599-
else:
600-
grid._defer_load = True
601-
grid._ti = -1
602-
data = DeferredArray()
603-
data.compute_shape(grid.xdim, grid.ydim, grid.zdim, grid.tdim, len(grid.timeslices))
542+
with NetcdfFileBuffer( # type: ignore[operator]
543+
data_filenames,
544+
dimensions,
545+
indices,
546+
netcdf_engine,
547+
interp_method=interp_method,
548+
data_full_zdim=data_full_zdim,
549+
) as filebuffer:
550+
# If Field.from_netcdf is called directly, it may not have a 'data' dimension
551+
# In that case, assume that 'name' is the data dimension
552+
filebuffer.name = variable[1]
553+
buffer_data = filebuffer.data
554+
if len(buffer_data.shape) == 4:
555+
errormessage = (
556+
f"Field {filebuffer.name} expecting a data shape of [tdim={grid.tdim}, zdim={grid.zdim}, "
557+
f"ydim={grid.ydim}, xdim={grid.xdim }] "
558+
f"but got shape {buffer_data.shape}."
559+
)
560+
assert buffer_data.shape[0] == grid.tdim, errormessage
561+
assert buffer_data.shape[2] == grid.ydim, errormessage
562+
assert buffer_data.shape[3] == grid.xdim, errormessage
563+
564+
data = buffer_data
604565

605566
if allow_time_extrapolation is None:
606567
allow_time_extrapolation = False if "time" in dimensions else True
@@ -727,8 +688,7 @@ def set_scaling_factor(self, factor):
727688
if self._scaling_factor:
728689
raise NotImplementedError(f"Scaling factor for field {self.name} already defined.")
729690
self._scaling_factor = factor
730-
if not self.grid.defer_load:
731-
self.data *= factor
691+
self.data *= factor
732692

733693
def set_depth_from_field(self, field):
734694
"""Define the depth dimensions from another (time-varying) field.
@@ -913,69 +873,6 @@ def _rescale_and_set_minmax(self, data):
913873
data *= self._scaling_factor
914874
return data
915875

916-
def _data_concatenate(self, data, data_to_concat, tindex):
917-
if data[tindex] is not None:
918-
if isinstance(data, np.ndarray):
919-
data[tindex] = None
920-
elif isinstance(data, list):
921-
del data[tindex]
922-
if tindex == 0:
923-
data = np.concatenate([data_to_concat, data[tindex + 1 :, :]], axis=0)
924-
elif tindex == 1:
925-
data = np.concatenate([data[:tindex, :], data_to_concat], axis=0)
926-
else:
927-
raise ValueError("data_concatenate is used for computeTimeChunk, with tindex in [0, 1]")
928-
return data
929-
930-
def computeTimeChunk(self, data, tindex):
931-
g = self.grid
932-
timestamp = self.timestamps
933-
if timestamp is not None:
934-
summedlen = np.cumsum([len(ls) for ls in self.timestamps])
935-
if g._ti + tindex >= summedlen[-1]:
936-
ti = g._ti + tindex - summedlen[-1]
937-
else:
938-
ti = g._ti + tindex
939-
timestamp = self.timestamps[np.where(ti < summedlen)[0][0]]
940-
941-
filebuffer = NetcdfFileBuffer(
942-
self._dataFiles[g._ti + tindex],
943-
self.dimensions,
944-
self.indices,
945-
netcdf_engine=self.netcdf_engine,
946-
timestamp=timestamp,
947-
interp_method=self.interp_method,
948-
data_full_zdim=self.data_full_zdim,
949-
)
950-
filebuffer.__enter__()
951-
time_data = filebuffer.time
952-
time_data = g.time_origin.reltime(time_data)
953-
filebuffer.ti = (time_data <= g.time[tindex]).argmin() - 1
954-
if self.netcdf_engine != "xarray":
955-
filebuffer.name = self.filebuffername
956-
buffer_data = filebuffer.data
957-
if len(buffer_data.shape) == 2:
958-
buffer_data = np.reshape(buffer_data, sum(((1, 1), buffer_data.shape), ()))
959-
elif len(buffer_data.shape) == 3 and g.zdim > 1:
960-
buffer_data = np.reshape(buffer_data, sum(((1,), buffer_data.shape), ()))
961-
elif len(buffer_data.shape) == 3:
962-
buffer_data = np.reshape(
963-
buffer_data,
964-
sum(
965-
(
966-
(
967-
buffer_data.shape[0],
968-
1,
969-
),
970-
buffer_data.shape[1:],
971-
),
972-
(),
973-
),
974-
)
975-
data = self._data_concatenate(data, buffer_data, tindex)
976-
self.filebuffers[tindex] = filebuffer
977-
return data
978-
979876
def ravel_index(self, zi, yi, xi):
980877
"""Return the flat index of the given grid points.
981878
@@ -1560,32 +1457,6 @@ def __getitem__(self, key):
15601457
return _deal_with_errors(error, key, vector_type=self.vector_type)
15611458

15621459

1563-
class DeferredArray:
1564-
"""Class used for throwing error when Field.data is not read in deferred loading mode."""
1565-
1566-
data_shape = ()
1567-
1568-
def __init__(self):
1569-
self.data_shape = (1,)
1570-
1571-
def compute_shape(self, xdim, ydim, zdim, tdim, tslices):
1572-
if zdim == 1 and tdim == 1:
1573-
self.data_shape = (tslices, 1, ydim, xdim)
1574-
elif zdim > 1 or tdim > 1:
1575-
if zdim > 1:
1576-
self.data_shape = (1, zdim, ydim, xdim)
1577-
else:
1578-
self.data_shape = (max(tdim, tslices), 1, ydim, xdim)
1579-
else:
1580-
self.data_shape = (tdim, zdim, ydim, xdim)
1581-
return self.data_shape
1582-
1583-
def __getitem__(self, key):
1584-
raise RuntimeError(
1585-
"Field is in deferred_load mode, so can't be accessed. Use .computeTimeChunk() method to force loading of data"
1586-
)
1587-
1588-
15891460
class NestedField(list):
15901461
"""NestedField is a class that allows for interpolation of fields on different grids of potentially varying resolution.
15911462

0 commit comments

Comments
 (0)