Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions parcels/_index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,30 +44,31 @@ def _search_time_index(field: Field, time: datetime):
if time not in field.time_interval:
_raise_time_extrapolation_error(time, field=None)

time_index = field.data.time <= time
time_index = field.data.time.data <= time

if time_index.all():
# If given time > last known field time, use
# the last field frame without interpolation
ti = len(field.data.time) - 1

ti = len(field.data.time.data) - 1
elif np.logical_not(time_index).all():
# If given time < any time in the field, use
# the first field frame without interpolation
ti = 0
else:
ti = int(time_index.argmin() - 1) if time_index.any() else 0
if len(field.data.time) == 1:

if len(field.data.time.data) == 1:
tau = 0
elif ti == len(field.data.time) - 1:
elif ti == len(field.data.time.data) - 1:
tau = 1
else:
tau = (
(time - field.data.time[ti]).dt.total_seconds()
/ (field.data.time[ti + 1] - field.data.time[ti]).dt.total_seconds()
if field.data.time[ti] != field.data.time[ti + 1]
(time - field.data.time.data[ti]) / (field.data.time.data[ti + 1] - field.data.time.data[ti])
if field.data.time.data[ti] != field.data.time.data[ti + 1]
else 0
)
if tau < 0 or tau > 1: # TODO only for debugging; test can go?
raise ValueError(f"Time {time} is out of bounds for field time data {field.data.time.data}.")
return tau, ti


Expand Down
41 changes: 25 additions & 16 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __init__(
data = _transpose_xfield_data_to_tzyx(data, grid.xgcm_grid)

self.name = name
self.data = data
self.data_full = data
self.grid = grid

try:
Expand Down Expand Up @@ -189,8 +189,8 @@ def __init__(
else:
raise ValueError("Unsupported mesh type in data array attributes. Choose either: 'spherical' or 'flat'")

if self.data.shape[0] > 1:
if "time" not in self.data.coords:
if data.shape[0] > 1:
if "time" not in data.coords:
raise ValueError("Field data is missing a 'time' coordinate.")

@property
Expand All @@ -205,27 +205,27 @@ def units(self, value):

@property
def xdim(self):
if type(self.data) is xr.DataArray:
if type(self.data_full) is xr.DataArray:
return self.grid.xdim
else:
raise NotImplementedError("xdim not implemented for unstructured grids")

@property
def ydim(self):
if type(self.data) is xr.DataArray:
if type(self.data_full) is xr.DataArray:
return self.grid.ydim
else:
raise NotImplementedError("ydim not implemented for unstructured grids")

@property
def zdim(self):
if type(self.data) is xr.DataArray:
if type(self.data_full) is xr.DataArray:
return self.grid.zdim
else:
if "nz1" in self.data.dims:
return self.data.sizes["nz1"]
elif "nz" in self.data.dims:
return self.data.sizes["nz"]
if "nz1" in self.data_full.dims:
return self.data_full.sizes["nz1"]
elif "nz" in self.data_full.dims:
return self.data_full.sizes["nz"]
else:
return 0

Expand All @@ -246,6 +246,19 @@ def _check_velocitysampling(self):
stacklevel=2,
)

def _load_timesteps(self, time):
    """Load the appropriate timesteps of a field.

    Maintains ``self.data`` as a two-timestep in-memory window sliced out of
    the (possibly lazy) ``self.data_full``, bracketing *time* so that linear
    time interpolation is possible without loading the full time series.

    Parameters
    ----------
    time :
        Simulation time to bracket. Assumed comparable with the entries of
        ``self.data_full.time.data``.
    """
    # Index of the last stored timestep that is <= time (first False of the
    # boolean mask, minus one).
    # NOTE(review): if *time* lies before the first timestep (mask all
    # False) or at/after the last one (mask all True), np.argmin returns 0
    # and ti becomes -1, making the slices below empty/wrong — confirm
    # callers guarantee *time* is strictly inside the field time interval.
    ti = np.argmin(self.data_full.time.data <= time) - 1  # TODO also implement dt < 0
    if not hasattr(self, "data"):
        # First access: load the bracketing pair [ti, ti + 1].
        self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load()
    elif self.data_full.time.data[ti] == self.data.time.data[1]:
        # Advanced by exactly one step: reuse the already-loaded second
        # timestep and only load the new one, avoiding a redundant read.
        self.data = xr.concat([self.data[1, :], self.data_full.isel({"time": ti + 1}).load()], dim="time")
    elif self.data_full.time.data[ti] != self.data.time.data[0]:
        # Jumped by more than one step: reload the bracketing pair.
        self.data = self.data_full.isel({"time": slice(ti, ti + 2)}).load()
    # (else: cache already starts at ti — nothing to load.)
    assert (
        len(self.data.time) == 2
    ), f"Field {self.name} has not been loaded correctly. Expected 2 timesteps, but got {len(self.data.time)}."

def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
"""Interpolate field values in space and time.

Expand All @@ -266,17 +279,14 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
if np.isnan(value):
# Detect Out-of-bounds sampling and raise exception
_raise_field_out_of_bound_error(z, y, x)
else:
return value

except (FieldSamplingError, FieldOutOfBoundError, FieldOutOfBoundSurfaceError) as e:
e.add_note(f"Error interpolating field '{self.name}'.")
raise e

if applyConversion:
return self.units.to_target(value, z, y, x)
else:
return value
value = self.units.to_target(value, z, y, x)
return value

def __getitem__(self, key):
self._check_velocitysampling()
Expand Down Expand Up @@ -359,7 +369,6 @@ def eval(self, time: datetime, z, y, x, particle=None, applyConversion=True):
else:
(u, v, w) = self._vector_interp_method(self, ti, position, time, z, y, x)

# print(u,v)
if applyConversion:
u = self.U.units.to_target(u, z, y, x)
v = self.V.units.to_target(v, z, y, x)
Expand Down
7 changes: 7 additions & 0 deletions parcels/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ def time_interval(self):
return None
return functools.reduce(lambda x, y: x.intersection(y), time_intervals)

def _load_timesteps(self, time):
    """Load the appropriate timesteps of all fields in the fieldset.

    Delegates to ``Field._load_timesteps`` for every ``Field`` stored in
    ``self.fields``; entries of other types (e.g. vector fields or
    constants) are skipped.

    Parameters
    ----------
    time :
        Simulation time passed through to each field's loader.
    """
    # Iterate the values directly instead of iterating keys and
    # re-indexing the mapping (same behavior, idiomatic and one lookup
    # fewer per entry).
    for field in self.fields.values():
        if isinstance(field, Field):
            field._load_timesteps(time)

def add_field(self, field: Field, name: str | None = None):
"""Add a :class:`parcels.field.Field` object to the FieldSet.

Expand Down
7 changes: 5 additions & 2 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,10 +805,13 @@ def execute(

time = start_time
while sign_dt * (time - end_time) < 0:
# Load the appropriate timesteps of the fieldset
self.fieldset._load_timesteps(self._data["time_nextloop"][0])

if sign_dt > 0:
next_time = end_time # TODO update to min(next_output, end_time) when ParticleFile works
next_time = min(time + dt, end_time)
else:
next_time = end_time # TODO update to max(next_output, end_time) when ParticleFile works
next_time = max(time - dt, end_time)
res = self._kernel.execute(self, endtime=next_time, dt=dt)
if res == StatusCode.StopAllExecution:
return StatusCode.StopAllExecution
Expand Down
10 changes: 7 additions & 3 deletions parcels/xgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,11 @@ def _gtype(self):
def search(self, z, y, x, ei=None):
ds = self.xgcm_grid._ds

zi, zeta = _search_1d_array(ds.depth.values, z)
zi, zeta = _search_1d_array(ds.depth.data, z)

if ds.lon.ndim == 1:
yi, eta = _search_1d_array(ds.lat.values, y)
xi, xsi = _search_1d_array(ds.lon.values, x)
yi, eta = _search_1d_array(ds.lat.data, y)
xi, xsi = _search_1d_array(ds.lon.data, x)
return {"Z": (zi, zeta), "Y": (yi, eta), "X": (xi, xsi)}

yi, xi = None, None
Expand Down Expand Up @@ -453,8 +453,12 @@ def _search_1d_array(
float
Barycentric coordinate.
"""
if len(arr) < 2:
return 0, 0.0
i = np.argmin(arr <= x) - 1
bcoord = (x - arr[i]) / (arr[i + 1] - arr[i])
if bcoord < 0 or bcoord > 1: # TODO only for debugging; test can go?
raise ValueError(f"Position {x} is out of bounds for array {arr}.")
return i, bcoord


Expand Down
Loading