def ShuffleIterator(num_items, seed):
    """Build an iterator over a seeded pseudo-random permutation of ``[0, num_items)``.

    The permutation is deterministic: it is fully determined by ``num_items``
    and ``seed``.

    Example:
        The code snippet below demonstrates the usage of a ``ShuffleIterator``
        to randomly permute indices:

        .. literalinclude:: ../../python/cuda_cccl/tests/compute/examples/iterator/shuffle_iterator_basic.py
            :language: python
            :start-after: # example-begin

    Args:
        num_items: Number of elements in the domain to permute
        seed: Seed used to parameterize the permutation

    Returns:
        A ``ShuffleIterator`` object that yields shuffled indices
    """
    # All of the work is delegated to the factory in ``_shuffle_iterator``.
    shuffled = make_shuffle_iterator(num_items, seed)
    return shuffled
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception


import ctypes

from numba import cuda, int64, types, uint32, uint64

from .._caching import cache_with_key
from ._iterators import IteratorBase, IteratorKind

# Number of Feistel rounds (matches C++ __feistel_bijection)
NUM_ROUNDS = 24

# Feistel multiplier (same as C++)
FEISTEL_M0 = 0xD2B74407B1CE6E93

# SplitMix64 constants for key derivation
SPLITMIX64_GAMMA = 0x9E3779B97F4A7C15
SPLITMIX64_MUL1 = 0xBF58476D1CE4E5B9
SPLITMIX64_MUL2 = 0x94D049BB133111EB


@cuda.jit(device=True, inline=True)
def _splitmix64_next(state):
    """
    Device-side SplitMix64 step. Returns (next_state, output).
    Used to generate independent keys from a seed.

    ``state`` is advanced by ``SPLITMIX64_GAMMA``; the output is the advanced
    state passed through two xor-shift/multiply stages and a final xor-shift.
    """
    GAMMA = uint64(SPLITMIX64_GAMMA)
    MUL1 = uint64(SPLITMIX64_MUL1)
    MUL2 = uint64(SPLITMIX64_MUL2)

    state = state + GAMMA
    z = state
    z = (z ^ (z >> uint64(30))) * MUL1
    z = (z ^ (z >> uint64(27))) * MUL2
    z = z ^ (z >> uint64(31))
    return state, z


@cuda.jit(device=True, inline=True)
def _feistel_bijection(val, seed, left_bits, right_bits, left_mask, right_mask):
    """
    Feistel bijection matching libcudacxx __feistel_bijection.

    Parameters
    ----------
    val
        Value to permute; it is split into a ``left_bits``-wide upper part and
        a ``right_bits``-wide lower part.
    seed
        Initial state of the SplitMix64 generator that supplies one 32-bit
        round key per round.
    left_bits, right_bits
        Bit widths of the two halves.
    left_mask, right_mask
        ``(1 << left_bits) - 1`` and ``(1 << right_bits) - 1`` as passed by the
        caller.

    Returns
    -------
    uint64
        ``(state_high << right_bits) | state_low`` after ``NUM_ROUNDS`` rounds.
    """
    M0 = uint64(FEISTEL_M0)

    # Match C++ initialization exactly:
    #   __state.__low  = val >> right_side_bits
    #   __state.__high = val & right_side_mask
    state_low = uint32((val >> uint64(right_bits)) & uint64(left_mask))
    state_high = uint32(val & uint64(right_mask))

    shift_amount = uint64(right_bits - left_bits)
    lbits = uint64(left_bits)
    lmask = uint32(left_mask)
    rmask = uint32(right_mask)

    # Initialize key generator state from seed
    key_state = uint64(seed)

    # 24 rounds with independent keys
    for _ in range(NUM_ROUNDS):
        # Generate next key using SplitMix64
        key_state, key_output = _splitmix64_next(key_state)
        round_key = uint32(key_output & uint64(0xFFFFFFFF))

        # Feistel round matching C++ exactly:
        #   product = M0 * __state.__high
        #   hi = product >> 32
        #   lo = product & 0xFFFFFFFF
        #   lo = (lo << shift) | (__state.__low >> left_bits)
        #   __state.__high = (hi ^ key ^ __state.__low) & left_mask
        #   __state.__low = lo & right_mask
        product = M0 * uint64(state_high)
        hi = uint32(product >> uint64(32))
        lo = uint32(product)

        lo = uint32((uint64(lo) << shift_amount) | (uint64(state_low) >> lbits))

        new_high = ((hi ^ round_key) ^ state_low) & lmask
        new_low = lo & rmask

        state_high = new_high
        state_low = new_low

    # Match C++ output: (__state.__high << right_bits) | __state.__low
    return (uint64(state_high) << uint64(right_bits)) | uint64(state_low)


def _splitmix64_host(x: int) -> int:
    """
    Host-side SplitMix64 used to derive a 64-bit seed from the user seed.

    Python integers are unbounded, so every intermediate result is masked to
    64 bits. The returned value lies in ``[0, 2**64)``.
    """
    mask64 = (1 << 64) - 1

    x &= mask64
    x = (x + SPLITMIX64_GAMMA) & mask64
    z = x
    z ^= z >> 30
    z = (z * SPLITMIX64_MUL1) & mask64
    z ^= z >> 27
    z = (z * SPLITMIX64_MUL2) & mask64
    z ^= z >> 31
    return z & mask64


class ShuffleIteratorKind(IteratorKind):
    """Iterator kind tag used by the generated ``ShuffleIterator`` classes."""

    pass


# Cache key excludes seed - only structure-defining parameters
def _make_cache_key(num_items: int, seed: int):
    """Return the cache key for ``_make_shuffle_iterator_class`` (``seed`` is ignored)."""
    return (num_items,)


@cache_with_key(_make_cache_key)
def _make_shuffle_iterator_class(num_items: int, seed: int):
    """
    Factory that creates a ShuffleIterator class for a given num_items.
    The seed is NOT part of the cache key, so the same class is reused for different seeds.

    Parameters
    ----------
    num_items : int
        Size of the permuted domain ``[0, num_items)``; must be positive after
        conversion with ``int()``.
    seed : int
        Unused here; present so the signature matches the cache-key function.

    Returns
    -------
    type
        An ``IteratorBase`` subclass whose constructor takes the runtime seed.

    Raises
    ------
    ValueError
        If ``num_items`` is not positive, or is so large that more than 63
        bits would be needed.
    """
    # Compare first so non-numeric inputs keep failing exactly as before, then
    # also reject values that truncate to zero (e.g. 0.5): with a zero-sized
    # domain the cycle-walk loop below can never exit, because ``y >= 0``
    # always holds for an unsigned value.
    if num_items <= 0 or int(num_items) <= 0:
        raise ValueError("num_items must be > 0")

    m = int(num_items)

    # total_bits = ceil(log2(m)), minimum 4 bits for proper mixing (matches C++)
    total_bits = max((m - 1).bit_length(), 4)

    if total_bits > 63:
        raise ValueError("num_items too large for uint64-based shuffle iterator")

    # Feistel uses unbalanced halves: left = floor(total/2), right = ceil(total/2)
    left_bits = total_bits // 2
    right_bits = total_bits - left_bits

    left_mask = (1 << left_bits) - 1
    right_mask = (1 << right_bits) - 1

    # Capture constants for the device functions
    _m = m
    _left_bits = left_bits
    _right_bits = right_bits
    _left_mask = left_mask
    _right_mask = right_mask

    @cuda.jit(device=True)
    def _permute_with_seed(index, seed):
        """Permute a single index using the Feistel bijection with cycle-walking."""
        mm = uint64(_m)
        x = uint64(index)

        y = _feistel_bijection(
            x,
            seed,
            _left_bits,
            _right_bits,
            uint64(_left_mask),
            uint64(_right_mask),
        )

        # Cycle-walk into [0, m): the bijection acts on a 2**total_bits domain,
        # so re-apply it until the result lands inside the requested range.
        while y >= mm:
            y = _feistel_bijection(
                y,
                seed,
                _left_bits,
                _right_bits,
                uint64(_left_mask),
                uint64(_right_mask),
            )

        return int64(y)

    # State: (index, seed) - matches C++ which stores (bijection, current_index)
    state_type = types.UniTuple(types.int64, 2)

    class ShuffleIterator(IteratorBase):
        """Input-only iterator whose state is the pair ``(current_index, seed)``."""

        iterator_kind_type = ShuffleIteratorKind

        def __init__(self, seed: int):
            # State: (current_index, seed)
            # One iterator = one permutation (matches C++ behavior)
            cvalue = (ctypes.c_int64 * 2)(0, seed)
            super().__init__(
                cvalue=cvalue,
                state_type=state_type,
                value_type=types.int64,
            )

        @property
        def host_advance(self):
            return ShuffleIterator._advance

        @property
        def advance(self):
            return ShuffleIterator._advance

        @property
        def input_dereference(self):
            return ShuffleIterator._input_dereference

        @property
        def output_dereference(self):
            raise AttributeError("ShuffleIterator cannot be used as an output iterator")

        @staticmethod
        def _advance(state, distance):
            # Only the index moves; the seed is carried along unchanged.
            idx = state[0][0]
            seed = state[0][1]
            state[0] = (idx + distance, seed)

        @staticmethod
        def _input_dereference(state, result):
            # Dereferencing yields the permuted image of the current index.
            idx = state[0][0]
            seed = state[0][1]
            result[0] = _permute_with_seed(idx, seed)

    return ShuffleIterator


def make_shuffle_iterator(num_items: int, seed: int):
    """
    Iterator that produces a deterministic "random" permutation
    of indices in ``[0, num_items)``.

    Uses a Feistel cipher bijection matching the libcudacxx implementation,
    with 24 rounds and independent keys per round for high-quality shuffling.

    Parameters
    ----------
    num_items : int
        Number of elements in the domain to permute.
    seed : int
        Seed used to parameterize the permutation. Different seeds produce
        different (deterministic) permutations.

    Returns
    -------
    ShuffleIterator
        An iterator that yields a shuffled ordering of indices in
        ``[0, num_items)``.

    Raises
    ------
    ValueError
        If ``num_items`` is not positive or is too large (see
        ``_make_shuffle_iterator_class``).
    """
    # Get the class (cached by num_items only, NOT seed)
    ShuffleIteratorClass = _make_shuffle_iterator_class(num_items, seed)

    # Derive the internal seed from the user seed
    internal_seed = _splitmix64_host(int(seed))

    # Create instance with the runtime seed
    return ShuffleIteratorClass(internal_seed)


# -----------------------------------------------------------------------------
# Separate file in the same patch:
# python/cuda_cccl/tests/compute/examples/iterator/shuffle_iterator_basic.py
# -----------------------------------------------------------------------------
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# example-begin
"""
Demonstrate using ShuffleIterator for deterministic random permutation of indices in ``[0, num_items)``.
"""

import cupy as cp
import numpy as np

import cuda.compute
from cuda.compute import (
    OpKind,
    PermutationIterator,
    ShuffleIterator,
)

# Create a shuffle iterator for 10 elements with a fixed seed
num_items = 10
seed = 42
shuffle_it = ShuffleIterator(num_items, seed)

# Collect the shuffled indices using unary_transform
d_indices = cp.empty(num_items, dtype=np.int64)
cuda.compute.unary_transform(shuffle_it, d_indices, lambda x: x, num_items)

print(f"Shuffled indices: {d_indices.get()}")
# Verify it's a valid permutation (all indices 0 to num_items-1 appear exactly once)
assert set(d_indices.get()) == set(range(num_items))

# Use ShuffleIterator with PermutationIterator to access data in shuffled order
d_values = cp.asarray([10, 20, 30, 40, 50, 60, 70, 80, 90, 100], dtype=np.int32)

# Create a new shuffle iterator (same seed for same order)
shuffle_it2 = ShuffleIterator(num_items, seed)

# Combine with PermutationIterator to access values in shuffled order
perm_it = PermutationIterator(d_values, shuffle_it2)

# Reduce the shuffled values - sum should equal sum of all values
h_init = np.array([0], dtype=np.int32)
d_output = cp.empty(1, dtype=np.int32)

cuda.compute.reduce_into(perm_it, d_output, OpKind.PLUS, num_items, h_init)

# Since shuffle is a permutation, sum equals sum of all values
expected_sum = d_values.sum()
print(f"Sum of shuffled values: {d_output[0]} (expected: {expected_sum})")
assert d_output[0] == expected_sum

# Different seeds produce different permutations
shuffle_it_a = ShuffleIterator(num_items, seed=1)
shuffle_it_b = ShuffleIterator(num_items, seed=2)

d_perm_a = cp.empty(num_items, dtype=np.int64)
d_perm_b = cp.empty(num_items, dtype=np.int64)

cuda.compute.unary_transform(shuffle_it_a, d_perm_a, lambda x: x, num_items)
cuda.compute.unary_transform(shuffle_it_b, d_perm_b, lambda x: x, num_items)

print(f"Permutation with seed=1: {d_perm_a.get()}")
print(f"Permutation with seed=2: {d_perm_b.get()}")
assert not np.array_equal(d_perm_a.get(), d_perm_b.get())
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import cupy as cp
import numpy as np
import pytest

import cuda.compute
from cuda.compute import OpKind
from cuda.compute.iterators import (
    PermutationIterator,
    ShuffleIterator,
)


def test_shuffle_iterator_bijectivity():
    """Test that ShuffleIterator produces a valid permutation (bijective)."""
    num_items = 100
    seed = 42

    shuffle_it = ShuffleIterator(num_items, seed)

    # Use unary_transform to collect all shuffled indices
    d_output = cp.empty(num_items, dtype=np.int64)
    cuda.compute.unary_transform(shuffle_it, d_output, lambda x: x, num_items)

    result = d_output.get()

    # Every index from 0 to num_items-1 should appear exactly once
    assert len(set(result)) == num_items
    assert set(result) == set(range(num_items))


def test_shuffle_iterator_determinism():
    """Test that same seed produces same permutation."""
    num_items = 50
    seed = 12345

    shuffle_it1 = ShuffleIterator(num_items, seed)
    shuffle_it2 = ShuffleIterator(num_items, seed)

    d_output1 = cp.empty(num_items, dtype=np.int64)
    d_output2 = cp.empty(num_items, dtype=np.int64)

    cuda.compute.unary_transform(shuffle_it1, d_output1, lambda x: x, num_items)
    cuda.compute.unary_transform(shuffle_it2, d_output2, lambda x: x, num_items)

    cp.testing.assert_array_equal(d_output1, d_output2)


def test_shuffle_iterator_different_seeds():
    """Test that different seeds produce different permutations."""
    num_items = 50

    shuffle_it1 = ShuffleIterator(num_items, seed=1)
    shuffle_it2 = ShuffleIterator(num_items, seed=2)

    d_output1 = cp.empty(num_items, dtype=np.int64)
    d_output2 = cp.empty(num_items, dtype=np.int64)

    cuda.compute.unary_transform(shuffle_it1, d_output1, lambda x: x, num_items)
    cuda.compute.unary_transform(shuffle_it2, d_output2, lambda x: x, num_items)

    # Very unlikely that two different seeds produce the same permutation
    assert not np.array_equal(d_output1.get(), d_output2.get())


@pytest.mark.parametrize("num_items", [1, 2, 7, 16, 17, 100, 1000, 1023, 1024, 1025])
def test_shuffle_iterator_various_sizes(num_items):
    """Test ShuffleIterator works correctly for various sizes."""
    seed = 42

    shuffle_it = ShuffleIterator(num_items, seed)

    d_output = cp.empty(num_items, dtype=np.int64)
    cuda.compute.unary_transform(shuffle_it, d_output, lambda x: x, num_items)

    result = d_output.get()

    # Should be a valid permutation
    assert len(set(result)) == num_items
    assert set(result) == set(range(num_items))


def test_shuffle_iterator_with_reduction():
    """Test ShuffleIterator with a reduction operation."""
    num_items = 100
    seed = 42

    shuffle_it = ShuffleIterator(num_items, seed)

    h_init = np.array([0], dtype=np.int64)
    d_output = cp.empty(1, dtype=np.int64)

    cuda.compute.reduce_into(shuffle_it, d_output, OpKind.PLUS, num_items, h_init)

    # Sum of a permutation of [0, num_items) should equal sum(0..num_items-1)
    expected = sum(range(num_items))
    assert d_output.get()[0] == expected


def test_shuffle_iterator_with_permutation_iterator():
    """Test ShuffleIterator composed with PermutationIterator for shuffled data access."""
    num_items = 10
    seed = 42

    # Create data array
    d_values = cp.asarray([10, 20, 30, 40, 50, 60, 70, 80, 90, 100], dtype=np.int32)

    # Create a shuffle iterator to generate shuffled indices
    shuffle_it = ShuffleIterator(num_items, seed)

    # Materialize the shuffled indices and check they really are a permutation,
    # otherwise the sum comparison below would not be meaningful.
    d_indices = cp.empty(num_items, dtype=np.int64)
    cuda.compute.unary_transform(shuffle_it, d_indices, lambda x: x, num_items)
    assert set(d_indices.get()) == set(range(num_items))

    # Create permutation iterator using shuffle iterator as indices
    shuffle_it2 = ShuffleIterator(num_items, seed)  # Fresh iterator
    perm_it = PermutationIterator(d_values, shuffle_it2)

    # Reduce the permuted values
    h_init = np.array([0], dtype=np.int32)
    d_output = cp.empty(1, dtype=np.int32)

    cuda.compute.reduce_into(perm_it, d_output, OpKind.PLUS, num_items, h_init)

    # Sum should equal sum of all values (since it's a permutation).
    # Bring the expected value to the host so both sides are plain scalars.
    expected = int(d_values.sum())
    assert d_output.get()[0] == expected


def test_shuffle_iterator_invalid_num_items():
    """Test that ShuffleIterator raises error for invalid num_items."""
    with pytest.raises(ValueError, match="num_items must be > 0"):
        ShuffleIterator(0, seed=42)

    with pytest.raises(ValueError, match="num_items must be > 0"):
        ShuffleIterator(-1, seed=42)


def test_shuffle_iterator_rounds():
    """Test that the Feistel round count is not a user-facing parameter.

    ``ShuffleIterator`` is declared as ``ShuffleIterator(num_items, seed)`` and
    the round count is the module constant ``NUM_ROUNDS``. The previous version
    of this test passed ``rounds=...`` and therefore always failed with
    ``TypeError`` before reaching any assertion; pin the real contract instead.
    """
    num_items = 50
    seed = 42

    with pytest.raises(TypeError):
        ShuffleIterator(num_items, seed, rounds=3)

    with pytest.raises(TypeError):
        ShuffleIterator(num_items, seed, rounds=12)


def test_shuffle_iterator_large():
    """Test ShuffleIterator with a larger dataset."""
    num_items = 10000
    seed = 12345

    shuffle_it = ShuffleIterator(num_items, seed)

    # Cheap sanity check: the sum of a permutation of [0, num_items) is fixed
    h_init = np.array([0], dtype=np.int64)
    d_output = cp.empty(1, dtype=np.int64)

    cuda.compute.reduce_into(shuffle_it, d_output, OpKind.PLUS, num_items, h_init)

    expected = sum(range(num_items))
    assert d_output.get()[0] == expected