diff --git a/parcels/_index_search.py b/parcels/_index_search.py index 6c7f8a26c..1f1e06811 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -35,6 +35,7 @@ def _search_time_index(field: Field, time: datetime): if not field.time_interval.is_all_time_in_interval(time): _raise_time_extrapolation_error(time, field=None) + # TODO this could be sped up when data has only two timeslices (i.e. when data_full is not None)? ti = np.searchsorted(field.data.time.data, time, side="right") - 1 tau = (time - field.data.time.data[ti]) / (field.data.time.data[ti + 1] - field.data.time.data[ti]) return np.atleast_1d(tau), np.atleast_1d(ti) diff --git a/parcels/field.py b/parcels/field.py index b248414a4..24ba91583 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -8,6 +8,7 @@ import numpy as np import uxarray as ux import xarray as xr +from dask import is_dask_collection from parcels._core.utils.time import TimeInterval from parcels._reprs import default_repr @@ -132,8 +133,14 @@ def __init__( data = _transpose_xfield_data_to_tzyx(data, grid.xgcm_grid) self.name = name - self.data = data self.grid = grid + if is_dask_collection(data) and ("time" in data.dims): + self.data = None + self.data_full = data + else: + self.data = data + self.data_full = None + self._nexttime_to_load = None try: self.time_interval = _get_time_interval(data) @@ -167,8 +174,8 @@ def __init__( elif self.grid._mesh == "spherical": self.units = unitconverters_map[self.name] - if self.data.shape[0] > 1: - if "time" not in self.data.coords: + if data.shape[0] > 1: + if "time" not in data.coords: raise ValueError("Field data is missing a 'time' coordinate.") @property @@ -183,25 +190,29 @@ def units(self, value): @property def xdim(self): - if type(self.data) is xr.DataArray: + if hasattr(self.grid, "xdim"): return self.grid.xdim else: raise NotImplementedError("xdim not implemented for unstructured grids") @property def ydim(self): - if type(self.data) is xr.DataArray: + if hasattr(self.grid, "ydim"): return self.grid.ydim else: raise NotImplementedError("ydim not implemented for unstructured grids") @property def zdim(self): - if type(self.data) is xr.DataArray: + if hasattr(self.grid, "zdim"): return self.grid.zdim else: - if "nz1" in self.data.dims: + if "nz1" in self.data_full.dims: + return self.data_full.sizes["nz1"] + elif "nz1" in self.data.dims: return self.data.sizes["nz1"] + elif "nz" in self.data_full.dims: + return self.data_full.sizes["nz"] elif "nz" in self.data.dims: return self.data.sizes["nz"] else: @@ -224,6 +235,21 @@ def _check_velocitysampling(self): stacklevel=2, ) + def _load_timesteps(self, time): + """Load the appropriate timesteps of a field.""" + if self.data_full is not None: + ti = np.argmin(self.data_full.time.data <= time) - 1 # TODO also implement dt < 0 + if self.data is None: + self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load() + elif self.data_full.time.data[ti] == self.data.time.data[1]: + self.data = xr.concat([self.data[1, :], self.data_full.isel({"time": ti + 1}).load()], dim="time") + elif self.data_full.time.data[ti] != self.data.time.data[0]: + self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load() + assert len(self.data.time) == 2, ( + f"Field {self.name} has not been loaded correctly. Expected 2 timesteps, but got {len(self.data.time)}." + ) + self._nexttime_to_load = self.data_full.time.data[ti + 1] + def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): """Interpolate field values in space and time. diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 113ef637d..6f666c3c6 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -82,6 +82,18 @@ def time_interval(self): return None return functools.reduce(lambda x, y: x.intersection(y), time_intervals) + def _load_timesteps(self, time): + """Load the appropriate timesteps of all fields in the fieldset.""" + next_times = [] + for fldname in self.fields: + field = self.fields[fldname] + if isinstance(field, Field): + field._load_timesteps(time) + if field._nexttime_to_load is not None: + next_times.append(field._nexttime_to_load) + + return min(next_times) if next_times else None + def add_field(self, field: Field, name: str | None = None): """Add a :class:`parcels.field.Field` object to the FieldSet. diff --git a/parcels/particleset.py b/parcels/particleset.py index b44cef0be..187beabbf 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -547,11 +547,16 @@ def execute( time = start_time while sign_dt * (time - end_time) < 0: + # Load the appropriate timesteps of the fieldset + next_load_time = self.fieldset._load_timesteps(time) + + possible_next_time = [end_time] + if next_load_time is not None: + possible_next_time.append(next_load_time) if next_output is not None: - f = min if sign_dt > 0 else max - next_time = f(next_output, end_time) - else: - next_time = end_time + possible_next_time.append(next_output) + f = min if sign_dt > 0 else max + next_time = f(possible_next_time) self._kernel.execute(self, endtime=next_time, dt=dt) diff --git a/tests/v4/test_advection.py b/tests/v4/test_advection.py index a8bae9686..79396b909 100644 --- a/tests/v4/test_advection.py +++ b/tests/v4/test_advection.py @@ -110,9 +110,9 @@ def test_horizontal_advection_in_3D_flow(npart=10): """Flat 2D zonal flow that increases linearly with depth from 0 m/s to 1 m/s.""" ds = simple_UV_dataset(mesh="flat") ds["U"].data[:] = 1.0 + ds["U"].data[:, 0, :, :] = 0.0 # Set U to 0 at the surface grid = XGrid.from_dataset(ds) U = Field("U", ds["U"], grid, interp_method=XLinear) - U.data[:, 0, :, :] = 0.0 # Set U to 0 at the surface V = Field("V", ds["V"], grid, interp_method=XLinear) UV = VectorField("UV", U, V) fieldset = FieldSet([U, V, UV]) @@ -128,12 +128,13 @@ def test_horizontal_advection_in_3D_flow(npart=10): @pytest.mark.parametrize("wErrorThroughSurface", [True, False]) def test_advection_3D_outofbounds(direction, wErrorThroughSurface): ds = simple_UV_dataset(mesh="flat") + ds["U"].data[:] = 0.01 # Set U to small value (to avoid horizontal out of bounds) + ds["W"] = ds["V"].copy() # Use V as W for testing + ds["W"].data[:] = -1.0 if direction == "up" else 1.0 grid = XGrid.from_dataset(ds) U = Field("U", ds["U"], grid, interp_method=XLinear) - U.data[:] = 0.01 # Set U to small value (to avoid horizontal out of bounds) V = Field("V", ds["V"], grid, interp_method=XLinear) - W = Field("W", ds["V"], grid, interp_method=XLinear) # Use V as W for testing - W.data[:] = -1.0 if direction == "up" else 1.0 + W = Field("W", ds["W"], grid, interp_method=XLinear) UVW = VectorField("UVW", U, V, W) UV = VectorField("UV", U, V) fieldset = FieldSet([U, V, W, UVW, UV]) @@ -213,6 +214,7 @@ def test_length1dimensions(u, v, w): # TODO: Refactor this test to be more read fields = [U, V, VectorField("UV", U, V)] if w: W = Field("W", ds["W"], grid, interp_method=XLinear) + fields.append(W) fields.append(VectorField("UVW", U, V, W)) fieldset = FieldSet(fields) diff --git a/tests/v4/test_uxarray_fieldset.py b/tests/v4/test_uxarray_fieldset.py index d3c8b5de6..53d91085c 100644 --- a/tests/v4/test_uxarray_fieldset.py +++ b/tests/v4/test_uxarray_fieldset.py @@ -78,16 +78,16 @@ def uvw_fesom_channel(ds_fesom_channel) -> VectorField: def test_fesom_fieldset(ds_fesom_channel, uv_fesom_channel): fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V]) # Check that the fieldset has the expected properties - assert (fieldset.U.data == ds_fesom_channel.U).all() - assert (fieldset.V.data == ds_fesom_channel.V).all() + assert (fieldset.U.data_full == ds_fesom_channel.U).all() + assert (fieldset.V.data_full == ds_fesom_channel.V).all() def test_fesom_in_particleset(ds_fesom_channel, uv_fesom_channel): fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V]) # Check that the fieldset has the expected properties - assert (fieldset.U.data == ds_fesom_channel.U).all() - assert (fieldset.V.data == ds_fesom_channel.V).all() + assert (fieldset.U.data_full == ds_fesom_channel.U).all() + assert (fieldset.V.data_full == ds_fesom_channel.V).all() pset = ParticleSet(fieldset, pclass=Particle) assert pset.fieldset == fieldset @@ -95,8 +95,8 @@ def test_fesom_in_particleset(ds_fesom_channel, uv_fesom_channel): def test_set_interp_methods(ds_fesom_channel, uv_fesom_channel): fieldset = FieldSet([uv_fesom_channel, uv_fesom_channel.U, uv_fesom_channel.V]) # Check that the fieldset has the expected properties - assert (fieldset.U.data == ds_fesom_channel.U).all() - assert (fieldset.V.data == ds_fesom_channel.V).all() + assert (fieldset.U.data_full == ds_fesom_channel.U).all() + assert (fieldset.V.data_full == ds_fesom_channel.V).all() # Set the interpolation method for each field fieldset.U.interp_method = UXPiecewiseConstantFace