Commit 69ae3e5

Enhance unique_mask function by refining sorting keys for lexsort. Introduce a stable tie-breaking mechanism using original indices and clarify the role of primary and secondary keys in the sorting process, improving the function's efficiency in handling unique states.
1 parent 08ecc2a commit 69ae3e5
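
The key ordering described in the commit message can be checked on a toy batch. This is a minimal sketch, not code from the repository: the arrays `hashes`, `cost`, and `index` are made-up illustration data, and a single hash column stands in for the list of hash columns the real function sorts on. `jnp.lexsort` treats the last key in the list as the most significant, so the hash groups rows, the cost orders rows within a group, and the original index settles any remaining ties.

```python
import jax.numpy as jnp

# Toy batch (assumed data): rows 0, 2 and 3 share a hash,
# and rows 2 and 3 also tie on cost.
hashes = jnp.array([7, 3, 7, 7], dtype=jnp.uint32)
cost = jnp.array([5.0, 1.0, 2.0, 2.0], dtype=jnp.float32)
index = jnp.arange(hashes.shape[0], dtype=jnp.int32)

# Keys are listed least significant first:
# tertiary = index, secondary = cost, primary = hash.
order = jnp.lexsort([index, cost, hashes])

print(order.tolist())  # [1, 2, 3, 0]
```

Row 1 comes first because it has the smallest hash. Within the hash-7 group, rows 2 and 3 precede row 0 on cost, and row 2 precedes row 3 only because of the index key.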

1 file changed

Lines changed: 7 additions & 5 deletions

xtructure/core/xtructure_numpy/dataclass_ops/unique_ops/optimized_unique_ops.py

@@ -118,21 +118,23 @@ def key_fn(x):
             hashes[i] = jnp.where(filled, hashes[i], max_val)
 
     # 4. Prepare columns for lexsort
+    # Lexsort expects [least_significant, ..., most_significant]
     sort_keys = []
 
+    # Tertiary key: original index (stable tie-breaking if hashes and keys are equal)
+    sort_keys.append(jnp.arange(batch_len, dtype=jnp.int32))
+
     if key is not None:
+        # Secondary key: cost (minimize cost within same hash group)
         if filled is not None:
             inf_fill = jnp.array(jnp.inf, dtype=key.dtype)
             valid_key = jnp.where(filled, key, inf_fill)
         else:
             valid_key = key
         sort_keys.append(valid_key)
-    else:
-        # Stable sort: prefer earlier index
-        sort_keys.append(jnp.arange(batch_len, dtype=jnp.int32))
 
-    # Add hashes as primary (last in lexsort list).
-    # Reversed so the first hash is most significant.
+    # Primary key: hashes (to group identical states)
+    # Reversed so the first hash is the most significant sorting factor.
     sort_keys.extend(reversed(hashes))
 
     # 5. Perform Lexsort
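
The hunk stops before the lexsort call, so the rest of the function is not visible here. As a rough illustration of why this key order matters, the sketch below shows one way a unique mask could be derived from such a sort. Everything in it is an assumption made for illustration: the helper name `first_in_group_mask`, the single `uint32` hash column, and the "first row of each hash group wins" rule are not taken from the repository.

```python
import jax.numpy as jnp


def first_in_group_mask(hashes, cost, filled):
    """Hypothetical helper: True for the cheapest filled row of each hash."""
    n = hashes.shape[0]
    max_val = jnp.array(jnp.iinfo(jnp.uint32).max, dtype=jnp.uint32)

    # Push unfilled rows to the end of the sort, as the diffed code does.
    h = jnp.where(filled, hashes, max_val)
    c = jnp.where(filled, cost, jnp.inf)

    # Least significant key first: index, then cost, then hash.
    order = jnp.lexsort([jnp.arange(n, dtype=jnp.int32), c, h])

    # In sorted order, a row starts a group when its hash differs
    # from the hash of the row before it.
    sorted_h = h[order]
    is_first = jnp.concatenate(
        [jnp.array([True]), sorted_h[1:] != sorted_h[:-1]]
    )
    is_first = is_first & filled[order]  # never select an unfilled row

    # Scatter the per-sorted-position flags back to original positions.
    return jnp.zeros(n, dtype=bool).at[order].set(is_first)


hashes = jnp.array([7, 3, 7, 7], dtype=jnp.uint32)
cost = jnp.array([5.0, 1.0, 2.0, 2.0], dtype=jnp.float32)
filled = jnp.array([True, True, True, True])
mask = first_in_group_mask(hashes, cost, filled)
print(mask.tolist())  # [False, True, True, False]
```

Rows 2 and 3 have the same hash and the same cost, so without the index key the choice between them would depend on the sort implementation. With the index as the least significant key, row 2 is always the one kept.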
