Merged · Changes from 7 commits
12 changes: 5 additions & 7 deletions parcels/_index_search.py
@@ -75,20 +75,18 @@ def curvilinear_point_in_cell(grid, y: np.ndarray, x: np.ndarray, yi: np.ndarray


def _search_indices_curvilinear_2d(
grid: XGrid, y: np.ndarray, x: np.ndarray, yi_guess: np.ndarray | None = None, xi_guess: np.ndarray | None = None
grid: XGrid, y: np.ndarray, x: np.ndarray, yi: np.ndarray | None = None, xi: np.ndarray | None = None
):
yi_guess = np.array(yi_guess)
xi_guess = np.array(xi_guess)
xi = np.full(len(x), GRID_SEARCH_ERROR, dtype=np.int32)
yi = np.full(len(y), GRID_SEARCH_ERROR, dtype=np.int32)
if np.any(xi_guess):
if np.any(xi):
# If an initial guess is provided, we first perform a point in cell check for all guessed indices
is_in_cell, coords = curvilinear_point_in_cell(grid, y, x, yi_guess, xi_guess)
is_in_cell, coords = curvilinear_point_in_cell(grid, y, x, yi, xi)
y_check = y[is_in_cell == 0]
x_check = x[is_in_cell == 0]
zero_indices = np.where(is_in_cell == 0)[0]
else:
# Otherwise, we need to check all points
yi = np.full(len(y), GRID_SEARCH_ERROR, dtype=np.int32)
xi = np.full(len(x), GRID_SEARCH_ERROR, dtype=np.int32)
y_check = y
x_check = x
coords = -1.0 * np.ones((len(y), 2), dtype=np.float32)
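The hunk above drops the separate `yi_guess`/`xi_guess` arrays and passes the guessed indices directly as `yi`/`xi`. The control flow it implements, accept the guessed cell where a point-in-cell check succeeds and run the expensive full search only on the misses, can be sketched in isolation; `in_cell` below is a toy stand-in for `curvilinear_point_in_cell` and the full search itself is omitted:

```python
import numpy as np

GRID_SEARCH_ERROR = -3  # placeholder sentinel; the real constant lives in parcels._index_search


def in_cell(y, x, yi, xi):
    # Toy point-in-cell test on a unit-spaced rectilinear grid: 1 if (y, x) lies in cell (yi, xi).
    return ((yi <= y) & (y < yi + 1) & (xi <= x) & (x < xi + 1)).astype(int)


def search_with_guess(y, x, yi=None, xi=None):
    if yi is not None:
        # Keep the guessed cell wherever the point-in-cell check succeeds ...
        is_in_cell = in_cell(y, x, yi, xi)
        y_check = y[is_in_cell == 0]
        x_check = x[is_in_cell == 0]
    else:
        # ... otherwise every point still needs the full search.
        yi = np.full(len(y), GRID_SEARCH_ERROR, dtype=np.int32)
        xi = np.full(len(x), GRID_SEARCH_ERROR, dtype=np.int32)
        y_check, x_check = y, x
    return yi, xi, y_check, x_check  # the full search (omitted) only visits y_check/x_check


y, x = np.array([0.5, 2.5]), np.array([0.5, 1.5])
print(search_with_guess(y, x, yi=np.array([0, 0]), xi=np.array([0, 0])))
# particle 0 keeps its guessed cell (0, 0); only particle 1 is left over for the full search
```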
14 changes: 7 additions & 7 deletions parcels/basegrid.py
@@ -69,7 +69,7 @@ def search(self, z: float, y: float, x: float, ei=None) -> dict[str, tuple[int,
"""
...

def ravel_index(self, axis_indices: dict[str, int]) -> int:
def ravel_index(self, axis_indices: dict[str, np.ndarray]) -> np.ndarray:
"""
Convert a dictionary of axis indices to a single encoded index (ei).

@@ -79,7 +79,7 @@ def ravel_index(self, axis_indices: dict[str, int]) -> int:

Parameters
----------
axis_indices : dict[str, int]
axis_indices : dict[str, np.ndarray(int)]
A dictionary mapping axis names to their corresponding indices.
The expected keys depend on the grid dimensionality and type:

@@ -90,8 +90,8 @@ def ravel_index(self, axis_indices: dict[str, int]) -> int:

Returns
-------
int
The encoded index (ei) representing the unique grid cell or face.
np.ndarray(int)
The encoded indices (ei) representing the unique grid cells or faces.

Raises
------
@@ -204,13 +204,13 @@ def _unravel(dims, ei):
"""
strides = np.cumprod(dims[::-1])[::-1]

indices = np.empty(len(dims), dtype=int)
indices = np.empty((len(dims), len(ei)), dtype=int)

for i in range(len(dims) - 1):
indices[i] = ei // strides[i + 1]
indices[i, :] = ei // strides[i + 1]
ei = ei % strides[i + 1]

indices[-1] = ei
indices[-1, :] = ei
return indices


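With `ravel_index` and `_unravel` now operating on arrays, one encoded index is produced per particle. A self-contained round-trip sketch; `_ravel` here is a hypothetical counterpart written only to mirror the C-order layout implied by `_unravel` above (the real `ravel_index` is implemented by the grid classes):

```python
import numpy as np


def _ravel(dims, indices):
    # C-order encoding matching _unravel: for dims = (Z, Y, X),
    # ei = zi * (Y * X) + yi * X + xi, computed for whole arrays at once.
    return np.ravel_multi_index(tuple(indices), tuple(dims))


def _unravel(dims, ei):
    # Vectorised inverse, as in the diff: one column of axis indices per encoded ei.
    strides = np.cumprod(dims[::-1])[::-1]
    indices = np.empty((len(dims), len(ei)), dtype=int)
    for i in range(len(dims) - 1):
        indices[i, :] = ei // strides[i + 1]
        ei = ei % strides[i + 1]
    indices[-1, :] = ei
    return indices


dims = np.array([10, 20, 30])                        # e.g. Z, Y, X
zyx = np.array([[0, 3, 9], [5, 0, 19], [7, 29, 0]])  # three particles, one column each
ei = _ravel(dims, zyx)                               # -> array([ 157, 1829, 5970])
print(np.array_equal(_unravel(dims, ei), zyx))       # -> True
```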
39 changes: 31 additions & 8 deletions parcels/field.py
@@ -212,13 +212,14 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
conversion to the result. Note that we defer to
scipy.interpolate to perform spatial interpolation.
"""
# if particle is None:
_ei = None
# else:
# _ei = particle.ei[self.igrid]
if particles is None:
_ei = None
else:
_ei = particles.ei[:, self.igrid]

tau, ti = _search_time_index(self, time)
position = self.grid.search(z, y, x, ei=_ei)
_update_particles_ei(particles, position, self)
_update_particle_states_position(particles, position)

value = self._interp_method(self, ti, position, tau, time, z, y, x)
@@ -251,6 +252,7 @@ def __init__(
self.V = V
self.W = W
self.grid = U.grid
self.igrid = U.igrid

if W is None:
_assert_same_time_interval((U, V))
@@ -294,13 +296,14 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True):
conversion to the result. Note that we defer to
scipy.interpolate to perform spatial interpolation.
"""
# if particle is None:
_ei = None
# else:
# _ei = particle.ei[self.igrid]
if particles is None:
_ei = None
else:
_ei = particles.ei[:, self.igrid]

tau, ti = _search_time_index(self.U, time)
position = self.grid.search(z, y, x, ei=_ei)
_update_particles_ei(particles, position, self)
_update_particle_states_position(particles, position)

if self._vector_interp_method is None:
@@ -339,6 +342,26 @@ def __getitem__(self, key):
return _deal_with_errors(error, key, vector_type=self.vector_type)


def _update_particles_ei(particles, position, field):
"""Update the element index (ei) of the particles"""
if particles is not None:
if isinstance(field.grid, XGrid):
particles.ei[:, field.igrid] = field.grid.ravel_index(
{
"X": position["X"][0],
"Y": position["Y"][0],
"Z": position["Z"][0],
}
)
elif isinstance(field.grid, UxGrid):
particles.ei[:, field.igrid] = field.grid.ravel_index(
{
"Z": position["Z"][0],
"FACE": position["FACE"][0],
}
)


def _update_particle_states_position(particles, position):
"""Update the particle states based on the position dictionary."""
if particles: # TODO also support uxgrid search
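As the new helper shows, `grid.search` is assumed to return a dictionary mapping axis names to `(cell_indices, within-cell coordinates)` pairs, of which only the integer indices feed the encoding. A minimal sketch of the write-back using a stub grid (hypothetical; the real `ravel_index` lives on `XGrid`/`UxGrid`):

```python
import numpy as np


class StubGrid:
    """Hypothetical stand-in with the same ravel_index contract the diff assumes."""

    dims = {"Z": 10, "Y": 20, "X": 30}

    def ravel_index(self, axis_indices):
        idx = tuple(np.asarray(axis_indices[a]) for a in ("Z", "Y", "X"))
        return np.ravel_multi_index(idx, tuple(self.dims.values()))


# Search output for two particles: (integer cell index, fractional coordinate) per axis.
position = {
    "Z": (np.array([0, 1]), np.array([0.2, 0.9])),
    "Y": (np.array([4, 5]), np.array([0.5, 0.5])),
    "X": (np.array([7, 8]), np.array([0.1, 0.6])),
}

grid = StubGrid()
ei = grid.ravel_index({axis: position[axis][0] for axis in ("X", "Y", "Z")})
print(ei)  # -> [127 758]; this is what gets written into particles.ei[:, field.igrid]
```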
31 changes: 9 additions & 22 deletions parcels/particleset.py
@@ -6,13 +6,11 @@

import numpy as np
import xarray as xr
from scipy.spatial import KDTree
from tqdm import tqdm

from parcels._core.utils.time import TimeInterval, maybe_convert_python_timedelta_to_numpy
from parcels._reprs import particleset_repr
from parcels.application_kernels.advection import AdvectionRK4
from parcels.basegrid import GridType
from parcels.kernel import Kernel
from parcels.particle import KernelParticle, Particle, create_particle_data
from parcels.tools.converters import convert_to_flat_array
@@ -305,28 +303,17 @@ def _neighbors_by_coor(self, coor):
neighbor_ids = self._data["trajectory"][neighbor_idx]
return neighbor_ids

# TODO: This method is only tested in tutorial notebook. Add unit test?
def populate_indices(self):
"""Pre-populate guesses of particle ei (element id) indices using a kdtree.

This is only intended for curvilinear grids, where the initial index search
may be quite expensive.
"""
"""Pre-populate guesses of particle ei (element id) indices"""
for i, grid in enumerate(self.fieldset.gridset):
if grid._gtype not in [GridType.CurvilinearZGrid, GridType.CurvilinearSGrid]:
continue

tree_data = np.stack((grid.lon.flat, grid.lat.flat), axis=-1)
IN = np.all(~np.isnan(tree_data), axis=1)
tree = KDTree(tree_data[IN, :])
# stack all the particle positions for a single query
pts = np.stack((self._data["lon"], self._data["lat"]), axis=-1)
# query datatype needs to match tree datatype
_, idx_nan = tree.query(pts.astype(tree_data.dtype))

idx = np.where(IN)[0][idx_nan]

self._data["ei"][:, i] = idx # assumes that we are in the surface layer (zi=0)
position = grid.search(self.depth, self.lat, self.lon)
self._data["ei"][:, i] = grid.ravel_index(
{
"X": position["X"][0],
"Y": position["Y"][0],
"Z": position["Z"][0],
}
)

@classmethod
def from_particlefile(cls, fieldset, pclass, filename, restart=True, restarttime=None, **kwargs):
8 changes: 4 additions & 4 deletions parcels/uxgrid.py
@@ -87,17 +87,18 @@ def try_face(fid):
zi, zeta = _search_1d_array(self.z.values, z)

if ei is not None:
_, fi = self.unravel_index(ei)
indices = self.unravel_index(ei)
fi = indices["FACE"][0]
bcoords = try_face(fi)
if bcoords is not None:
return bcoords, self.ravel_index(zi, fi)
return {"Z": (zi, zeta), "FACE": (np.asarray([fi]), bcoords)}
# Try neighbors of current face
for neighbor in self.uxgrid.face_face_connectivity[fi, :]:
if neighbor == -1:
continue
bcoords = try_face(neighbor)
if bcoords is not None:
return bcoords, self.ravel_index(zi, neighbor)
return {"Z": (zi, zeta), "FACE": (np.asarray([neighbor]), bcoords)}

# Global fallback as last ditch effort
points = np.column_stack((x, y))
@@ -113,7 +114,6 @@ def _get_barycentric_coordinates_latlon(self, y, x, fi):
def _get_barycentric_coordinates_latlon(self, y, x, fi):
"""Checks if a point is inside a given face id on a UxGrid."""
# Check if particle is in the same face, otherwise search again.

n_nodes = self.uxgrid.n_nodes_per_face[fi].to_numpy()
node_ids = self.uxgrid.face_node_connectivity[fi, 0:n_nodes]
nodes = np.column_stack(
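The face test above rests on barycentric coordinates: a point lies inside a face when all of its barycentric weights fall within [0, 1]. A generic triangle version for illustration only, since `_get_barycentric_coordinates_latlon` works on lat/lon polygons with an arbitrary number of nodes:

```python
import numpy as np


def barycentric_coords_triangle(p, a, b, c):
    # Generic 2D barycentric coordinates of point p in triangle (a, b, c);
    # all three weights lie in [0, 1] iff p is inside the triangle.
    v0, v1, v2 = b - a, c - a, p - a
    d00, d01, d11 = v0 @ v0, v0 @ v1, v1 @ v1
    d20, d21 = v2 @ v0, v2 @ v1
    denom = d00 * d11 - d01 * d01
    w1 = (d11 * d20 - d01 * d21) / denom
    w2 = (d00 * d21 - d01 * d20) / denom
    return np.array([1.0 - w1 - w2, w1, w2])


print(barycentric_coords_triangle(np.array([0.25, 0.25]),
                                  np.array([0.0, 0.0]),
                                  np.array([1.0, 0.0]),
                                  np.array([0.0, 1.0])))
# -> [0.5 0.25 0.25]; the point is inside since all weights are within [0, 1]
```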
14 changes: 7 additions & 7 deletions tests/v4/test_basegrid.py
@@ -26,26 +26,26 @@ def get_axis_dim(self, axis: str) -> int:
@pytest.mark.parametrize(
"grid",
[
TestGrid({"Z": 10, "Y": 20, "X": 30}),
TestGrid({"Z": 5, "Y": 15}),
TestGrid({"Z": 8}),
TestGrid({"Z": 12, "FACE": 25}),
TestGrid({"Z": [10], "Y": [20], "X": [30]}),
TestGrid({"Z": [5], "Y": [15]}),
TestGrid({"Z": [8]}),
TestGrid({"Z": [12], "FACE": [25]}),
],
)
def test_basegrid_ravel_unravel_index(grid):
axes = grid.axes
dimensionalities = (grid.get_axis_dim(axis) for axis in axes)
all_possible_axis_indices = itertools.product(*[range(dim) for dim in dimensionalities])
all_possible_axis_indices = itertools.product(*[range(dim[0]) for dim in dimensionalities])

encountered_eis = []

for axis_indices_numeric in all_possible_axis_indices:
axis_indices = dict(zip(axes, axis_indices_numeric, strict=True))
axis_indices = {axis: [index] for axis, index in zip(axes, axis_indices_numeric, strict=True)}

ei = grid.ravel_index(axis_indices)
axis_indices_test = grid.unravel_index(ei)
assert axis_indices_test == axis_indices
encountered_eis.append(ei)
encountered_eis.append(ei[0])

encountered_eis = sorted(encountered_eis)
assert len(set(encountered_eis)) == len(encountered_eis), "Raveled indices are not unique."
8 changes: 8 additions & 0 deletions tests/v4/test_particleset.py
@@ -17,6 +17,7 @@
from parcels._datasets.structured.generic import datasets as datasets_structured
from parcels.xgrid import XGrid
from tests.common_kernels import DoNothing
from tests.utils import round_and_hash_float_array


@pytest.fixture
@@ -126,6 +127,13 @@ def Addlon(particles, fieldset):  # pragma: no cover
assert np.allclose([p.lon + p.dlon for p in pset], [8 - t for t in times])


def test_populate_indices(fieldset):
npart = 11
pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.linspace(1, 0, npart))
pset.populate_indices()
np.testing.assert_equal(round_and_hash_float_array(pset.ei, decimals=0), 935996932384571063274191)


def test_pset_add_explicit(fieldset):
npart = 11
lon = np.linspace(0, 1, npart)
4 changes: 2 additions & 2 deletions tests/v4/test_particleset_execute.py
@@ -400,8 +400,8 @@ def test_uxstommelgyre_pset_execute():
dt=np.timedelta64(60, "s"),
pyfunc=AdvectionEE,
)
assert utils.round_and_hash_float_array([p.lon for p in pset]) == 1165396086
assert utils.round_and_hash_float_array([p.lat for p in pset]) == 1142124776
assert utils.round_and_hash_float_array([p.lon for p in pset]) == 1165397121
assert utils.round_and_hash_float_array([p.lat for p in pset]) == 1142123780


@pytest.mark.xfail(reason="Output file not implemented yet")