diff --git a/cub/test/catch2_test_warp_scan_api.cu b/cub/test/catch2_test_warp_scan_api.cu index acca315305f..45444c44e14 100644 --- a/cub/test/catch2_test_warp_scan_api.cu +++ b/cub/test/catch2_test_warp_scan_api.cu @@ -33,10 +33,32 @@ #include +#include + #include "catch2_test_helper.h" +#include "cuda/std/__algorithm/max.h" +#include "cuda/std/__numeric/inclusive_scan.h" constexpr int num_warps = 4; +template +struct max_op +{ + __host__ __device__ T const& operator()(T const& i, T const& j) + { + return cuda::std::max(i, j); + } +}; + +template +struct sum_op +{ + __host__ __device__ T operator()(T const& i, T const& j) + { + return i + j; + } +}; + // example-begin inclusive-warp-scan-init-value __global__ void InclusiveScanKernel(int* output) { @@ -45,26 +67,26 @@ __global__ void InclusiveScanKernel(int* output) // Allocate WarpScan shared memory for 4 warps __shared__ typename warp_scan_t::TempStorage temp_storage[num_warps]; - int initial_value = 1; - int thread_data = threadIdx.x; + int warp_id = threadIdx.x / 32; + int initial_value = 3; + int thread_data = threadIdx.x % 32 + warp_id; - // warp #0 input: { 0, 1, 2, 3, 4, ..., 31} - // warp #1 input: {32, 33, 34, 35, 36, ..., 63} - // warp #2 input: {64, 65, 66, 67, 68, ..., 95} - // warp #4 input: {96, 97, 98, 99, 100, ..., 127} + // warp #0 input: {0, 1, 2, 3, ..., 31} + // warp #1 input: {1, 2, 3, 4, ..., 32} + // warp #2 input: {2, 3, 4, 5, ..., 33} + // warp #4 input: {3, 4, 5, 6, ..., 34} // Collectively compute the block-wide inclusive prefix max scan - int warp_id = threadIdx.x / 32; - warp_scan_t(temp_storage[warp_id]).InclusiveScan(thread_data, thread_data, initial_value, cub::Sum()); + warp_scan_t(temp_storage[warp_id]).InclusiveScan(thread_data, thread_data, initial_value, cub::Max()); - // initial value = 1 (for each warp) - // warp #0 output: { 1, 2, 4, ..., 497} - // warp #1 output: {33, 66, 100, ..., 1521} - // warp #2 output: {65, 130, 196, ..., 2545} - // warp #3 output: {97, 194, 292, ..., 3569} + // initial value = 3 (for each warp) + // warp #0 output: {3, 3, 3, 3, ..., 31} + // warp #1 output: {3, 3, 3, 4, ..., 32} + // warp #2 output: {3, 3, 4, 5, ..., 33} + // warp #3 output: {3, 4, 5, 6, ..., 34} + output[threadIdx.x] = thread_data; // example-end inclusive-warp-scan-init-value - output[threadIdx.x] = thread_data; } CUB_TEST("Block array-based inclusive scan works with initial value", "[scan][block]") @@ -76,24 +98,21 @@ CUB_TEST("Block array-based inclusive scan works with initial value", "[scan][bl REQUIRE(cudaSuccess == cudaDeviceSynchronize()); c2h::host_vector expected(d_out.size()); - expected[0] = 1; // Initial value - // Calculate the prefix sum with an additional +1 every 32 elements - for (int i = 1; i < num_warps * 32; ++i) + for (int i = 0; i < num_warps; ++i) { - if (i % 32 == 0) - { - expected[i] = i + 1; // Reset at the start of each warp - } - else - { - expected[i] = expected[i - 1] + i; - } + auto start = expected.begin() + i * 32; + auto end = start + 32; + + std::iota(start, end, i); // initialize host input for every warp + + cuda::std::inclusive_scan(start, end, start, max_op{}, 3); } REQUIRE(expected == d_out); } +// example-begin inclusive-warp-scan-init-value-aggregate __global__ void InclusiveScanKernelAggr(int* output, int* d_warp_aggregate) { // Specialize WarpScan for type int @@ -101,28 +120,27 @@ __global__ void InclusiveScanKernelAggr(int* output, int* d_warp_aggregate) // Allocate WarpScan shared memory for 4 warps __shared__ typename warp_scan_t::TempStorage temp_storage[num_warps]; - int initial_value = 1; - int thread_data = threadIdx.x; + int warp_id = threadIdx.x / 32; + int initial_value = 3; // for each warp + int thread_data = 1; + int warp_aggregate; - // warp #0 input: { 0, 1, 2, 3, 4, ..., 31} - // warp #1 input: {32, 33, 34, 35, 36, ..., 63} - // warp #2 input: {64, 65, 66, 67, 68, ..., 95} - // warp #4 input: {96, 97, 98, 99, 100, ..., 127} + // warp #0 input: {1, 1, 1, 1, ..., 1} + // warp #1 input: {1, 1, 1, 1, ..., 1} + // warp #2 input: {1, 1, 1, 1, ..., 1} + // warp #4 input: {1, 1, 1, 1, ..., 1} // Collectively compute the block-wide inclusive prefix max scan - int warp_aggregate; - int warp_id = threadIdx.x / 32; warp_scan_t(temp_storage[warp_id]).InclusiveScan(thread_data, thread_data, initial_value, cub::Sum(), warp_aggregate); - // initial value = 1 (for each warp) - // warp #0 output: { 1, 2, 4, ..., 497} - aggregate: 496 - // warp #1 output: {33, 66, 100, ..., 1521} - aggregate: 1520 - // warp #2 output: {65, 130, 196, ..., 2545} - aggregate: 2544 - // warp #3 output: {97, 194, 292, ..., 3569} - aggregate: 3568 + // warp #1 output: {4, 5, 6, 7, ..., 35} - warp aggregate: 32 + // warp #2 output: {4, 5, 6, 7, ..., 35} - warp aggregate: 32 + // warp #0 output: {4, 5, 6, 7, ..., 35} - warp aggregate: 32 + // warp #3 output: {4, 5, 6, 7, ..., 35} - warp aggregate: 32 - // example-end inclusive-warp-scan-init-value - d_warp_aggregate[warp_id] = warp_aggregate; + // example-end inclusive-warp-scan-init-value-aggregate output[threadIdx.x] = thread_data; + d_warp_aggregate[warp_id] = warp_aggregate; } CUB_TEST("Block array-based inclusive scan aggregate works with initial value", "[scan][block]") @@ -137,28 +155,20 @@ CUB_TEST("Block array-based inclusive scan aggregate works with initial value", c2h::host_vector expected(d_out.size()); c2h::host_vector expected_aggr{}; - expected[0] = 1; // Initial value - // Calculate the prefix sum with an additional +1 every 32 elements - for (int i = 1; i < num_warps * 32; ++i) + for (int i = 0; i < num_warps; ++i) { - if (i % 32 == 0) - { - expected[i] = i + 1; // Reset at the start of each warp - } - else - { - expected[i] = expected[i - 1] + i; - } - - // fetch the aggregate at the end of each warp - if (i % 32 == 0) - { - expected_aggr.push_back(expected[i - 1] - 1); // warp aggregate doed not take - // initial value into account - } + auto start = expected.begin() + i * 32; + auto end = start + 32; + int init_val = 3; + + std::fill(start, end, 1); // initialize host input for every warp + + cuda::std::inclusive_scan(start, end, start, sum_op{}, init_val); + + expected_aggr.push_back(expected[i * 32 + 31] - init_val); // warp aggregate doed not take + // initial value into account } - expected_aggr.push_back(expected.back() - 1); REQUIRE(expected == d_out); REQUIRE(expected_aggr == d_warp_aggregate);