Skip to content

Commit ac53039

Browse files
Merge pull request #1878 from OceanParcels/1874-ravel-unravel-index
1874 ravel unravel index
2 parents 043d0ba + c8487d9 commit ac53039

File tree

10 files changed

+88
-121
lines changed

10 files changed

+88
-121
lines changed

docs/examples/tutorial_particle_field_interaction.ipynb

Lines changed: 19 additions & 84 deletions
Large diffs are not rendered by default.

parcels/_index_search.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -218,17 +218,14 @@ def _search_indices_rectilinear(
218218
_raise_field_sampling_error(z, y, x)
219219

220220
if particle:
221-
particle.xi[field.igrid] = xi
222-
particle.yi[field.igrid] = yi
223-
particle.zi[field.igrid] = zi
221+
particle.ei[field.igrid] = field.ravel_index(zi, yi, xi)
224222

225223
return (zeta, eta, xsi, zi, yi, xi)
226224

227225

228226
def _search_indices_curvilinear(field: Field, time, z, y, x, ti=-1, particle=None, search2D=False):
229227
if particle:
230-
xi = particle.xi[field.igrid]
231-
yi = particle.yi[field.igrid]
228+
zi, yi, xi = field.unravel_index(particle.ei)
232229
else:
233230
xi = int(field.grid.xdim / 2) - 1
234231
yi = int(field.grid.ydim / 2) - 1
@@ -310,9 +307,7 @@ def _search_indices_curvilinear(field: Field, time, z, y, x, ti=-1, particle=Non
310307
_raise_field_sampling_error(z, y, x)
311308

312309
if particle:
313-
particle.xi[field.igrid] = xi
314-
particle.yi[field.igrid] = yi
315-
particle.zi[field.igrid] = zi
310+
particle.ei[field.igrid] = field.ravel_index(zi, yi, xi)
316311

317312
return (zeta, eta, xsi, zi, yi, xi)
318313

parcels/application_kernels/advection.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,7 @@ def AdvectionAnalytical(particle, fieldset, time): # pragma: no cover
210210
yi += 1
211211
eta = 0
212212

213-
particle.xi[:] = xi
214-
particle.yi[:] = yi
215-
particle.zi[:] = zi
213+
particle.ei[:] = fieldset.U.ravel_index(zi, yi, xi)
216214

217215
grid = fieldset.U.grid
218216
if grid._gtype < 2:

parcels/field.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,6 +1255,49 @@ def computeTimeChunk(self, data, tindex):
12551255
self.filebuffers[tindex] = filebuffer
12561256
return data
12571257

1258+
def ravel_index(self, zi, yi, xi):
1259+
"""Return the flat index of the given grid points.
1260+
1261+
Parameters
1262+
----------
1263+
zi : int
1264+
z index
1265+
yi : int
1266+
y index
1267+
xi : int
1268+
x index
1269+
1270+
Returns
1271+
-------
1272+
int
1273+
flat index
1274+
"""
1275+
return xi + self.grid.xdim * (yi + self.grid.ydim * zi)
1276+
1277+
def unravel_index(self, ei):
1278+
"""Return the zi, yi, xi indices for a given flat index.
1279+
1280+
Parameters
1281+
----------
1282+
ei : int
1283+
The flat index to be unraveled.
1284+
1285+
Returns
1286+
-------
1287+
zi : int
1288+
The z index.
1289+
yi : int
1290+
The y index.
1291+
xi : int
1292+
The x index.
1293+
"""
1294+
_ei = ei[self.igrid]
1295+
zi = _ei // (self.grid.xdim * self.grid.ydim)
1296+
_ei = _ei % (self.grid.xdim * self.grid.ydim)
1297+
yi = _ei // self.grid.xdim
1298+
xi = _ei % self.grid.xdim
1299+
return zi, yi, xi
1300+
12581301

12591302
class VectorField:
12601303
"""Class VectorField stores 2 or 3 fields which defines together a vector field.

parcels/particledata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def __init__(self, pclass, lon, lat, depth, time, lonlatdepth_dtype, pid_orig, n
122122
self._ncount = len(lon)
123123

124124
for v in self.ptype.variables:
125-
if v.name in ["xi", "yi", "zi", "ti"]:
125+
if v.name in ["ei", "ti"]:
126126
self._data[v.name] = np.empty((len(lon), ngrid), dtype=v.dtype)
127127
else:
128128
self._data[v.name] = np.empty(self._ncount, dtype=v.dtype)

parcels/particleset.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -127,18 +127,13 @@ def ArrayClass_init(self, *args, **kwargs):
127127
type(self).ngrids.initial = numgrids
128128
self.ngrids = type(self).ngrids.initial
129129
if self.ngrids >= 0:
130-
for index in ["xi", "yi", "zi", "ti"]:
131-
if index != "ti":
132-
setattr(self, index, np.zeros(self.ngrids, dtype=np.int32))
133-
else:
134-
setattr(self, index, -1 * np.ones(self.ngrids, dtype=np.int32))
130+
self.ei = np.zeros(self.ngrids, dtype=np.int32)
131+
self.ti = -1 * np.ones(self.ngrids, dtype=np.int32)
135132
super(type(self), self).__init__(*args, **kwargs)
136133

137134
array_class_vdict = {
138135
"ngrids": Variable("ngrids", dtype=np.int32, to_write=False, initial=-1),
139-
"xi": Variable("xi", dtype=np.int32, to_write=False),
140-
"yi": Variable("yi", dtype=np.int32, to_write=False),
141-
"zi": Variable("zi", dtype=np.int32, to_write=False),
136+
"ei": Variable("ei", dtype=np.int32, to_write=False),
142137
"ti": Variable("ti", dtype=np.int32, to_write=False, initial=-1),
143138
"__init__": ArrayClass_init,
144139
}
@@ -436,7 +431,7 @@ def _neighbors_by_coor(self, coor):
436431

437432
# TODO: This method is only tested in tutorial notebook. Add unit test?
438433
def populate_indices(self):
439-
"""Pre-populate guesses of particle xi/yi indices using a kdtree.
434+
"""Pre-populate guesses of particle ei (element id) indices using a kdtree.
440435
441436
This is only intended for curvilinear grids, where the initial index search
442437
may be quite expensive.
@@ -454,10 +449,8 @@ def populate_indices(self):
454449
_, idx_nan = tree.query(pts.astype(tree_data.dtype))
455450

456451
idx = np.where(IN)[0][idx_nan]
457-
yi, xi = np.unravel_index(idx, grid.lon.shape)
458452

459-
self.particledata.data["xi"][:, i] = xi
460-
self.particledata.data["yi"][:, i] = yi
453+
self.particledata.data["ei"][:, i] = idx # assumes that we are in the surface layer (zi=0)
461454

462455
@classmethod
463456
def from_list(
@@ -725,9 +718,7 @@ def from_particlefile(
725718
elif (
726719
v.name
727720
not in [
728-
"xi",
729-
"yi",
730-
"zi",
721+
"ei",
731722
"ti",
732723
"dt",
733724
"depth",

tests/test_fieldset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,8 @@ def test_fieldset_write(tmp_zarrfile):
570570

571571
def UpdateU(particle, fieldset, time): # pragma: no cover
572572
tmp1, tmp2 = fieldset.UV[particle]
573-
fieldset.U.data[particle.ti, particle.yi, particle.xi] += 1
573+
_, yi, xi = fieldset.U.unravel_index(particle.ei)
574+
fieldset.U.data[particle.ti, yi, xi] += 1
574575
fieldset.U.grid.time[0] = time
575576

576577
pset = ParticleSet(fieldset, pclass=Particle, lon=5, lat=5)

tests/test_fieldset_sampling.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,8 @@ def test_verticalsampling(zdir):
131131
fieldset = FieldSet.from_data(data, dimensions, mesh="flat")
132132
pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0, depth=0.7 * zdir)
133133
pset.execute(AdvectionRK4, dt=1.0, runtime=1.0)
134-
assert pset[0].zi == [2]
134+
zi, yi, xi = fieldset.U.unravel_index(pset[0].ei)
135+
assert zi == [2]
135136

136137

137138
def test_pset_from_field():

tests/test_grids.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,14 @@ def sampleTemp(particle, fieldset, time): # pragma: no cover
9090
pset.execute(AdvectionRK4 + pset.Kernel(sampleTemp), runtime=3, dt=1)
9191

9292
# check if particle xi and yi are different for the two grids
93-
# assert np.all([pset.xi[i, 0] != pset.xi[i, 1] for i in range(3)])
94-
# assert np.all([pset.yi[i, 0] != pset.yi[i, 1] for i in range(3)])
95-
assert np.all([pset[i].xi[0] != pset[i].xi[1] for i in range(3)])
96-
assert np.all([pset[i].yi[0] != pset[i].yi[1] for i in range(3)])
97-
93+
# xi check from unraveled index
94+
assert np.all(
95+
[fieldset.U.unravel_index(pset[i].ei)[2] != fieldset.V.unravel_index(pset[i].ei)[2] for i in range(3)]
96+
)
97+
# yi check from unraveled index
98+
assert np.all(
99+
[fieldset.U.unravel_index(pset[i].ei)[1] != fieldset.V.unravel_index(pset[i].ei)[1] for i in range(3)]
100+
)
98101
# advect without updating temperature to test particle deletion
99102
pset.remove_indices(np.array([1]))
100103
pset.execute(AdvectionRK4, runtime=1, dt=1)

tests/test_particlefile.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,9 +273,9 @@ def Get_XiYi(particle, fieldset, time): # pragma: no cover
273273
and that the first outputted value is zero.
274274
Be careful when using multiple grids, as the index may be different for the grids.
275275
"""
276-
particle.pxi0 = particle.xi[0]
277-
particle.pxi1 = particle.xi[1]
278-
particle.pyi = particle.yi[0]
276+
particle.pxi0 = fieldset.U.unravel_index(particle.ei)[2]
277+
particle.pxi1 = fieldset.P.unravel_index(particle.ei)[2]
278+
particle.pyi = fieldset.U.unravel_index(particle.ei)[1]
279279

280280
def SampleP(particle, fieldset, time): # pragma: no cover
281281
if time > 5 * 3600:

0 commit comments

Comments
 (0)