Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/cuda_cccl/cuda/compute/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
DiscardIterator,
PermutationIterator,
ReverseIterator,
ShuffleIterator,
TransformIterator,
TransformOutputIterator,
ZipIterator,
Expand Down Expand Up @@ -83,6 +84,7 @@
"segmented_reduce",
"segmented_sort",
"select",
"ShuffleIterator",
"SortOrder",
"TransformIterator",
"TransformOutputIterator",
Expand Down
2 changes: 2 additions & 0 deletions python/cuda_cccl/cuda/compute/iterators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
DiscardIterator,
PermutationIterator,
ReverseIterator,
ShuffleIterator,
TransformIterator,
TransformOutputIterator,
ZipIterator,
Expand All @@ -17,6 +18,7 @@
"DiscardIterator",
"PermutationIterator",
"ReverseIterator",
"ShuffleIterator",
"TransformIterator",
"TransformOutputIterator",
"ZipIterator",
Expand Down
28 changes: 25 additions & 3 deletions python/cuda_cccl/cuda/compute/iterators/_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
make_transform_iterator,
)
from ._permutation_iterator import make_permutation_iterator
from ._shuffle_iterator import make_shuffle_iterator
from ._zip_iterator import make_zip_iterator


Expand Down Expand Up @@ -219,14 +220,35 @@ def PermutationIterator(values, indices):
return make_permutation_iterator(values, indices)


def ShuffleIterator(num_items, seed, rounds=8):
    """Return an iterator over a deterministic pseudo-random permutation of ``[0, num_items)``.

    This is a thin public factory: the arguments are forwarded unchanged to
    ``make_shuffle_iterator``, which builds the iterator.

    Example:
        The code snippet below demonstrates the usage of a ``ShuffleIterator``
        to randomly permute indices:

        .. literalinclude:: ../../python/cuda_cccl/tests/compute/examples/iterator/shuffle_iterator_basic.py
            :language: python
            :start-after: # example-begin

    Args:
        num_items: Number of elements in the domain to permute
        seed: Seed used to parameterize the permutation
        rounds: Number of Feistel rounds to use (default: 8); values below 6
            are raised to 6 by ``make_shuffle_iterator``

    Returns:
        The iterator built by ``make_shuffle_iterator``, which yields the
        shuffled indices
    """
    shuffled = make_shuffle_iterator(num_items, seed, rounds)
    return shuffled


def ZipIterator(*iterators):
"""Returns an Iterator representing a zipped sequence of values from N iterators.

Similar to https://nvidia.github.io/cccl/thrust/api/classthrust_1_1zip__iterator.html

The resulting iterator yields gpu_struct objects with fields corresponding to each input iterator.
For 2 iterators, fields are named 'first' and 'second'. For N iterators, fields are indexed
as field_0, field_1, ..., field_N-1.
The resulting iterator yields structs with fields corresponding to each input iterator.
Fields can be accessed by index using `[]`.

Example:
The code snippet below demonstrates the usage of a ``ZipIterator``
Expand Down
216 changes: 216 additions & 0 deletions python/cuda_cccl/cuda/compute/iterators/_shuffle_iterator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception


from numba import cuda, int64, uint64

from .._caching import cache_with_key
from ._iterators import (
CountingIterator as _CountingIterator,
)
from ._iterators import (
make_transform_iterator,
)

# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------

# SplitMix64 step (≈ 2^64 / φ). Added to the seed (mod 2^64) by the host-side
# key derivation before the avalanche steps are applied.
SPLITMIX64_GAMMA = 0x9E3779B97F4A7C15

# SplitMix64 avalanche multipliers. Shared by the device-side mixer (_mix64)
# and the host-side key derivation (_splitmix64_host).
SPLITMIX64_MUL1 = 0xBF58476D1CE4E5B9
SPLITMIX64_MUL2 = 0x94D049BB133111EB

# Per-round constant to decorrelate Feistel rounds (any odd 64-bit constant works).
# It is scaled by the round index and XORed into the round-function input.
FEISTEL_ROUND_C = 0xD6E8FEB86659FD93


@cuda.jit(device=True, inline=True)
def _mix64(z):
    """
    Scramble a 64-bit value with the SplitMix64 avalanche sequence.

    Three xor-shift steps (by 30, 27 and 31 bits) are interleaved with two
    multiplications by ``SPLITMIX64_MUL1`` and ``SPLITMIX64_MUL2``. Every
    operand is cast to ``uint64`` explicitly so the whole computation stays in
    unsigned 64-bit arithmetic.

    The Feistel round function uses the result as its core.
    """
    v = uint64(z)
    v = v ^ (v >> uint64(30))
    v = uint64(v * uint64(SPLITMIX64_MUL1))
    v = v ^ (v >> uint64(27))
    v = uint64(v * uint64(SPLITMIX64_MUL2))
    v = v ^ (v >> uint64(31))
    return v


@cuda.jit(device=True, inline=True)
def _feistel_balanced(x, key, half_bits, half_mask, rounds):
    """
    Balanced Feistel permutation over ``2 * half_bits`` bits.

    The input domain is ``[0, 2^(2*half_bits))``; bits of ``x`` above that
    range are discarded by the masking below. A Feistel network is invertible
    whatever the round function is, so this defines a bijection on the domain.

    Parameters
    ----------
    x : value to permute (the caller passes a ``uint64``).
    key : 64-bit key XORed into every round's input.
    half_bits : width in bits of each half.
    half_mask : ``(1 << half_bits) - 1``, used to truncate each half.
    rounds : number of Feistel rounds to apply.

    Returns
    -------
    The permuted value, reassembled as ``(R << half_bits) | L``.
    """
    hb = uint64(half_bits)

    # Split x into equal-width halves
    L = x & half_mask
    R = (x >> hb) & half_mask

    for rnd in range(rounds):
        # Cast BOTH factors to uint64 before multiplying. The loop counter is
        # a signed integer while FEISTEL_ROUND_C does not fit in int64; a
        # mixed signed/unsigned product is left to Numba's type-promotion
        # rules and is not guaranteed to be a wrapping 64-bit integer
        # multiply. With two uint64 operands the product wraps modulo 2^64.
        round_const = uint64(rnd) * uint64(FEISTEL_ROUND_C)

        # Round function F(R) -> half_bits bits
        z = R ^ key ^ round_const
        F = _mix64(z) & half_mask

        # Standard Feistel step
        new_L = R
        new_R = (L ^ F) & half_mask
        L = new_L
        R = new_R

    return (R << hb) | L


def _splitmix64_host(x: int) -> int:
    """
    Host-side SplitMix64 used to derive a 64-bit key from the seed.

    The input is reduced modulo 2^64, advanced by ``SPLITMIX64_GAMMA`` and
    then passed through the SplitMix64 avalanche steps. Python integers are
    unbounded, so every intermediate result is masked back to 64 bits.
    """
    # NOTE(review): a reviewer considered SplitMix fine in theory but asked
    # for the VariablePhilox algorithm from https://arxiv.org/abs/2106.06161
    # instead: this implementation has to round the sequence up to the nearest
    # power of 4, whereas the one from the paper rounds up to the nearest
    # power of 2, and it is tested.
    mask64 = (1 << 64) - 1

    state = ((x & mask64) + SPLITMIX64_GAMMA) & mask64
    z = state ^ (state >> 30)
    z = (z * SPLITMIX64_MUL1) & mask64
    z ^= z >> 27
    z = (z * SPLITMIX64_MUL2) & mask64
    z ^= z >> 31
    return z & mask64


def _make_cache_key(num_items: int, seed: int, rounds: int):
return (num_items, seed, rounds)


@cache_with_key(_make_cache_key)
def make_shuffle_iterator(num_items: int, seed: int, rounds: int = 8):
    """
    Iterator that produces a deterministic "random" permutation
    of indices in ``[0, num_items)``.

    Parameters
    ----------
    num_items : int
        Number of elements in the domain to permute. Must be positive.
    seed : int
        Seed used to parameterize the permutation. Different seeds produce
        different (deterministic) permutations.
    rounds : int, optional
        Number of Feistel rounds to use. More rounds improve diffusion at the
        cost of additional arithmetic. Typical values are 6–10; values below
        6 are silently raised to 6.

    Returns
    -------
    TransformIterator
        An iterator that yields a shuffled ordering of indices in
        ``[0, num_items)``.

    Raises
    ------
    ValueError
        If ``num_items <= 0``, or if the padded power-of-two domain would
        need more than 63 bits (``num_items`` greater than ``2**62``).

    Notes
    -----
    This iterator does **not** materialize a permutation table. Instead, it
    computes each permuted index on demand using a *stateless bijection*
    derived from a fixed seed.

    The iterator is implemented as::

        TransformIterator(CountingIterator(0), permute)

    where ``permute(i)`` is a pure function that maps ``i`` to a unique value
    in ``[0, num_items)``.

    The permutation is constructed as follows:

    1. Let ``k = ceil(log2(num_items))`` and ``h = ceil(k / 2)``.
       We define a permutation over ``2^(2h)`` elements (a power-of-two
       domain large enough to cover ``[0, num_items)``).

    2. A **balanced Feistel network** with ``h``-bit halves is used to define
       a bijection over this ``2^(2h)`` domain. Each Feistel round applies a
       simple, fast mixing function (SplitMix64-style) keyed by ``seed`` and
       the round index.

    3. To restrict the permutation to ``[0, num_items)``, **cycle-walking**
       is used: the Feistel permutation is repeatedly applied until the
       result lies within ``[0, num_items)``. This preserves bijectivity on
       the restricted domain.

    Properties
    ----------
    - **Bijective on ``[0, num_items)``**: every index appears exactly once.
    - **Deterministic**: the same ``num_items`` and ``seed`` always produce
      the same ordering.
    - **Stateless**: no per-element or per-thread state is required.
    - **Lazy**: indices are computed on demand; no permutation buffer is
      stored.
    - **Device-friendly**: implemented using simple integer arithmetic and
      device-callable functions.

    Limitations
    -----------
    - The resulting permutation is *not* uniformly sampled from all
      ``num_items!`` permutations. It is drawn from a large, structured
      family of permutations induced by the Feistel construction.
    - Cycle-walking may apply the Feistel permutation more than once per
      element when ``num_items`` is far from a power of two, though the
      expected number of iterations is close to 1.
    """
    # NOTE(review): a reviewer questioned the first "Limitations" bullet
    # ("kind of feels like it might be correct but I think it's nonsense").
    # NOTE(review): a reviewer suggested users do not care about the
    # cycle-walking implementation detail. One consequence of it is that the
    # worst-case cost of generating one element is O(n), where n is the
    # permutation length, although that is vanishingly unlikely; it may not
    # need to be mentioned at all.
    if num_items <= 0:
        raise ValueError("num_items must be > 0")

    # Enforce a floor of six Feistel rounds.
    rounds = max(rounds, 6)

    n_items = int(num_items)

    # ceil(log2(n_items)): bits needed to represent the largest index.
    index_bits = (n_items - 1).bit_length()

    # Balanced halves: pad to an even bit count so that 2 * half_bits >= index_bits.
    half_bits = (index_bits + 1) // 2
    domain_bits = 2 * half_bits

    if domain_bits > 63:
        raise ValueError("num_items too large for uint64-based shuffle iterator")

    half_mask = (1 << half_bits) - 1
    domain_mask = (1 << domain_bits) - 1

    key = _splitmix64_host(int(seed))

    # Closure capturing only constants; device-callable helpers do the work
    def permute(i):
        limit = uint64(n_items)

        y = _feistel_balanced(
            uint64(i) & uint64(domain_mask),
            uint64(key),
            half_bits,
            uint64(half_mask),
            rounds,
        )

        # Cycle-walk: keep permuting until the value lands in [0, n_items)
        while y >= limit:
            y = _feistel_balanced(
                y,
                uint64(key),
                half_bits,
                uint64(half_mask),
                rounds,
            )

        return int64(y)

    return make_transform_iterator(_CountingIterator(int64(0)), permute, "input")
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# example-begin
"""
Demonstrate using ShuffleIterator for deterministic random permutation of indices in ``[0, num_items)``.
"""

import cupy as cp
import numpy as np

import cuda.compute
from cuda.compute import OpKind, PermutationIterator, ShuffleIterator

# --- 1. Materialize the shuffled indices -----------------------------------
# A shuffle iterator over 10 elements with a fixed seed.
num_items = 10
seed = 42
shuffle_it = ShuffleIterator(num_items, seed)

# The iterator is lazy, so an identity unary_transform is used to write the
# indices it yields into a device array.
d_indices = cp.empty(num_items, dtype=np.int64)
cuda.compute.unary_transform(shuffle_it, d_indices, lambda x: x, num_items)

# Copy to the host once and reuse the result for printing and checking.
h_indices = d_indices.get()
print(f"Shuffled indices: {h_indices}")
# A valid permutation contains every index 0..num_items-1 exactly once.
assert set(h_indices) == set(range(num_items))

# --- 2. Read data in shuffled order -----------------------------------------
d_values = cp.asarray([10, 20, 30, 40, 50, 60, 70, 80, 90, 100], dtype=np.int32)

# Same num_items and seed, therefore the same ordering as above.
shuffle_it2 = ShuffleIterator(num_items, seed)

# PermutationIterator uses the shuffled indices to index into d_values.
perm_it = PermutationIterator(d_values, shuffle_it2)

# Sum the values as visited through the permutation.
h_init = np.array([0], dtype=np.int32)
d_output = cp.empty(1, dtype=np.int32)

cuda.compute.reduce_into(perm_it, d_output, OpKind.PLUS, num_items, h_init)

# A permutation only reorders the elements, so the total is unchanged.
expected_sum = d_values.sum()
print(f"Sum of shuffled values: {d_output[0]} (expected: {expected_sum})")
assert d_output[0] == expected_sum

# --- 3. Different seeds produce different permutations -----------------------
shuffle_it_a = ShuffleIterator(num_items, seed=1)
shuffle_it_b = ShuffleIterator(num_items, seed=2)

d_perm_a = cp.empty(num_items, dtype=np.int64)
d_perm_b = cp.empty(num_items, dtype=np.int64)

cuda.compute.unary_transform(shuffle_it_a, d_perm_a, lambda x: x, num_items)
cuda.compute.unary_transform(shuffle_it_b, d_perm_b, lambda x: x, num_items)

h_perm_a = d_perm_a.get()
h_perm_b = d_perm_b.get()
print(f"Permutation with seed=1: {h_perm_a}")
print(f"Permutation with seed=2: {h_perm_b}")
assert not np.array_equal(h_perm_a, h_perm_b)
Loading
Loading