@@ -164,27 +164,63 @@ def _initialize_hash_table(self):
164164 nz = zqhigh - zqlow + 1
165165 num_hash_per_face = nx * ny * nz
166166 total_hash_entries = np .sum (num_hash_per_face )
167-
168167 morton_codes = np .zeros (total_hash_entries , dtype = np .uint32 )
169168
170169 # Compute the j, i indices corresponding to each hash entry
171170 nface = np .size (self ._xlow )
172171 face_ids = np .repeat (np .arange (nface , dtype = np .int64 ), num_hash_per_face )
173- offsets = np .concatenate (([0 ], np .cumsum (num_hash_per_face ))).astype (np .int64 )
174-
175- for k in range (len (num_hash_per_face )):
176- if num_hash_per_face [k ] == 0 :
177- continue
178- start , end = offsets [k ], offsets [k + 1 ]
179- # Local sizes
180- nxk , nyk , nzk = int (nx [k ]), int (ny [k ]), int (nz [k ])
181-
182- # Build the Cartesian product
183- xq_block = xqlow [k ] + np .repeat (np .arange (nxk ), nyk * nzk )
184- yq_block = yqlow [k ] + np .tile (np .repeat (np .arange (nyk ), nzk ), nxk )
185- zq_block = zqlow [k ] + np .tile (np .arange (nzk ), nxk * nyk )
186-
187- morton_codes [start :end ] = _encode_quantized_morton3d (xq_block , yq_block , zq_block )
172+ offsets = np .concatenate (([0 ], np .cumsum (num_hash_per_face ))).astype (np .int64 )[:- 1 ]
173+
174+ valid = num_hash_per_face != 0
175+ if not np .any (valid ):
176+ # nothing to do
177+ pass
178+ else :
179+ # Grab only valid faces to avoid empty arrays
180+ nx_v = np .asarray (nx [valid ], dtype = np .int64 )
181+ ny_v = np .asarray (ny [valid ], dtype = np .int64 )
182+ nz_v = np .asarray (nz [valid ], dtype = np .int64 )
183+ xlow_v = np .asarray (xqlow [valid ], dtype = np .int64 )
184+ ylow_v = np .asarray (yqlow [valid ], dtype = np .int64 )
185+ zlow_v = np .asarray (zqlow [valid ], dtype = np .int64 )
186+ starts_v = np .asarray (offsets [valid ], dtype = np .int64 )
187+
188+ # Count of elements per valid face (should match num_hash_per_face[valid])
189+ counts = (nx_v * ny_v * nz_v ).astype (np .int64 )
190+ total = int (counts .sum ())
191+
192+ # Map each global element to its face and output position
193+ start_for_elem = np .repeat (starts_v , counts ) # shape (total,)
194+
195+ # Intra-face linear index for each element (0..counts_i-1)
196+ # Offsets per face within the concatenation of valid faces:
197+ face_starts_local = np .cumsum (np .r_ [0 , counts [:- 1 ]])
198+ intra = np .arange (total , dtype = np .int64 ) - np .repeat (face_starts_local , counts )
199+
200+ # Derive (zi, yi, xi) from intra using per-face sizes
201+ ny_nz = np .repeat (ny_v * nz_v , counts )
202+ nz_rep = np .repeat (nz_v , counts )
203+
204+ xi = intra // ny_nz
205+ rem = intra % ny_nz
206+ yi = rem // nz_rep
207+ zi = rem % nz_rep
208+
209+ # Add per-face lows
210+ x0 = np .repeat (xlow_v , counts )
211+ y0 = np .repeat (ylow_v , counts )
212+ z0 = np .repeat (zlow_v , counts )
213+
214+ xq = x0 + xi
215+ yq = y0 + yi
216+ zq = z0 + zi
217+
218+ # Vectorized morton encode for all elements at once
219+ codes_all = _encode_quantized_morton3d (xq , yq , zq )
220+
221+ # Scatter into the preallocated output using computed absolute indices
222+ out_idx = start_for_elem + intra
223+ morton_codes [out_idx ] = codes_all
188224
189225 # Sort face indices by morton code
190226 order = np .argsort (morton_codes )
@@ -194,6 +230,7 @@ def _initialize_hash_table(self):
194230
195231 # Get a list of unique morton codes and their corresponding starts and counts (CSR format)
196232 keys , starts , counts = np .unique (morton_codes_sorted , return_index = True , return_counts = True )
233+
197234 hash_table = {
198235 "keys" : keys ,
199236 "starts" : starts ,
@@ -458,8 +495,8 @@ def quantize_coordinates(x, y, z, xmin, xmax, ymin, ymax, zmin, zmax, bitwidth=1
458495 yn = np .where (dy != 0 , (y - ymin ) / dy , 0.0 )
459496 zn = np .where (dz != 0 , (z - zmin ) / dz , 0.0 )
460497
461- # --- 2) Quantize to 10 bits (0..1023 ). ---
462- # Multiply by 1023 , round down, and clip to be safe against slight overshoot.
498+ # --- 2) Quantize to (0..bitwidth ). ---
499+ # Multiply by bitwidth , round down, and clip to be safe against slight overshoot.
463500 xq = np .clip ((xn * bitwidth ).astype (np .uint32 ), 0 , bitwidth )
464501 yq = np .clip ((yn * bitwidth ).astype (np .uint32 ), 0 , bitwidth )
465502 zq = np .clip ((zn * bitwidth ).astype (np .uint32 ), 0 , bitwidth )
0 commit comments