66import uxarray as ux
77
88from parcels ._typing import assert_valid_mesh
9- from parcels .spatialhash import _barycentric_coordinates
10- from parcels .tools .statuscodes import FieldOutOfBoundError
119from parcels .xgrid import _search_1d_array
1210
1311from .basegrid import BaseGrid
@@ -43,6 +41,7 @@ def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh="flat") -> UxGrid
4341 raise ValueError ("z must be a 1D array of vertical coordinates" )
4442 self .z = z
4543 self ._mesh = mesh
44+ self ._spatialhash = None
4645
4746 assert_valid_mesh (mesh )
4847
@@ -74,63 +73,43 @@ def get_axis_dim(self, axis: _UXGRID_AXES) -> int:
7473 return self .uxgrid .n_face
7574
7675 def search (self , z , y , x , ei = None , tol = 1e-6 ):
77- def try_face (fid ):
78- bcoords , err = self ._get_barycentric_coordinates_latlon (y , x , fid )
79- if (bcoords >= 0 ).all () and (bcoords <= 1 ).all () and err < tol :
80- return bcoords
81- else :
82- bcoords = self ._get_barycentric_coordinates_cartesian (y , x , fid )
83- if (bcoords >= 0 ).all () and (bcoords <= 1 ).all ():
84- return bcoords
85-
86- return None
76+ """
77+ Search for the grid cell (face) and vertical layer that contains the given points.
8778
79+ Parameters
80+ ----------
81+ z : float or np.ndarray
82+ The vertical coordinate(s) (depth) of the point(s).
83+ y : float or np.ndarray
84+ The latitude(s) of the point(s).
85+ x : float or np.ndarray
86+ The longitude(s) of the point(s).
87+ ei : np.ndarray, optional
88+ Precomputed horizontal indices (face indices) for the points.
89+
90+ TO BE IMPLEMENTED : If provided, we'll check
91+ if the points are within the faces specified by these indices. For cells where the particles
92+ are not found, a nearest neighbor search will be performed. As a last resort, the spatial hash will be used.
93+ tol : float, optional
94+ Tolerance for barycentric coordinate checks. Default is 1e-6.
95+ """
8896 zi , zeta = _search_1d_array (self .z .values , z )
97+ _ , face_ids = self .get_spatial_hash ().query (y , x )
98+ valid_faces = face_ids != - 1
99+ bcoords = np .zeros ((len (face_ids ), self .uxgrid .n_max_face_nodes ), dtype = np .float32 )
100+ # Get the barycentric coordinates for all valid faces
101+ for idx in np .where (valid_faces )[0 ]:
102+ fi = face_ids [idx ]
103+ bc = self ._get_barycentric_coordinates (y , x , fi )
104+ if np .all (bc <= 1.0 ) and np .all (bc >= 0.0 ) and np .isclose (np .sum (bc ), 1.0 , atol = tol ):
105+ bcoords [idx , : len (bc )] = bc
106+ else :
107+ # If the barycentric coordinates are invalid, mark the face as invalid
108+ face_ids [idx ] = - 1
89109
90- if ei is not None :
91- _ , fi = self .unravel_index (ei )
92- bcoords = try_face (fi )
93- if bcoords is not None :
94- return bcoords , self .ravel_index (zi , fi )
95- # Try neighbors of current face
96- for neighbor in self .uxgrid .face_face_connectivity [fi , :]:
97- if neighbor == - 1 :
98- continue
99- bcoords = try_face (neighbor )
100- if bcoords is not None :
101- return bcoords , self .ravel_index (zi , neighbor )
102-
103- # Global fallback as last ditch effort
104- points = np .column_stack ((x , y ))
105- face_ids = self .uxgrid .get_faces_containing_point (points , return_counts = False )[0 ]
106- fi = face_ids [0 ] if len (face_ids ) > 0 else - 1
107- if fi == - 1 :
108- raise FieldOutOfBoundError (z , y , x )
109- bcoords = try_face (fi )
110- if bcoords is None :
111- raise FieldOutOfBoundError (z , y , x )
112- return {"Z" : (zi , zeta ), "FACE" : (fi , bcoords )}
113-
114- def _get_barycentric_coordinates_latlon (self , y , x , fi ):
115- """Checks if a point is inside a given face id on a UxGrid."""
116- # Check if particle is in the same face, otherwise search again.
117-
118- n_nodes = self .uxgrid .n_nodes_per_face [fi ].to_numpy ()
119- node_ids = self .uxgrid .face_node_connectivity [fi , 0 :n_nodes ]
120- nodes = np .column_stack (
121- (
122- np .deg2rad (self .uxgrid .node_lon [node_ids ].to_numpy ()),
123- np .deg2rad (self .uxgrid .node_lat [node_ids ].to_numpy ()),
124- )
125- )
126-
127- coord = np .deg2rad (np .column_stack ((x , y )))
128- bcoord = np .asarray (_barycentric_coordinates (nodes , coord ))
129- proj_coord = np .matmul (np .transpose (nodes ), bcoord )
130- err = np .linalg .norm (proj_coord - coord )
131- return bcoord , err
110+ return {"Z" : (zi , zeta ), "FACE" : (face_ids , bcoords )}
132111
133- def _get_barycentric_coordinates_cartesian (self , y , x , fi ):
112+ def _get_barycentric_coordinates (self , y , x , fi ):
134113 n_nodes = self .uxgrid .n_nodes_per_face [fi ].to_numpy ()
135114 node_ids = self .uxgrid .face_node_connectivity [fi , 0 :n_nodes ]
136115
@@ -152,6 +131,51 @@ def _get_barycentric_coordinates_cartesian(self, y, x, fi):
152131 return bcoord
153132
154133
134+ def _triangle_area (A , B , C ):
135+ """Compute the area of a triangle given by three points."""
136+ d1 = B - A
137+ d2 = C - A
138+ d3 = np .cross (d1 , d2 )
139+ return 0.5 * np .linalg .norm (d3 )
140+
141+
142+ def _barycentric_coordinates (nodes , point , min_area = 1e-8 ):
143+ """
144+ Compute the barycentric coordinates of a point P inside a convex polygon using area-based weights.
145+ So that this method generalizes to n-sided polygons, we use the Waschpress points as the generalized
146+ barycentric coordinates, which is only valid for convex polygons.
147+
148+ Parameters
149+ ----------
150+ nodes : numpy.ndarray
151+ Spherical coordinates (lat,lon) of each corner node of a face
152+ point : numpy.ndarray
153+ Spherical coordinates (lat,lon) of the point
154+
155+ Returns
156+ -------
157+ numpy.ndarray
158+ Barycentric coordinates corresponding to each vertex.
159+
160+ """
161+ n = len (nodes )
162+ sum_wi = 0
163+ w = []
164+
165+ for i in range (0 , n ):
166+ vim1 = nodes [i - 1 ]
167+ vi = nodes [i ]
168+ vi1 = nodes [(i + 1 ) % n ]
169+ a0 = _triangle_area (vim1 , vi , vi1 )
170+ a1 = max (_triangle_area (point , vim1 , vi ), min_area )
171+ a2 = max (_triangle_area (point , vi , vi1 ), min_area )
172+ sum_wi += a0 / (a1 * a2 )
173+ w .append (a0 / (a1 * a2 ))
174+ barycentric_coords = [w_i / sum_wi for w_i in w ]
175+
176+ return barycentric_coords
177+
178+
155179def _lonlat_rad_to_xyz (
156180 lon ,
157181 lat ,
0 commit comments