42 changes: 11 additions & 31 deletions docs/examples/example_globcurrent.py
@@ -11,7 +11,6 @@
def set_globcurrent_fieldset(
filename=None,
indices=None,
deferred_load=True,
use_xarray=False,
timestamps=None,
):
@@ -41,7 +40,6 @@ def set_globcurrent_fieldset(
variables,
dimensions,
indices,
deferred_load=deferred_load,
timestamps=timestamps,
)

@@ -80,9 +78,7 @@ def test_globcurrent_fieldset_advancetime(dt, lonstart, latstart, use_xarray):
lat=[latstart],
)

fieldsetall = set_globcurrent_fieldset(
files[0:10], deferred_load=False, use_xarray=use_xarray
)
fieldsetall = set_globcurrent_fieldset(files[0:10], use_xarray=use_xarray)
psetall = parcels.ParticleSet.from_list(
fieldset=fieldsetall,
pclass=parcels.Particle,
@@ -118,32 +114,6 @@ def test_globcurrent_particles(use_xarray):
assert abs(pset[0].lat - -35.3) < 1


@pytest.mark.v4remove
@pytest.mark.xfail(reason="time_periodic removed in v4")
@pytest.mark.parametrize("rundays", [300, 900])
def test_globcurrent_time_periodic(rundays):
sample_var = []
for deferred_load in [True, False]:
fieldset = set_globcurrent_fieldset(
time_periodic=timedelta(days=365), deferred_load=deferred_load
)

MyParticle = parcels.Particle.add_variable("sample_var", initial=0.0)

pset = parcels.ParticleSet(
fieldset, pclass=MyParticle, lon=25, lat=-35, time=fieldset.U.grid.time[0]
)

def SampleU(particle, fieldset, time): # pragma: no cover
u, v = fieldset.UV[time, particle.depth, particle.lat, particle.lon]
particle.sample_var += u

pset.execute(SampleU, runtime=timedelta(days=rundays), dt=timedelta(days=1))
sample_var.append(pset[0].sample_var)

assert np.allclose(sample_var[0], sample_var[1])


@pytest.mark.parametrize("dt", [-300, 300])
def test_globcurrent_xarray_vs_netcdf(dt):
fieldsetNetcdf = set_globcurrent_fieldset(use_xarray=False)
@@ -241,9 +211,19 @@ def test_globcurrent_time_extrapolation_error(use_xarray):
)


@pytest.mark.v4alpha
@pytest.mark.xfail(
reason="This was always broken when using eager loading (`deferred_load=False`) for the P field. Needs to be fixed."
)
@pytest.mark.parametrize("dt", [-300, 300])
@pytest.mark.parametrize("with_starttime", [True, False])
def test_globcurrent_startparticles_between_time_arrays(dt, with_starttime):
"""Test for correctly initialising particle start times.

When using Fields with different temporal domains, it's important to initialise particles
at the beginning of the time period where all Fields have available data (i.e., the
intersection of the temporal domains).
"""
fieldset = set_globcurrent_fieldset()

data_folder = parcels.download_example_dataset("GlobCurrent_example_data")
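As context for the docstring added in `test_globcurrent_startparticles_between_time_arrays` above: a minimal sketch of what "the intersection of the temporal domains" means for particle start times. The arrays and values here are hypothetical, not taken from the test.

```python
import numpy as np

# Each field is only defined on its own time axis (hypothetical values).
time_U = np.arange(0.0, 10.0)  # U available for t in [0, 9]
time_P = np.arange(3.0, 12.0)  # P only available from t = 3

# The earliest safe start time is the latest first snapshot; the latest safe
# end time is the earliest last snapshot.
start_time = max(time_U[0], time_P[0])  # 3.0
end_time = min(time_U[-1], time_P[-1])  # 9.0
assert start_time <= end_time, "temporal domains do not overlap"
```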
5 changes: 2 additions & 3 deletions docs/examples/example_ofam.py
@@ -8,7 +8,7 @@
import parcels


def set_ofam_fieldset(deferred_load=True, use_xarray=False):
def set_ofam_fieldset(use_xarray=False):
data_folder = parcels.download_example_dataset("OFAM_example_data")
filenames = {
"U": f"{data_folder}/OFAM_simple_U.nc",
@@ -32,13 +32,12 @@ def set_ofam_fieldset(deferred_load=True, use_xarray=False):
variables,
dimensions,
allow_time_extrapolation=True,
deferred_load=deferred_load,
)


@pytest.mark.parametrize("use_xarray", [True, False])
def test_ofam_fieldset_fillvalues(use_xarray):
fieldset = set_ofam_fieldset(deferred_load=False, use_xarray=use_xarray)
fieldset = set_ofam_fieldset(use_xarray=use_xarray)
# V.data[0, 0, 150] is a land point, which makes netCDF4 generate a masked array instead of an ndarray
assert fieldset.V.data[0, 0, 150] == 0

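For readers unfamiliar with the fill-value behaviour `test_ofam_fieldset_fillvalues` checks: netCDF readers return variables with a `_FillValue` as masked arrays, and the masked (land) entries end up as 0 in the field data. A minimal NumPy sketch of that replacement, using hypothetical data rather than the OFAM files:

```python
import numpy as np

# 1-D stand-in for a velocity slice with one masked land point.
v = np.ma.masked_array([0.1, 0.2, 0.3], mask=[False, True, False])

# Fill masked entries with 0.0, which is what the assertion above relies on.
v_filled = np.ma.filled(v, 0.0)
assert v_filled[1] == 0.0
```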
1 change: 0 additions & 1 deletion parcels/_typing.py
@@ -29,7 +29,6 @@
Mesh = Literal["spherical", "flat"] # corresponds with `mesh`
VectorType = Literal["3D", "3DSigma", "2D"] | None # corresponds with `vector_type`
GridIndexingType = Literal["pop", "mom5", "mitgcm", "nemo", "croco"] # corresponds with `gridindexingtype`
UpdateStatus = Literal["not_updated", "first_updated", "updated"] # corresponds with `_update_status`
NetcdfEngine = Literal["netcdf4", "xarray", "scipy"]


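The aliases in `parcels/_typing.py` are plain `typing.Literal` unions, so a static type checker rejects any string outside the allowed set. A small illustration (the `make_grid` function is hypothetical, not part of Parcels):

```python
from typing import Literal

Mesh = Literal["spherical", "flat"]  # mirrors the alias kept in _typing.py

def make_grid(mesh: Mesh) -> str:
    # mypy/pyright flag call sites that pass a string not listed in the Literal.
    return f"grid on a {mesh} mesh"

make_grid("flat")     # OK
# make_grid("cubed")  # rejected by a type checker (not checked at runtime)
```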
187 changes: 29 additions & 158 deletions parcels/field.py
@@ -177,10 +177,6 @@
self.filebuffername = name[1]
self.data = data
if grid:
if grid.defer_load and isinstance(data, np.ndarray):
raise ValueError(
"Cannot combine Grid from defer_loaded Field with np.ndarray data. please specify lon, lat, depth and time dimensions separately"
)
self._grid = grid
else:
if (time is not None) and isinstance(time[0], np.datetime64):
@@ -225,14 +221,12 @@
else:
self.allow_time_extrapolation = allow_time_extrapolation

if not self.grid.defer_load:
self.data = self._reshape(self.data)
self._loaded_time_indices = range(self.grid.tdim)

# Hack around the fact that NaN and ridiculously large values
# propagate in SciPy's interpolators
self.data[np.isnan(self.data)] = 0.0
self.data = self._reshape(self.data)
self._loaded_time_indices = range(self.grid.tdim)

# Hack around the fact that NaN and ridiculously large values
# propagate in SciPy's interpolators
self.data[np.isnan(self.data)] = 0.0
self._scaling_factor = None

self._dimensions = kwargs.pop("dimensions", None)
@@ -355,7 +349,6 @@
mesh: Mesh = "spherical",
timestamps=None,
allow_time_extrapolation: bool | None = None,
deferred_load: bool = True,
**kwargs,
) -> "Field":
"""Create field from netCDF file.
@@ -388,11 +381,6 @@
boolean whether to allow for extrapolation in time
(i.e. beyond the last available time snapshot)
Default is False if dimensions includes time, else True
deferred_load : bool
boolean whether to only pre-load data (in deferred mode) or
fully load them (default: True). It is advised to deferred load the data, since in
that case Parcels deals with a better memory management during particle set execution.
deferred_load=False is however sometimes necessary for plotting the fields.
gridindexingtype : str
The type of gridindexing. Either 'nemo' (default), 'mitgcm', 'mom5', 'pop', or 'croco' are supported.
See also the Grid indexing documentation on oceanparcels.org
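With `deferred_load` gone, `Field.from_netcdf` always loads data eagerly. A hedged sketch of a call under the signature shown above; the file name and the netCDF variable/dimension names are placeholders, not one of the example datasets:

```python
import parcels

# Hypothetical file and names; adjust to your dataset.
field = parcels.Field.from_netcdf(
    "my_velocity.nc",
    variable=("U", "eastward_velocity"),  # (Parcels name, netCDF variable name)
    dimensions={"lon": "lon", "lat": "lat", "time": "time"},
    allow_time_extrapolation=True,  # sample beyond the last time snapshot
)
```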
@@ -551,56 +539,29 @@
"time dimension in indices is not necessary anymore. It is then ignored.", FieldSetWarning, stacklevel=2
)

if grid.time.size <= 2:
deferred_load = False

if not deferred_load:
# Pre-allocate data before reading files into buffer
data_list = []
ti = 0
for tslice, fname in zip(grid.timeslices, data_filenames, strict=True):
with NetcdfFileBuffer( # type: ignore[operator]
fname,
dimensions,
indices,
netcdf_engine,
interp_method=interp_method,
data_full_zdim=data_full_zdim,
) as filebuffer:
# If Field.from_netcdf is called directly, it may not have a 'data' dimension
# In that case, assume that 'name' is the data dimension
filebuffer.name = variable[1]
buffer_data = filebuffer.data
if len(buffer_data.shape) == 4:
errormessage = (
f"Field {filebuffer.name} expecting a data shape of [tdim={grid.tdim}, zdim={grid.zdim}, "
f"ydim={grid.ydim}, xdim={grid.xdim }] "
f"but got shape {buffer_data.shape}."
)
assert buffer_data.shape[0] == grid.tdim, errormessage
assert buffer_data.shape[2] == grid.ydim, errormessage
assert buffer_data.shape[3] == grid.xdim, errormessage

if len(buffer_data.shape) == 2:
data_list.append(buffer_data.reshape(sum(((len(tslice), 1), buffer_data.shape), ())))
elif len(buffer_data.shape) == 3:
if len(filebuffer.indices["depth"]) > 1:
data_list.append(buffer_data.reshape(sum(((1,), buffer_data.shape), ())))
else:
if type(tslice) not in [list, np.ndarray, xr.DataArray]:
tslice = [tslice]
data_list.append(buffer_data.reshape(sum(((len(tslice), 1), buffer_data.shape[1:]), ())))
else:
data_list.append(buffer_data)
if type(tslice) not in [list, np.ndarray, xr.DataArray]:
tslice = [tslice]
ti += len(tslice)
data = np.concatenate(data_list, axis=0)
else:
grid._defer_load = True
grid._ti = -1
data = DeferredArray()
data.compute_shape(grid.xdim, grid.ydim, grid.zdim, grid.tdim, len(grid.timeslices))
with NetcdfFileBuffer( # type: ignore[operator]
data_filenames,
dimensions,
indices,
netcdf_engine,
interp_method=interp_method,
data_full_zdim=data_full_zdim,
) as filebuffer:
# If Field.from_netcdf is called directly, it may not have a 'data' dimension
# In that case, assume that 'name' is the data dimension
filebuffer.name = variable[1]
buffer_data = filebuffer.data
if len(buffer_data.shape) == 4:
errormessage = (
f"Field {filebuffer.name} expecting a data shape of [tdim={grid.tdim}, zdim={grid.zdim}, "
f"ydim={grid.ydim}, xdim={grid.xdim }] "
f"but got shape {buffer_data.shape}."

Check warning on line 558 in parcels/field.py

View check run for this annotation

Codecov / codecov/patch

parcels/field.py#L556-L558

Added lines #L556 - L558 were not covered by tests
)
assert buffer_data.shape[0] == grid.tdim, errormessage
assert buffer_data.shape[2] == grid.ydim, errormessage
assert buffer_data.shape[3] == grid.xdim, errormessage

data = buffer_data

if allow_time_extrapolation is None:
allow_time_extrapolation = False if "time" in dimensions else True
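The assertions above expect the loaded buffer in the canonical `[tdim, zdim, ydim, xdim]` layout. A rough sketch of that normalisation, independent of `NetcdfFileBuffer` (shapes are hypothetical; the real code consults the depth indices to resolve the 3-D case):

```python
import numpy as np

def to_4d(buffer_data: np.ndarray, has_depth: bool) -> np.ndarray:
    """Normalise a buffer to [tdim, zdim, ydim, xdim]."""
    if buffer_data.ndim == 2:  # single snapshot, no depth axis: (ydim, xdim)
        return buffer_data.reshape((1, 1) + buffer_data.shape)
    if buffer_data.ndim == 3 and has_depth:  # (zdim, ydim, xdim)
        return buffer_data.reshape((1,) + buffer_data.shape)
    if buffer_data.ndim == 3:  # (tdim, ydim, xdim)
        return buffer_data.reshape((buffer_data.shape[0], 1) + buffer_data.shape[1:])
    return buffer_data  # already 4-D
```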
@@ -727,8 +688,7 @@
if self._scaling_factor:
raise NotImplementedError(f"Scaling factor for field {self.name} already defined.")
self._scaling_factor = factor
if not self.grid.defer_load:
self.data *= factor
self.data *= factor
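With eager loading, the scaling factor is now applied to the in-memory data as soon as it is set. A toy illustration of the same in-place multiplication (hypothetical values, e.g. converting cm/s to m/s):

```python
import numpy as np

data = np.array([100.0, 250.0])  # velocities in cm/s (hypothetical)
factor = 0.01                    # cm/s -> m/s
data *= factor                   # applied immediately, as in set_scaling_factor
print(data)                      # [1.   2.5]
```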

def set_depth_from_field(self, field):
"""Define the depth dimensions from another (time-varying) field.
@@ -913,69 +873,6 @@
data *= self._scaling_factor
return data

def _data_concatenate(self, data, data_to_concat, tindex):
if data[tindex] is not None:
if isinstance(data, np.ndarray):
data[tindex] = None
elif isinstance(data, list):
del data[tindex]
if tindex == 0:
data = np.concatenate([data_to_concat, data[tindex + 1 :, :]], axis=0)
elif tindex == 1:
data = np.concatenate([data[:tindex, :], data_to_concat], axis=0)
else:
raise ValueError("data_concatenate is used for computeTimeChunk, with tindex in [0, 1]")
return data

def computeTimeChunk(self, data, tindex):
g = self.grid
timestamp = self.timestamps
if timestamp is not None:
summedlen = np.cumsum([len(ls) for ls in self.timestamps])
if g._ti + tindex >= summedlen[-1]:
ti = g._ti + tindex - summedlen[-1]
else:
ti = g._ti + tindex
timestamp = self.timestamps[np.where(ti < summedlen)[0][0]]

filebuffer = NetcdfFileBuffer(
self._dataFiles[g._ti + tindex],
self.dimensions,
self.indices,
netcdf_engine=self.netcdf_engine,
timestamp=timestamp,
interp_method=self.interp_method,
data_full_zdim=self.data_full_zdim,
)
filebuffer.__enter__()
time_data = filebuffer.time
time_data = g.time_origin.reltime(time_data)
filebuffer.ti = (time_data <= g.time[tindex]).argmin() - 1
if self.netcdf_engine != "xarray":
filebuffer.name = self.filebuffername
buffer_data = filebuffer.data
if len(buffer_data.shape) == 2:
buffer_data = np.reshape(buffer_data, sum(((1, 1), buffer_data.shape), ()))
elif len(buffer_data.shape) == 3 and g.zdim > 1:
buffer_data = np.reshape(buffer_data, sum(((1,), buffer_data.shape), ()))
elif len(buffer_data.shape) == 3:
buffer_data = np.reshape(
buffer_data,
sum(
(
(
buffer_data.shape[0],
1,
),
buffer_data.shape[1:],
),
(),
),
)
data = self._data_concatenate(data, buffer_data, tindex)
self.filebuffers[tindex] = filebuffer
return data
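For readers following the removal: `computeTimeChunk` and `_data_concatenate` maintained a two-snapshot rolling window over the time axis, replacing slot 0 or slot 1 as the simulation advanced. A stripped-down NumPy sketch of that concatenation pattern, without the file I/O:

```python
import numpy as np

def roll_in(data: np.ndarray, new_chunk: np.ndarray, tindex: int) -> np.ndarray:
    """Replace one of the two time slots with a freshly loaded chunk."""
    if tindex == 0:  # new chunk becomes the earlier snapshot
        return np.concatenate([new_chunk, data[1:]], axis=0)
    if tindex == 1:  # new chunk becomes the later snapshot
        return np.concatenate([data[:1], new_chunk], axis=0)
    raise ValueError("rolling buffer only supports tindex in [0, 1]")

window = np.zeros((2, 1, 4, 4))  # [time, depth, lat, lon], hypothetical sizes
window = roll_in(window, np.ones((1, 1, 4, 4)), tindex=1)
assert window[1].mean() == 1.0
```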

def ravel_index(self, zi, yi, xi):
"""Return the flat index of the given grid points.

@@ -1560,32 +1457,6 @@
return _deal_with_errors(error, key, vector_type=self.vector_type)


class DeferredArray:
"""Class used for throwing error when Field.data is not read in deferred loading mode."""

data_shape = ()

def __init__(self):
self.data_shape = (1,)

def compute_shape(self, xdim, ydim, zdim, tdim, tslices):
if zdim == 1 and tdim == 1:
self.data_shape = (tslices, 1, ydim, xdim)
elif zdim > 1 or tdim > 1:
if zdim > 1:
self.data_shape = (1, zdim, ydim, xdim)
else:
self.data_shape = (max(tdim, tslices), 1, ydim, xdim)
else:
self.data_shape = (tdim, zdim, ydim, xdim)
return self.data_shape

def __getitem__(self, key):
raise RuntimeError(
"Field is in deferred_load mode, so can't be accessed. Use .computeTimeChunk() method to force loading of data"
)


class NestedField(list):
"""NestedField is a class that allows for interpolation of fields on different grids of potentially varying resolution.

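The `NestedField` docstring above describes lookup across grids of varying resolution: fields are tried in order, falling through when a point lies outside a grid. A generic sketch of that fallback pattern (hypothetical objects and error type, not the Parcels API):

```python
class OutOfBounds(Exception):
    """Stand-in for Parcels' out-of-bound field error."""

def nested_sample(fields, t, z, y, x):
    # Try each field from finest to coarsest; fall through when the point
    # lies outside a field's grid.
    for field in fields:
        try:
            return field.eval(t, z, y, x)
        except OutOfBounds:
            continue
    raise OutOfBounds("point outside all nested grids")
```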