Skip to content

Commit fd1dda2

Browse files
Merge pull request #1891 from OceanParcels/removing_particle_ti
Removing particle ti
2 parents 5e36f8e + 560f9ce commit fd1dda2

File tree

6 files changed

+18
-20
lines changed

6 files changed

+18
-20
lines changed

parcels/_index_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def search_indices_vertical_s(
137137

138138

139139
def _search_indices_rectilinear(
140-
field: Field, time: float, z: float, y: float, x: float, ti=-1, particle=None, search2D=False
140+
field: Field, time: float, z: float, y: float, x: float, ti: int, particle=None, search2D=False
141141
):
142142
grid = field.grid
143143

@@ -223,7 +223,7 @@ def _search_indices_rectilinear(
223223
return (zeta, eta, xsi, zi, yi, xi)
224224

225225

226-
def _search_indices_curvilinear(field: Field, time, z, y, x, ti=-1, particle=None, search2D=False):
226+
def _search_indices_curvilinear(field: Field, time, z, y, x, ti, particle=None, search2D=False):
227227
if particle:
228228
zi, yi, xi = field.unravel_index(particle.ei)
229229
else:

parcels/application_kernels/advection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def AdvectionAnalytical(particle, fieldset, time): # pragma: no cover
185185
ds_t = min(ds_t, time_i[np.where(time - fieldset.U.grid.time[ti] < time_i)[0][0]])
186186

187187
zeta, eta, xsi, zi, yi, xi = fieldset.U._search_indices(
188-
-1, particle.depth, particle.lat, particle.lon, particle=particle
188+
time, particle.depth, particle.lat, particle.lon, ti, particle=particle
189189
)
190190
if withW:
191191
if abs(xsi - 1) < tol:

parcels/field.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -877,15 +877,15 @@ def cell_areas(self):
877877
"""
878878
return _calc_cell_areas(self.grid)
879879

880-
def _search_indices(self, time, z, y, x, ti=-1, particle=None, search2D=False):
880+
def _search_indices(self, time, z, y, x, ti, particle=None, search2D=False):
881881
if self.grid._gtype in [GridType.RectilinearSGrid, GridType.RectilinearZGrid]:
882882
return _search_indices_rectilinear(self, time, z, y, x, ti, particle=particle, search2D=search2D)
883883
else:
884884
return _search_indices_curvilinear(self, time, z, y, x, ti, particle=particle, search2D=search2D)
885885

886-
def _interpolator2D(self, ti, z, y, x, particle=None):
886+
def _interpolator2D(self, time, z, y, x, ti, particle=None):
887887
"""Impelement 2D interpolation with coordinate transformations as seen in Delandmeter and Van Sebille (2019), 10.5194/gmd-12-3571-2019.."""
888-
(_, eta, xsi, _, yi, xi) = self._search_indices(-1, z, y, x, particle=particle)
888+
(_, eta, xsi, _, yi, xi) = self._search_indices(time, z, y, x, ti, particle=particle)
889889
ctx = InterpolationContext2D(self.data, eta, xsi, ti, yi, xi)
890890

891891
try:
@@ -899,7 +899,7 @@ def _interpolator2D(self, ti, z, y, x, particle=None):
899899
raise RuntimeError(self.interp_method + " is not implemented for 2D grids")
900900
return f(ctx)
901901

902-
def _interpolator3D(self, ti, z, y, x, time, particle=None):
902+
def _interpolator3D(self, time, z, y, x, ti, particle=None):
903903
"""Impelement 3D interpolation with coordinate transformations as seen in Delandmeter and Van Sebille (2019), 10.5194/gmd-12-3571-2019.."""
904904
(zeta, eta, xsi, zi, yi, xi) = self._search_indices(time, z, y, x, ti, particle=particle)
905905
ctx = InterpolationContext3D(self.data, zeta, eta, xsi, ti, zi, yi, xi, self.gridindexingtype)
@@ -910,13 +910,13 @@ def _interpolator3D(self, ti, z, y, x, time, particle=None):
910910
raise RuntimeError(self.interp_method + " is not implemented for 3D grids")
911911
return f(ctx)
912912

913-
def _spatial_interpolation(self, ti, z, y, x, time, particle=None):
914-
"""Interpolate horizontal field values using a SciPy interpolator."""
913+
def _spatial_interpolation(self, time, z, y, x, ti, particle=None):
914+
"""Interpolate spatial field values."""
915915
try:
916916
if self.grid.zdim == 1:
917-
val = self._interpolator2D(ti, z, y, x, particle=particle)
917+
val = self._interpolator2D(time, z, y, x, ti, particle=particle)
918918
else:
919-
val = self._interpolator3D(ti, z, y, x, time, particle=particle)
919+
val = self._interpolator3D(time, z, y, x, ti, particle=particle)
920920

921921
if np.isnan(val):
922922
# Detect Out-of-bounds sampling and raise exception
@@ -980,16 +980,16 @@ def eval(self, time, z, y, x, particle=None, applyConversion=True):
980980
if self.gridindexingtype == "croco" and self not in [self.fieldset.H, self.fieldset.Zeta]:
981981
z = _croco_from_z_to_sigma_scipy(self.fieldset, time, z, y, x, particle=particle)
982982
if ti < self.grid.tdim - 1 and time > self.grid.time[ti]:
983-
f0 = self._spatial_interpolation(ti, z, y, x, time, particle=particle)
984-
f1 = self._spatial_interpolation(ti + 1, z, y, x, time, particle=particle)
983+
f0 = self._spatial_interpolation(time, z, y, x, ti, particle=particle)
984+
f1 = self._spatial_interpolation(time, z, y, x, ti + 1, particle=particle)
985985
t0 = self.grid.time[ti]
986986
t1 = self.grid.time[ti + 1]
987987
value = f0 + (f1 - f0) * ((time - t0) / (t1 - t0))
988988
else:
989989
# Skip temporal interpolation if time is outside
990990
# of the defined time range or if we have hit an
991991
# exact value in the time array.
992-
value = self._spatial_interpolation(ti, z, y, x, self.grid.time[ti], particle=particle)
992+
value = self._spatial_interpolation(self.grid.time[ti], z, y, x, ti, particle=particle)
993993

994994
if applyConversion:
995995
return self.units.to_target(value, z, y, x)

parcels/particledata.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ def __init__(self, pclass, lon, lat, depth, time, lonlatdepth_dtype, pid_orig, n
122122
self._ncount = len(lon)
123123

124124
for v in self.ptype.variables:
125-
if v.name in ["ei", "ti"]:
126-
self._data[v.name] = np.empty((len(lon), ngrid), dtype=v.dtype)
125+
if v.name == "ei":
126+
self._data[v.name] = np.empty((len(lon), ngrid), dtype=v.dtype) # TODO len(lon) can be self._ncount?
127127
else:
128128
self._data[v.name] = np.empty(self._ncount, dtype=v.dtype)
129129

parcels/particleset.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,11 @@ def ArrayClass_init(self, *args, **kwargs):
128128
self.ngrids = type(self).ngrids.initial
129129
if self.ngrids >= 0:
130130
self.ei = np.zeros(self.ngrids, dtype=np.int32)
131-
self.ti = -1 * np.ones(self.ngrids, dtype=np.int32)
132131
super(type(self), self).__init__(*args, **kwargs)
133132

134133
array_class_vdict = {
135134
"ngrids": Variable("ngrids", dtype=np.int32, to_write=False, initial=-1),
136135
"ei": Variable("ei", dtype=np.int32, to_write=False),
137-
"ti": Variable("ti", dtype=np.int32, to_write=False, initial=-1),
138136
"__init__": ArrayClass_init,
139137
}
140138
array_class = type(class_name, (pclass,), array_class_vdict)
@@ -719,7 +717,6 @@ def from_particlefile(
719717
v.name
720718
not in [
721719
"ei",
722-
"ti",
723720
"dt",
724721
"depth",
725722
"id",

tests/test_fieldset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,8 @@ def test_fieldset_write(tmp_zarrfile):
571571
def UpdateU(particle, fieldset, time): # pragma: no cover
572572
tmp1, tmp2 = fieldset.UV[particle]
573573
_, yi, xi = fieldset.U.unravel_index(particle.ei)
574-
fieldset.U.data[particle.ti, yi, xi] += 1
574+
ti = fieldset.U._time_index(time)
575+
fieldset.U.data[ti, yi, xi] += 1
575576
fieldset.U.grid.time[0] = time
576577

577578
pset = ParticleSet(fieldset, pclass=Particle, lon=5, lat=5)

0 commit comments

Comments
 (0)