Add prefetching kernel as new fallback for cub::DeviceTransform #2396

Merged
merged 6 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
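For context, `cub::DeviceTransform` is the algorithm whose fallback path this PR changes. A minimal usage sketch follows (API shape assumed from recent CCCL headers; unlike most CUB algorithms, `DeviceTransform` needs no temporary-storage phase — double-check against your version):

```cuda
// Minimal sketch: apply an elementwise transform with cub::DeviceTransform.
// Assumes a recent CCCL and nvcc with --extended-lambda.
#include <cub/device/device_transform.cuh>
#include <thrust/device_vector.h>
#include <cuda_runtime.h>

int main()
{
  thrust::device_vector<float> in(1 << 20, 2.0f);
  thrust::device_vector<float> out(1 << 20);

  // One input, one output, an element count, and a transform op;
  // runs on the default stream when no stream is passed.
  cub::DeviceTransform::Transform(
    in.begin(), out.begin(), static_cast<int>(in.size()),
    [] __device__(float x) { return 2.0f * x; });

  cudaDeviceSynchronize();
  return 0;
}
```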
4 changes: 3 additions & 1 deletion cub/benchmarks/bench/transform/babelstream.h
@@ -23,7 +23,9 @@ struct policy_hub_t
using algo_policy =
::cuda::std::_If<algorithm == cub::detail::transform::Algorithm::fallback_for,
cub::detail::transform::fallback_for_policy,
- cub::detail::transform::async_copy_policy_t<TUNE_THREADS>>;
+ ::cuda::std::_If<algorithm == cub::detail::transform::Algorithm::prefetch,
+ cub::detail::transform::prefetch_policy_t<TUNE_THREADS>,
+ cub::detail::transform::async_copy_policy_t<TUNE_THREADS>>>;
};
};
#endif
6 changes: 3 additions & 3 deletions cub/benchmarks/bench/transform/babelstream1.cu
@@ -2,15 +2,15 @@
// SPDX-License-Identifier: BSD-3-Clause

// %RANGE% TUNE_THREADS tpb 128:1024:128
- // %RANGE% TUNE_ALGORITHM alg 0:1:1
+ // %RANGE% TUNE_ALGORITHM alg 0:2:1

// keep checks at the top so compilation of discarded variants fails really fast
#if !TUNE_BASE
- # if TUNE_ALGORITHM == 1 && (__CUDA_ARCH_LIST__) < 900
+ # if TUNE_ALGORITHM == 2 && (__CUDA_ARCH_LIST__) < 900
# error "Cannot compile algorithm 4 (ublkcp) below sm90"
# endif

- # if TUNE_ALGORITHM == 1 && !defined(_CUB_HAS_TRANSFORM_UBLKCP)
+ # if TUNE_ALGORITHM == 2 && !defined(_CUB_HAS_TRANSFORM_UBLKCP)
# error "Cannot tune for ublkcp algorithm, which is not provided by CUB (old CTK?)"
# endif
#endif
6 changes: 3 additions & 3 deletions cub/benchmarks/bench/transform/babelstream2.cu
@@ -2,15 +2,15 @@
// SPDX-License-Identifier: BSD-3-Clause

// %RANGE% TUNE_THREADS tpb 128:1024:128
- // %RANGE% TUNE_ALGORITHM alg 0:1:1
+ // %RANGE% TUNE_ALGORITHM alg 0:2:1

// keep checks at the top so compilation of discarded variants fails really fast
#if !TUNE_BASE
- # if TUNE_ALGORITHM == 1 && (__CUDA_ARCH_LIST__) < 900
+ # if TUNE_ALGORITHM == 2 && (__CUDA_ARCH_LIST__) < 900
# error "Cannot compile algorithm 4 (ublkcp) below sm90"
# endif

- # if TUNE_ALGORITHM == 1 && !defined(_CUB_HAS_TRANSFORM_UBLKCP)
+ # if TUNE_ALGORITHM == 2 && !defined(_CUB_HAS_TRANSFORM_UBLKCP)
# error "Cannot tune for ublkcp algorithm, which is not provided by CUB (old CTK?)"
# endif
#endif
6 changes: 3 additions & 3 deletions cub/benchmarks/bench/transform/babelstream3.cu
@@ -2,15 +2,15 @@
// SPDX-License-Identifier: BSD-3-Clause

// %RANGE% TUNE_THREADS tpb 128:1024:128
- // %RANGE% TUNE_ALGORITHM alg 0:1:1
+ // %RANGE% TUNE_ALGORITHM alg 0:2:1

// keep checks at the top so compilation of discarded variants fails really fast
#if !TUNE_BASE
- # if TUNE_ALGORITHM == 1 && (__CUDA_ARCH_LIST__) < 900
+ # if TUNE_ALGORITHM == 2 && (__CUDA_ARCH_LIST__) < 900
# error "Cannot compile algorithm 4 (ublkcp) below sm90"
# endif

- # if TUNE_ALGORITHM == 1 && !defined(_CUB_HAS_TRANSFORM_UBLKCP)
+ # if TUNE_ALGORITHM == 2 && !defined(_CUB_HAS_TRANSFORM_UBLKCP)
# error "Cannot tune for ublkcp algorithm, which is not provided by CUB (old CTK?)"
# endif
#endif
230 changes: 207 additions & 23 deletions cub/cub/device/dispatch/dispatch_transform.cuh
@@ -89,6 +89,7 @@ _CCCL_HOST_DEVICE constexpr auto loaded_bytes_per_iteration() -> int
enum class Algorithm
{
fallback_for,
prefetch,
#ifdef _CUB_HAS_TRANSFORM_UBLKCP
ublkcp,
#endif // _CUB_HAS_TRANSFORM_UBLKCP
@@ -133,6 +134,116 @@ _CCCL_DEVICE void transform_kernel_impl(
}
}

template <typename T>
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE const char* round_down_ptr(const T* ptr, unsigned alignment)
{
#if _CCCL_STD_VER > 2011
_CCCL_ASSERT(::cuda::std::has_single_bit(alignment), "");
#endif // _CCCL_STD_VER > 2011
return reinterpret_cast<const char*>(
reinterpret_cast<::cuda::std::uintptr_t>(ptr) & ~::cuda::std::uintptr_t{alignment - 1});
}
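For example, `round_down_ptr(p, 128)` with `p == 0x1234` yields `0x1200`, the closest 128-byte-aligned address at or below `p`.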

template <int BlockThreads>
struct prefetch_policy_t
{
static constexpr int block_threads = BlockThreads;
// items per tile are determined at runtime. these (inclusive) bounds allow overriding that value via a tuning policy
static constexpr int items_per_thread_no_input = 2; // when there are no input iterators, the kernel is just filling
static constexpr int min_items_per_thread = 1;
static constexpr int max_items_per_thread = 32;
};

// Prefetches (at least on Hopper) a 128-byte cache line. Prefetching out-of-bounds addresses has no side effects.
// TODO(bgruber): there is also the cp.async.bulk.prefetch instruction available on Hopper. It may improve perf a tiny
// bit, since we need fewer instructions to prefetch the same amount of data.
template <typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void prefetch(const T* addr)
{
// TODO(bgruber): prefetch to L1 may be even better
asm volatile("prefetch.global.L2 [%0];" : : "l"(__cvta_generic_to_global(addr)) : "memory");
}
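As an aside, PTX `prefetch` also accepts the `.L1` cache level, so the L1 variant mentioned in the TODO could look roughly like this (a sketch, not part of this PR):

```cuda
// Hypothetical L1 variant of the prefetch helper above (sketch only).
template <typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void prefetch_l1(const T* addr)
{
  asm volatile("prefetch.global.L1 [%0];" : : "l"(__cvta_generic_to_global(addr)) : "memory");
}
```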

template <int BlockDim, typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void prefetch_tile(const T* addr, int tile_size)
{
constexpr int prefetch_byte_stride = 128; // TODO(bgruber): should correspond to cache line size. Does this need to be
// architecture dependent?
const int tile_size_bytes = tile_size * sizeof(T);
// prefetch does not stall and unrolling just generates a lot of unnecessary computations and predicate handling
#pragma unroll 1
for (int offset = threadIdx.x * prefetch_byte_stride; offset < tile_size_bytes;
offset += BlockDim * prefetch_byte_stride)
{
prefetch(reinterpret_cast<const char*>(addr) + offset);
}
}

// TODO(miscco): we should probably constrain It to not be a contiguous iterator in C++17 (and change the overload
// above to accept any contiguous iterator)
// overload for any iterator that is not a pointer, do nothing
template <int, typename It, ::cuda::std::__enable_if_t<!::cuda::std::is_pointer<It>::value, int> = 0>
Review comment (Collaborator): Suggestion: We could use cuda::std::contiguous_iterator in C++17 onwards. Definitely something for a followup.

Reply (Author): Added a note.

_CCCL_DEVICE _CCCL_FORCEINLINE void prefetch_tile(It, int)
{}
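A rough sketch of the follow-up suggested in the review thread above — constrain on contiguous iterators rather than raw pointers. This assumes libcu++'s `cuda::std::contiguous_iterator` and `cuda::std::to_address`; C++20 concept syntax is shown for brevity, while libcu++ also offers C++17 emulation machinery:

```cuda
// Sketch (not in this PR): prefetch through any contiguous iterator by
// converting it to a raw pointer; leave all other iterators as a no-op.
template <int BlockDim, typename It>
_CCCL_DEVICE _CCCL_FORCEINLINE void prefetch_tile_cpp20(It it, int tile_size)
{
  if constexpr (::cuda::std::contiguous_iterator<It>)
  {
    prefetch_tile<BlockDim>(::cuda::std::to_address(it), tile_size);
  }
}
```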

// This kernel guarantees that objects passed as arguments to the user-provided transformation function f reside in
// global memory. No intermediate copies are taken. If the parameter type of f is a reference, taking the address of the
// parameter yields a global memory address.
template <typename PrefetchPolicy,
typename Offset,
typename F,
typename RandomAccessIteratorOut,
typename... RandomAccessIteratorIn>
_CCCL_DEVICE void transform_kernel_impl(
::cuda::std::integral_constant<Algorithm, Algorithm::prefetch>,
Offset num_items,
int num_elem_per_thread,
F f,
RandomAccessIteratorOut out,
RandomAccessIteratorIn... ins)
{
constexpr int block_dim = PrefetchPolicy::block_threads;
const int tile_stride = block_dim * num_elem_per_thread;
const Offset offset = static_cast<Offset>(blockIdx.x) * tile_stride;
const int tile_size = static_cast<int>(::cuda::std::min(num_items - offset, Offset{tile_stride}));

// move index and iterator domain to the block/thread index, to reduce arithmetic in the loops below
{
int dummy[] = {(ins += offset, 0)..., 0};
(void) &dummy;
out += offset;
}

{
// TODO(bgruber): replace by fold over comma in C++17
int dummy[] = {(prefetch_tile<block_dim>(ins, tile_size), 0)..., 0}; // extra zero to handle empty packs
(void) &dummy; // nvcc 11.1 needs extra strong unused warning suppression
}

#define PREFETCH_AGENT(full_tile) \
/* ahendriksen: various unrolling yields less <1% gains at much higher compile-time cost */ \
/* bgruber: but A6000 and H100 show small gains without pragma */ \
/*_Pragma("unroll 1")*/ for (int j = 0; j < num_elem_per_thread; ++j) \
{ \
const int idx = j * block_dim + threadIdx.x; \
if (full_tile || idx < tile_size) \
{ \
/* we have to unwrap Thrust's proxy references here for backward compatibility (try zip_iterator.cu test) */ \
out[idx] = f(THRUST_NS_QUALIFIER::raw_reference_cast(ins[idx])...); \
} \
}

if (tile_stride == tile_size)
Review comment (Contributor):
Question: does the split between full_tile and partial tiles make a big difference in run time?

I expect that the difference shouldn't be too big, as the partial tile version only adds one comparison and predicates the rest of the computation.

Reply (Author):
That is a good question! I followed existing practice here, but I see your point. I will evaluate!

Reply (Author):
Here is a benchmark on H200:

Babelstream on H200: two code paths (baseline) vs. single code path
# mul

## [0] NVIDIA H200

|  T{ct}  |  OffsetT{ct}  |  Elements{io}  |   Ref Time |   Ref Noise |   Cmp Time |   Cmp Noise |      Diff |   %Diff |  Status  |
|---------|---------------|----------------|------------|-------------|------------|-------------|-----------|---------|----------|
|   I8    |      I32      |      2^16      |   3.964 us |       3.53% |   3.948 us |       2.92% | -0.016 us |  -0.41% |   PASS   |
|   I8    |      I32      |      2^20      |  10.199 us |       1.88% |  10.266 us |       1.87% |  0.067 us |   0.65% |   PASS   |
|   I8    |      I32      |      2^24      |  22.588 us |       1.08% |  22.788 us |       1.11% |  0.200 us |   0.89% |   PASS   |
|   I8    |      I32      |      2^28      | 236.103 us |       0.15% | 238.034 us |       0.14% |  1.931 us |   0.82% |   FAIL   |
|   I8    |      I64      |      2^16      |   4.094 us |       3.64% |   4.114 us |       3.47% |  0.021 us |   0.51% |   PASS   |
|   I8    |      I64      |      2^20      |  10.245 us |       1.56% |  10.260 us |       1.47% |  0.015 us |   0.15% |   PASS   |
|   I8    |      I64      |      2^24      |  22.733 us |       1.09% |  22.792 us |       1.14% |  0.059 us |   0.26% |   PASS   |
|   I8    |      I64      |      2^28      | 237.687 us |       0.13% | 237.323 us |       0.14% | -0.363 us |  -0.15% |   FAIL   |
|   I16   |      I32      |      2^16      |   4.133 us |       2.97% |   4.168 us |       2.64% |  0.035 us |   0.84% |   PASS   |
|   I16   |      I32      |      2^20      |   7.748 us |       2.16% |   7.838 us |       2.33% |  0.089 us |   1.15% |   PASS   |
|   I16   |      I32      |      2^24      |  25.688 us |       1.48% |  25.959 us |       1.46% |  0.271 us |   1.05% |   PASS   |
|   I16   |      I32      |      2^28      | 299.739 us |       0.12% | 303.418 us |       0.10% |  3.679 us |   1.23% |   FAIL   |
|   I16   |      I64      |      2^16      |   4.191 us |       2.98% |   4.198 us |       3.03% |  0.007 us |   0.16% |   PASS   |
|   I16   |      I64      |      2^20      |   7.925 us |       2.11% |   7.921 us |       2.21% | -0.004 us |  -0.05% |   PASS   |
|   I16   |      I64      |      2^24      |  25.890 us |       1.39% |  25.933 us |       1.48% |  0.043 us |   0.17% |   PASS   |
|   I16   |      I64      |      2^28      | 300.617 us |       0.12% | 301.334 us |       0.10% |  0.717 us |   0.24% |   FAIL   |
|   F32   |      I32      |      2^16      |   4.108 us |       3.51% |   4.207 us |       3.29% |  0.100 us |   2.43% |   PASS   |
|   F32   |      I32      |      2^20      |   7.212 us |       2.63% |   7.398 us |       2.94% |  0.186 us |   2.58% |   PASS   |
|   F32   |      I32      |      2^24      |  38.193 us |       1.75% |  38.389 us |       1.59% |  0.196 us |   0.51% |   PASS   |
|   F32   |      I32      |      2^28      | 520.746 us |       0.15% | 522.979 us |       0.16% |  2.233 us |   0.43% |   FAIL   |
|   F32   |      I64      |      2^16      |   4.228 us |       4.57% |   4.325 us |       3.97% |  0.097 us |   2.29% |   PASS   |
|   F32   |      I64      |      2^20      |   7.264 us |       2.86% |   7.429 us |       2.98% |  0.165 us |   2.28% |   PASS   |
|   F32   |      I64      |      2^24      |  38.171 us |       1.62% |  38.409 us |       1.61% |  0.239 us |   0.63% |   PASS   |
|   F32   |      I64      |      2^28      | 521.283 us |       0.15% | 522.771 us |       0.16% |  1.488 us |   0.29% |   FAIL   |
|   F64   |      I32      |      2^16      |   4.387 us |       4.15% |   4.480 us |       3.78% |  0.093 us |   2.12% |   PASS   |
|   F64   |      I32      |      2^20      |   9.059 us |       2.89% |   9.198 us |       2.65% |  0.139 us |   1.54% |   PASS   |
|   F64   |      I32      |      2^24      |  67.603 us |       1.08% |  67.765 us |       1.06% |  0.161 us |   0.24% |   PASS   |
|   F64   |      I32      |      2^28      |   1.008 ms |       0.05% |   1.009 ms |       0.06% |  0.501 us |   0.05% |   PASS   |
|   F64   |      I64      |      2^16      |   4.516 us |       4.02% |   4.647 us |       5.02% |  0.130 us |   2.89% |   PASS   |
|   F64   |      I64      |      2^20      |   9.079 us |       2.79% |   9.217 us |       2.35% |  0.138 us |   1.52% |   PASS   |
|   F64   |      I64      |      2^24      |  67.634 us |       1.10% |  67.780 us |       1.07% |  0.147 us |   0.22% |   PASS   |
|   F64   |      I64      |      2^28      |   1.009 ms |       0.06% |   1.009 ms |       0.06% |  0.578 us |   0.06% |   FAIL   |
|  I128   |      I32      |      2^16      |   4.790 us |       4.22% |   4.934 us |       5.32% |  0.144 us |   3.00% |   PASS   |
|  I128   |      I32      |      2^20      |  12.903 us |       2.26% |  13.029 us |       2.36% |  0.126 us |   0.98% |   PASS   |
|  I128   |      I32      |      2^24      | 130.827 us |       0.64% | 130.909 us |       0.66% |  0.082 us |   0.06% |   PASS   |
|  I128   |      I32      |      2^28      |   2.030 ms |       0.14% |   2.031 ms |       0.14% |  0.725 us |   0.04% |   PASS   |
|  I128   |      I64      |      2^16      |   4.789 us |       3.89% |   4.788 us |       3.42% | -0.001 us |  -0.01% |   PASS   |
|  I128   |      I64      |      2^20      |  12.949 us |       2.35% |  13.150 us |       3.15% |  0.200 us |   1.55% |   PASS   |
|  I128   |      I64      |      2^24      | 131.666 us |       0.73% | 131.626 us |       0.74% | -0.040 us |  -0.03% |   PASS   |
|  I128   |      I64      |      2^28      |   2.032 ms |       0.41% |   2.033 ms |       0.41% |  0.144 us |   0.01% |   PASS   |

# add

## [0] NVIDIA H200

|  T{ct}  |  OffsetT{ct}  |  Elements{io}  |   Ref Time |   Ref Noise |   Cmp Time |   Cmp Noise |      Diff |   %Diff |  Status  |
|---------|---------------|----------------|------------|-------------|------------|-------------|-----------|---------|----------|
|   I8    |      I32      |      2^16      |   3.972 us |       2.69% |   4.055 us |       2.47% |  0.083 us |   2.09% |   PASS   |
|   I8    |      I32      |      2^20      |   7.763 us |       2.59% |   7.943 us |       2.18% |  0.179 us |   2.31% |   FAIL   |
|   I8    |      I32      |      2^24      |  26.937 us |       1.10% |  27.457 us |       1.05% |  0.521 us |   1.93% |   FAIL   |
|   I8    |      I32      |      2^28      | 300.779 us |       0.17% | 308.288 us |       0.16% |  7.509 us |   2.50% |   FAIL   |
|   I8    |      I64      |      2^16      |   4.136 us |       2.88% |   4.202 us |       2.87% |  0.066 us |   1.60% |   PASS   |
|   I8    |      I64      |      2^20      |   7.790 us |       2.26% |   7.934 us |       2.29% |  0.144 us |   1.84% |   PASS   |
|   I8    |      I64      |      2^24      |  27.023 us |       1.10% |  27.256 us |       1.10% |  0.234 us |   0.86% |   PASS   |
|   I8    |      I64      |      2^28      | 301.382 us |       0.11% | 308.000 us |       0.13% |  6.618 us |   2.20% |   FAIL   |
|   I16   |      I32      |      2^16      |   4.219 us |       2.79% |   4.305 us |       3.10% |  0.086 us |   2.03% |   PASS   |
|   I16   |      I32      |      2^20      |   7.193 us |       2.82% |   7.305 us |       2.80% |  0.112 us |   1.55% |   PASS   |
|   I16   |      I32      |      2^24      |  34.167 us |       1.71% |  34.630 us |       1.67% |  0.463 us |   1.36% |   PASS   |
|   I16   |      I32      |      2^28      | 444.715 us |       0.16% | 451.323 us |       0.15% |  6.607 us |   1.49% |   FAIL   |
|   I16   |      I64      |      2^16      |   4.242 us |       3.31% |   4.331 us |       3.11% |  0.090 us |   2.12% |   PASS   |
|   I16   |      I64      |      2^20      |   7.192 us |       2.76% |   7.344 us |       2.73% |  0.152 us |   2.12% |   PASS   |
|   I16   |      I64      |      2^24      |  34.275 us |       1.78% |  34.495 us |       1.73% |  0.220 us |   0.64% |   PASS   |
|   I16   |      I64      |      2^28      | 445.632 us |       0.15% | 449.233 us |       0.16% |  3.601 us |   0.81% |   FAIL   |
|   F32   |      I32      |      2^16      |   4.365 us |       2.98% |   4.601 us |       3.67% |  0.236 us |   5.40% |   FAIL   |
|   F32   |      I32      |      2^20      |   8.399 us |       2.74% |   8.559 us |       2.39% |  0.160 us |   1.90% |   PASS   |
|   F32   |      I32      |      2^24      |  54.969 us |       1.37% |  54.921 us |       1.37% | -0.048 us |  -0.09% |   PASS   |
|   F32   |      I32      |      2^28      | 785.761 us |       0.09% | 789.161 us |       0.09% |  3.400 us |   0.43% |   FAIL   |
|   F32   |      I64      |      2^16      |   4.429 us |       4.19% |   4.704 us |       4.52% |  0.275 us |   6.21% |   FAIL   |
|   F32   |      I64      |      2^20      |   8.373 us |       2.40% |   8.546 us |       2.48% |  0.173 us |   2.07% |   PASS   |
|   F32   |      I64      |      2^24      |  55.000 us |       1.37% |  55.306 us |       1.37% |  0.307 us |   0.56% |   PASS   |
|   F32   |      I64      |      2^28      | 786.775 us |       0.08% | 790.047 us |       0.08% |  3.272 us |   0.42% |   FAIL   |
|   F64   |      I32      |      2^16      |   4.706 us |       3.67% |   4.801 us |       4.52% |  0.095 us |   2.01% |   PASS   |
|   F64   |      I32      |      2^20      |  11.428 us |       3.41% |  11.495 us |       2.49% |  0.066 us |   0.58% |   PASS   |
|   F64   |      I32      |      2^24      | 100.121 us |       0.39% | 100.326 us |       0.39% |  0.205 us |   0.20% |   PASS   |
|   F64   |      I32      |      2^28      |   1.508 ms |       0.05% |   1.509 ms |       0.05% |  0.855 us |   0.06% |   FAIL   |
|   F64   |      I64      |      2^16      |   4.804 us |       4.27% |   4.931 us |       4.97% |  0.126 us |   2.63% |   PASS   |
|   F64   |      I64      |      2^20      |  11.316 us |       2.47% |  11.444 us |       2.55% |  0.128 us |   1.13% |   PASS   |
|   F64   |      I64      |      2^24      | 100.004 us |       0.41% | 100.244 us |       0.39% |  0.240 us |   0.24% |   PASS   |
|   F64   |      I64      |      2^28      |   1.509 ms |       0.05% |   1.509 ms |       0.06% |  0.658 us |   0.04% |   PASS   |
|  I128   |      I32      |      2^16      |   5.227 us |       4.38% |   5.329 us |       4.78% |  0.102 us |   1.95% |   PASS   |
|  I128   |      I32      |      2^20      |  17.633 us |       2.22% |  17.680 us |       2.28% |  0.047 us |   0.27% |   PASS   |
|  I128   |      I32      |      2^24      | 194.202 us |       0.29% | 194.335 us |       0.29% |  0.133 us |   0.07% |   PASS   |
|  I128   |      I32      |      2^28      |   3.008 ms |       0.09% |   3.009 ms |       0.09% |  1.169 us |   0.04% |   PASS   |
|  I128   |      I64      |      2^16      |   5.164 us |       3.50% |   5.179 us |       3.29% |  0.014 us |   0.28% |   PASS   |
|  I128   |      I64      |      2^20      |  17.804 us |       2.52% |  17.788 us |       2.36% | -0.016 us |  -0.09% |   PASS   |
|  I128   |      I64      |      2^24      | 195.061 us |       0.36% | 195.180 us |       0.38% |  0.119 us |   0.06% |   PASS   |
|  I128   |      I64      |      2^28      |   3.010 ms |       0.23% |   3.009 ms |       0.21% | -0.449 us |  -0.01% |   PASS   |

# triad

## [0] NVIDIA H200

|  T{ct}  |  OffsetT{ct}  |  Elements{io}  |   Ref Time |   Ref Noise |   Cmp Time |   Cmp Noise |      Diff |   %Diff |  Status  |
|---------|---------------|----------------|------------|-------------|------------|-------------|-----------|---------|----------|
|   I8    |      I32      |      2^16      |   4.246 us |       3.71% |   4.226 us |       3.63% | -0.019 us |  -0.45% |   PASS   |
|   I8    |      I32      |      2^20      |   7.839 us |       2.64% |   7.915 us |       2.36% |  0.075 us |   0.96% |   PASS   |
|   I8    |      I32      |      2^24      |  26.018 us |       1.33% |  26.417 us |       1.25% |  0.399 us |   1.53% |   FAIL   |
|   I8    |      I32      |      2^28      | 286.034 us |       0.15% | 294.496 us |       0.13% |  8.461 us |   2.96% |   FAIL   |
|   I8    |      I64      |      2^16      |   4.264 us |       3.84% |   4.301 us |       3.56% |  0.038 us |   0.88% |   PASS   |
|   I8    |      I64      |      2^20      |   7.958 us |       3.06% |   8.046 us |       3.02% |  0.088 us |   1.10% |   PASS   |
|   I8    |      I64      |      2^24      |  25.732 us |       1.25% |  25.935 us |       1.30% |  0.202 us |   0.79% |   PASS   |
|   I8    |      I64      |      2^28      | 314.272 us |       0.12% | 316.872 us |       0.21% |  2.600 us |   0.83% |   FAIL   |
|   I16   |      I32      |      2^16      |   4.327 us |       2.76% |   4.331 us |       3.10% |  0.004 us |   0.09% |   PASS   |
|   I16   |      I32      |      2^20      |   7.305 us |       2.79% |   7.410 us |       2.68% |  0.105 us |   1.44% |   PASS   |
|   I16   |      I32      |      2^24      |  34.765 us |       1.67% |  34.117 us |       1.75% | -0.647 us |  -1.86% |   FAIL   |
|   I16   |      I32      |      2^28      | 425.006 us |       0.17% | 430.051 us |       0.17% |  5.045 us |   1.19% |   FAIL   |
|   I16   |      I64      |      2^16      |   4.373 us |       3.45% |   4.388 us |       3.47% |  0.015 us |   0.34% |   PASS   |
|   I16   |      I64      |      2^20      |   7.401 us |       3.11% |   7.397 us |       2.74% | -0.003 us |  -0.04% |   PASS   |
|   I16   |      I64      |      2^24      |  33.826 us |       1.79% |  33.999 us |       1.75% |  0.174 us |   0.51% |   PASS   |
|   I16   |      I64      |      2^28      | 425.722 us |       0.17% | 428.301 us |       0.16% |  2.579 us |   0.61% |   FAIL   |
|   F32   |      I32      |      2^16      |   4.458 us |       3.71% |   4.407 us |       3.56% | -0.051 us |  -1.14% |   PASS   |
|   F32   |      I32      |      2^20      |   8.391 us |       2.70% |   8.450 us |       2.82% |  0.059 us |   0.71% |   PASS   |
|   F32   |      I32      |      2^24      |  55.046 us |       1.33% |  55.285 us |       1.34% |  0.239 us |   0.43% |   PASS   |
|   F32   |      I32      |      2^28      | 773.132 us |       0.12% | 775.959 us |       0.11% |  2.828 us |   0.37% |   FAIL   |
|   F32   |      I64      |      2^16      |   4.484 us |       4.79% |   4.493 us |       4.85% |  0.009 us |   0.20% |   PASS   |
|   F32   |      I64      |      2^20      |   8.401 us |       2.66% |   8.494 us |       2.66% |  0.093 us |   1.11% |   PASS   |
|   F32   |      I64      |      2^24      |  55.101 us |       1.32% |  55.349 us |       1.32% |  0.248 us |   0.45% |   PASS   |
|   F32   |      I64      |      2^28      | 774.147 us |       0.12% | 776.581 us |       0.11% |  2.434 us |   0.31% |   FAIL   |
|   F64   |      I32      |      2^16      |   4.827 us |       3.95% |   4.810 us |       3.70% | -0.017 us |  -0.35% |   PASS   |
|   F64   |      I32      |      2^20      |  11.340 us |       1.35% |  11.445 us |       2.54% |  0.105 us |   0.92% |   PASS   |
|   F64   |      I32      |      2^24      | 100.355 us |       0.43% | 100.262 us |       0.44% | -0.094 us |  -0.09% |   PASS   |
|   F64   |      I32      |      2^28      |   1.509 ms |       0.05% |   1.509 ms |       0.06% |  0.636 us |   0.04% |   PASS   |
|   F64   |      I64      |      2^16      |   4.903 us |       5.42% |   4.947 us |       5.01% |  0.044 us |   0.90% |   PASS   |
|   F64   |      I64      |      2^20      |  11.371 us |       1.66% |  11.553 us |       3.20% |  0.182 us |   1.60% |   PASS   |
|   F64   |      I64      |      2^24      | 100.220 us |       0.43% | 100.266 us |       0.43% |  0.046 us |   0.05% |   PASS   |
|   F64   |      I64      |      2^28      |   1.509 ms |       0.06% |   1.509 ms |       0.06% |  0.494 us |   0.03% |   PASS   |
|  I128   |      I32      |      2^16      |   5.422 us |       4.90% |   5.422 us |       5.07% |  0.000 us |   0.00% |   PASS   |
|  I128   |      I32      |      2^20      |  17.905 us |       2.25% |  17.918 us |       2.14% |  0.013 us |   0.07% |   PASS   |
|  I128   |      I32      |      2^24      | 194.422 us |       0.30% | 194.603 us |       0.27% |  0.181 us |   0.09% |   PASS   |
|  I128   |      I32      |      2^28      |   3.007 ms |       0.09% |   3.010 ms |       0.09% |  2.715 us |   0.09% |   FAIL   |
|  I128   |      I64      |      2^16      |   5.202 us |       3.37% |   5.235 us |       4.08% |  0.032 us |   0.62% |   PASS   |
|  I128   |      I64      |      2^20      |  17.851 us |       2.39% |  17.837 us |       2.32% | -0.014 us |  -0.08% |   PASS   |
|  I128   |      I64      |      2^24      | 194.996 us |       0.36% | 195.120 us |       0.34% |  0.124 us |   0.06% |   PASS   |
|  I128   |      I64      |      2^28      |   3.007 ms |       0.19% |   3.009 ms |       0.21% |  2.588 us |   0.09% |   PASS   |

# nstream

## [0] NVIDIA H200

|  T{ct}  |  OffsetT{ct}  |  Elements{io}  |  OverwriteInput  |   Ref Time |   Ref Noise |   Cmp Time |   Cmp Noise |       Diff |   %Diff |  Status  |
|---------|---------------|----------------|------------------|------------|-------------|------------|-------------|------------|---------|----------|
|   I8    |      I32      |      2^16      |        1         |   4.153 us |       3.08% |   4.072 us |       3.18% |  -0.081 us |  -1.95% |   PASS   |
|   I8    |      I32      |      2^20      |        1         |   7.199 us |       2.68% |   7.263 us |       2.30% |   0.065 us |   0.90% |   PASS   |
|   I8    |      I32      |      2^24      |        1         |  30.769 us |       1.27% |  31.748 us |       1.16% |   0.979 us |   3.18% |   FAIL   |
|   I8    |      I32      |      2^28      |        1         | 367.426 us |       0.15% | 387.490 us |       0.11% |  20.065 us |   5.46% |   FAIL   |
|   I8    |      I64      |      2^16      |        1         |   4.308 us |       2.83% |   4.194 us |       2.96% |  -0.114 us |  -2.64% |   PASS   |
|   I8    |      I64      |      2^20      |        1         |   7.247 us |       2.46% |   7.133 us |       2.63% |  -0.114 us |  -1.58% |   PASS   |
|   I8    |      I64      |      2^24      |        1         |  30.870 us |       1.22% |  30.982 us |       1.22% |   0.112 us |   0.36% |   PASS   |
|   I8    |      I64      |      2^28      |        1         | 369.782 us |       0.23% | 373.844 us |       0.13% |   4.062 us |   1.10% |   FAIL   |
|   I16   |      I32      |      2^16      |        1         |   4.440 us |       3.49% |   4.359 us |       2.92% |  -0.081 us |  -1.84% |   PASS   |
|   I16   |      I32      |      2^20      |        1         |   7.498 us |       2.97% |   7.485 us |       3.08% |  -0.013 us |  -0.18% |   PASS   |
|   I16   |      I32      |      2^24      |        1         |  42.648 us |       1.66% |  43.014 us |       1.60% |   0.367 us |   0.86% |   PASS   |
|   I16   |      I32      |      2^28      |        1         | 562.846 us |       0.13% | 572.306 us |       0.13% |   9.460 us |   1.68% |   FAIL   |
|   I16   |      I64      |      2^16      |        1         |   4.472 us |       3.73% |   4.376 us |       3.08% |  -0.096 us |  -2.14% |   PASS   |
|   I16   |      I64      |      2^20      |        1         |   7.528 us |       2.69% |   7.452 us |       3.06% |  -0.077 us |  -1.02% |   PASS   |
|   I16   |      I64      |      2^24      |        1         |  42.431 us |       1.67% |  42.550 us |       1.63% |   0.119 us |   0.28% |   PASS   |
|   I16   |      I64      |      2^28      |        1         | 560.706 us |       0.14% | 563.724 us |       0.09% |   3.018 us |   0.54% |   FAIL   |
|   F32   |      I32      |      2^16      |        1         |   4.452 us |       3.22% |   4.576 us |       3.21% |   0.124 us |   2.79% |   PASS   |
|   F32   |      I32      |      2^20      |        1         |   9.191 us |       2.85% |   9.327 us |       2.67% |   0.136 us |   1.48% |   PASS   |
|   F32   |      I32      |      2^24      |        1         |  70.198 us |       1.07% |  71.007 us |       1.02% |   0.809 us |   1.15% |   FAIL   |
|   F32   |      I32      |      2^28      |        1         |   1.021 ms |       0.21% |   1.036 ms |       0.19% |  15.479 us |   1.52% |   FAIL   |
|   F32   |      I64      |      2^16      |        1         |   4.519 us |       3.78% |   4.652 us |       4.58% |   0.133 us |   2.95% |   PASS   |
|   F32   |      I64      |      2^20      |        1         |   9.242 us |       2.80% |   9.355 us |       2.63% |   0.113 us |   1.22% |   PASS   |
|   F32   |      I64      |      2^24      |        1         |  70.674 us |       1.04% |  71.199 us |       1.02% |   0.524 us |   0.74% |   PASS   |
|   F32   |      I64      |      2^28      |        1         |   1.031 ms |       0.21% |   1.039 ms |       0.20% |   8.410 us |   0.82% |   FAIL   |
|   F64   |      I32      |      2^16      |        1         |   4.864 us |       4.07% |   5.022 us |       4.96% |   0.158 us |   3.25% |   PASS   |
|   F64   |      I32      |      2^20      |        1         |  13.639 us |       2.16% |  13.737 us |       1.61% |   0.098 us |   0.72% |   PASS   |
|   F64   |      I32      |      2^24      |        1         | 129.316 us |       0.65% | 133.768 us |       0.74% |   4.452 us |   3.44% |   FAIL   |
|   F64   |      I32      |      2^28      |        1         |   1.952 ms |       0.21% |   2.043 ms |       0.16% |  91.277 us |   4.68% |   FAIL   |
|   F64   |      I64      |      2^16      |        1         |   4.984 us |       4.54% |   5.129 us |       4.96% |   0.144 us |   2.90% |   PASS   |
|   F64   |      I64      |      2^20      |        1         |  13.639 us |       1.81% |  13.771 us |       1.87% |   0.132 us |   0.97% |   PASS   |
|   F64   |      I64      |      2^24      |        1         | 130.536 us |       0.64% | 133.817 us |       0.74% |   3.280 us |   2.51% |   FAIL   |
|   F64   |      I64      |      2^28      |        1         |   1.982 ms |       0.10% |   2.040 ms |       0.22% |  57.276 us |   2.89% |   FAIL   |
|  I128   |      I32      |      2^16      |        1         |   5.508 us |       4.58% |   5.697 us |       5.21% |   0.189 us |   3.43% |   PASS   |
|  I128   |      I32      |      2^20      |        1         |  21.720 us |       2.16% |  21.817 us |       2.20% |   0.097 us |   0.45% |   PASS   |
|  I128   |      I32      |      2^24      |        1         | 251.228 us |       0.40% | 258.446 us |       0.41% |   7.218 us |   2.87% |   FAIL   |
|  I128   |      I32      |      2^28      |        1         |   3.907 ms |       0.09% |   4.042 ms |       0.08% | 135.477 us |   3.47% |   FAIL   |
|  I128   |      I64      |      2^16      |        1         |   5.471 us |       3.86% |   5.409 us |       3.47% |  -0.063 us |  -1.15% |   PASS   |
|  I128   |      I64      |      2^20      |        1         |  21.859 us |       2.20% |  21.719 us |       2.19% |  -0.140 us |  -0.64% |   PASS   |
|  I128   |      I64      |      2^24      |        1         | 260.256 us |       0.38% | 258.323 us |       0.43% |  -1.933 us |  -0.74% |   FAIL   |
|  I128   |      I64      |      2^28      |        1         |   4.074 ms |       0.08% |   4.043 ms |       0.09% | -30.665 us |  -0.75% |   FAIL   |

# Summary

- Total Matches: 160
  - Pass    (diff <= min_noise): 117
  - Unknown (infinite noise):    0
  - Failure (diff > min_noise):  43

While the change seems not to matter in most cases, I see a few regressions, especially on the large problem sizes and with nstream.
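For reference, the "single code path" variant being compared here presumably drops the full-tile specialization and always predicates on the tile boundary, roughly (a sketch reconstructed from the PREFETCH_AGENT loop, not code from the PR):

```cuda
// Assumed shape of the benchmarked single-code-path variant: one loop,
// always bounds-checked, no full-tile fast path.
for (int j = 0; j < num_elem_per_thread; ++j)
{
  const int idx = j * block_dim + threadIdx.x;
  if (idx < tile_size)
  {
    out[idx] = f(THRUST_NS_QUALIFIER::raw_reference_cast(ins[idx])...);
  }
}
```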

{
PREFETCH_AGENT(true);
}
else
{
PREFETCH_AGENT(false);
}
#undef PREFETCH_AGENT
}
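As an aside, the TODO in the kernel above about folding over the comma operator would, in C++17, reduce the dummy-array idiom to something like this (a sketch, not part of the diff):

```cuda
// C++17 fold expressions over the comma operator replace the
// `int dummy[] = {(expr, 0)..., 0};` workaround:
(..., (ins += offset));                           // advance every input iterator
(..., prefetch_tile<block_dim>(ins, tile_size));  // prefetch every input tile
```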

template <int BlockThreads>
struct async_copy_policy_t
{
@@ -173,16 +284,6 @@ _CCCL_HOST_DEVICE _CCCL_FORCEINLINE constexpr auto round_up_to_po2_multiple(Inte
return (x + mult - 1) & ~(mult - 1);
}

- template <typename T>
- _CCCL_HOST_DEVICE _CCCL_FORCEINLINE const char* round_down_ptr(const T* ptr, unsigned alignment)
- {
- #if _CCCL_STD_VER > 2011
-   _CCCL_ASSERT(::cuda::std::has_single_bit(alignment), "");
- #endif // _CCCL_STD_VER > 2011
-   return reinterpret_cast<const char*>(
-     reinterpret_cast<::cuda::std::uintptr_t>(ptr) & ~::cuda::std::uintptr_t{alignment - 1});
- }

// Implementation notes on memcpy_async and UBLKCP kernels regarding copy alignment and padding
//
// For performance considerations of memcpy_async:
@@ -543,8 +644,8 @@ struct policy_hub<RequiresStableAddress, ::cuda::std::tuple<RandomAccessIterator
{
static constexpr int min_bif = arch_to_min_bytes_in_flight(300);
// TODO(bgruber): we don't need algo, because we can just detect the type of algo_policy
- static constexpr auto algorithm = Algorithm::fallback_for;
- using algo_policy = fallback_for_policy;
+ static constexpr auto algorithm = Algorithm::prefetch;
+ using algo_policy = prefetch_policy_t<256>;
};

#ifdef _CUB_HAS_TRANSFORM_UBLKCP
@@ -566,8 +667,8 @@ struct policy_hub<RequiresStableAddress, ::cuda::std::tuple<RandomAccessIterator

static constexpr bool use_fallback =
RequiresStableAddress || !can_memcpy || no_input_streams || exhaust_smem || any_type_is_overalinged;
- static constexpr auto algorithm = use_fallback ? Algorithm::fallback_for : Algorithm::ublkcp;
- using algo_policy = ::cuda::std::_If<use_fallback, fallback_for_policy, async_policy>;
+ static constexpr auto algorithm = use_fallback ? Algorithm::prefetch : Algorithm::ublkcp;
+ using algo_policy = ::cuda::std::_If<use_fallback, prefetch_policy_t<256>, async_policy>;
};

using max_policy = policy900;
@@ -647,13 +748,38 @@ _CCCL_HOST_DEVICE inline PoorExpected<int> get_max_shared_memory()
return max_smem;
}

_CCCL_HOST_DEVICE inline PoorExpected<int> get_sm_count()
{
int device = 0;
auto error = CubDebug(cudaGetDevice(&device));
if (error != cudaSuccess)
{
return error;
}

int sm_count = 0;
error = CubDebug(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device));
if (error != cudaSuccess)
{
return error;
}

return sm_count;
}

struct elem_counts
{
int elem_per_thread;
int tile_size;
int smem_size;
};

struct prefetch_config
{
int max_occupancy;
int sm_count;
};

template <bool RequiresStableAddress,
typename Offset,
typename RandomAccessIteratorTupleIn,
@@ -758,15 +884,11 @@ struct dispatch_t<RequiresStableAddress,
return last_counts;
};
PoorExpected<elem_counts> config = [&]() {
- NV_IF_TARGET(
-   NV_IS_HOST,
-   (
-     // this static variable exists for each template instantiation of the surrounding function and class, on which
-     // the chosen element count solely depends (assuming max SMEM is constant during a program execution)
-     static auto cached_config = determine_element_counts(); return cached_config;),
-   (
-     // we cannot cache the determined element count in device code
-     return determine_element_counts();));
+ NV_IF_TARGET(NV_IS_HOST,
+   (static auto cached_config = determine_element_counts(); return cached_config;),
+   (
+     // we cannot cache the determined element count in device code
+     return determine_element_counts();));
}();
if (!config)
{
@@ -828,6 +950,68 @@ struct dispatch_t<RequiresStableAddress,
make_iterator_kernel_arg(THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(::cuda::std::get<Is>(in)))...));
}

template <typename ActivePolicy, std::size_t... Is>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t
invoke_algorithm(cuda::std::index_sequence<Is...>, ::cuda::std::integral_constant<Algorithm, Algorithm::prefetch>)
{
using policy_t = typename ActivePolicy::algo_policy;
constexpr int block_dim = policy_t::block_threads;

auto determine_config = [&]() -> PoorExpected<prefetch_config> {
int max_occupancy = 0;
const auto error = CubDebug(MaxSmOccupancy(max_occupancy, CUB_DETAIL_TRANSFORM_KERNEL_PTR, block_dim, 0));
if (error != cudaSuccess)
{
return error;
}
const auto sm_count = get_sm_count();
if (!sm_count)
{
return sm_count.error;
}
return prefetch_config{max_occupancy, *sm_count};
};
Review comment on lines +960 to +973 (Collaborator):

Could this rather be a struct than a lambda? AFAIK the only thing that is needed is block_dim, which is a static property.

Reply (Author):
Discussed offline. We could move the lambda outside the surrounding function, but determined it's more a style question than correctness. We then decided to leave it as is, to stay consistent with the way the ublkcp kernel configuration is written.


PoorExpected<prefetch_config> config = [&]() {
NV_IF_TARGET(
NV_IS_HOST,
(
// this static variable exists for each template instantiation of the surrounding function and class, on which
// the chosen element count solely depends (assuming max SMEM is constant during a program execution)
static auto cached_config = determine_config(); return cached_config;),
(
// we cannot cache the determined element count in device code
return determine_config();));
}();
if (!config)
{
return config.error;
}

const int items_per_thread =
loaded_bytes_per_iter == 0
? +policy_t::items_per_thread_no_input
: ::cuda::ceil_div(ActivePolicy::min_bif, config->max_occupancy * block_dim * loaded_bytes_per_iter);

// Generate at least one block per SM. This improves tiny problem sizes (e.g. 2^16 elements).
const int items_per_thread_evenly_spread =
static_cast<int>(::cuda::std::min(Offset{items_per_thread}, num_items / (config->sm_count * block_dim)));

const int items_per_thread_clamped = ::cuda::std::clamp(
items_per_thread_evenly_spread, +policy_t::min_items_per_thread, +policy_t::max_items_per_thread);
const int tile_size = block_dim * items_per_thread_clamped;
const auto grid_dim = static_cast<unsigned int>(::cuda::ceil_div(num_items, Offset{tile_size}));
return CubDebug(
THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(grid_dim, block_dim, 0, stream)
.doit(
CUB_DETAIL_TRANSFORM_KERNEL_PTR,
num_items,
items_per_thread_clamped,
op,
out,
make_iterator_kernel_arg(THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(::cuda::std::get<Is>(in)))...));
}
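To make the grid sizing concrete, here is a worked example under assumed numbers (132 SMs, as on H100/H200-class parts, block_dim = 256, and a tiny problem of 2^16 elements):

```cuda
// Worked example with assumed values (not from the PR):
//   num_items = 65536, sm_count = 132, block_dim = 256
//   num_items / (sm_count * block_dim) = 65536 / 33792 = 1   (integer division)
//   items_per_thread_evenly_spread = min(items_per_thread, 1) = 1
//   clamped to [min_items_per_thread, max_items_per_thread] = [1, 32] -> 1
//   tile_size = 256 * 1 = 256
//   grid_dim = ceil_div(65536, 256) = 256 blocks, i.e. ~2 blocks per SM,
// rather than a handful of large tiles that would leave most SMs idle.
```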

template <typename ActivePolicy>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke()
{
7 changes: 5 additions & 2 deletions cub/test/catch2_test_device_transform.cu
@@ -34,7 +34,9 @@ struct policy_hub_for_alg
using algo_policy =
::cuda::std::_If<Alg == Algorithm::fallback_for,
cub::detail::transform::fallback_for_policy,
- cub::detail::transform::async_copy_policy_t<256>>;
+ ::cuda::std::_If<Alg == Algorithm::prefetch,
+ cub::detail::transform::prefetch_policy_t<256>,
+ cub::detail::transform::async_copy_policy_t<256>>>;
};
};

@@ -77,7 +79,8 @@ DECLARE_TMPL_LAUNCH_WRAPPER(transform_many_with_alg_entry_point,

using algorithms =
c2h::enum_type_list<Algorithm,
- Algorithm::fallback_for
+ Algorithm::fallback_for,
+ Algorithm::prefetch
#ifdef _CUB_HAS_TRANSFORM_UBLKCP
,
Algorithm::ublkcp