diff --git a/parcels/_index_search.py b/parcels/_index_search.py index beb77c351..574a31159 100644 --- a/parcels/_index_search.py +++ b/parcels/_index_search.py @@ -75,20 +75,42 @@ def curvilinear_point_in_cell(grid, y: np.ndarray, x: np.ndarray, yi: np.ndarray def _search_indices_curvilinear_2d( - grid: XGrid, y: np.ndarray, x: np.ndarray, yi_guess: np.ndarray | None = None, xi_guess: np.ndarray | None = None + grid: XGrid, y: np.ndarray, x: np.ndarray, yi: np.ndarray | None = None, xi: np.ndarray | None = None ): - yi_guess = np.array(yi_guess) - xi_guess = np.array(xi_guess) - xi = np.full(len(x), GRID_SEARCH_ERROR, dtype=np.int32) - yi = np.full(len(y), GRID_SEARCH_ERROR, dtype=np.int32) - if np.any(xi_guess): + """Searches a grid for particle locations in 2D curvilinear coordinates. + + Parameters + ---------- + grid : XGrid + The curvilinear grid to search within. + y : np.ndarray + Array of latitude-coordinates of the points to locate. + x : np.ndarray + Array of longitude-coordinates of the points to locate. + yi : np.ndarray | None, optional + Array of initial guesses for the j indices of the points to locate. + xi : np.ndarray | None, optional + Array of initial guesses for the i indices of the points to locate. + + Returns + ------- + tuple + A tuple containing four elements: + - yi (np.ndarray): Array of found j-indices corresponding to the input coordinates. + - eta (np.ndarray): Array of barycentric coordinates in the j-direction within the found grid cells. + - xi (np.ndarray): Array of found i-indices corresponding to the input cooordinates. + - xsi (np.ndarray): Array of barycentric coordinates in the i-direction within the found grid cells. + """ + if np.any(xi): # If an initial guess is provided, we first perform a point in cell check for all guessed indices - is_in_cell, coords = curvilinear_point_in_cell(grid, y, x, yi_guess, xi_guess) + is_in_cell, coords = curvilinear_point_in_cell(grid, y, x, yi, xi) y_check = y[is_in_cell == 0] x_check = x[is_in_cell == 0] zero_indices = np.where(is_in_cell == 0)[0] else: # Otherwise, we need to check all points + yi = np.full(len(y), GRID_SEARCH_ERROR, dtype=np.int32) + xi = np.full(len(x), GRID_SEARCH_ERROR, dtype=np.int32) y_check = y x_check = x coords = -1.0 * np.ones((len(y), 2), dtype=np.float32) diff --git a/parcels/basegrid.py b/parcels/basegrid.py index b53f51314..d758d5f17 100644 --- a/parcels/basegrid.py +++ b/parcels/basegrid.py @@ -69,7 +69,7 @@ def search(self, z: float, y: float, x: float, ei=None) -> dict[str, tuple[int, """ ... - def ravel_index(self, axis_indices: dict[str, int]) -> int: + def ravel_index(self, axis_indices: dict[str, np.ndarray]) -> np.ndarray: """ Convert a dictionary of axis indices to a single encoded index (ei). @@ -79,7 +79,7 @@ def ravel_index(self, axis_indices: dict[str, int]) -> int: Parameters ---------- - axis_indices : dict[str, int] + axis_indices : dict[str, np.ndarray(int)] A dictionary mapping axis names to their corresponding indices. The expected keys depend on the grid dimensionality and type: @@ -90,8 +90,8 @@ def ravel_index(self, axis_indices: dict[str, int]) -> int: Returns ------- - int - The encoded index (ei) representing the unique grid cell or face. + np.ndarray(int) + The encoded indices (ei) representing the unique grid cells or faces. Raises ------ @@ -204,13 +204,13 @@ def _unravel(dims, ei): """ strides = np.cumprod(dims[::-1])[::-1] - indices = np.empty(len(dims), dtype=int) + indices = np.empty((len(dims), len(ei)), dtype=int) for i in range(len(dims) - 1): - indices[i] = ei // strides[i + 1] + indices[i, :] = ei // strides[i + 1] ei = ei % strides[i + 1] - indices[-1] = ei + indices[-1, :] = ei return indices diff --git a/parcels/field.py b/parcels/field.py index fff0796b5..212c042c0 100644 --- a/parcels/field.py +++ b/parcels/field.py @@ -212,13 +212,14 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True): conversion to the result. Note that we defer to scipy.interpolate to perform spatial interpolation. """ - # if particle is None: - _ei = None - # else: - # _ei = particle.ei[self.igrid] + if particles is None: + _ei = None + else: + _ei = particles.ei[:, self.igrid] tau, ti = _search_time_index(self, time) position = self.grid.search(z, y, x, ei=_ei) + _update_particles_ei(particles, position, self) _update_particle_states_position(particles, position) value = self._interp_method(self, ti, position, tau, time, z, y, x) @@ -251,6 +252,7 @@ def __init__( self.V = V self.W = W self.grid = U.grid + self.igrid = U.igrid if W is None: _assert_same_time_interval((U, V)) @@ -294,13 +296,14 @@ def eval(self, time: datetime, z, y, x, particles=None, applyConversion=True): conversion to the result. Note that we defer to scipy.interpolate to perform spatial interpolation. """ - # if particle is None: - _ei = None - # else: - # _ei = particle.ei[self.igrid] + if particles is None: + _ei = None + else: + _ei = particles.ei[:, self.igrid] tau, ti = _search_time_index(self.U, time) position = self.grid.search(z, y, x, ei=_ei) + _update_particles_ei(particles, position, self) _update_particle_states_position(particles, position) if self._vector_interp_method is None: @@ -339,6 +342,26 @@ def __getitem__(self, key): return _deal_with_errors(error, key, vector_type=self.vector_type) +def _update_particles_ei(particles, position, field): + """Update the element index (ei) of the particles""" + if particles is not None: + if isinstance(field.grid, XGrid): + particles.ei[:, field.igrid] = field.grid.ravel_index( + { + "X": position["X"][0], + "Y": position["Y"][0], + "Z": position["Z"][0], + } + ) + elif isinstance(field.grid, UxGrid): + particles.ei[:, field.igrid] = field.grid.ravel_index( + { + "Z": position["Z"][0], + "FACE": position["FACE"][0], + } + ) + + def _update_particle_states_position(particles, position): """Update the particle states based on the position dictionary.""" if particles: # TODO also support uxgrid search diff --git a/parcels/particleset.py b/parcels/particleset.py index 7db2af65c..caded06f5 100644 --- a/parcels/particleset.py +++ b/parcels/particleset.py @@ -6,13 +6,11 @@ import numpy as np import xarray as xr -from scipy.spatial import KDTree from tqdm import tqdm from parcels._core.utils.time import TimeInterval, maybe_convert_python_timedelta_to_numpy from parcels._reprs import particleset_repr from parcels.application_kernels.advection import AdvectionRK4 -from parcels.basegrid import GridType from parcels.kernel import Kernel from parcels.particle import KernelParticle, Particle, create_particle_data from parcels.tools.converters import convert_to_flat_array @@ -305,28 +303,17 @@ def _neighbors_by_coor(self, coor): neighbor_ids = self._data["trajectory"][neighbor_idx] return neighbor_ids - # TODO: This method is only tested in tutorial notebook. Add unit test? def populate_indices(self): - """Pre-populate guesses of particle ei (element id) indices using a kdtree. - - This is only intended for curvilinear grids, where the initial index search - may be quite expensive. - """ + """Pre-populate guesses of particle ei (element id) indices""" for i, grid in enumerate(self.fieldset.gridset): - if grid._gtype not in [GridType.CurvilinearZGrid, GridType.CurvilinearSGrid]: - continue - - tree_data = np.stack((grid.lon.flat, grid.lat.flat), axis=-1) - IN = np.all(~np.isnan(tree_data), axis=1) - tree = KDTree(tree_data[IN, :]) - # stack all the particle positions for a single query - pts = np.stack((self._data["lon"], self._data["lat"]), axis=-1) - # query datatype needs to match tree datatype - _, idx_nan = tree.query(pts.astype(tree_data.dtype)) - - idx = np.where(IN)[0][idx_nan] - - self._data["ei"][:, i] = idx # assumes that we are in the surface layer (zi=0) + position = grid.search(self.depth, self.lat, self.lon) + self._data["ei"][:, i] = grid.ravel_index( + { + "X": position["X"][0], + "Y": position["Y"][0], + "Z": position["Z"][0], + } + ) @classmethod def from_particlefile(cls, fieldset, pclass, filename, restart=True, restarttime=None, **kwargs): diff --git a/parcels/uxgrid.py b/parcels/uxgrid.py index a82037ef3..4702c6f7c 100644 --- a/parcels/uxgrid.py +++ b/parcels/uxgrid.py @@ -87,17 +87,18 @@ def try_face(fid): zi, zeta = _search_1d_array(self.z.values, z) if ei is not None: - _, fi = self.unravel_index(ei) + indices = self.unravel_index(ei) + fi = indices["FACE"][0] bcoords = try_face(fi) if bcoords is not None: - return bcoords, self.ravel_index(zi, fi) + return {"Z": (zi, zeta), "FACE": (np.asarray([fi]), bcoords)} # Try neighbors of current face for neighbor in self.uxgrid.face_face_connectivity[fi, :]: if neighbor == -1: continue bcoords = try_face(neighbor) if bcoords is not None: - return bcoords, self.ravel_index(zi, neighbor) + return {"Z": (zi, zeta), "FACE": (np.asarray([neighbor]), bcoords)} # Global fallback as last ditch effort points = np.column_stack((x, y)) @@ -113,7 +114,6 @@ def try_face(fid): def _get_barycentric_coordinates_latlon(self, y, x, fi): """Checks if a point is inside a given face id on a UxGrid.""" # Check if particle is in the same face, otherwise search again. - n_nodes = self.uxgrid.n_nodes_per_face[fi].to_numpy() node_ids = self.uxgrid.face_node_connectivity[fi, 0:n_nodes] nodes = np.column_stack( diff --git a/tests/v4/test_basegrid.py b/tests/v4/test_basegrid.py index 86a4f12ce..1185ad933 100644 --- a/tests/v4/test_basegrid.py +++ b/tests/v4/test_basegrid.py @@ -35,7 +35,7 @@ def get_axis_dim(self, axis: str) -> int: def test_basegrid_ravel_unravel_index(grid): axes = grid.axes dimensionalities = (grid.get_axis_dim(axis) for axis in axes) - all_possible_axis_indices = itertools.product(*[range(dim) for dim in dimensionalities]) + all_possible_axis_indices = itertools.product(*[np.arange(dim)[:, np.newaxis] for dim in dimensionalities]) encountered_eis = [] @@ -45,7 +45,7 @@ def test_basegrid_ravel_unravel_index(grid): ei = grid.ravel_index(axis_indices) axis_indices_test = grid.unravel_index(ei) assert axis_indices_test == axis_indices - encountered_eis.append(ei) + encountered_eis.append(ei[0]) encountered_eis = sorted(encountered_eis) assert len(set(encountered_eis)) == len(encountered_eis), "Raveled indices are not unique." diff --git a/tests/v4/test_particleset.py b/tests/v4/test_particleset.py index 754032865..46d03d622 100644 --- a/tests/v4/test_particleset.py +++ b/tests/v4/test_particleset.py @@ -17,6 +17,7 @@ from parcels._datasets.structured.generic import datasets as datasets_structured from parcels.xgrid import XGrid from tests.common_kernels import DoNothing +from tests.utils import round_and_hash_float_array @pytest.fixture @@ -126,6 +127,13 @@ def Addlon(particles, fieldset): # pragma: no cover assert np.allclose([p.lon + p.dlon for p in pset], [8 - t for t in times]) +def test_populate_indices(fieldset): + npart = 11 + pset = ParticleSet(fieldset, lon=np.linspace(0, 1, npart), lat=np.linspace(1, 0, npart)) + pset.populate_indices() + np.testing.assert_equal(round_and_hash_float_array(pset.ei, decimals=0), 935996932384571063274191) + + def test_pset_add_explicit(fieldset): npart = 11 lon = np.linspace(0, 1, npart) diff --git a/tests/v4/test_particleset_execute.py b/tests/v4/test_particleset_execute.py index c5aef6905..0bfaca1a4 100644 --- a/tests/v4/test_particleset_execute.py +++ b/tests/v4/test_particleset_execute.py @@ -400,8 +400,8 @@ def test_uxstommelgyre_pset_execute(): dt=np.timedelta64(60, "s"), pyfunc=AdvectionEE, ) - assert utils.round_and_hash_float_array([p.lon for p in pset]) == 1165396086 - assert utils.round_and_hash_float_array([p.lat for p in pset]) == 1142124776 + assert utils.round_and_hash_float_array([p.lon for p in pset]) == 1165397121 + assert utils.round_and_hash_float_array([p.lat for p in pset]) == 1142123780 @pytest.mark.xfail(reason="Output file not implemented yet")