diff --git a/parcels/_index_search.py b/parcels/_index_search.py index 4638b1848..4d470488e 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -44,30 +44,31 @@ def _search_time_index(field: Field, time: datetime): if time not in field.time_interval: _raise_time_extrapolation_error(time, field=None) - time_index = field.data.time <= time + time_index = field.data.time.data <= time if time_index.all(): # If given time > last known field time, use # the last field frame without interpolation - ti = len(field.data.time) - 1 - + ti = len(field.data.time.data) - 1 elif np.logical_not(time_index).all(): # If given time < any time in the field, use # the first field frame without interpolation ti = 0 else: ti = int(time_index.argmin() - 1) if time_index.any() else 0 - if len(field.data.time) == 1: + + if len(field.data.time.data) == 1: tau = 0 - elif ti == len(field.data.time) - 1: + elif ti == len(field.data.time.data) - 1: tau = 1 else: tau = ( - (time - field.data.time[ti]).dt.total_seconds() - / (field.data.time[ti + 1] - field.data.time[ti]).dt.total_seconds() - if field.data.time[ti] != field.data.time[ti + 1] + (time - field.data.time.data[ti]) / (field.data.time.data[ti + 1] - field.data.time.data[ti]) + if field.data.time.data[ti] != field.data.time.data[ti + 1] else 0 ) + if tau < 0 or tau > 1: # TODO only for debugging; test can go? + raise ValueError(f"Time {time} is out of bounds for field time data {field.data.time.data}.") return tau, ti diff --git a/parcels/field.py b/parcels/field.py index 1a1fbb717..4fbe5a815 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -150,7 +150,7 @@ def __init__( data = _transpose_xfield_data_to_tzyx(data, grid.xgcm_grid) self.name = name - self.data = data + self.data_full = data self.grid = grid try: @@ -189,8 +189,8 @@ def __init__( else: raise ValueError("Unsupported mesh type in data array attributes. Choose either: 'spherical' or 'flat'") - if self.data.shape[0] > 1: - if "time" not in self.data.coords: + if data.shape[0] > 1: + if "time" not in data.coords: raise ValueError("Field data is missing a 'time' coordinate.") @property @@ -205,27 +205,27 @@ def units(self, value): @property def xdim(self): - if type(self.data) is xr.DataArray: + if type(self.data_full) is xr.DataArray: return self.grid.xdim else: raise NotImplementedError("xdim not implemented for unstructured grids") @property def ydim(self): - if type(self.data) is xr.DataArray: + if type(self.data_full) is xr.DataArray: return self.grid.ydim else: raise NotImplementedError("ydim not implemented for unstructured grids") @property def zdim(self): - if type(self.data) is xr.DataArray: + if type(self.data_full) is xr.DataArray: return self.grid.zdim else: - if "nz1" in self.data.dims: - return self.data.sizes["nz1"] - elif "nz" in self.data.dims: - return self.data.sizes["nz"] + if "nz1" in self.data_full.dims: + return self.data_full.sizes["nz1"] + elif "nz" in self.data_full.dims: + return self.data_full.sizes["nz"] else: return 0 @@ -246,6 +246,19 @@ def _check_velocitysampling(self): stacklevel=2, ) + def _load_timesteps(self, time): + """Load the appropriate timesteps of a field.""" + ti = np.argmin(self.data_full.time.data <= time) - 1 # TODO also implement dt < 0 + if not hasattr(self, "data"): + self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load() + elif self.data_full.time.data[ti] == self.data.time.data[1]: + self.data = xr.concat([self.data[1, :], self.data_full.isel({"time": ti + 1}).load()], dim="time") + elif self.data_full.time.data[ti] != self.data.time.data[0]: + self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load() + assert ( + len(self.data.time) == 2 + ), f"Field {self.name} has not been loaded correctly. Expected 2 timesteps, but got {len(self.data.time)}." + def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): """Interpolate field values in space and time. @@ -266,17 +279,14 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): if np.isnan(value): # Detect Out-of-bounds sampling and raise exception _raise_field_out_of_bound_error(z, y, x) - else: - return value except (FieldSamplingError, FieldOutOfBoundError, FieldOutOfBoundSurfaceError) as e: e.add_note(f"Error interpolating field '{self.name}'.") raise e if applyConversion: - return self.units.to_target(value, z, y, x) - else: - return value + value = self.units.to_target(value, z, y, x) + return value def __getitem__(self, key): self._check_velocitysampling() @@ -359,7 +369,6 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True): else: (u, v, w) = self._vector_interp_method(self, ti, position, time, z, y, x) - # print(u,v) if applyConversion: u = self.U.units.to_target(u, z, y, x) v = self.V.units.to_target(v, z, y, x) diff --git a/parcels/fieldset.py b/parcels/fieldset.py index 5937c7f57..1fe9a6656 100644 --- a/parcels/fieldset.py +++ b/parcels/fieldset.py @@ -82,6 +82,13 @@ def time_interval(self): return None return functools.reduce(lambda x, y: x.intersection(y), time_intervals) + def _load_timesteps(self, time): + """Load the appropriate timesteps of all fields in the fieldset.""" + for fldname in self.fields: + field = self.fields[fldname] + if isinstance(field, Field): + field._load_timesteps(time) + def add_field(self, field: Field, name: str | None = None): """Add a :class:`parcels.field.Field` object to the FieldSet. diff --git a/parcels/particleset.py b/parcels/particleset.py index f018a356d..1e235f0e5 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -805,10 +805,13 @@ def execute( time = start_time while sign_dt * (time - end_time) < 0: + # Load the appropriate timesteps of the fieldset + self.fieldset._load_timesteps(self._data["time_nextloop"][0]) + if sign_dt > 0: - next_time = end_time # TODO update to min(next_output, end_time) when ParticleFile works + next_time = min(time + dt, end_time) else: - next_time = end_time # TODO update to max(next_output, end_time) when ParticleFile works + next_time = max(time - dt, end_time) res = self._kernel.execute(self, endtime=next_time, dt=dt) if res == StatusCode.StopAllExecution: return StatusCode.StopAllExecution diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 81391ee81..95f05a69f 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -270,11 +270,11 @@ def _gtype(self): def search(self, z, y, x, ei=None): ds = self.xgcm_grid._ds - zi, zeta = _search_1d_array(ds.depth.values, z) + zi, zeta = _search_1d_array(ds.depth.data, z) if ds.lon.ndim == 1: - yi, eta = _search_1d_array(ds.lat.values, y) - xi, xsi = _search_1d_array(ds.lon.values, x) + yi, eta = _search_1d_array(ds.lat.data, y) + xi, xsi = _search_1d_array(ds.lon.data, x) return {"Z": (zi, zeta), "Y": (yi, eta), "X": (xi, xsi)} yi, xi = None, None @@ -453,8 +453,12 @@ def _search_1d_array( float Barycentric coordinate. """ + if len(arr) < 2: + return 0, 0.0 i = np.argmin(arr <= x) - 1 bcoord = (x - arr[i]) / (arr[i + 1] - arr[i]) + if bcoord < 0 or bcoord > 1: # TODO only for debugging; test can go? + raise ValueError(f"Position {x} is out of bounds for array {arr}.") return i, bcoord