Add prefetching kernel as new fallback for cub::DeviceTransform
#2396
Changes from all commits
65605a2
67080eb
4e4d57a
15d30e8
da08eec
0f5d523
@@ -89,6 +89,7 @@ _CCCL_HOST_DEVICE constexpr auto loaded_bytes_per_iteration() -> int

enum class Algorithm
{
  fallback_for,
  prefetch,
#ifdef _CUB_HAS_TRANSFORM_UBLKCP
  ublkcp,
#endif // _CUB_HAS_TRANSFORM_UBLKCP
@@ -133,6 +134,116 @@ _CCCL_DEVICE void transform_kernel_impl(
  }
}

template <typename T>
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE const char* round_down_ptr(const T* ptr, unsigned alignment)
{
#if _CCCL_STD_VER > 2011
  _CCCL_ASSERT(::cuda::std::has_single_bit(alignment), "");
#endif // _CCCL_STD_VER > 2011
  return reinterpret_cast<const char*>(
    reinterpret_cast<::cuda::std::uintptr_t>(ptr) & ~::cuda::std::uintptr_t{alignment - 1});
}
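For illustration only (not part of the diff), here is a small host-side sketch of what round_down_ptr computes; round_down_ptr_sketch reimplements the function above without the CCCL macros so it compiles standalone:

#include <cassert>
#include <cstdint>

// Standalone sketch of the masking performed by round_down_ptr above.
const char* round_down_ptr_sketch(const void* ptr, unsigned alignment)
{
  return reinterpret_cast<const char*>(
    reinterpret_cast<std::uintptr_t>(ptr) & ~std::uintptr_t{alignment - 1});
}

int main()
{
  alignas(256) char buffer[256];
  const char* p       = buffer + 37;
  const char* aligned = round_down_ptr_sketch(p, 128);
  assert(aligned == buffer);                                    // greatest 128-byte boundary <= p
  assert(reinterpret_cast<std::uintptr_t>(aligned) % 128 == 0); // result is 128-byte aligned
}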

template <int BlockThreads>
struct prefetch_policy_t
{
  static constexpr int block_threads = BlockThreads;
  // items per tile are determined at runtime. These (inclusive) bounds allow overriding that value via a tuning policy
  static constexpr int items_per_thread_no_input = 2; // when there are no input iterators, the kernel is just filling
  static constexpr int min_items_per_thread      = 1;
  static constexpr int max_items_per_thread      = 32;
};

// Prefetches (at least on Hopper) a 128-byte cache line. Prefetching out-of-bounds addresses has no side effects.
// TODO(bgruber): there is also the cp.async.bulk.prefetch instruction available on Hopper. May improve perf a tiny bit
// as we need to create fewer instructions to prefetch the same amount of data.
template <typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void prefetch(const T* addr)
{
  // TODO(bgruber): prefetch to L1 may be even better
  asm volatile("prefetch.global.L2 [%0];" : : "l"(__cvta_generic_to_global(addr)) : "memory");
}

template <int BlockDim, typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void prefetch_tile(const T* addr, int tile_size)
{
  constexpr int prefetch_byte_stride = 128; // TODO(bgruber): should correspond to cache line size. Does this need to be
                                            // architecture dependent?
  const int tile_size_bytes = tile_size * sizeof(T);
  // prefetch does not stall and unrolling just generates a lot of unnecessary computations and predicate handling
#pragma unroll 1
  for (int offset = threadIdx.x * prefetch_byte_stride; offset < tile_size_bytes;
       offset += BlockDim * prefetch_byte_stride)
  {
    prefetch(reinterpret_cast<const char*>(addr) + offset);
  }
}
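A quick back-of-the-envelope check of the stride math above (illustrative only; the block size of 256 is the value the policy hub below instantiates as prefetch_policy_t<256>):

// Illustration only: bytes covered by one iteration of the prefetch_tile loop.
constexpr int assumed_block_dim    = 256; // assumption: matches prefetch_policy_t<256> used below
constexpr int prefetch_byte_stride = 128; // same stride as in prefetch_tile
constexpr int bytes_per_sweep      = assumed_block_dim * prefetch_byte_stride;
static_assert(bytes_per_sweep == 32 * 1024, "one sweep of the block prefetches 32 KiB");
// E.g. a tile of 256 threads * 16 items * 4-byte elements = 16 KiB fits in a single sweep,
// so each thread issues at most one prefetch for such a tile.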

// TODO(miscco): we should probably constrain It to not be a contiguous iterator in C++17 (and change the overload
// above to accept any contiguous iterator)
// overload for any iterator that is not a pointer, do nothing
template <int, typename It, ::cuda::std::__enable_if_t<!::cuda::std::is_pointer<It>::value, int> = 0>
_CCCL_DEVICE _CCCL_FORCEINLINE void prefetch_tile(It, int)
{}

// This kernel guarantees that objects passed as arguments to the user-provided transformation function f reside in
// global memory. No intermediate copies are taken. If the parameter type of f is a reference, taking the address of the
// parameter yields a global memory address.
template <typename PrefetchPolicy,
          typename Offset,
          typename F,
          typename RandomAccessIteratorOut,
          typename... RandomAccessIteratorIn>
_CCCL_DEVICE void transform_kernel_impl(
  ::cuda::std::integral_constant<Algorithm, Algorithm::prefetch>,
  Offset num_items,
  int num_elem_per_thread,
  F f,
  RandomAccessIteratorOut out,
  RandomAccessIteratorIn... ins)
{
  constexpr int block_dim = PrefetchPolicy::block_threads;
  const int tile_stride   = block_dim * num_elem_per_thread;
  const Offset offset     = static_cast<Offset>(blockIdx.x) * tile_stride;
  const int tile_size     = static_cast<int>(::cuda::std::min(num_items - offset, Offset{tile_stride}));

  // move index and iterator domain to the block/thread index, to reduce arithmetic in the loops below
  {
    int dummy[] = {(ins += offset, 0)..., 0};
    (void) &dummy;
    out += offset;
  }

  {
    // TODO(bgruber): replace by fold over comma in C++17
    int dummy[] = {(prefetch_tile<block_dim>(ins, tile_size), 0)..., 0}; // extra zero to handle empty packs
    (void) &dummy; // nvcc 11.1 needs extra strong unused warning suppression
  }

#define PREFETCH_AGENT(full_tile)                                                                                  \
  /* ahendriksen: various unrolling yields <1% gains at much higher compile-time cost */                           \
  /* bgruber: but A6000 and H100 show small gains without pragma */                                                \
  /*_Pragma("unroll 1")*/ for (int j = 0; j < num_elem_per_thread; ++j)                                            \
  {                                                                                                                \
    const int idx = j * block_dim + threadIdx.x;                                                                   \
    if (full_tile || idx < tile_size)                                                                              \
    {                                                                                                              \
      /* we have to unwrap Thrust's proxy references here for backward compatibility (try zip_iterator.cu test) */ \
      out[idx] = f(THRUST_NS_QUALIFIER::raw_reference_cast(ins[idx])...);                                          \
    }                                                                                                              \
  }

  if (tile_stride == tile_size)
[Review comment] Question: does the split between full_tile and partial tiles make a big difference in run time? I expect that the difference shouldn't be too big, as the partial-tile version only adds one comparison and predicates the rest of the computation.

[Reply] That is a good question! I followed existing practice here, but I see your point. I will evaluate!

[Reply] Here is a benchmark on H200: [Benchmark: Babelstream on H200, two code paths (baseline) vs. single code path] While the change seems to not matter in most cases, I see a few regressions, especially on the large problem sizes and with nstream.
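For reference, a sketch of the single-code-path variant discussed and benchmarked above (illustrative only; the merged code keeps the full-tile/partial-tile split via PREFETCH_AGENT):

  // Sketch: always bounds-check, no separate full-tile specialization.
  for (int j = 0; j < num_elem_per_thread; ++j)
  {
    const int idx = j * block_dim + threadIdx.x;
    if (idx < tile_size) // predicate every iteration instead of branching on full_tile
    {
      out[idx] = f(THRUST_NS_QUALIFIER::raw_reference_cast(ins[idx])...);
    }
  }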
  {
    PREFETCH_AGENT(true);
  }
  else
  {
    PREFETCH_AGENT(false);
  }
#undef PREFETCH_AGENT
}
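As an aside, the C++17 fold over the comma operator mentioned in the TODO inside the kernel above could look roughly like this (a sketch, not part of the diff):

  // C++17 sketch replacing the two dummy-array expansions in the kernel above:
  (..., (ins += offset));                          // advance every input iterator by the tile offset
  (..., prefetch_tile<block_dim>(ins, tile_size)); // prefetch every input tile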

template <int BlockThreads>
struct async_copy_policy_t
{

@@ -173,16 +284,6 @@ _CCCL_HOST_DEVICE _CCCL_FORCEINLINE constexpr auto round_up_to_po2_multiple(Inte
  return (x + mult - 1) & ~(mult - 1);
}

template <typename T>
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE const char* round_down_ptr(const T* ptr, unsigned alignment)
{
#if _CCCL_STD_VER > 2011
  _CCCL_ASSERT(::cuda::std::has_single_bit(alignment), "");
#endif // _CCCL_STD_VER > 2011
  return reinterpret_cast<const char*>(
    reinterpret_cast<::cuda::std::uintptr_t>(ptr) & ~::cuda::std::uintptr_t{alignment - 1});
}

// Implementation notes on memcpy_async and UBLKCP kernels regarding copy alignment and padding
//
// For performance considerations of memcpy_async:
@@ -543,8 +644,8 @@ struct policy_hub<RequiresStableAddress, ::cuda::std::tuple<RandomAccessIterator
  {
    static constexpr int min_bif = arch_to_min_bytes_in_flight(300);
    // TODO(bgruber): we don't need algo, because we can just detect the type of algo_policy
-   static constexpr auto algorithm = Algorithm::fallback_for;
-   using algo_policy = fallback_for_policy;
+   static constexpr auto algorithm = Algorithm::prefetch;
+   using algo_policy = prefetch_policy_t<256>;
  };

#ifdef _CUB_HAS_TRANSFORM_UBLKCP

@@ -566,8 +667,8 @@ struct policy_hub<RequiresStableAddress, ::cuda::std::tuple<RandomAccessIterator

    static constexpr bool use_fallback =
      RequiresStableAddress || !can_memcpy || no_input_streams || exhaust_smem || any_type_is_overalinged;
-   static constexpr auto algorithm = use_fallback ? Algorithm::fallback_for : Algorithm::ublkcp;
-   using algo_policy = ::cuda::std::_If<use_fallback, fallback_for_policy, async_policy>;
+   static constexpr auto algorithm = use_fallback ? Algorithm::prefetch : Algorithm::ublkcp;
+   using algo_policy = ::cuda::std::_If<use_fallback, prefetch_policy_t<256>, async_policy>;
  };

  using max_policy = policy900;
@@ -647,13 +748,38 @@ _CCCL_HOST_DEVICE inline PoorExpected<int> get_max_shared_memory()
  return max_smem;
}

_CCCL_HOST_DEVICE inline PoorExpected<int> get_sm_count()
{
  int device = 0;
  auto error = CubDebug(cudaGetDevice(&device));
  if (error != cudaSuccess)
  {
    return error;
  }

  int sm_count = 0;
  error        = CubDebug(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device));
  if (error != cudaSuccess)
  {
    return error;
  }

  return sm_count;
}

struct elem_counts
{
  int elem_per_thread;
  int tile_size;
  int smem_size;
};

struct prefetch_config
{
  int max_occupancy;
  int sm_count;
};

template <bool RequiresStableAddress,
          typename Offset,
          typename RandomAccessIteratorTupleIn,
@@ -758,15 +884,11 @@ struct dispatch_t<RequiresStableAddress,
      return last_counts;
    };
    PoorExpected<elem_counts> config = [&]() {
-     NV_IF_TARGET(
-       NV_IS_HOST,
-       (
-         // this static variable exists for each template instantiation of the surrounding function and class, on which
-         // the chosen element count solely depends (assuming max SMEM is constant during a program execution)
-         static auto cached_config = determine_element_counts(); return cached_config;),
-       (
-         // we cannot cache the determined element count in device code
-         return determine_element_counts();));
+     NV_IF_TARGET(NV_IS_HOST,
+                  (static auto cached_config = determine_element_counts(); return cached_config;),
+                  (
+                    // we cannot cache the determined element count in device code
+                    return determine_element_counts();));
    }();
    if (!config)
    {

@@ -828,6 +950,68 @@ struct dispatch_t<RequiresStableAddress,
      make_iterator_kernel_arg(THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(::cuda::std::get<Is>(in)))...));
  }

  template <typename ActivePolicy, std::size_t... Is>
  CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t
  invoke_algorithm(cuda::std::index_sequence<Is...>, ::cuda::std::integral_constant<Algorithm, Algorithm::prefetch>)
  {
    using policy_t          = typename ActivePolicy::algo_policy;
    constexpr int block_dim = policy_t::block_threads;

    auto determine_config = [&]() -> PoorExpected<prefetch_config> {
      int max_occupancy = 0;
      const auto error  = CubDebug(MaxSmOccupancy(max_occupancy, CUB_DETAIL_TRANSFORM_KERNEL_PTR, block_dim, 0));
      if (error != cudaSuccess)
      {
        return error;
      }
      const auto sm_count = get_sm_count();
      if (!sm_count)
      {
        return sm_count.error;
      }
      return prefetch_config{max_occupancy, *sm_count};
    };
[Review comment on lines +960 to +973] Could this rather be a struct than a lambda? Afaik the only thing that is needed is

[Reply] Discussed offline. We could move the lambda outside the surrounding function, but determined it's more a style question than correctness. We then decided to leave it as is to stay consistent with the way the ublkcp kernel configuration is written.
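For illustration, the non-lambda alternative raised above could look roughly like the following hypothetical free function (determine_prefetch_config is an invented name; the PR keeps the lambda):

    // Hypothetical sketch of the non-lambda variant discussed above; not part of the PR.
    template <typename KernelPtr>
    CUB_RUNTIME_FUNCTION static PoorExpected<prefetch_config>
    determine_prefetch_config(KernelPtr kernel, int block_dim)
    {
      int max_occupancy = 0;
      const auto error  = CubDebug(MaxSmOccupancy(max_occupancy, kernel, block_dim, 0));
      if (error != cudaSuccess)
      {
        return error;
      }
      const auto sm_count = get_sm_count();
      if (!sm_count)
      {
        return sm_count.error;
      }
      return prefetch_config{max_occupancy, *sm_count};
    }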

    PoorExpected<prefetch_config> config = [&]() {
      NV_IF_TARGET(
        NV_IS_HOST,
        (
          // this static variable exists for each template instantiation of the surrounding function and class, on which
          // the chosen element count solely depends (assuming max SMEM is constant during a program execution)
          static auto cached_config = determine_config(); return cached_config;),
        (
          // we cannot cache the determined element count in device code
          return determine_config();));
    }();
    if (!config)
    {
      return config.error;
    }

    const int items_per_thread =
      loaded_bytes_per_iter == 0
        ? +policy_t::items_per_thread_no_input
        : ::cuda::ceil_div(ActivePolicy::min_bif, config->max_occupancy * block_dim * loaded_bytes_per_iter);

    // Generate at least one block per SM. This improves tiny problem sizes (e.g. 2^16 elements).
    const int items_per_thread_evenly_spread =
      static_cast<int>(::cuda::std::min(Offset{items_per_thread}, num_items / (config->sm_count * block_dim)));

    const int items_per_thread_clamped = ::cuda::std::clamp(
      items_per_thread_evenly_spread, +policy_t::min_items_per_thread, +policy_t::max_items_per_thread);
    const int tile_size = block_dim * items_per_thread_clamped;
    const auto grid_dim = static_cast<unsigned int>(::cuda::ceil_div(num_items, Offset{tile_size}));
    return CubDebug(
      THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(grid_dim, block_dim, 0, stream)
        .doit(
          CUB_DETAIL_TRANSFORM_KERNEL_PTR,
          num_items,
          items_per_thread_clamped,
          op,
          out,
          make_iterator_kernel_arg(THRUST_NS_QUALIFIER::try_unwrap_contiguous_iterator(::cuda::std::get<Is>(in)))...));
  }
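To make the launch-configuration arithmetic above concrete, here is a small worked example with hypothetical numbers (min_bif, occupancy, SM count, and element sizes are assumptions for illustration, not values taken from the PR):

#include <algorithm>
#include <cstdio>

int main()
{
  const long long num_items       = 1 << 16;    // tiny problem: 65536 elements
  const int block_dim             = 256;        // prefetch_policy_t<256>
  const int loaded_bytes_per_iter = 8;          // e.g. two float inputs (assumption)
  const int min_bif               = 64 * 1024;  // hypothetical bytes-in-flight target
  const int max_occupancy         = 4;          // hypothetical resident blocks per SM
  const int sm_count              = 132;        // hypothetical SM count

  // ceil_div(min_bif, max_occupancy * block_dim * loaded_bytes_per_iter) -> 8
  const int denom            = max_occupancy * block_dim * loaded_bytes_per_iter;
  const int items_per_thread = (min_bif + denom - 1) / denom;
  // spread work so every SM gets at least one block -> min(8, 65536 / (132*256)) = 1
  const long long evenly_spread =
    std::min<long long>(items_per_thread, num_items / (1LL * sm_count * block_dim));
  const int clamped        = std::clamp(static_cast<int>(evenly_spread), 1, 32); // 1 item/thread
  const int tile_size      = block_dim * clamped;                                // 256
  const long long grid_dim = (num_items + tile_size - 1) / tile_size;            // 256 blocks >= 132 SMs
  std::printf("items/thread=%d tile=%d grid=%lld\n", clamped, tile_size, grid_dim);
}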

  template <typename ActivePolicy>
  CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke()
  {
[Review comment] Suggestion: We could use cuda::std::contiguous_iterator in C++17 onwards. Definitely something for a followup.

[Reply] Added a note.
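A C++20-flavoured sketch of that follow-up for the prefetch_tile overloads (an assumption, not the merged code: it presumes cuda::std::contiguous_iterator and cuda::std::to_address are usable here, and the in-tree spelling would need libcu++'s concept emulation to also cover C++17):

// Sketch: widen the prefetching overload to any contiguous iterator and keep a no-op otherwise.
template <int BlockDim, typename It>
  requires ::cuda::std::contiguous_iterator<It>
_CCCL_DEVICE _CCCL_FORCEINLINE void prefetch_tile(It it, int tile_size)
{
  // reuse the raw-pointer overload defined earlier in this file
  prefetch_tile<BlockDim>(::cuda::std::to_address(it), tile_size);
}

template <int BlockDim, typename It>
  requires(!::cuda::std::contiguous_iterator<It>)
_CCCL_DEVICE _CCCL_FORCEINLINE void prefetch_tile(It, int) // still a no-op for everything else
{}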