@@ -128,3 +128,150 @@ def _search_indices_curvilinear_2d(
128128 eta = coords [:, 1 ]
129129
130130 return (yi , eta , xi , xsi )
131+
132+
133+ def uxgrid_point_in_cell (grid , y : np .ndarray , x : np .ndarray , yi : np .ndarray , xi : np .ndarray ):
134+ """Check if points are inside the grid cells defined by the given face indices.
135+
136+ Parameters
137+ ----------
138+ grid : ux.grid.Grid
139+ The uxarray grid object containing the unstructured grid data.
140+ y : np.ndarray
141+ Array of latitudes of the points to check.
142+ x : np.ndarray
143+ Array of longitudes of the points to check.
144+ yi : np.ndarray
145+ Array of face indices corresponding to the points.
146+ xi : np.ndarray
147+ Not used, but included for compatibility with other search functions.
148+
149+ Returns
150+ -------
151+ is_in_cell : np.ndarray
152+ An array indicating whether each point is inside (1) or outside (0) the corresponding cell.
153+ coords : np.ndarray
154+ Barycentric coordinates of the points within their respective cells.
155+ """
156+ if grid ._mesh == "spherical" :
157+ lon_rad = np .deg2rad (x )
158+ lat_rad = np .deg2rad (y )
159+ x_cart , y_cart , z_cart = _latlon_rad_to_xyz (lat_rad , lon_rad )
160+ points = np .column_stack ((x_cart .flatten (), y_cart .flatten (), z_cart .flatten ()))
161+
162+ # Get the vertex indices for each face
163+ nids = grid .uxgrid .face_node_connectivity [yi ].values
164+ face_vertices = np .stack (
165+ (
166+ grid .uxgrid .node_x [nids .ravel ()].values .reshape (nids .shape ),
167+ grid .uxgrid .node_y [nids .ravel ()].values .reshape (nids .shape ),
168+ grid .uxgrid .node_z [nids .ravel ()].values .reshape (nids .shape ),
169+ ),
170+ axis = - 1 ,
171+ )
172+ else :
173+ nids = grid .uxgrid .face_node_connectivity [yi ].values
174+ face_vertices = np .stack (
175+ (
176+ grid .uxgrid .node_lon [nids .ravel ()].values .reshape (nids .shape ),
177+ grid .uxgrid .node_lat [nids .ravel ()].values .reshape (nids .shape ),
178+ ),
179+ axis = - 1 ,
180+ )
181+ points = np .stack ((x , y ))
182+
183+ M = len (points )
184+
185+ is_in_cell = np .zeros (M , dtype = np .int32 )
186+
187+ coords = _barycentric_coordinates (face_vertices , points )
188+ is_in_cell = np .where (np .all ((coords >= - 1e-6 ) & (coords <= 1 + 1e-6 ), axis = 1 ), 1 , 0 )
189+
190+ return is_in_cell , coords
191+
192+
193+ def _triangle_area (A , B , C ):
194+ """Compute the area of a triangle given by three points."""
195+ d1 = B - A
196+ d2 = C - A
197+ if A .shape [- 1 ] == 2 :
198+ # 2D case: cross product reduces to scalar z-component
199+ cross = d1 [..., 0 ] * d2 [..., 1 ] - d1 [..., 1 ] * d2 [..., 0 ]
200+ area = 0.5 * np .abs (cross )
201+ elif A .shape [- 1 ] == 3 :
202+ # 3D case: full vector cross product
203+ cross = np .cross (d1 , d2 )
204+ area = 0.5 * np .linalg .norm (cross , axis = - 1 )
205+ else :
206+ raise ValueError (f"Expected last dim=2 or 3, got { A .shape [- 1 ]} " )
207+
208+ return area
209+ # d3 = np.cross(d1, d2, axis=-1)
210+ # breakpoint()
211+ # return 0.5 * np.linalg.norm(d3, axis=-1)
212+
213+
214+ def _barycentric_coordinates (nodes , points , min_area = 1e-8 ):
215+ """
216+ Compute the barycentric coordinates of a point P inside a convex polygon using area-based weights.
217+ So that this method generalizes to n-sided polygons, we use the Waschpress points as the generalized
218+ barycentric coordinates, which is only valid for convex polygons.
219+
220+ Parameters
221+ ----------
222+ nodes : numpy.ndarray
223+ Polygon verties per query of shape (M, 3, 2/3) where M is the number of query points. The second dimension corresponds to the number
224+ of vertices
225+ The last dimension can be either 2 or 3, where 3 corresponds to the (z, y, x) coordinates of each vertex and 2 corresponds to the
226+ (lat, lon) coordinates of each vertex.
227+
228+ points : numpy.ndarray
229+ Spherical coordinates of the point (M,2/3) where M is the number of query points.
230+
231+ Returns
232+ -------
233+ numpy.ndarray
234+ Barycentric coordinates corresponding to each vertex.
235+
236+ """
237+ M , K = nodes .shape [:2 ]
238+
239+ # roll(-1) to get vi+1, roll(+1) to get vi-1
240+ vi = nodes # (M,K,2)
241+ vi1 = np .roll (nodes , shift = - 1 , axis = 1 ) # (M,K,2)
242+ vim1 = np .roll (nodes , shift = + 1 , axis = 1 ) # (M,K,2)
243+
244+ # a0 = area(v_{i-1}, v_i, v_{i+1})
245+ a0 = _triangle_area (vim1 , vi , vi1 ) # (M,K)
246+
247+ # a1 = area(P, v_{i-1}, v_i); a2 = area(P, v_i, v_{i+1})
248+ P = points [:, None , :] # (M,1,2) -> (M,K,2)
249+ a1 = _triangle_area (P , vim1 , vi )
250+ a2 = _triangle_area (P , vi , vi1 )
251+
252+ # clamp tiny denominators for stability
253+ a1c = np .maximum (a1 , min_area )
254+ a2c = np .maximum (a2 , min_area )
255+
256+ wi = a0 / (a1c * a2c ) # (M,K)
257+
258+ sum_wi = wi .sum (axis = 1 , keepdims = True ) # (M,1)
259+ # Avoid 0/0: if sum_wi==0 (degenerate), keep zeros
260+ with np .errstate (invalid = "ignore" , divide = "ignore" ):
261+ bcoords = wi / sum_wi
262+
263+ return bcoords
264+
265+
266+ def _latlon_rad_to_xyz (
267+ lat ,
268+ lon ,
269+ ):
270+ """Converts Spherical latitude and longitude coordinates into Cartesian x,
271+ y, z coordinates.
272+ """
273+ x = np .cos (lon ) * np .cos (lat )
274+ y = np .sin (lon ) * np .cos (lat )
275+ z = np .sin (lat )
276+
277+ return x , y , z
0 commit comments