Skip to content

Commit a9eaa88

Browse files
Merge pull request #2218 from OceanParcels/feature/uxgrid-morton-hashing
Feature/uxgrid morton hashing
2 parents 3592b64 + 0944797 commit a9eaa88

File tree

10 files changed

+405
-289
lines changed

10 files changed

+405
-289
lines changed

parcels/_datasets/unstructured/generic.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,9 @@ def _fesom2_square_delaunay_antimeridian():
218218
All fields are placed on location consistent with FESOM2 variable placement conventions
219219
"""
220220
lon, lat = np.meshgrid(
221-
np.linspace(-210.0, -150.0, Nx, dtype=np.float32), np.linspace(0, 60.0, Nx, dtype=np.float32)
221+
np.linspace(-210.0, -150.0, Nx, dtype=np.float32), np.linspace(-40.0, 40.0, Nx, dtype=np.float32)
222222
)
223223
# wrap longitude from [-180,180]
224-
lon = np.where(lon < -180, lon + 360, lon)
225224
lon_flat = lon.ravel()
226225
lat_flat = lat.ravel()
227226
zf = np.linspace(0.0, 1000.0, 10, endpoint=True, dtype=np.float32) # Vertical element faces
@@ -231,7 +230,10 @@ def _fesom2_square_delaunay_antimeridian():
231230

232231
# mask any point on one of the boundaries
233232
mask = (
234-
np.isclose(lon_flat, 0.0) | np.isclose(lon_flat, 60.0) | np.isclose(lat_flat, 0.0) | np.isclose(lat_flat, 60.0)
233+
np.isclose(lon_flat, -210.0)
234+
| np.isclose(lon_flat, -150.0)
235+
| np.isclose(lat_flat, -40.0)
236+
| np.isclose(lat_flat, 40.0)
235237
)
236238

237239
boundary_points = np.flatnonzero(mask)

parcels/_index_search.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,3 +128,150 @@ def _search_indices_curvilinear_2d(
128128
eta = coords[:, 1]
129129

130130
return (yi, eta, xi, xsi)
131+
132+
133+
def uxgrid_point_in_cell(grid, y: np.ndarray, x: np.ndarray, yi: np.ndarray, xi: np.ndarray):
134+
"""Check if points are inside the grid cells defined by the given face indices.
135+
136+
Parameters
137+
----------
138+
grid : ux.grid.Grid
139+
The uxarray grid object containing the unstructured grid data.
140+
y : np.ndarray
141+
Array of latitudes of the points to check.
142+
x : np.ndarray
143+
Array of longitudes of the points to check.
144+
yi : np.ndarray
145+
Array of face indices corresponding to the points.
146+
xi : np.ndarray
147+
Not used, but included for compatibility with other search functions.
148+
149+
Returns
150+
-------
151+
is_in_cell : np.ndarray
152+
An array indicating whether each point is inside (1) or outside (0) the corresponding cell.
153+
coords : np.ndarray
154+
Barycentric coordinates of the points within their respective cells.
155+
"""
156+
if grid._mesh == "spherical":
157+
lon_rad = np.deg2rad(x)
158+
lat_rad = np.deg2rad(y)
159+
x_cart, y_cart, z_cart = _latlon_rad_to_xyz(lat_rad, lon_rad)
160+
points = np.column_stack((x_cart.flatten(), y_cart.flatten(), z_cart.flatten()))
161+
162+
# Get the vertex indices for each face
163+
nids = grid.uxgrid.face_node_connectivity[yi].values
164+
face_vertices = np.stack(
165+
(
166+
grid.uxgrid.node_x[nids.ravel()].values.reshape(nids.shape),
167+
grid.uxgrid.node_y[nids.ravel()].values.reshape(nids.shape),
168+
grid.uxgrid.node_z[nids.ravel()].values.reshape(nids.shape),
169+
),
170+
axis=-1,
171+
)
172+
else:
173+
nids = grid.uxgrid.face_node_connectivity[yi].values
174+
face_vertices = np.stack(
175+
(
176+
grid.uxgrid.node_lon[nids.ravel()].values.reshape(nids.shape),
177+
grid.uxgrid.node_lat[nids.ravel()].values.reshape(nids.shape),
178+
),
179+
axis=-1,
180+
)
181+
points = np.stack((x, y))
182+
183+
M = len(points)
184+
185+
is_in_cell = np.zeros(M, dtype=np.int32)
186+
187+
coords = _barycentric_coordinates(face_vertices, points)
188+
is_in_cell = np.where(np.all((coords >= -1e-6) & (coords <= 1 + 1e-6), axis=1), 1, 0)
189+
190+
return is_in_cell, coords
191+
192+
193+
def _triangle_area(A, B, C):
194+
"""Compute the area of a triangle given by three points."""
195+
d1 = B - A
196+
d2 = C - A
197+
if A.shape[-1] == 2:
198+
# 2D case: cross product reduces to scalar z-component
199+
cross = d1[..., 0] * d2[..., 1] - d1[..., 1] * d2[..., 0]
200+
area = 0.5 * np.abs(cross)
201+
elif A.shape[-1] == 3:
202+
# 3D case: full vector cross product
203+
cross = np.cross(d1, d2)
204+
area = 0.5 * np.linalg.norm(cross, axis=-1)
205+
else:
206+
raise ValueError(f"Expected last dim=2 or 3, got {A.shape[-1]}")
207+
208+
return area
209+
# d3 = np.cross(d1, d2, axis=-1)
210+
# breakpoint()
211+
# return 0.5 * np.linalg.norm(d3, axis=-1)
212+
213+
214+
def _barycentric_coordinates(nodes, points, min_area=1e-8):
215+
"""
216+
Compute the barycentric coordinates of a point P inside a convex polygon using area-based weights.
217+
So that this method generalizes to n-sided polygons, we use the Waschpress points as the generalized
218+
barycentric coordinates, which is only valid for convex polygons.
219+
220+
Parameters
221+
----------
222+
nodes : numpy.ndarray
223+
Polygon verties per query of shape (M, 3, 2/3) where M is the number of query points. The second dimension corresponds to the number
224+
of vertices
225+
The last dimension can be either 2 or 3, where 3 corresponds to the (z, y, x) coordinates of each vertex and 2 corresponds to the
226+
(lat, lon) coordinates of each vertex.
227+
228+
points : numpy.ndarray
229+
Spherical coordinates of the point (M,2/3) where M is the number of query points.
230+
231+
Returns
232+
-------
233+
numpy.ndarray
234+
Barycentric coordinates corresponding to each vertex.
235+
236+
"""
237+
M, K = nodes.shape[:2]
238+
239+
# roll(-1) to get vi+1, roll(+1) to get vi-1
240+
vi = nodes # (M,K,2)
241+
vi1 = np.roll(nodes, shift=-1, axis=1) # (M,K,2)
242+
vim1 = np.roll(nodes, shift=+1, axis=1) # (M,K,2)
243+
244+
# a0 = area(v_{i-1}, v_i, v_{i+1})
245+
a0 = _triangle_area(vim1, vi, vi1) # (M,K)
246+
247+
# a1 = area(P, v_{i-1}, v_i); a2 = area(P, v_i, v_{i+1})
248+
P = points[:, None, :] # (M,1,2) -> (M,K,2)
249+
a1 = _triangle_area(P, vim1, vi)
250+
a2 = _triangle_area(P, vi, vi1)
251+
252+
# clamp tiny denominators for stability
253+
a1c = np.maximum(a1, min_area)
254+
a2c = np.maximum(a2, min_area)
255+
256+
wi = a0 / (a1c * a2c) # (M,K)
257+
258+
sum_wi = wi.sum(axis=1, keepdims=True) # (M,1)
259+
# Avoid 0/0: if sum_wi==0 (degenerate), keep zeros
260+
with np.errstate(invalid="ignore", divide="ignore"):
261+
bcoords = wi / sum_wi
262+
263+
return bcoords
264+
265+
266+
def _latlon_rad_to_xyz(
267+
lat,
268+
lon,
269+
):
270+
"""Converts Spherical latitude and longitude coordinates into Cartesian x,
271+
y, z coordinates.
272+
"""
273+
x = np.cos(lon) * np.cos(lat)
274+
y = np.sin(lon) * np.cos(lat)
275+
z = np.sin(lat)
276+
277+
return x, y, z

parcels/application_kernels/interpolation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -657,8 +657,8 @@ def UXPiecewiseLinearNode(
657657
# The zi refers to the vertical layer index. The field in this routine are assumed to be defined at the vertical interface levels.
658658
# For interface zi, the interface indices are [zi, zi+1], so we need to use the values at zi and zi+1.
659659
# First, do barycentric interpolation in the lateral direction for each interface level
660-
fzk = np.dot(field.data.values[ti, k, node_ids], bcoords)
661-
fzkp1 = np.dot(field.data.values[ti, k + 1, node_ids], bcoords)
660+
fzk = np.sum(field.data.values[ti, k, node_ids] * bcoords, axis=-1)
661+
fzkp1 = np.sum(field.data.values[ti, k + 1, node_ids] * bcoords, axis=-1)
662662

663663
# Then, do piecewise linear interpolation in the vertical direction
664664
zk = field.grid.z.values[k]

parcels/basegrid.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import numpy as np
88

9+
from parcels.spatialhash import SpatialHash
10+
911
if TYPE_CHECKING:
1012
import numpy as np
1113

@@ -178,6 +180,32 @@ def get_axis_dim(self, axis: str) -> int:
178180
"""
179181
...
180182

183+
def get_spatial_hash(
184+
self,
185+
reconstruct=False,
186+
):
187+
"""Get the SpatialHash data structure of this Grid that allows for
188+
fast face search queries. Face searches are used to find the faces that
189+
a list of points, in spherical coordinates, are contained within.
190+
191+
Parameters
192+
----------
193+
global_grid : bool, default=False
194+
If true, the hash grid is constructed using the domain [-pi,pi] x [-pi,pi]
195+
reconstruct : bool, default=False
196+
If true, reconstructs the spatial hash
197+
198+
Returns
199+
-------
200+
self._spatialhash : parcels.spatialhash.SpatialHash
201+
SpatialHash instance
202+
203+
"""
204+
if self._spatialhash is None or reconstruct:
205+
self._spatialhash = SpatialHash(self)
206+
207+
return self._spatialhash
208+
181209

182210
def _unravel(dims, ei):
183211
"""

0 commit comments

Comments
 (0)