Skip to content

Commit 0cbbbc9

Browse files
Vectorize hash table initialization
1 parent f9e044b commit 0cbbbc9

File tree

1 file changed

+55
-18
lines changed

1 file changed

+55
-18
lines changed

parcels/spatialhash.py

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -164,27 +164,63 @@ def _initialize_hash_table(self):
164164
nz = zqhigh - zqlow + 1
165165
num_hash_per_face = nx * ny * nz
166166
total_hash_entries = np.sum(num_hash_per_face)
167-
168167
morton_codes = np.zeros(total_hash_entries, dtype=np.uint32)
169168

170169
# Compute the j, i indices corresponding to each hash entry
171170
nface = np.size(self._xlow)
172171
face_ids = np.repeat(np.arange(nface, dtype=np.int64), num_hash_per_face)
173-
offsets = np.concatenate(([0], np.cumsum(num_hash_per_face))).astype(np.int64)
174-
175-
for k in range(len(num_hash_per_face)):
176-
if num_hash_per_face[k] == 0:
177-
continue
178-
start, end = offsets[k], offsets[k + 1]
179-
# Local sizes
180-
nxk, nyk, nzk = int(nx[k]), int(ny[k]), int(nz[k])
181-
182-
# Build the Cartesian product
183-
xq_block = xqlow[k] + np.repeat(np.arange(nxk), nyk * nzk)
184-
yq_block = yqlow[k] + np.tile(np.repeat(np.arange(nyk), nzk), nxk)
185-
zq_block = zqlow[k] + np.tile(np.arange(nzk), nxk * nyk)
186-
187-
morton_codes[start:end] = _encode_quantized_morton3d(xq_block, yq_block, zq_block)
172+
offsets = np.concatenate(([0], np.cumsum(num_hash_per_face))).astype(np.int64)[:-1]
173+
174+
valid = num_hash_per_face != 0
175+
if not np.any(valid):
176+
# nothing to do
177+
pass
178+
else:
179+
# Grab only valid faces to avoid empty arrays
180+
nx_v = np.asarray(nx[valid], dtype=np.int64)
181+
ny_v = np.asarray(ny[valid], dtype=np.int64)
182+
nz_v = np.asarray(nz[valid], dtype=np.int64)
183+
xlow_v = np.asarray(xqlow[valid], dtype=np.int64)
184+
ylow_v = np.asarray(yqlow[valid], dtype=np.int64)
185+
zlow_v = np.asarray(zqlow[valid], dtype=np.int64)
186+
starts_v = np.asarray(offsets[valid], dtype=np.int64)
187+
188+
# Count of elements per valid face (should match num_hash_per_face[valid])
189+
counts = (nx_v * ny_v * nz_v).astype(np.int64)
190+
total = int(counts.sum())
191+
192+
# Map each global element to its face and output position
193+
start_for_elem = np.repeat(starts_v, counts) # shape (total,)
194+
195+
# Intra-face linear index for each element (0..counts_i-1)
196+
# Offsets per face within the concatenation of valid faces:
197+
face_starts_local = np.cumsum(np.r_[0, counts[:-1]])
198+
intra = np.arange(total, dtype=np.int64) - np.repeat(face_starts_local, counts)
199+
200+
# Derive (zi, yi, xi) from intra using per-face sizes
201+
ny_nz = np.repeat(ny_v * nz_v, counts)
202+
nz_rep = np.repeat(nz_v, counts)
203+
204+
xi = intra // ny_nz
205+
rem = intra % ny_nz
206+
yi = rem // nz_rep
207+
zi = rem % nz_rep
208+
209+
# Add per-face lows
210+
x0 = np.repeat(xlow_v, counts)
211+
y0 = np.repeat(ylow_v, counts)
212+
z0 = np.repeat(zlow_v, counts)
213+
214+
xq = x0 + xi
215+
yq = y0 + yi
216+
zq = z0 + zi
217+
218+
# Vectorized morton encode for all elements at once
219+
codes_all = _encode_quantized_morton3d(xq, yq, zq)
220+
221+
# Scatter into the preallocated output using computed absolute indices
222+
out_idx = start_for_elem + intra
223+
morton_codes[out_idx] = codes_all
188224

189225
# Sort face indices by morton code
190226
order = np.argsort(morton_codes)
@@ -194,6 +230,7 @@ def _initialize_hash_table(self):
194230

195231
# Get a list of unique morton codes and their corresponding starts and counts (CSR format)
196232
keys, starts, counts = np.unique(morton_codes_sorted, return_index=True, return_counts=True)
233+
197234
hash_table = {
198235
"keys": keys,
199236
"starts": starts,
@@ -458,8 +495,8 @@ def quantize_coordinates(x, y, z, xmin, xmax, ymin, ymax, zmin, zmax, bitwidth=1
458495
yn = np.where(dy != 0, (y - ymin) / dy, 0.0)
459496
zn = np.where(dz != 0, (z - zmin) / dz, 0.0)
460497

461-
# --- 2) Quantize to 10 bits (0..1023). ---
462-
# Multiply by 1023, round down, and clip to be safe against slight overshoot.
498+
# --- 2) Quantize to (0..bitwidth). ---
499+
# Multiply by bitwidth, round down, and clip to be safe against slight overshoot.
463500
xq = np.clip((xn * bitwidth).astype(np.uint32), 0, bitwidth)
464501
yq = np.clip((yn * bitwidth).astype(np.uint32), 0, bitwidth)
465502
zq = np.clip((zn * bitwidth).astype(np.uint32), 0, bitwidth)

0 commit comments

Comments
 (0)