Commit 3fb4473
[#2162] Add proper handling for invalid queries in Morton coding
This commit also adds a test for queries in which every position is invalid, and one for queries mixing valid and invalid positions.
1 parent 1c638c8 · commit 3fb4473

File tree

2 files changed: +64 −32 lines changed

parcels/spatialhash.py

Lines changed: 39 additions & 32 deletions
```diff
@@ -192,11 +192,18 @@ def query(
         # Locate each query in the unique key array
         pos = np.searchsorted(keys, query_codes)  # pos is shape (N,)

-        # Valid hits: inside range
-        valid = pos < len(keys)
+        # Valid hits: inside range with finite query coordinates
+        valid = (pos < len(keys)) & np.isfinite(x) & np.isfinite(y)
+
+        # Pre-allocate i and j indices of the best match for each query
+        # Default values to -1 (no match case)
+        j_best = np.full(num_queries, -1, dtype=np.int64)
+        i_best = np.full(num_queries, -1, dtype=np.int64)

         # How many matches each query has; hit_counts[i] is the number of hits for query i
         hit_counts = np.where(valid, counts[pos], 0).astype(np.int64)  # has shape (N,)
+        if hit_counts.sum() == 0:
+            return (j_best.reshape(query_codes.shape), i_best.reshape(query_codes.shape))

         # CSR-style offsets (prefix sum), total number of hits
         offsets = np.empty(hit_counts.size + 1, dtype=np.int64)
@@ -220,40 +227,40 @@ def query(
         j_all = j[source_idx]
         i_all = i[source_idx]

-        # Gather centroid coordinates at those (j,i)
-        xc_all = xc[j_all, i_all]
-        yc_all = yc[j_all, i_all]
-        zc_all = zc[j_all, i_all]
-
-        # Broadcast to flat (same as q_flat), then repeat per candidate
-        # This makes it easy to compute distances from the query points
-        # to the candidate points for minimization.
-        qx_all = np.repeat(qx.ravel(), hit_counts)
-        qy_all = np.repeat(qy.ravel(), hit_counts)
-        qz_all = np.repeat(qz.ravel(), hit_counts)
-
-        # Squared distances for all candidates
-        dist_all = (xc_all - qx_all) ** 2 + (yc_all - qy_all) ** 2 + (zc_all - qz_all) ** 2
-
         # Segment-wise minima per query using reduceat
         # For each query, we need to find the minimum distance.
-        dmin_per_q = np.minimum.reduceat(dist_all, offsets[:-1])
-
-        # To get argmin indices per query (without loops):
-        # Build a masked "within-index" array that is large unless it equals the segment-min.
-        big = np.iinfo(np.int64).max
-        within_masked = np.where(dist_all == np.repeat(dmin_per_q, hit_counts), intra, big)
-        argmin_within = np.minimum.reduceat(within_masked, offsets[:-1])  # first occurrence in ties
-
-        # Build absolute source index for the winning candidate in each query
-        start_for_q = np.where(valid, starts[pos], 0)  # 0 is dummy for invalid queries
-        src_best = (start_for_q + argmin_within).astype(np.int64)
+        if total == 1:
+            # Build absolute source index for the winning candidate in each query
+            start_for_q = np.where(valid, starts[pos], 0)  # 0 is dummy for invalid queries
+            src_best = (start_for_q).astype(np.int64)
+        else:
+            # Gather centroid coordinates at those (j,i)
+            xc_all = xc[j_all, i_all]
+            yc_all = yc[j_all, i_all]
+            zc_all = zc[j_all, i_all]
+
+            # Broadcast to flat (same as q_flat), then repeat per candidate
+            # This makes it easy to compute distances from the query points
+            # to the candidate points for minimization.
+            qx_all = np.repeat(qx.ravel(), hit_counts)
+            qy_all = np.repeat(qy.ravel(), hit_counts)
+            qz_all = np.repeat(qz.ravel(), hit_counts)
+
+            # Squared distances for all candidates
+            dist_all = (xc_all - qx_all) ** 2 + (yc_all - qy_all) ** 2 + (zc_all - qz_all) ** 2
+
+            dmin_per_q = np.minimum.reduceat(dist_all, offsets[:-1])
+            # To get argmin indices per query (without loops):
+            # Build a masked "within-index" array that is large unless it equals the segment-min.
+            big = np.iinfo(np.int64).max
+            within_masked = np.where(dist_all == np.repeat(dmin_per_q, hit_counts), intra, big)
+            argmin_within = np.minimum.reduceat(within_masked, offsets[:-1])  # first occurrence in ties
+
+            # Build absolute source index for the winning candidate in each query
+            start_for_q = np.where(valid, starts[pos], 0)  # 0 is dummy for invalid queries
+            src_best = (start_for_q + argmin_within).astype(np.int64)

         # Write outputs only for queries that had candidates
-        # Pre-allocate i and j indices of the best match for each query
-        # Default values to -1 (no match case)
-        j_best = np.full(num_queries, -1, dtype=np.int64)
-        i_best = np.full(num_queries, -1, dtype=np.int64)
         has_hits = hit_counts > 0
         j_best[has_hits] = j[src_best[has_hits]]
         i_best[has_hits] = i[src_best[has_hits]]
```
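The subtle part of this diff is the loop-free per-query argmin it preserves: candidates for every query sit in one flat array, `np.minimum.reduceat` takes the minimum of each CSR segment, and a sentinel mask over the within-segment indices recovers where that minimum first occurs. Below is a minimal standalone sketch of the same pattern; the variable names mirror the diff, but the data is synthetic and `intra` is built explicitly rather than taken from the surrounding method.

```python
import numpy as np

# Three queries with 2, 3, and 1 candidate cell(s) respectively (synthetic)
hit_counts = np.array([2, 3, 1], dtype=np.int64)

# CSR-style offsets: segment k spans offsets[k]:offsets[k+1]
offsets = np.empty(hit_counts.size + 1, dtype=np.int64)
offsets[0] = 0
np.cumsum(hit_counts, out=offsets[1:])

# Squared distances of every candidate, flattened segment by segment (synthetic)
dist_all = np.array([4.0, 1.0, 9.0, 2.0, 2.0, 5.0])

# Position of each candidate within its own segment: 0, 1, ... per query
intra = np.arange(dist_all.size) - np.repeat(offsets[:-1], hit_counts)

# Minimum distance per query, computed segment-wise without a Python loop
dmin_per_q = np.minimum.reduceat(dist_all, offsets[:-1])

# Argmin per segment: replace non-minimal entries with a huge sentinel, so the
# segment-wise minimum of `intra` is the first argmin within each segment
big = np.iinfo(np.int64).max
within_masked = np.where(dist_all == np.repeat(dmin_per_q, hit_counts), intra, big)
argmin_within = np.minimum.reduceat(within_masked, offsets[:-1])

print(dmin_per_q)     # [1. 2. 5.]
print(argmin_within)  # [1 1 0] -- ties resolve to the first occurrence
```

The diff short-circuits this machinery when `total == 1` (a single candidate overall), in which case the winner is simply the segment start and no distances need to be computed.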

tests/v4/test_spatialhash.py

Lines changed: 25 additions & 0 deletions
```diff
@@ -1,3 +1,5 @@
+import numpy as np
+
 from parcels._datasets.structured.generic import datasets
 from parcels.xgrid import XGrid

@@ -7,3 +9,26 @@ def test_spatialhash_init():
     grid = XGrid.from_dataset(ds)
     spatialhash = grid.get_spatial_hash()
     assert spatialhash is not None
+
+
+def test_invalid_positions():
+    ds = datasets["2d_left_rotated"]
+    grid = XGrid.from_dataset(ds)
+
+    j, i = grid.get_spatial_hash().query([np.nan, np.inf], [np.nan, np.inf])
+    assert np.all(j == -1)
+    assert np.all(i == -1)
+
+
+def test_mixed_positions():
+    ds = datasets["2d_left_rotated"]
+    grid = XGrid.from_dataset(ds)
+    lat = grid.lat.mean()
+    lon = grid.lon.mean()
+    y = [lat, np.nan]
+    x = [lon, np.nan]
+    j, i = grid.get_spatial_hash().query(y, x)
+    assert j[0] == 29  # Actual value for 2d_left_rotated center
+    assert i[0] == 14  # Actual value for 2d_left_rotated center
+    assert j[1] == -1
+    assert i[1] == -1
```
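Background for why these tests need the new `np.isfinite` guard rather than the old bounds check: casting nan/inf coordinates to the integers a Morton encoding works with is undefined in NumPy and typically produces values whose `searchsorted` position is perfectly in range. A hedged standalone sketch follows; the integer "encoding" here is a stand-in for illustration, not parcels' actual Morton code.

```python
import numpy as np

keys = np.arange(100, dtype=np.int64)   # stand-in for the sorted unique Morton keys
x = np.array([5.3, np.nan, np.inf])

# Casting non-finite floats to int is undefined behaviour in NumPy;
# on most platforms nan/inf become INT64_MIN, which sorts before every key.
with np.errstate(invalid="ignore"):
    codes = x.astype(np.int64)

pos = np.searchsorted(keys, codes)
print(pos)                # e.g. [5 0 0] -- the bad queries look "in range"
print(pos < len(keys))    # [ True  True  True] -- does not catch them
print(np.isfinite(x))     # [ True False False] -- the reliable guard
```

This is why the diff computes validity as `(pos < len(keys)) & np.isfinite(x) & np.isfinite(y)` and the tests above assert `-1` for every non-finite position.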
