Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions parcels/_datasets/unstructured/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,9 @@ def _fesom2_square_delaunay_antimeridian():
All fields are placed on location consistent with FESOM2 variable placement conventions
"""
lon, lat = np.meshgrid(
np.linspace(-210.0, -150.0, Nx, dtype=np.float32), np.linspace(0, 60.0, Nx, dtype=np.float32)
np.linspace(-210.0, -150.0, Nx, dtype=np.float32), np.linspace(-40.0, 40.0, Nx, dtype=np.float32)
)
# wrap longitude from [-180,180]
lon = np.where(lon < -180, lon + 360, lon)
lon_flat = lon.ravel()
lat_flat = lat.ravel()
zf = np.linspace(0.0, 1000.0, 10, endpoint=True, dtype=np.float32) # Vertical element faces
Expand All @@ -231,7 +230,10 @@ def _fesom2_square_delaunay_antimeridian():

# mask any point on one of the boundaries
mask = (
np.isclose(lon_flat, 0.0) | np.isclose(lon_flat, 60.0) | np.isclose(lat_flat, 0.0) | np.isclose(lat_flat, 60.0)
np.isclose(lon_flat, -210.0)
| np.isclose(lon_flat, -150.0)
| np.isclose(lat_flat, -40.0)
| np.isclose(lat_flat, 40.0)
)

boundary_points = np.flatnonzero(mask)
Expand Down
147 changes: 147 additions & 0 deletions parcels/_index_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,150 @@ def _search_indices_curvilinear_2d(
eta = coords[:, 1]

return (yi, eta, xi, xsi)


def uxgrid_point_in_cell(grid, y: np.ndarray, x: np.ndarray, yi: np.ndarray, xi: np.ndarray):
"""Check if points are inside the grid cells defined by the given face indices.

Parameters
----------
grid : ux.grid.Grid
The uxarray grid object containing the unstructured grid data.
y : np.ndarray
Array of latitudes of the points to check.
x : np.ndarray
Array of longitudes of the points to check.
yi : np.ndarray
Array of face indices corresponding to the points.
xi : np.ndarray
Not used, but included for compatibility with other search functions.

Returns
-------
is_in_cell : np.ndarray
An array indicating whether each point is inside (1) or outside (0) the corresponding cell.
coords : np.ndarray
Barycentric coordinates of the points within their respective cells.
"""
if grid._mesh == "spherical":
lon_rad = np.deg2rad(x)
lat_rad = np.deg2rad(y)
x_cart, y_cart, z_cart = _latlon_rad_to_xyz(lat_rad, lon_rad)
points = np.column_stack((x_cart.flatten(), y_cart.flatten(), z_cart.flatten()))

# Get the vertex indices for each face
nids = grid.uxgrid.face_node_connectivity[yi].values
face_vertices = np.stack(
(
grid.uxgrid.node_x[nids.ravel()].values.reshape(nids.shape),
grid.uxgrid.node_y[nids.ravel()].values.reshape(nids.shape),
grid.uxgrid.node_z[nids.ravel()].values.reshape(nids.shape),
),
axis=-1,
)
else:
nids = grid.uxgrid.face_node_connectivity[yi].values
face_vertices = np.stack(
(
grid.uxgrid.node_lon[nids.ravel()].values.reshape(nids.shape),
grid.uxgrid.node_lat[nids.ravel()].values.reshape(nids.shape),
),
axis=-1,
)
points = np.stack((x, y))

M = len(points)

is_in_cell = np.zeros(M, dtype=np.int32)

coords = _barycentric_coordinates(face_vertices, points)
is_in_cell = np.where(np.all((coords >= -1e-6) & (coords <= 1 + 1e-6), axis=1), 1, 0)

return is_in_cell, coords


def _triangle_area(A, B, C):
"""Compute the area of a triangle given by three points."""
d1 = B - A
d2 = C - A
if A.shape[-1] == 2:
# 2D case: cross product reduces to scalar z-component
cross = d1[..., 0] * d2[..., 1] - d1[..., 1] * d2[..., 0]
area = 0.5 * np.abs(cross)
elif A.shape[-1] == 3:
# 3D case: full vector cross product
cross = np.cross(d1, d2)
area = 0.5 * np.linalg.norm(cross, axis=-1)
else:
raise ValueError(f"Expected last dim=2 or 3, got {A.shape[-1]}")

return area
# d3 = np.cross(d1, d2, axis=-1)
# breakpoint()
# return 0.5 * np.linalg.norm(d3, axis=-1)


def _barycentric_coordinates(nodes, points, min_area=1e-8):
"""
Compute the barycentric coordinates of a point P inside a convex polygon using area-based weights.
So that this method generalizes to n-sided polygons, we use the Waschpress points as the generalized
barycentric coordinates, which is only valid for convex polygons.

Parameters
----------
nodes : numpy.ndarray
Polygon verties per query of shape (M, 3, 2/3) where M is the number of query points. The second dimension corresponds to the number
of vertices
The last dimension can be either 2 or 3, where 3 corresponds to the (z, y, x) coordinates of each vertex and 2 corresponds to the
(lat, lon) coordinates of each vertex.

points : numpy.ndarray
Spherical coordinates of the point (M,2/3) where M is the number of query points.

Returns
-------
numpy.ndarray
Barycentric coordinates corresponding to each vertex.

"""
M, K = nodes.shape[:2]

# roll(-1) to get vi+1, roll(+1) to get vi-1
vi = nodes # (M,K,2)
vi1 = np.roll(nodes, shift=-1, axis=1) # (M,K,2)
vim1 = np.roll(nodes, shift=+1, axis=1) # (M,K,2)

# a0 = area(v_{i-1}, v_i, v_{i+1})
a0 = _triangle_area(vim1, vi, vi1) # (M,K)

# a1 = area(P, v_{i-1}, v_i); a2 = area(P, v_i, v_{i+1})
P = points[:, None, :] # (M,1,2) -> (M,K,2)
a1 = _triangle_area(P, vim1, vi)
a2 = _triangle_area(P, vi, vi1)

# clamp tiny denominators for stability
a1c = np.maximum(a1, min_area)
a2c = np.maximum(a2, min_area)

wi = a0 / (a1c * a2c) # (M,K)

sum_wi = wi.sum(axis=1, keepdims=True) # (M,1)
# Avoid 0/0: if sum_wi==0 (degenerate), keep zeros
with np.errstate(invalid="ignore", divide="ignore"):
bcoords = wi / sum_wi

return bcoords


def _latlon_rad_to_xyz(
lat,
lon,
):
"""Converts Spherical latitude and longitude coordinates into Cartesian x,
y, z coordinates.
"""
x = np.cos(lon) * np.cos(lat)
y = np.sin(lon) * np.cos(lat)
z = np.sin(lat)

return x, y, z
4 changes: 2 additions & 2 deletions parcels/application_kernels/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,8 +657,8 @@ def UXPiecewiseLinearNode(
# The zi refers to the vertical layer index. The field in this routine are assumed to be defined at the vertical interface levels.
# For interface zi, the interface indices are [zi, zi+1], so we need to use the values at zi and zi+1.
# First, do barycentric interpolation in the lateral direction for each interface level
fzk = np.dot(field.data.values[ti, k, node_ids], bcoords)
fzkp1 = np.dot(field.data.values[ti, k + 1, node_ids], bcoords)
fzk = np.sum(field.data.values[ti, k, node_ids] * bcoords, axis=-1)
fzkp1 = np.sum(field.data.values[ti, k + 1, node_ids] * bcoords, axis=-1)

# Then, do piecewise linear interpolation in the vertical direction
zk = field.grid.z.values[k]
Expand Down
28 changes: 28 additions & 0 deletions parcels/basegrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import numpy as np

from parcels.spatialhash import SpatialHash

if TYPE_CHECKING:
import numpy as np

Expand Down Expand Up @@ -178,6 +180,32 @@ def get_axis_dim(self, axis: str) -> int:
"""
...

def get_spatial_hash(
self,
reconstruct=False,
):
"""Get the SpatialHash data structure of this Grid that allows for
fast face search queries. Face searches are used to find the faces that
a list of points, in spherical coordinates, are contained within.

Parameters
----------
global_grid : bool, default=False
If true, the hash grid is constructed using the domain [-pi,pi] x [-pi,pi]
reconstruct : bool, default=False
If true, reconstructs the spatial hash

Returns
-------
self._spatialhash : parcels.spatialhash.SpatialHash
SpatialHash instance

"""
if self._spatialhash is None or reconstruct:
self._spatialhash = SpatialHash(self)

return self._spatialhash


def _unravel(dims, ei):
"""
Expand Down
Loading
Loading