Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions parcels/_index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def search_indices_vertical_s(


def _search_indices_rectilinear(
field: Field, time: float, z: float, y: float, x: float, ti=-1, particle=None, search2D=False
field: Field, time: float, z: float, y: float, x: float, ti: int, particle=None, search2D=False
):
grid = field.grid

Expand Down Expand Up @@ -223,7 +223,7 @@ def _search_indices_rectilinear(
return (zeta, eta, xsi, zi, yi, xi)


def _search_indices_curvilinear(field: Field, time, z, y, x, ti=-1, particle=None, search2D=False):
def _search_indices_curvilinear(field: Field, time, z, y, x, ti, particle=None, search2D=False):
if particle:
zi, yi, xi = field.unravel_index(particle.ei)
else:
Expand Down
2 changes: 1 addition & 1 deletion parcels/application_kernels/advection.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@
ds_t = min(ds_t, time_i[np.where(time - fieldset.U.grid.time[ti] < time_i)[0][0]])

zeta, eta, xsi, zi, yi, xi = fieldset.U._search_indices(
-1, particle.depth, particle.lat, particle.lon, particle=particle
time, particle.depth, particle.lat, particle.lon, ti, particle=particle

Check warning on line 188 in parcels/application_kernels/advection.py

View check run for this annotation

Codecov / codecov/patch

parcels/application_kernels/advection.py#L188

Added line #L188 was not covered by tests
)
if withW:
if abs(xsi - 1) < tol:
Expand Down
22 changes: 11 additions & 11 deletions parcels/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,15 +877,15 @@ def cell_areas(self):
"""
return _calc_cell_areas(self.grid)

def _search_indices(self, time, z, y, x, ti=-1, particle=None, search2D=False):
def _search_indices(self, time, z, y, x, ti, particle=None, search2D=False):
if self.grid._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]:
return _search_indices_rectilinear(self, time, z, y, x, ti, particle=particle, search2D=search2D)
else:
return _search_indices_curvilinear(self, time, z, y, x, ti, particle=particle, search2D=search2D)

def _interpolator2D(self, ti, z, y, x, particle=None):
def _interpolator2D(self, time, z, y, x, ti, particle=None):
"""Impelement 2D interpolation with coordinate transformations as seen in Delandmeter and Van Sebille (2019), 10.5194/gmd-12-3571-2019.."""
(_, eta, xsi, _, yi, xi) = self._search_indices(-1, z, y, x, particle=particle)
(_, eta, xsi, _, yi, xi) = self._search_indices(time, z, y, x, ti, particle=particle)
ctx = InterpolationContext2D(self.data, eta, xsi, ti, yi, xi)

try:
Expand All @@ -899,7 +899,7 @@ def _interpolator2D(self, ti, z, y, x, particle=None):
raise RuntimeError(self.interp_method + " is not implemented for 2D grids")
return f(ctx)

def _interpolator3D(self, ti, z, y, x, time, particle=None):
def _interpolator3D(self, time, z, y, x, ti, particle=None):
"""Impelement 3D interpolation with coordinate transformations as seen in Delandmeter and Van Sebille (2019), 10.5194/gmd-12-3571-2019.."""
(zeta, eta, xsi, zi, yi, xi) = self._search_indices(time, z, y, x, ti, particle=particle)
ctx = InterpolationContext3D(self.data, zeta, eta, xsi, ti, zi, yi, xi, self.gridindexingtype)
Expand All @@ -910,13 +910,13 @@ def _interpolator3D(self, ti, z, y, x, time, particle=None):
raise RuntimeError(self.interp_method + " is not implemented for 3D grids")
return f(ctx)

def _spatial_interpolation(self, ti, z, y, x, time, particle=None):
"""Interpolate horizontal field values using a SciPy interpolator."""
def _spatial_interpolation(self, time, z, y, x, ti, particle=None):
"""Interpolate spatial field values."""
try:
if self.grid.zdim == 1:
val = self._interpolator2D(ti, z, y, x, particle=particle)
val = self._interpolator2D(time, z, y, x, ti, particle=particle)
else:
val = self._interpolator3D(ti, z, y, x, time, particle=particle)
val = self._interpolator3D(time, z, y, x, ti, particle=particle)

if np.isnan(val):
# Detect Out-of-bounds sampling and raise exception
Expand Down Expand Up @@ -980,16 +980,16 @@ def eval(self, time, z, y, x, particle=None, applyConversion=True):
if self.gridindexingtype == "croco" and self not in [self.fieldset.H, self.fieldset.Zeta]:
z = _croco_from_z_to_sigma_scipy(self.fieldset, time, z, y, x, particle=particle)
if ti < self.grid.tdim - 1 and time > self.grid.time[ti]:
f0 = self._spatial_interpolation(ti, z, y, x, time, particle=particle)
f1 = self._spatial_interpolation(ti + 1, z, y, x, time, particle=particle)
f0 = self._spatial_interpolation(time, z, y, x, ti, particle=particle)
f1 = self._spatial_interpolation(time, z, y, x, ti + 1, particle=particle)
t0 = self.grid.time[ti]
t1 = self.grid.time[ti + 1]
value = f0 + (f1 - f0) * ((time - t0) / (t1 - t0))
else:
# Skip temporal interpolation if time is outside
# of the defined time range or if we have hit an
# exact value in the time array.
value = self._spatial_interpolation(ti, z, y, x, self.grid.time[ti], particle=particle)
value = self._spatial_interpolation(self.grid.time[ti], z, y, x, ti, particle=particle)

if applyConversion:
return self.units.to_target(value, z, y, x)
Expand Down
4 changes: 2 additions & 2 deletions parcels/particledata.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def __init__(self, pclass, lon, lat, depth, time, lonlatdepth_dtype, pid_orig, n
self._ncount = len(lon)

for v in self.ptype.variables:
if v.name in ["ei", "ti"]:
self._data[v.name] = np.empty((len(lon), ngrid), dtype=v.dtype)
if v.name == "ei":
self._data[v.name] = np.empty((len(lon), ngrid), dtype=v.dtype) # TODO len(lon) can be self._ncount?
else:
self._data[v.name] = np.empty(self._ncount, dtype=v.dtype)

Expand Down
3 changes: 0 additions & 3 deletions parcels/particleset.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,11 @@ def ArrayClass_init(self, *args, **kwargs):
self.ngrids = type(self).ngrids.initial
if self.ngrids >= 0:
self.ei = np.zeros(self.ngrids, dtype=np.int32)
self.ti = -1 * np.ones(self.ngrids, dtype=np.int32)
super(type(self), self).__init__(*args, **kwargs)

array_class_vdict = {
"ngrids": Variable("ngrids", dtype=np.int32, to_write=False, initial=-1),
"ei": Variable("ei", dtype=np.int32, to_write=False),
"ti": Variable("ti", dtype=np.int32, to_write=False, initial=-1),
"__init__": ArrayClass_init,
}
array_class = type(class_name, (pclass,), array_class_vdict)
Expand Down Expand Up @@ -719,7 +717,6 @@ def from_particlefile(
v.name
not in [
"ei",
"ti",
"dt",
"depth",
"id",
Expand Down
3 changes: 2 additions & 1 deletion tests/test_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,8 @@ def test_fieldset_write(tmp_zarrfile):
def UpdateU(particle, fieldset, time): # pragma: no cover
tmp1, tmp2 = fieldset.UV[particle]
_, yi, xi = fieldset.U.unravel_index(particle.ei)
fieldset.U.data[particle.ti, yi, xi] += 1
ti = fieldset.U._time_index(time)
fieldset.U.data[ti, yi, xi] += 1
fieldset.U.grid.time[0] = time

pset = ParticleSet(fieldset, pclass=Particle, lon=5, lat=5)
Expand Down
Loading