diff --git a/parcels/basegrid.py b/parcels/basegrid.py index 2765dd59e..5efbabf22 100644 --- a/parcels/basegrid.py +++ b/parcels/basegrid.py @@ -6,6 +6,8 @@ import numpy as np +from parcels.spatialhash import SpatialHash + if TYPE_CHECKING: import numpy as np @@ -140,6 +142,30 @@ def unravel_index(self, ei: int) -> dict[str, int]: indices = _unravel(dims, ei) return dict(zip(self.axes, indices, strict=True)) + def get_spatial_hash( + self, + reconstruct=False, + ): + """Get the SpatialHash data structure of this Grid that allows for + fast face search queries. Face searches are used to find the faces that + a list of points, in spherical coordinates, are contained within. + + Parameters + ---------- + reconstruct : bool, default=False + If true, reconstructs the spatial hash + + Returns + ------- + self._spatialhash : parcels.spatialhash.SpatialHash + SpatialHash instance + + """ + if self._spatialhash is None or reconstruct: + self._spatialhash = SpatialHash(self, reconstruct) + + return self._spatialhash + @property @abstractmethod def axes(self) -> list[str]: diff --git a/parcels/spatialhash.py b/parcels/spatialhash.py index 76447bc23..e6d1597c9 100644 --- a/parcels/spatialhash.py +++ b/parcels/spatialhash.py @@ -1,5 +1,7 @@ import numpy as np +import parcels + class SpatialHash: """Custom data structure that is used for performing grid searches using Spatial Hashing. This class constructs an overlying @@ -24,89 +26,100 @@ def __init__( grid, reconstruct=False, ): - # TODO : Enforce grid to be an instance of parcels.xgrid.XGrid - # Currently, this is not done due to circular import with parcels.xgrid - self._source_grid = grid self.reconstruct = reconstruct - if self._source_grid._mesh == "spherical": - # Boundaries of the hash grid are the unit cube + if isinstance(grid, parcels.xgrid.XGrid): + if self._source_grid._mesh == "spherical": + # Boundaries of the hash grid are the unit cube + self._xmin = -1.0 + self._ymin = -1.0 + self._zmin = -1.0 + self._xmax = 1.0 + self._ymax = 1.0 + self._zmax = 1.0 # Compute the cell centers of the source grid (for now, assuming Xgrid) + lon = np.deg2rad(self._source_grid.lon) + lat = np.deg2rad(self._source_grid.lat) + x, y, z = _latlon_rad_to_xyz(lat, lon) + _xbound = np.stack( + ( + x[:-1, :-1], + x[:-1, 1:], + x[1:, 1:], + x[1:, :-1], + ), + axis=-1, + ) + _ybound = np.stack( + ( + y[:-1, :-1], + y[:-1, 1:], + y[1:, 1:], + y[1:, :-1], + ), + axis=-1, + ) + _zbound = np.stack( + ( + z[:-1, :-1], + z[:-1, 1:], + z[1:, 1:], + z[1:, :-1], + ), + axis=-1, + ) + # Compute centroid locations of each cells + self._xc = np.mean(_xbound, axis=-1) + self._yc = np.mean(_ybound, axis=-1) + self._zc = np.mean(_zbound, axis=-1) + + else: + # Boundaries of the hash grid are the bounding box of the source grid + self._xmin = self._source_grid.lon.min() + self._xmax = self._source_grid.lon.max() + self._ymin = self._source_grid.lat.min() + self._ymax = self._source_grid.lat.max() + # setting min and max below is needed for mesh="flat" + self._zmin = 0.0 + self._zmax = 0.0 + x = self._source_grid.lon + y = self._source_grid.lat + + _xbound = np.stack( + ( + x[:-1, :-1], + x[:-1, 1:], + x[1:, 1:], + x[1:, :-1], + ), + axis=-1, + ) + _ybound = np.stack( + ( + y[:-1, :-1], + y[:-1, 1:], + y[1:, 1:], + y[1:, :-1], + ), + axis=-1, + ) + # Compute centroid locations of each cells + self._xc = np.mean(_xbound, axis=-1) + self._yc = np.mean(_ybound, axis=-1) + self._zc = np.zeros_like(self._xc) + else: self._xmin = -1.0 self._ymin = -1.0 self._zmin = -1.0 self._xmax = 1.0 self._ymax = 1.0 - self._zmax = 1.0 # Compute the cell centers of the source grid (for now, assuming Xgrid) - lon = np.deg2rad(self._source_grid.lon) - lat = np.deg2rad(self._source_grid.lat) - x, y, z = _latlon_rad_to_xyz(lat, lon) - _xbound = np.stack( - ( - x[:-1, :-1], - x[:-1, 1:], - x[1:, 1:], - x[1:, :-1], - ), - axis=-1, - ) - _ybound = np.stack( - ( - y[:-1, :-1], - y[:-1, 1:], - y[1:, 1:], - y[1:, :-1], - ), - axis=-1, - ) - _zbound = np.stack( - ( - z[:-1, :-1], - z[:-1, 1:], - z[1:, 1:], - z[1:, :-1], - ), - axis=-1, - ) - # Compute centroid locations of each cells - self._xc = np.mean(_xbound, axis=-1) - self._yc = np.mean(_ybound, axis=-1) - self._zc = np.mean(_zbound, axis=-1) + self._zmax = 1.0 - else: - # Boundaries of the hash grid are the bounding box of the source grid - self._xmin = self._source_grid.lon.min() - self._xmax = self._source_grid.lon.max() - self._ymin = self._source_grid.lat.min() - self._ymax = self._source_grid.lat.max() - # setting min and max below is needed for mesh="flat" - self._zmin = 0.0 - self._zmax = 0.0 - x = self._source_grid.lon - y = self._source_grid.lat - - _xbound = np.stack( - ( - x[:-1, :-1], - x[:-1, 1:], - x[1:, 1:], - x[1:, :-1], - ), - axis=-1, - ) - _ybound = np.stack( - ( - y[:-1, :-1], - y[:-1, 1:], - y[1:, 1:], - y[1:, :-1], - ), - axis=-1, - ) - # Compute centroid locations of each cells - self._xc = np.mean(_xbound, axis=-1) - self._yc = np.mean(_ybound, axis=-1) - self._zc = np.zeros_like(self._xc) + # Here, we force _xc, _yc, _zc to be 2D arrays to + # mininimizes code change requirements between curvilinear and unstructured grids + self._xc = np.atleast_2d(self._source_grid.uxgrid.face_x.values) + self._yc = np.atleast_2d(self._source_grid.uxgrid.face_y.values) + self._zc = np.atleast_2d(self._source_grid.uxgrid.face_z.values) # Generate the mapping from the hash indices to unstructured grid elements self._hash_table = None @@ -268,51 +281,6 @@ def query( return (j_best.reshape(query_codes.shape), i_best.reshape(query_codes.shape)) -def _triangle_area(A, B, C): - """Compute the area of a triangle given by three points.""" - d1 = B - A - d2 = C - A - d3 = np.cross(d1, d2) - return 0.5 * np.linalg.norm(d3) - - -def _barycentric_coordinates(nodes, point, min_area=1e-8): - """ - Compute the barycentric coordinates of a point P inside a convex polygon using area-based weights. - So that this method generalizes to n-sided polygons, we use the Waschpress points as the generalized - barycentric coordinates, which is only valid for convex polygons. - - Parameters - ---------- - nodes : numpy.ndarray - Spherical coordinates (lat,lon) of each corner node of a face - point : numpy.ndarray - Spherical coordinates (lat,lon) of the point - - Returns - ------- - numpy.ndarray - Barycentric coordinates corresponding to each vertex. - - """ - n = len(nodes) - sum_wi = 0 - w = [] - - for i in range(0, n): - vim1 = nodes[i - 1] - vi = nodes[i] - vi1 = nodes[(i + 1) % n] - a0 = _triangle_area(vim1, vi, vi1) - a1 = max(_triangle_area(point, vim1, vi), min_area) - a2 = max(_triangle_area(point, vi, vi1), min_area) - sum_wi += a0 / (a1 * a2) - w.append(a0 / (a1 * a2)) - barycentric_coords = [w_i / sum_wi for w_i in w] - - return barycentric_coords - - def _latlon_rad_to_xyz(lat, lon): """Converts Spherical latitude and longitude coordinates into Cartesian x, y, z coordinates. diff --git a/parcels/uxgrid.py b/parcels/uxgrid.py index bc8014f76..1daff2ad4 100644 --- a/parcels/uxgrid.py +++ b/parcels/uxgrid.py @@ -6,8 +6,6 @@ import uxarray as ux from parcels._typing import assert_valid_mesh -from parcels.spatialhash import _barycentric_coordinates -from parcels.tools.statuscodes import FieldOutOfBoundError from parcels.xgrid import _search_1d_array from .basegrid import BaseGrid @@ -43,6 +41,7 @@ def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh="flat") -> UxGrid raise ValueError("z must be a 1D array of vertical coordinates") self.z = z self._mesh = mesh + self._spatialhash = None assert_valid_mesh(mesh) @@ -74,63 +73,43 @@ def get_axis_dim(self, axis: _UXGRID_AXES) -> int: return self.uxgrid.n_face def search(self, z, y, x, ei=None, tol=1e-6): - def try_face(fid): - bcoords, err = self._get_barycentric_coordinates_latlon(y, x, fid) - if (bcoords >= 0).all() and (bcoords <= 1).all() and err < tol: - return bcoords - else: - bcoords = self._get_barycentric_coordinates_cartesian(y, x, fid) - if (bcoords >= 0).all() and (bcoords <= 1).all(): - return bcoords - - return None + """ + Search for the grid cell (face) and vertical layer that contains the given points. + Parameters + ---------- + z : float or np.ndarray + The vertical coordinate(s) (depth) of the point(s). + y : float or np.ndarray + The latitude(s) of the point(s). + x : float or np.ndarray + The longitude(s) of the point(s). + ei : np.ndarray, optional + Precomputed horizontal indices (face indices) for the points. + + TO BE IMPLEMENTED : If provided, we'll check + if the points are within the faces specified by these indices. For cells where the particles + are not found, a nearest neighbor search will be performed. As a last resort, the spatial hash will be used. + tol : float, optional + Tolerance for barycentric coordinate checks. Default is 1e-6. + """ zi, zeta = _search_1d_array(self.z.values, z) + _, face_ids = self.get_spatial_hash().query(y, x) + valid_faces = face_ids != -1 + bcoords = np.zeros((len(face_ids), self.uxgrid.n_max_face_nodes), dtype=np.float32) + # Get the barycentric coordinates for all valid faces + for idx in np.where(valid_faces)[0]: + fi = face_ids[idx] + bc = self._get_barycentric_coordinates(y, x, fi) + if np.all(bc <= 1.0) and np.all(bc >= 0.0) and np.isclose(np.sum(bc), 1.0, atol=tol): + bcoords[idx, : len(bc)] = bc + else: + # If the barycentric coordinates are invalid, mark the face as invalid + face_ids[idx] = -1 - if ei is not None: - _, fi = self.unravel_index(ei) - bcoords = try_face(fi) - if bcoords is not None: - return bcoords, self.ravel_index(zi, fi) - # Try neighbors of current face - for neighbor in self.uxgrid.face_face_connectivity[fi, :]: - if neighbor == -1: - continue - bcoords = try_face(neighbor) - if bcoords is not None: - return bcoords, self.ravel_index(zi, neighbor) - - # Global fallback as last ditch effort - points = np.column_stack((x, y)) - face_ids = self.uxgrid.get_faces_containing_point(points, return_counts=False)[0] - fi = face_ids[0] if len(face_ids) > 0 else -1 - if fi == -1: - raise FieldOutOfBoundError(z, y, x) - bcoords = try_face(fi) - if bcoords is None: - raise FieldOutOfBoundError(z, y, x) - return {"Z": (zi, zeta), "FACE": (fi, bcoords)} - - def _get_barycentric_coordinates_latlon(self, y, x, fi): - """Checks if a point is inside a given face id on a UxGrid.""" - # Check if particle is in the same face, otherwise search again. - - n_nodes = self.uxgrid.n_nodes_per_face[fi].to_numpy() - node_ids = self.uxgrid.face_node_connectivity[fi, 0:n_nodes] - nodes = np.column_stack( - ( - np.deg2rad(self.uxgrid.node_lon[node_ids].to_numpy()), - np.deg2rad(self.uxgrid.node_lat[node_ids].to_numpy()), - ) - ) - - coord = np.deg2rad(np.column_stack((x, y))) - bcoord = np.asarray(_barycentric_coordinates(nodes, coord)) - proj_coord = np.matmul(np.transpose(nodes), bcoord) - err = np.linalg.norm(proj_coord - coord) - return bcoord, err + return {"Z": (zi, zeta), "FACE": (face_ids, bcoords)} - def _get_barycentric_coordinates_cartesian(self, y, x, fi): + def _get_barycentric_coordinates(self, y, x, fi): n_nodes = self.uxgrid.n_nodes_per_face[fi].to_numpy() node_ids = self.uxgrid.face_node_connectivity[fi, 0:n_nodes] @@ -152,6 +131,51 @@ def _get_barycentric_coordinates_cartesian(self, y, x, fi): return bcoord +def _triangle_area(A, B, C): + """Compute the area of a triangle given by three points.""" + d1 = B - A + d2 = C - A + d3 = np.cross(d1, d2) + return 0.5 * np.linalg.norm(d3) + + +def _barycentric_coordinates(nodes, point, min_area=1e-8): + """ + Compute the barycentric coordinates of a point P inside a convex polygon using area-based weights. + So that this method generalizes to n-sided polygons, we use the Waschpress points as the generalized + barycentric coordinates, which is only valid for convex polygons. + + Parameters + ---------- + nodes : numpy.ndarray + Spherical coordinates (lat,lon) of each corner node of a face + point : numpy.ndarray + Spherical coordinates (lat,lon) of the point + + Returns + ------- + numpy.ndarray + Barycentric coordinates corresponding to each vertex. + + """ + n = len(nodes) + sum_wi = 0 + w = [] + + for i in range(0, n): + vim1 = nodes[i - 1] + vi = nodes[i] + vi1 = nodes[(i + 1) % n] + a0 = _triangle_area(vim1, vi, vi1) + a1 = max(_triangle_area(point, vim1, vi), min_area) + a2 = max(_triangle_area(point, vi, vi1), min_area) + sum_wi += a0 / (a1 * a2) + w.append(a0 / (a1 * a2)) + barycentric_coords = [w_i / sum_wi for w_i in w] + + return barycentric_coords + + def _lonlat_rad_to_xyz( lon, lat, diff --git a/parcels/xgrid.py b/parcels/xgrid.py index 6979d6f11..4f6eab7cf 100644 --- a/parcels/xgrid.py +++ b/parcels/xgrid.py @@ -10,7 +10,6 @@ from parcels._index_search import _search_indices_curvilinear_2d from parcels._typing import assert_valid_mesh from parcels.basegrid import BaseGrid -from parcels.spatialhash import SpatialHash _XGRID_AXES = Literal["X", "Y", "Z"] _XGRID_AXES_ORDERING: Sequence[_XGRID_AXES] = "ZYX" @@ -347,32 +346,6 @@ def get_axis_dim_mapping(self, dims: list[str]) -> dict[_XGRID_AXES, str]: result[cast(_XGRID_AXES, axis)] = dim return result - def get_spatial_hash( - self, - reconstruct=False, - ): - """Get the SpatialHash data structure of this Grid that allows for - fast face search queries. Face searches are used to find the faces that - a list of points, in spherical coordinates, are contained within. - - Parameters - ---------- - global_grid : bool, default=False - If true, the hash grid is constructed using the domain [-pi,pi] x [-pi,pi] - reconstruct : bool, default=False - If true, reconstructs the spatial hash - - Returns - ------- - self._spatialhash : parcels.spatialhash.SpatialHash - SpatialHash instance - - """ - if self._spatialhash is None or reconstruct: - self._spatialhash = SpatialHash(self, reconstruct) - - return self._spatialhash - def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | None: """For a given dimension name in a grid, returns the direction axis it is on."""