@@ -41,6 +41,7 @@ def __init__(
4141 self ._bitwidth = bitwidth # Max integer to use per coordinate in quantization (10 bits = 0..1023)
4242
4343 if isinstance_noimport (grid , "XGrid" ):
44+ self ._coord_dim = 2 # Number of computational coordinates is 2 (bilinear interpolation)
4445 if self ._source_grid ._mesh == "spherical" :
4546 # Boundaries of the hash grid are the unit cube
4647 self ._xmin = - 1.0
@@ -126,6 +127,7 @@ def __init__(
126127 self ._zhigh = np .zeros_like (self ._xlow )
127128
128129 elif isinstance_noimport (grid , "UxGrid" ):
130+ self ._coord_dim = grid .uxgrid .n_max_face_nodes # Number of barycentric coordinates
129131 if self ._source_grid ._mesh == "spherical" :
130132 # Boundaries of the hash grid are the unit cube
131133 self ._xmin = - 1.0
@@ -344,7 +346,11 @@ def query(self, y, x):
344346 pos = np .searchsorted (keys , query_codes ) # pos is shape (num_queries,)
345347
346348 # Valid hits: inside range with finite query coordinates and query codes give exact morton code match.
347- valid = (pos < len (keys )) & np .isfinite (x ) & np .isfinite (y ) & (query_codes == keys [pos ])
349+ valid = (pos < len (keys )) & np .isfinite (x ) & np .isfinite (y )
350+ # Clip pos to valid range to avoid out-of-bounds indexing
351+ pos = np .clip (pos , 0 , len (keys ) - 1 )
352+ # Further filter out false positives from searchsorted by checking for exact code match
353+ valid [valid ] &= query_codes [valid ] == keys [pos [valid ]]
348354
349355 # Pre-allocate i and j indices of the best match for each query
350356 # Default values to -1 (no match case)
@@ -357,7 +363,7 @@ def query(self, y, x):
357363 return (
358364 j_best .reshape (query_codes .shape ),
359365 i_best .reshape (query_codes .shape ),
360- np .full ((num_queries , 2 ), - 1.0 , dtype = np .float32 ),
366+ np .full ((num_queries , self . _coord_dim ), - 1.0 , dtype = np .float32 ),
361367 )
362368
363369 # Now, for each query, we need to gather the candidate (j,i) indices from the hash table
0 commit comments