@@ -192,11 +192,18 @@ def query(
192192 # Locate each query in the unique key array
193193 pos = np .searchsorted (keys , query_codes ) # pos is shape (N,)
194194
195- # Valid hits: inside range
196- valid = pos < len (keys )
195+ # Valid hits: inside range with finite query coordinates
196+ valid = (pos < len (keys )) & np .isfinite (x ) & np .isfinite (y )
197+
198+ # Pre-allocate i and j indices of the best match for each query
199+ # Default values to -1 (no match case)
200+ j_best = np .full (num_queries , - 1 , dtype = np .int64 )
201+ i_best = np .full (num_queries , - 1 , dtype = np .int64 )
197202
198203 # How many matches each query has; hit_counts[i] is the number of hits for query i
199204 hit_counts = np .where (valid , counts [pos ], 0 ).astype (np .int64 ) # has shape (N,)
205+ if hit_counts .sum () == 0 :
206+ return (j_best .reshape (query_codes .shape ), i_best .reshape (query_codes .shape ))
200207
201208 # CSR-style offsets (prefix sum), total number of hits
202209 offsets = np .empty (hit_counts .size + 1 , dtype = np .int64 )
@@ -220,40 +227,40 @@ def query(
220227 j_all = j [source_idx ]
221228 i_all = i [source_idx ]
222229
223- # Gather centroid coordinates at those (j,i)
224- xc_all = xc [j_all , i_all ]
225- yc_all = yc [j_all , i_all ]
226- zc_all = zc [j_all , i_all ]
227-
228- # Broadcast to flat (same as q_flat), then repeat per candidate
229- # This makes it easy to compute distances from the query points
230- # to the candidate points for minimization.
231- qx_all = np .repeat (qx .ravel (), hit_counts )
232- qy_all = np .repeat (qy .ravel (), hit_counts )
233- qz_all = np .repeat (qz .ravel (), hit_counts )
234-
235- # Squared distances for all candidates
236- dist_all = (xc_all - qx_all ) ** 2 + (yc_all - qy_all ) ** 2 + (zc_all - qz_all ) ** 2
237-
238230 # Segment-wise minima per query using reduceat
239231 # For each query, we need to find the minimum distance.
240- dmin_per_q = np .minimum .reduceat (dist_all , offsets [:- 1 ])
241-
242- # To get argmin indices per query (without loops):
243- # Build a masked "within-index" array that is large unless it equals the segment-min.
244- big = np .iinfo (np .int64 ).max
245- within_masked = np .where (dist_all == np .repeat (dmin_per_q , hit_counts ), intra , big )
246- argmin_within = np .minimum .reduceat (within_masked , offsets [:- 1 ]) # first occurrence in ties
247-
248- # Build absolute source index for the winning candidate in each query
249- start_for_q = np .where (valid , starts [pos ], 0 ) # 0 is dummy for invalid queries
250- src_best = (start_for_q + argmin_within ).astype (np .int64 )
232+ if total == 1 :
233+ # Build absolute source index for the winning candidate in each query
234+ start_for_q = np .where (valid , starts [pos ], 0 ) # 0 is dummy for invalid queries
235+ src_best = (start_for_q ).astype (np .int64 )
236+ else :
237+ # Gather centroid coordinates at those (j,i)
238+ xc_all = xc [j_all , i_all ]
239+ yc_all = yc [j_all , i_all ]
240+ zc_all = zc [j_all , i_all ]
241+
242+ # Broadcast to flat (same as q_flat), then repeat per candidate
243+ # This makes it easy to compute distances from the query points
244+ # to the candidate points for minimization.
245+ qx_all = np .repeat (qx .ravel (), hit_counts )
246+ qy_all = np .repeat (qy .ravel (), hit_counts )
247+ qz_all = np .repeat (qz .ravel (), hit_counts )
248+
249+ # Squared distances for all candidates
250+ dist_all = (xc_all - qx_all ) ** 2 + (yc_all - qy_all ) ** 2 + (zc_all - qz_all ) ** 2
251+
252+ dmin_per_q = np .minimum .reduceat (dist_all , offsets [:- 1 ])
253+ # To get argmin indices per query (without loops):
254+ # Build a masked "within-index" array that is large unless it equals the segment-min.
255+ big = np .iinfo (np .int64 ).max
256+ within_masked = np .where (dist_all == np .repeat (dmin_per_q , hit_counts ), intra , big )
257+ argmin_within = np .minimum .reduceat (within_masked , offsets [:- 1 ]) # first occurrence in ties
258+
259+ # Build absolute source index for the winning candidate in each query
260+ start_for_q = np .where (valid , starts [pos ], 0 ) # 0 is dummy for invalid queries
261+ src_best = (start_for_q + argmin_within ).astype (np .int64 )
251262
252263 # Write outputs only for queries that had candidates
253- # Pre-allocate i and j indices of the best match for each query
254- # Default values to -1 (no match case)
255- j_best = np .full (num_queries , - 1 , dtype = np .int64 )
256- i_best = np .full (num_queries , - 1 , dtype = np .int64 )
257264 has_hits = hit_counts > 0
258265 j_best [has_hits ] = j [src_best [has_hits ]]
259266 i_best [has_hits ] = i [src_best [has_hits ]]
0 commit comments