Skip to content

Commit 95c8a31

Browse files
Move get_spatial_hash() to basegrid and update uxgrid.search
Now that both xgrid and uxgrid use get_spatial_hash(), this method has been pushed to the basegrid class.
1 parent 4048685 commit 95c8a31

File tree

3 files changed

+105
-82
lines changed

3 files changed

+105
-82
lines changed

parcels/basegrid.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import numpy as np
88

9+
from parcels.spatialhash import SpatialHash
10+
911
if TYPE_CHECKING:
1012
import numpy as np
1113

@@ -140,6 +142,30 @@ def unravel_index(self, ei: int) -> dict[str, int]:
140142
indices = _unravel(dims, ei)
141143
return dict(zip(self.axes, indices, strict=True))
142144

145+
def get_spatial_hash(
146+
self,
147+
reconstruct=False,
148+
):
149+
"""Get the SpatialHash data structure of this Grid that allows for
150+
fast face search queries. Face searches are used to find the faces that
151+
a list of points, in spherical coordinates, are contained within.
152+
153+
Parameters
154+
----------
155+
reconstruct : bool, default=False
156+
If true, reconstructs the spatial hash
157+
158+
Returns
159+
-------
160+
self._spatialhash : parcels.spatialhash.SpatialHash
161+
SpatialHash instance
162+
163+
"""
164+
if self._spatialhash is None or reconstruct:
165+
self._spatialhash = SpatialHash(self, reconstruct)
166+
167+
return self._spatialhash
168+
143169
@property
144170
@abstractmethod
145171
def axes(self) -> list[str]:

parcels/uxgrid.py

Lines changed: 79 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
import uxarray as ux
77

88
from parcels._typing import assert_valid_mesh
9-
from parcels.spatialhash import _barycentric_coordinates
10-
from parcels.tools.statuscodes import FieldOutOfBoundError
119
from parcels.xgrid import _search_1d_array
1210

1311
from .basegrid import BaseGrid
@@ -43,6 +41,7 @@ def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh="flat") -> UxGrid
4341
raise ValueError("z must be a 1D array of vertical coordinates")
4442
self.z = z
4543
self._mesh = mesh
44+
self._spatialhash = None
4645

4746
assert_valid_mesh(mesh)
4847

@@ -74,63 +73,43 @@ def get_axis_dim(self, axis: _UXGRID_AXES) -> int:
7473
return self.uxgrid.n_face
7574

7675
def search(self, z, y, x, ei=None, tol=1e-6):
77-
def try_face(fid):
78-
bcoords, err = self._get_barycentric_coordinates_latlon(y, x, fid)
79-
if (bcoords >= 0).all() and (bcoords <= 1).all() and err < tol:
80-
return bcoords
81-
else:
82-
bcoords = self._get_barycentric_coordinates_cartesian(y, x, fid)
83-
if (bcoords >= 0).all() and (bcoords <= 1).all():
84-
return bcoords
85-
86-
return None
76+
"""
77+
Search for the grid cell (face) and vertical layer that contains the given points.
8778
79+
Parameters
80+
----------
81+
z : float or np.ndarray
82+
The vertical coordinate(s) (depth) of the point(s).
83+
y : float or np.ndarray
84+
The latitude(s) of the point(s).
85+
x : float or np.ndarray
86+
The longitude(s) of the point(s).
87+
ei : np.ndarray, optional
88+
Precomputed horizontal indices (face indices) for the points.
89+
90+
TO BE IMPLEMENTED : If provided, we'll check
91+
if the points are within the faces specified by these indices. For cells where the particles
92+
are not found, a nearest neighbor search will be performed. As a last resort, the spatial hash will be used.
93+
tol : float, optional
94+
Tolerance for barycentric coordinate checks. Default is 1e-6.
95+
"""
8896
zi, zeta = _search_1d_array(self.z.values, z)
97+
_, face_ids = self.get_spatial_hash().query(y, x)
98+
valid_faces = face_ids != -1
99+
bcoords = np.zeros((len(face_ids), self.uxgrid.n_max_face_nodes), dtype=np.float32)
100+
# Get the barycentric coordinates for all valid faces
101+
for idx in np.where(valid_faces)[0]:
102+
fi = face_ids[idx]
103+
bc = self._get_barycentric_coordinates(y, x, fi)
104+
if np.all(bc <= 1.0) and np.all(bc >= 0.0) and np.isclose(np.sum(bc), 1.0, atol=tol):
105+
bcoords[idx, : len(bc)] = bc
106+
else:
107+
# If the barycentric coordinates are invalid, mark the face as invalid
108+
face_ids[idx] = -1
89109

90-
if ei is not None:
91-
_, fi = self.unravel_index(ei)
92-
bcoords = try_face(fi)
93-
if bcoords is not None:
94-
return bcoords, self.ravel_index(zi, fi)
95-
# Try neighbors of current face
96-
for neighbor in self.uxgrid.face_face_connectivity[fi, :]:
97-
if neighbor == -1:
98-
continue
99-
bcoords = try_face(neighbor)
100-
if bcoords is not None:
101-
return bcoords, self.ravel_index(zi, neighbor)
102-
103-
# Global fallback as last ditch effort
104-
points = np.column_stack((x, y))
105-
face_ids = self.uxgrid.get_faces_containing_point(points, return_counts=False)[0]
106-
fi = face_ids[0] if len(face_ids) > 0 else -1
107-
if fi == -1:
108-
raise FieldOutOfBoundError(z, y, x)
109-
bcoords = try_face(fi)
110-
if bcoords is None:
111-
raise FieldOutOfBoundError(z, y, x)
112-
return {"Z": (zi, zeta), "FACE": (fi, bcoords)}
113-
114-
def _get_barycentric_coordinates_latlon(self, y, x, fi):
115-
"""Checks if a point is inside a given face id on a UxGrid."""
116-
# Check if particle is in the same face, otherwise search again.
117-
118-
n_nodes = self.uxgrid.n_nodes_per_face[fi].to_numpy()
119-
node_ids = self.uxgrid.face_node_connectivity[fi, 0:n_nodes]
120-
nodes = np.column_stack(
121-
(
122-
np.deg2rad(self.uxgrid.node_lon[node_ids].to_numpy()),
123-
np.deg2rad(self.uxgrid.node_lat[node_ids].to_numpy()),
124-
)
125-
)
126-
127-
coord = np.deg2rad(np.column_stack((x, y)))
128-
bcoord = np.asarray(_barycentric_coordinates(nodes, coord))
129-
proj_coord = np.matmul(np.transpose(nodes), bcoord)
130-
err = np.linalg.norm(proj_coord - coord)
131-
return bcoord, err
110+
return {"Z": (zi, zeta), "FACE": (face_ids, bcoords)}
132111

133-
def _get_barycentric_coordinates_cartesian(self, y, x, fi):
112+
def _get_barycentric_coordinates(self, y, x, fi):
134113
n_nodes = self.uxgrid.n_nodes_per_face[fi].to_numpy()
135114
node_ids = self.uxgrid.face_node_connectivity[fi, 0:n_nodes]
136115

@@ -152,6 +131,51 @@ def _get_barycentric_coordinates_cartesian(self, y, x, fi):
152131
return bcoord
153132

154133

134+
def _triangle_area(A, B, C):
135+
"""Compute the area of a triangle given by three points."""
136+
d1 = B - A
137+
d2 = C - A
138+
d3 = np.cross(d1, d2)
139+
return 0.5 * np.linalg.norm(d3)
140+
141+
142+
def _barycentric_coordinates(nodes, point, min_area=1e-8):
143+
"""
144+
Compute the barycentric coordinates of a point P inside a convex polygon using area-based weights.
145+
So that this method generalizes to n-sided polygons, we use the Waschpress points as the generalized
146+
barycentric coordinates, which is only valid for convex polygons.
147+
148+
Parameters
149+
----------
150+
nodes : numpy.ndarray
151+
Spherical coordinates (lat,lon) of each corner node of a face
152+
point : numpy.ndarray
153+
Spherical coordinates (lat,lon) of the point
154+
155+
Returns
156+
-------
157+
numpy.ndarray
158+
Barycentric coordinates corresponding to each vertex.
159+
160+
"""
161+
n = len(nodes)
162+
sum_wi = 0
163+
w = []
164+
165+
for i in range(0, n):
166+
vim1 = nodes[i - 1]
167+
vi = nodes[i]
168+
vi1 = nodes[(i + 1) % n]
169+
a0 = _triangle_area(vim1, vi, vi1)
170+
a1 = max(_triangle_area(point, vim1, vi), min_area)
171+
a2 = max(_triangle_area(point, vi, vi1), min_area)
172+
sum_wi += a0 / (a1 * a2)
173+
w.append(a0 / (a1 * a2))
174+
barycentric_coords = [w_i / sum_wi for w_i in w]
175+
176+
return barycentric_coords
177+
178+
155179
def _lonlat_rad_to_xyz(
156180
lon,
157181
lat,

parcels/xgrid.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from parcels._index_search import _search_indices_curvilinear_2d
1111
from parcels._typing import assert_valid_mesh
1212
from parcels.basegrid import BaseGrid
13-
from parcels.spatialhash import SpatialHash
1413

1514
_XGRID_AXES = Literal["X", "Y", "Z"]
1615
_XGRID_AXES_ORDERING: Sequence[_XGRID_AXES] = "ZYX"
@@ -347,32 +346,6 @@ def get_axis_dim_mapping(self, dims: list[str]) -> dict[_XGRID_AXES, str]:
347346
result[cast(_XGRID_AXES, axis)] = dim
348347
return result
349348

350-
def get_spatial_hash(
351-
self,
352-
reconstruct=False,
353-
):
354-
"""Get the SpatialHash data structure of this Grid that allows for
355-
fast face search queries. Face searches are used to find the faces that
356-
a list of points, in spherical coordinates, are contained within.
357-
358-
Parameters
359-
----------
360-
global_grid : bool, default=False
361-
If true, the hash grid is constructed using the domain [-pi,pi] x [-pi,pi]
362-
reconstruct : bool, default=False
363-
If true, reconstructs the spatial hash
364-
365-
Returns
366-
-------
367-
self._spatialhash : parcels.spatialhash.SpatialHash
368-
SpatialHash instance
369-
370-
"""
371-
if self._spatialhash is None or reconstruct:
372-
self._spatialhash = SpatialHash(self, reconstruct)
373-
374-
return self._spatialhash
375-
376349

377350
def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | None:
378351
"""For a given dimension name in a grid, returns the direction axis it is on."""

0 commit comments

Comments
 (0)