Skip to content

Commit

Permalink
Improve warp_scan api test
Browse files Browse the repository at this point in the history
  • Loading branch information
gonidelis committed Jun 10, 2024
1 parent 9582835 commit a2c3772
Showing 1 changed file with 69 additions and 59 deletions.
128 changes: 69 additions & 59 deletions cub/test/catch2_test_warp_scan_api.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,32 @@

#include <cuda/std/numeric>

#include <numeric>

#include "catch2_test_helper.h"
#include "cuda/std/__algorithm/max.h"
#include "cuda/std/__numeric/inclusive_scan.h"

constexpr int num_warps = 4;

template <typename T>
struct max_op
{
__host__ __device__ T const& operator()(T const& i, T const& j)
{
return cuda::std::max(i, j);
}
};

template <typename T>
struct sum_op
{
__host__ __device__ T operator()(T const& i, T const& j)
{
return i + j;
}
};

// example-begin inclusive-warp-scan-init-value
__global__ void InclusiveScanKernel(int* output)
{
Expand All @@ -45,26 +67,26 @@ __global__ void InclusiveScanKernel(int* output)
// Allocate WarpScan shared memory for 4 warps
__shared__ typename warp_scan_t::TempStorage temp_storage[num_warps];

int initial_value = 1;
int thread_data = threadIdx.x;
int warp_id = threadIdx.x / 32;
int initial_value = 3;
int thread_data = threadIdx.x % 32 + warp_id;

// warp #0 input: { 0, 1, 2, 3, 4, ..., 31}
// warp #1 input: {32, 33, 34, 35, 36, ..., 63}
// warp #2 input: {64, 65, 66, 67, 68, ..., 95}
// warp #4 input: {96, 97, 98, 99, 100, ..., 127}
// warp #0 input: {0, 1, 2, 3, ..., 31}
// warp #1 input: {1, 2, 3, 4, ..., 32}
// warp #2 input: {2, 3, 4, 5, ..., 33}
// warp #4 input: {3, 4, 5, 6, ..., 34}

// Collectively compute the block-wide inclusive prefix max scan
int warp_id = threadIdx.x / 32;
warp_scan_t(temp_storage[warp_id]).InclusiveScan(thread_data, thread_data, initial_value, cub::Sum());
warp_scan_t(temp_storage[warp_id]).InclusiveScan(thread_data, thread_data, initial_value, cub::Max());

// initial value = 1 (for each warp)
// warp #0 output: { 1, 2, 4, ..., 497}
// warp #1 output: {33, 66, 100, ..., 1521}
// warp #2 output: {65, 130, 196, ..., 2545}
// warp #3 output: {97, 194, 292, ..., 3569}
// initial value = 3 (for each warp)
// warp #0 output: {3, 3, 3, 3, ..., 31}
// warp #1 output: {3, 3, 3, 4, ..., 32}
// warp #2 output: {3, 3, 4, 5, ..., 33}
// warp #3 output: {3, 4, 5, 6, ..., 34}
output[threadIdx.x] = thread_data;

// example-end inclusive-warp-scan-init-value
output[threadIdx.x] = thread_data;
}

CUB_TEST("Block array-based inclusive scan works with initial value", "[scan][block]")
Expand All @@ -76,53 +98,49 @@ CUB_TEST("Block array-based inclusive scan works with initial value", "[scan][bl
REQUIRE(cudaSuccess == cudaDeviceSynchronize());

c2h::host_vector<int> expected(d_out.size());
expected[0] = 1; // Initial value

// Calculate the prefix sum with an additional +1 every 32 elements
for (int i = 1; i < num_warps * 32; ++i)
for (int i = 0; i < num_warps; ++i)
{
if (i % 32 == 0)
{
expected[i] = i + 1; // Reset at the start of each warp
}
else
{
expected[i] = expected[i - 1] + i;
}
auto start = expected.begin() + i * 32;
auto end = start + 32;

std::iota(start, end, i); // initialize host input for every warp

cuda::std::inclusive_scan(start, end, start, max_op<int>{}, 3);
}

REQUIRE(expected == d_out);
}

// example-begin inclusive-warp-scan-init-value-aggregate
__global__ void InclusiveScanKernelAggr(int* output, int* d_warp_aggregate)
{
// Specialize WarpScan for type int
typedef cub::WarpScan<int> warp_scan_t;
// Allocate WarpScan shared memory for 4 warps
__shared__ typename warp_scan_t::TempStorage temp_storage[num_warps];

int initial_value = 1;
int thread_data = threadIdx.x;
int warp_id = threadIdx.x / 32;
int initial_value = 3; // for each warp
int thread_data = 1;
int warp_aggregate;

// warp #0 input: { 0, 1, 2, 3, 4, ..., 31}
// warp #1 input: {32, 33, 34, 35, 36, ..., 63}
// warp #2 input: {64, 65, 66, 67, 68, ..., 95}
// warp #4 input: {96, 97, 98, 99, 100, ..., 127}
// warp #0 input: {1, 1, 1, 1, ..., 1}
// warp #1 input: {1, 1, 1, 1, ..., 1}
// warp #2 input: {1, 1, 1, 1, ..., 1}
// warp #4 input: {1, 1, 1, 1, ..., 1}

// Collectively compute the block-wide inclusive prefix max scan
int warp_aggregate;
int warp_id = threadIdx.x / 32;
warp_scan_t(temp_storage[warp_id]).InclusiveScan(thread_data, thread_data, initial_value, cub::Sum(), warp_aggregate);

// initial value = 1 (for each warp)
// warp #0 output: { 1, 2, 4, ..., 497} - aggregate: 496
// warp #1 output: {33, 66, 100, ..., 1521} - aggregate: 1520
// warp #2 output: {65, 130, 196, ..., 2545} - aggregate: 2544
// warp #3 output: {97, 194, 292, ..., 3569} - aggregate: 3568
// warp #1 output: {4, 5, 6, 7, ..., 35} - warp aggregate: 32
// warp #2 output: {4, 5, 6, 7, ..., 35} - warp aggregate: 32
// warp #0 output: {4, 5, 6, 7, ..., 35} - warp aggregate: 32
// warp #3 output: {4, 5, 6, 7, ..., 35} - warp aggregate: 32

// example-end inclusive-warp-scan-init-value
d_warp_aggregate[warp_id] = warp_aggregate;
// example-end inclusive-warp-scan-init-value-aggregate
output[threadIdx.x] = thread_data;
d_warp_aggregate[warp_id] = warp_aggregate;
}

CUB_TEST("Block array-based inclusive scan aggregate works with initial value", "[scan][block]")
Expand All @@ -137,28 +155,20 @@ CUB_TEST("Block array-based inclusive scan aggregate works with initial value",

c2h::host_vector<int> expected(d_out.size());
c2h::host_vector<int> expected_aggr{};
expected[0] = 1; // Initial value

// Calculate the prefix sum with an additional +1 every 32 elements
for (int i = 1; i < num_warps * 32; ++i)
for (int i = 0; i < num_warps; ++i)
{
if (i % 32 == 0)
{
expected[i] = i + 1; // Reset at the start of each warp
}
else
{
expected[i] = expected[i - 1] + i;
}

// fetch the aggregate at the end of each warp
if (i % 32 == 0)
{
expected_aggr.push_back(expected[i - 1] - 1); // warp aggregate doed not take
// initial value into account
}
auto start = expected.begin() + i * 32;
auto end = start + 32;
int init_val = 3;

std::fill(start, end, 1); // initialize host input for every warp

cuda::std::inclusive_scan(start, end, start, sum_op<int>{}, init_val);

expected_aggr.push_back(expected[i * 32 + 31] - init_val); // warp aggregate doed not take
// initial value into account
}
expected_aggr.push_back(expected.back() - 1);

REQUIRE(expected == d_out);
REQUIRE(expected_aggr == d_warp_aggregate);
Expand Down

0 comments on commit a2c3772

Please sign in to comment.