Skip to content

Commit 8fe61fe

Browse files
reint-fischerreint-fischer
authored andcommitted
Implementing separate particle_positions and grid_positions dictionary as suggested in #2287
1 parent 5752fa2 commit 8fe61fe

File tree

10 files changed

+128
-91
lines changed

10 files changed

+128
-91
lines changed

src/parcels/_core/field.py

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -216,13 +216,14 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
216216
else:
217217
_ei = particles.ei[:, self.igrid]
218218

219-
position = {"time": time, "z": z, "lat": y, "lon": x}
220-
position.update(_search_time_index(self, time))
221-
position.update(self.grid.search(z, y, x, ei=_ei))
222-
_update_particles_ei(particles, position, self)
223-
_update_particle_states_position(particles, position)
219+
particle_positions = {"time": time, "z": z, "lat": y, "lon": x}
220+
grid_positions = {}
221+
grid_positions.update(_search_time_index(self, time))
222+
grid_positions.update(self.grid.search(z, y, x, ei=_ei))
223+
_update_particles_ei(particles, grid_positions, self)
224+
_update_particle_states_position(particles, grid_positions)
224225

225-
value = self._interp_method(position, self)
226+
value = self._interp_method(particle_positions, grid_positions, self)
226227

227228
_update_particle_states_interp_value(particles, value)
228229

@@ -301,21 +302,22 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
301302
else:
302303
_ei = particles.ei[:, self.igrid]
303304

304-
position = {"time": time, "z": z, "lat": y, "lon": x}
305-
position.update(_search_time_index(self.U, time))
306-
position.update(self.grid.search(z, y, x, ei=_ei))
307-
_update_particles_ei(particles, position, self)
308-
_update_particle_states_position(particles, position)
305+
particle_positions = {"time": time, "z": z, "lat": y, "lon": x}
306+
grid_positions = {}
307+
grid_positions.update(_search_time_index(self.U, time))
308+
grid_positions.update(self.grid.search(z, y, x, ei=_ei))
309+
_update_particles_ei(particles, grid_positions, self)
310+
_update_particle_states_position(particles, grid_positions)
309311

310312
if self._vector_interp_method is None:
311-
u = self.U._interp_method(position, self.U)
312-
v = self.V._interp_method(position, self.V)
313+
u = self.U._interp_method(particle_positions, grid_positions, self.U)
314+
v = self.V._interp_method(particle_positions, grid_positions, self.V)
313315
if "3D" in self.vector_type:
314-
w = self.W._interp_method(position, self.W)
316+
w = self.W._interp_method(particle_positions, grid_positions, self.W)
315317
else:
316318
w = 0.0
317319
else:
318-
(u, v, w) = self._vector_interp_method(position, self)
320+
(u, v, w) = self._vector_interp_method(particle_positions, grid_positions, self)
319321

320322
if applyConversion:
321323
u = self.U.units.to_target(u, z, y, x)
@@ -341,45 +343,54 @@ def __getitem__(self, key):
341343
return _deal_with_errors(error, key, vector_type=self.vector_type)
342344

343345

344-
def _update_particles_ei(particles, position, field):
346+
def _update_particles_ei(particles, grid_positions, field):
345347
"""Update the element index (ei) of the particles"""
346348
if particles is not None:
347349
if isinstance(field.grid, XGrid):
348350
particles.ei[:, field.igrid] = field.grid.ravel_index(
349351
{
350-
"X": position["xi"],
351-
"Y": position["yi"],
352-
"Z": position["zi"],
352+
"X": grid_positions["X"]["index"],
353+
"Y": grid_positions["Y"]["index"],
354+
"Z": grid_positions["Z"]["index"],
353355
}
354356
)
355357
elif isinstance(field.grid, UxGrid):
356358
particles.ei[:, field.igrid] = field.grid.ravel_index(
357359
{
358-
"Z": position["Z"][0],
359-
"FACE": position["FACE"][0],
360+
"Z": grid_positions["Z"]["index"],
361+
"FACE": grid_positions["FACE"]["index"],
360362
}
361363
)
362364

363365

364-
def _update_particle_states_position(particles, position):
366+
def _update_particle_states_position(particles, grid_positions):
365367
"""Update the particle states based on the position dictionary."""
366368
if particles: # TODO also support uxgrid search
367-
for dim in ["xi", "yi"]:
368-
if dim in position:
369+
for dim in ["X", "Y"]:
370+
if dim in grid_positions:
369371
particles.state = np.maximum(
370-
np.where(position[dim] == -1, StatusCode.ErrorOutOfBounds, particles.state), particles.state
372+
np.where(grid_positions[dim]["index"] == -1, StatusCode.ErrorOutOfBounds, particles.state),
373+
particles.state,
371374
)
372375
particles.state = np.maximum(
373-
np.where(position[dim] == GRID_SEARCH_ERROR, StatusCode.ErrorGridSearching, particles.state),
376+
np.where(
377+
grid_positions[dim]["index"] == GRID_SEARCH_ERROR,
378+
StatusCode.ErrorGridSearching,
379+
particles.state,
380+
),
374381
particles.state,
375382
)
376-
if "zi" in position:
383+
if "Z" in grid_positions:
377384
particles.state = np.maximum(
378-
np.where(position["zi"][0] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particles.state),
385+
np.where(
386+
grid_positions["Z"]["index"] == RIGHT_OUT_OF_BOUNDS, StatusCode.ErrorOutOfBounds, particles.state
387+
),
379388
particles.state,
380389
)
381390
particles.state = np.maximum(
382-
np.where(position["zi"][0] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particles.state),
391+
np.where(
392+
grid_positions["Z"]["index"] == LEFT_OUT_OF_BOUNDS, StatusCode.ErrorThroughSurface, particles.state
393+
),
383394
particles.state,
384395
)
385396

src/parcels/_core/index_search.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,19 @@ def _search_time_index(field: Field, time: datetime):
7575
if the sampled value is outside the time value range.
7676
"""
7777
if field.time_interval is None:
78-
return {"ti": np.zeros(shape=time.shape, dtype=np.int32), "tau": np.zeros(shape=time.shape, dtype=np.float32)}
78+
return {
79+
"T": {
80+
"index": np.zeros(shape=time.shape, dtype=np.int32),
81+
"bcoord": np.zeros(shape=time.shape, dtype=np.float32),
82+
}
83+
}
7984

8085
if not field.time_interval.is_all_time_in_interval(time):
8186
_raise_time_extrapolation_error(time, field=None)
8287

8388
ti = np.searchsorted(field.data.time.data, time, side="right") - 1
8489
tau = (time - field.data.time.data[ti]) / (field.data.time.data[ti + 1] - field.data.time.data[ti])
85-
return {"ti": np.atleast_1d(ti), "tau": np.atleast_1d(tau)}
90+
return {"T": {"index": np.atleast_1d(ti), "bcoord": np.atleast_1d(tau)}}
8691

8792

8893
def curvilinear_point_in_cell(grid, y: np.ndarray, x: np.ndarray, yi: np.ndarray, xi: np.ndarray):

src/parcels/_core/particleset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -295,12 +295,12 @@ def _neighbors_by_coor(self, coor):
295295
def populate_indices(self):
296296
"""Pre-populate guesses of particle ei (element id) indices"""
297297
for i, grid in enumerate(self.fieldset.gridset):
298-
position = grid.search(self.z, self.lat, self.lon)
298+
grid_positions = grid.search(self.z, self.lat, self.lon)
299299
self._data["ei"][:, i] = grid.ravel_index(
300300
{
301-
"X": position["xi"],
302-
"Y": position["yi"],
303-
"Z": position["zi"],
301+
"X": grid_positions["X"]["index"],
302+
"Y": grid_positions["Y"]["index"],
303+
"Z": grid_positions["Z"]["index"],
304304
}
305305
)
306306

src/parcels/_core/uxgrid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,4 +118,4 @@ def search(self, z, y, x, ei=None, tol=1e-6):
118118
coords[zero_indices, :] = coords_q
119119
fi[zero_indices] = face_ids_q
120120

121-
return {"Z": (zi, zeta), "FACE": (fi, coords)}
121+
return {"Z": {"index":zi, "bcoord":zeta}, "FACE": {"index":fi, "bcoord":coords}}

src/parcels/_core/xgrid.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,11 @@ def search(self, z, y, x, ei=None):
289289
if ds.lon.ndim == 1:
290290
yi, eta = _search_1d_array(ds.lat.values, y)
291291
xi, xsi = _search_1d_array(ds.lon.values, x)
292-
return {"zi": zi, "zeta": zeta, "yi": yi, "eta": eta, "xi": xi, "xsi": xsi}
292+
return {
293+
"Z": {"index": zi, "bcoord": zeta},
294+
"Y": {"index": yi, "bcoord": eta},
295+
"X": {"index": xi, "bcoord": xsi},
296+
}
293297

294298
yi, xi = None, None
295299
if ei is not None:
@@ -300,7 +304,11 @@ def search(self, z, y, x, ei=None):
300304
if ds.lon.ndim == 2:
301305
yi, eta, xi, xsi = _search_indices_curvilinear_2d(self, y, x, yi, xi)
302306

303-
return {"zi": zi, "zeta": zeta, "yi": yi, "eta": eta, "xi": xi, "xsi": xsi}
307+
return {
308+
"Z": {"index": zi, "bcoord": zeta},
309+
"Y": {"index": yi, "bcoord": eta},
310+
"X": {"index": xi, "bcoord": xsi},
311+
}
304312

305313
raise NotImplementedError("Searching in >2D lon/lat arrays is not implemented yet.")
306314

0 commit comments

Comments
 (0)