71 changes: 39 additions & 32 deletions parcels/spatialhash.py
@@ -192,11 +192,18 @@ def query(
         # Locate each query in the unique key array
         pos = np.searchsorted(keys, query_codes)  # pos is shape (N,)

-        # Valid hits: inside range
-        valid = pos < len(keys)
+        # Valid hits: inside range with finite query coordinates
+        valid = (pos < len(keys)) & np.isfinite(x) & np.isfinite(y)
+
+        # Pre-allocate i and j indices of the best match for each query
+        # Default values to -1 (no match case)
+        j_best = np.full(num_queries, -1, dtype=np.int64)
+        i_best = np.full(num_queries, -1, dtype=np.int64)

         # How many matches each query has; hit_counts[i] is the number of hits for query i
         hit_counts = np.where(valid, counts[pos], 0).astype(np.int64)  # has shape (N,)
+        if hit_counts.sum() == 0:
+            return (j_best.reshape(query_codes.shape), i_best.reshape(query_codes.shape))

         # CSR-style offsets (prefix sum), total number of hits
         offsets = np.empty(hit_counts.size + 1, dtype=np.int64)
@@ -220,40 +227,40 @@ def query(
         j_all = j[source_idx]
         i_all = i[source_idx]

-        # Gather centroid coordinates at those (j,i)
-        xc_all = xc[j_all, i_all]
-        yc_all = yc[j_all, i_all]
-        zc_all = zc[j_all, i_all]
-
-        # Broadcast to flat (same as q_flat), then repeat per candidate
-        # This makes it easy to compute distances from the query points
-        # to the candidate points for minimization.
-        qx_all = np.repeat(qx.ravel(), hit_counts)
-        qy_all = np.repeat(qy.ravel(), hit_counts)
-        qz_all = np.repeat(qz.ravel(), hit_counts)
-
-        # Squared distances for all candidates
-        dist_all = (xc_all - qx_all) ** 2 + (yc_all - qy_all) ** 2 + (zc_all - qz_all) ** 2
-
-        # Segment-wise minima per query using reduceat
-        # For each query, we need to find the minimum distance.
-        dmin_per_q = np.minimum.reduceat(dist_all, offsets[:-1])
-
-        # To get argmin indices per query (without loops):
-        # Build a masked "within-index" array that is large unless it equals the segment-min.
-        big = np.iinfo(np.int64).max
-        within_masked = np.where(dist_all == np.repeat(dmin_per_q, hit_counts), intra, big)
-        argmin_within = np.minimum.reduceat(within_masked, offsets[:-1])  # first occurrence in ties
-
-        # Build absolute source index for the winning candidate in each query
-        start_for_q = np.where(valid, starts[pos], 0)  # 0 is dummy for invalid queries
-        src_best = (start_for_q + argmin_within).astype(np.int64)
+        if total == 1:
+            # Build absolute source index for the winning candidate in each query
+            start_for_q = np.where(valid, starts[pos], 0)  # 0 is dummy for invalid queries
+            src_best = (start_for_q).astype(np.int64)
+        else:
+            # Gather centroid coordinates at those (j,i)
+            xc_all = xc[j_all, i_all]
+            yc_all = yc[j_all, i_all]
+            zc_all = zc[j_all, i_all]
+
+            # Broadcast to flat (same as q_flat), then repeat per candidate
+            # This makes it easy to compute distances from the query points
+            # to the candidate points for minimization.
+            qx_all = np.repeat(qx.ravel(), hit_counts)
+            qy_all = np.repeat(qy.ravel(), hit_counts)
+            qz_all = np.repeat(qz.ravel(), hit_counts)
+
+            # Squared distances for all candidates
+            dist_all = (xc_all - qx_all) ** 2 + (yc_all - qy_all) ** 2 + (zc_all - qz_all) ** 2
+
+            dmin_per_q = np.minimum.reduceat(dist_all, offsets[:-1])
+            # To get argmin indices per query (without loops):
+            # Build a masked "within-index" array that is large unless it equals the segment-min.
+            big = np.iinfo(np.int64).max
+            within_masked = np.where(dist_all == np.repeat(dmin_per_q, hit_counts), intra, big)
+            argmin_within = np.minimum.reduceat(within_masked, offsets[:-1])  # first occurrence in ties
+
+            # Build absolute source index for the winning candidate in each query
+            start_for_q = np.where(valid, starts[pos], 0)  # 0 is dummy for invalid queries
+            src_best = (start_for_q + argmin_within).astype(np.int64)

         # Write outputs only for queries that had candidates
-        # Pre-allocate i and j indices of the best match for each query
-        # Default values to -1 (no match case)
-        j_best = np.full(num_queries, -1, dtype=np.int64)
-        i_best = np.full(num_queries, -1, dtype=np.int64)
         has_hits = hit_counts > 0
         j_best[has_hits] = j[src_best[has_hits]]
         i_best[has_hits] = i[src_best[has_hits]]
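The core of this change is the `total == 1` fast path plus the loop-free segment argmin in the `else` branch: candidates for all queries sit in one flat array grouped by CSR offsets, `np.minimum.reduceat` takes each segment's minimum distance, and a masked within-segment index recovers the argmin. A minimal self-contained sketch of that pattern with made-up values (illustrative only, not code from the PR; `offsets[:-1]` stands in here for the PR's `starts[pos]`):

import numpy as np

# Three queries with 3, 1, and 2 candidate cells respectively.
dist_all = np.array([4.0, 1.0, 3.0, 2.0, 2.0, 5.0])  # squared distances, grouped per query
hit_counts = np.array([3, 1, 2])
offsets = np.concatenate(([0], np.cumsum(hit_counts)))  # CSR offsets: [0, 3, 4, 6]
# Index of each candidate within its own segment: [0, 1, 2, 0, 0, 1]
intra = np.arange(dist_all.size) - np.repeat(offsets[:-1], hit_counts)

# Segment-wise minimum distance per query: [1.0, 2.0, 2.0]
dmin_per_q = np.minimum.reduceat(dist_all, offsets[:-1])

# Keep the within-segment index only where the distance equals its segment
# minimum, mask the rest with a huge value, and reduce again; ties resolve
# to the first occurrence. Result: [1, 0, 0]
big = np.iinfo(np.int64).max
within_masked = np.where(dist_all == np.repeat(dmin_per_q, hit_counts), intra, big)
argmin_within = np.minimum.reduceat(within_masked, offsets[:-1])

# Absolute flat index of the winning candidate per query: [1, 3, 4]
src_best = offsets[:-1] + argmin_within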
25 changes: 25 additions & 0 deletions tests/v4/test_spatialhash.py
@@ -1,3 +1,5 @@
+import numpy as np
+
 from parcels._datasets.structured.generic import datasets
 from parcels.xgrid import XGrid

@@ -7,3 +9,26 @@ def test_spatialhash_init():
     grid = XGrid.from_dataset(ds)
     spatialhash = grid.get_spatial_hash()
     assert spatialhash is not None
+
+
+def test_invalid_positions():
+    ds = datasets["2d_left_rotated"]
+    grid = XGrid.from_dataset(ds)
+
+    j, i = grid.get_spatial_hash().query([np.nan, np.inf], [np.nan, np.inf])
+    assert np.all(j == -1)
+    assert np.all(i == -1)
+
+
+def test_mixed_positions():
+    ds = datasets["2d_left_rotated"]
+    grid = XGrid.from_dataset(ds)
+    lat = grid.lat.mean()
+    lon = grid.lon.mean()
+    y = [lat, np.nan]
+    x = [lon, np.nan]
+    j, i = grid.get_spatial_hash().query(y, x)
+    assert j[0] == 29  # Actual value for 2d_left_rotated center
+    assert i[0] == 14  # Actual value for 2d_left_rotated center
+    assert j[1] == -1
+    assert i[1] == -1
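These tests pin down the behavior of the new validity mask: non-finite coordinates are excluded from `valid`, receive a `hit_counts` of zero, and keep the pre-allocated (j, i) = (-1, -1) defaults. A minimal sketch of the mask itself, assuming hypothetical `pos` values and key count:

import numpy as np

x = np.array([12.5, np.nan, np.inf])   # query longitudes
y = np.array([-34.0, np.nan, np.inf])  # query latitudes
pos = np.array([7, 2, 4])  # hypothetical np.searchsorted results
num_keys = 10              # hypothetical number of unique hash keys

valid = (pos < num_keys) & np.isfinite(x) & np.isfinite(y)
print(valid)  # [ True False False] -> the last two queries return (-1, -1)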