diff --git a/parcels/spatialhash.py b/parcels/spatialhash.py
index dddba7432..76447bc23 100644
--- a/parcels/spatialhash.py
+++ b/parcels/spatialhash.py
@@ -192,11 +192,18 @@ def query(
         # Locate each query in the unique key array
         pos = np.searchsorted(keys, query_codes)  # pos is shape (N,)
 
-        # Valid hits: inside range
-        valid = pos < len(keys)
+        # Valid hits: inside range with finite query coordinates
+        valid = (pos < len(keys)) & np.isfinite(x) & np.isfinite(y)
+
+        # Pre-allocate i and j indices of the best match for each query
+        # Default values to -1 (no match case)
+        j_best = np.full(num_queries, -1, dtype=np.int64)
+        i_best = np.full(num_queries, -1, dtype=np.int64)
 
         # How many matches each query has; hit_counts[i] is the number of hits for query i
         hit_counts = np.where(valid, counts[pos], 0).astype(np.int64)  # has shape (N,)
+        if hit_counts.sum() == 0:
+            return (j_best.reshape(query_codes.shape), i_best.reshape(query_codes.shape))
 
         # CSR-style offsets (prefix sum), total number of hits
         offsets = np.empty(hit_counts.size + 1, dtype=np.int64)
@@ -220,40 +227,40 @@
         j_all = j[source_idx]
         i_all = i[source_idx]
 
-        # Gather centroid coordinates at those (j,i)
-        xc_all = xc[j_all, i_all]
-        yc_all = yc[j_all, i_all]
-        zc_all = zc[j_all, i_all]
-
-        # Broadcast to flat (same as q_flat), then repeat per candidate
-        # This makes it easy to compute distances from the query points
-        # to the candidate points for minimization.
-        qx_all = np.repeat(qx.ravel(), hit_counts)
-        qy_all = np.repeat(qy.ravel(), hit_counts)
-        qz_all = np.repeat(qz.ravel(), hit_counts)
-
-        # Squared distances for all candidates
-        dist_all = (xc_all - qx_all) ** 2 + (yc_all - qy_all) ** 2 + (zc_all - qz_all) ** 2
-
         # Segment-wise minima per query using reduceat
         # For each query, we need to find the minimum distance.
-        dmin_per_q = np.minimum.reduceat(dist_all, offsets[:-1])
-
-        # To get argmin indices per query (without loops):
-        # Build a masked "within-index" array that is large unless it equals the segment-min.
-        big = np.iinfo(np.int64).max
-        within_masked = np.where(dist_all == np.repeat(dmin_per_q, hit_counts), intra, big)
-        argmin_within = np.minimum.reduceat(within_masked, offsets[:-1])  # first occurrence in ties
-
-        # Build absolute source index for the winning candidate in each query
-        start_for_q = np.where(valid, starts[pos], 0)  # 0 is dummy for invalid queries
-        src_best = (start_for_q + argmin_within).astype(np.int64)
+        if total == 1:
+            # Build absolute source index for the winning candidate in each query
+            start_for_q = np.where(valid, starts[pos], 0)  # 0 is dummy for invalid queries
+            src_best = (start_for_q).astype(np.int64)
+        else:
+            # Gather centroid coordinates at those (j,i)
+            xc_all = xc[j_all, i_all]
+            yc_all = yc[j_all, i_all]
+            zc_all = zc[j_all, i_all]
+
+            # Broadcast to flat (same as q_flat), then repeat per candidate
+            # This makes it easy to compute distances from the query points
+            # to the candidate points for minimization.
+            qx_all = np.repeat(qx.ravel(), hit_counts)
+            qy_all = np.repeat(qy.ravel(), hit_counts)
+            qz_all = np.repeat(qz.ravel(), hit_counts)
+
+            # Squared distances for all candidates
+            dist_all = (xc_all - qx_all) ** 2 + (yc_all - qy_all) ** 2 + (zc_all - qz_all) ** 2
+
+            dmin_per_q = np.minimum.reduceat(dist_all, offsets[:-1])
+            # To get argmin indices per query (without loops):
+            # Build a masked "within-index" array that is large unless it equals the segment-min.
+            big = np.iinfo(np.int64).max
+            within_masked = np.where(dist_all == np.repeat(dmin_per_q, hit_counts), intra, big)
+            argmin_within = np.minimum.reduceat(within_masked, offsets[:-1])  # first occurrence in ties
+
+            # Build absolute source index for the winning candidate in each query
+            start_for_q = np.where(valid, starts[pos], 0)  # 0 is dummy for invalid queries
+            src_best = (start_for_q + argmin_within).astype(np.int64)
 
         # Write outputs only for queries that had candidates
-        # Pre-allocate i and j indices of the best match for each query
-        # Default values to -1 (no match case)
-        j_best = np.full(num_queries, -1, dtype=np.int64)
-        i_best = np.full(num_queries, -1, dtype=np.int64)
         has_hits = hit_counts > 0
         j_best[has_hits] = j[src_best[has_hits]]
         i_best[has_hits] = i[src_best[has_hits]]
diff --git a/tests/v4/test_spatialhash.py b/tests/v4/test_spatialhash.py
index 9d958507f..c9026e7af 100644
--- a/tests/v4/test_spatialhash.py
+++ b/tests/v4/test_spatialhash.py
@@ -1,3 +1,5 @@
+import numpy as np
+
 from parcels._datasets.structured.generic import datasets
 from parcels.xgrid import XGrid
 
@@ -7,3 +9,26 @@ def test_spatialhash_init():
     grid = XGrid.from_dataset(ds)
     spatialhash = grid.get_spatial_hash()
     assert spatialhash is not None
+
+
+def test_invalid_positions():
+    ds = datasets["2d_left_rotated"]
+    grid = XGrid.from_dataset(ds)
+
+    j, i = grid.get_spatial_hash().query([np.nan, np.inf], [np.nan, np.inf])
+    assert np.all(j == -1)
+    assert np.all(i == -1)
+
+
+def test_mixed_positions():
+    ds = datasets["2d_left_rotated"]
+    grid = XGrid.from_dataset(ds)
+    lat = grid.lat.mean()
+    lon = grid.lon.mean()
+    y = [lat, np.nan]
+    x = [lon, np.nan]
+    j, i = grid.get_spatial_hash().query(y, x)
+    assert j[0] == 29  # Actual value for 2d_left_rotated center
+    assert i[0] == 14  # Actual value for 2d_left_rotated center
+    assert j[1] == -1
+    assert i[1] == -1
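Note on the multi-candidate path this patch keeps in the `else` branch: the per-query nearest centroid is found without Python loops by running `np.minimum.reduceat` over CSR-style segment offsets twice, once for the minimum distance and once for the argmin via an index-masking trick (when `total == 1` there is only one candidate overall, so the winner is just `starts[pos]`). A minimal standalone sketch of that pattern, with made-up distances and counts; `intra` is the within-segment candidate index, as in the diff:

```python
import numpy as np

# Flattened candidate distances for 3 queries with 2, 1 and 3 candidates each
dist_all = np.array([4.0, 1.0, 9.0, 3.0, 3.0, 7.0])
hit_counts = np.array([2, 1, 3])

# CSR-style offsets: segment k spans dist_all[offsets[k]:offsets[k + 1]]
offsets = np.concatenate(([0], np.cumsum(hit_counts)))  # [0, 2, 3, 6]
# Index of each candidate within its own segment: [0, 1, 0, 0, 1, 2]
intra = np.arange(dist_all.size) - np.repeat(offsets[:-1], hit_counts)

# Per-query minimum distance in one vectorized call
dmin_per_q = np.minimum.reduceat(dist_all, offsets[:-1])  # [1., 9., 3.]

# Argmin per segment: keep the within-segment index where the distance
# equals the segment minimum, mask everything else with a huge sentinel,
# then take the segment-wise minimum again (ties -> first occurrence)
big = np.iinfo(np.int64).max
within_masked = np.where(dist_all == np.repeat(dmin_per_q, hit_counts), intra, big)
argmin_within = np.minimum.reduceat(within_masked, offsets[:-1])  # [1, 0, 0]

# Absolute flat index of the winning candidate for each query
src_best = offsets[:-1] + argmin_within  # [1, 2, 3]
```

One caveat the sketch sidesteps: `reduceat` does not reduce empty segments (for `offsets[k] == offsets[k + 1]` it simply returns the element at `offsets[k]`), which is presumably why the patched code only writes results where `hit_counts > 0` and now bails out early when there are no hits at all.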