From 66943b4e9f198ddf93dc405d76a3c4d553a18438 Mon Sep 17 00:00:00 2001
From: Bernhard Manfred Gruber
Date: Thu, 5 Sep 2024 10:38:24 +0200
Subject: [PATCH] Add memcpy_async transform kernel for A100

Fixes: #2361
---
 .../bench/transform/babelstream1.cu         |   6 +-
 .../bench/transform/babelstream2.cu         |   6 +-
 .../bench/transform/babelstream3.cu         |   6 +-
 .../device/dispatch/dispatch_transform.cuh  | 288 +++++++++++++++++-
 cub/test/catch2_test_device_transform.cu    |   7 +-
 5 files changed, 300 insertions(+), 13 deletions(-)

diff --git a/cub/benchmarks/bench/transform/babelstream1.cu b/cub/benchmarks/bench/transform/babelstream1.cu
index 87abdfef6ff..3b62e0cf708 100644
--- a/cub/benchmarks/bench/transform/babelstream1.cu
+++ b/cub/benchmarks/bench/transform/babelstream1.cu
@@ -2,15 +2,15 @@
 // SPDX-License-Identifier: BSD-3-Clause
 
 // %RANGE% TUNE_THREADS tpb 128:1024:128
-// %RANGE% TUNE_ALGORITHM alg 0:1:1
+// %RANGE% TUNE_ALGORITHM alg 0:2:1
 
 // keep checks at the top so compilation of discarded variants fails really fast
 #if !TUNE_BASE
-#  if TUNE_ALGORITHM == 1 && (__CUDA_ARCH_LIST__) < 900
+#  if TUNE_ALGORITHM == 2 && (__CUDA_ARCH_LIST__) < 900
 #    error "Cannot compile algorithm 4 (ublkcp) below sm90"
 #  endif
-#  if TUNE_ALGORITHM == 1 && !defined(_CUB_HAS_TRANSFORM_UBLKCP)
+#  if TUNE_ALGORITHM == 2 && !defined(_CUB_HAS_TRANSFORM_UBLKCP)
 #    error "Cannot tune for ublkcp algorithm, which is not provided by CUB (old CTK?)"
 #  endif
 #endif
diff --git a/cub/benchmarks/bench/transform/babelstream2.cu b/cub/benchmarks/bench/transform/babelstream2.cu
index c8fa017b788..fafbf49b491 100644
--- a/cub/benchmarks/bench/transform/babelstream2.cu
+++ b/cub/benchmarks/bench/transform/babelstream2.cu
@@ -2,15 +2,15 @@
 // SPDX-License-Identifier: BSD-3-Clause
 
 // %RANGE% TUNE_THREADS tpb 128:1024:128
-// %RANGE% TUNE_ALGORITHM alg 0:1:1
+// %RANGE% TUNE_ALGORITHM alg 0:2:1
 
 // keep checks at the top so compilation of discarded variants fails really fast
 #if !TUNE_BASE
-#  if TUNE_ALGORITHM == 1 && (__CUDA_ARCH_LIST__) < 900
+#  if TUNE_ALGORITHM == 2 && (__CUDA_ARCH_LIST__) < 900
 #    error "Cannot compile algorithm 4 (ublkcp) below sm90"
 #  endif
-#  if TUNE_ALGORITHM == 1 && !defined(_CUB_HAS_TRANSFORM_UBLKCP)
+#  if TUNE_ALGORITHM == 2 && !defined(_CUB_HAS_TRANSFORM_UBLKCP)
 #    error "Cannot tune for ublkcp algorithm, which is not provided by CUB (old CTK?)"
 #  endif
 #endif
diff --git a/cub/benchmarks/bench/transform/babelstream3.cu b/cub/benchmarks/bench/transform/babelstream3.cu
index db541554210..8bd48b47570 100644
--- a/cub/benchmarks/bench/transform/babelstream3.cu
+++ b/cub/benchmarks/bench/transform/babelstream3.cu
@@ -2,15 +2,15 @@
 // SPDX-License-Identifier: BSD-3-Clause
 
 // %RANGE% TUNE_THREADS tpb 128:1024:128
-// %RANGE% TUNE_ALGORITHM alg 0:1:1
+// %RANGE% TUNE_ALGORITHM alg 0:2:1
 
 // keep checks at the top so compilation of discarded variants fails really fast
 #if !TUNE_BASE
-#  if TUNE_ALGORITHM == 1 && (__CUDA_ARCH_LIST__) < 900
+#  if TUNE_ALGORITHM == 2 && (__CUDA_ARCH_LIST__) < 900
 #    error "Cannot compile algorithm 4 (ublkcp) below sm90"
 #  endif
-#  if TUNE_ALGORITHM == 1 && !defined(_CUB_HAS_TRANSFORM_UBLKCP)
+#  if TUNE_ALGORITHM == 2 && !defined(_CUB_HAS_TRANSFORM_UBLKCP)
 #    error "Cannot tune for ublkcp algorithm, which is not provided by CUB (old CTK?)"
 #  endif
 #endif
diff --git a/cub/cub/device/dispatch/dispatch_transform.cuh b/cub/cub/device/dispatch/dispatch_transform.cuh
index 1b25d003a39..7b8cc497879 100644
--- a/cub/cub/device/dispatch/dispatch_transform.cuh
+++ b/cub/cub/device/dispatch/dispatch_transform.cuh
@@ -94,6 +94,7 @@ enum class Algorithm
   // We previously had a fallback algorithm that would use cub::DeviceFor. Benchmarks showed that the prefetch algorithm
   // is always superior to that fallback, so it was removed.
   prefetch,
+  memcpy_async,
 #ifdef _CUB_HAS_TRANSFORM_UBLKCP
   ublkcp,
 #endif // _CUB_HAS_TRANSFORM_UBLKCP
@@ -310,6 +311,136 @@ _CCCL_HOST_DEVICE auto make_aligned_base_ptr(const T* ptr, int alignment) -> ali
   return aligned_base_ptr<T>{base_ptr, static_cast<int>(reinterpret_cast<const char*>(ptr) - base_ptr)};
 }
 
+constexpr int memcpy_async_alignment     = 16;
+constexpr int memcpy_async_size_multiple = 16;
+
+// Our own version of ::cuda::aligned_size_t, since we cannot include <cuda/barrier> on CUDA_ARCH < 700
+template <_CUDA_VSTD::size_t _Alignment>
+struct aligned_size_t
+{
+  _CUDA_VSTD::size_t value; // TODO(bgruber): can this be an int?
+
+  _CCCL_HOST_DEVICE constexpr operator size_t() const
+  {
+    return value;
+  }
+};
+
+// TODO(bgruber): inline this as lambda in C++14
+template <typename T, typename Offset>
+_CCCL_DEVICE const T* copy_and_return_smem_dst(
+  cooperative_groups::thread_block& group,
+  int tile_size,
+  char* smem,
+  int& smem_offset,
+  Offset global_offset,
+  aligned_base_ptr<T> aligned_ptr)
+{
+  // because SMEM base pointer and bytes_to_copy are always multiples of 16-byte, we only need to align the SMEM start
+  // for types with larger alignment
+  _CCCL_IF_CONSTEXPR (alignof(T) > memcpy_async_alignment)
+  {
+    smem_offset = round_up_to_po2_multiple(smem_offset, static_cast<int>(alignof(T)));
+  }
+  const char* const src = aligned_ptr.ptr + global_offset * sizeof(T);
+  char* const dst       = smem + smem_offset;
+  _CCCL_ASSERT(reinterpret_cast<uintptr_t>(src) % memcpy_async_alignment == 0, "");
+  _CCCL_ASSERT(reinterpret_cast<uintptr_t>(dst) % memcpy_async_alignment == 0, "");
+  const int bytes_to_copy = round_up_to_po2_multiple(
+    aligned_ptr.head_padding + static_cast<int>(sizeof(T)) * tile_size, memcpy_async_size_multiple);
+  smem_offset += bytes_to_copy; // leave aligned address for follow-up copy
+  cooperative_groups::memcpy_async(
+    group, dst, src, aligned_size_t<memcpy_async_alignment>{static_cast<::cuda::std::size_t>(bytes_to_copy)});
+
+  const char* const dst_start_of_data = dst + aligned_ptr.head_padding;
+  _CCCL_ASSERT(reinterpret_cast<uintptr_t>(dst_start_of_data) % alignof(T) == 0, "");
+  return reinterpret_cast<const T*>(dst_start_of_data);
+}
+
+// TODO(ahendriksen): the codegen for memcpy_async for char and short is really verbose (300 instructions). we may
+// rather want to just do an unrolled loop here.
+template <typename T, typename Offset>
+_CCCL_DEVICE const T* copy_and_return_smem_dst_fallback(
+  cooperative_groups::thread_block& group,
+  int tile_size,
+  char* smem,
+  int& smem_offset,
+  Offset global_offset,
+  aligned_base_ptr<T> aligned_ptr)
+{
+  smem_offset  = round_up_to_po2_multiple(smem_offset, static_cast<int>(alignof(T)));
+  const T* src = aligned_ptr.ptr_to_elements() + global_offset;
+  T* dst       = reinterpret_cast<T*>(smem + smem_offset);
+  _CCCL_ASSERT(reinterpret_cast<uintptr_t>(src) % alignof(T) == 0, "");
+  _CCCL_ASSERT(reinterpret_cast<uintptr_t>(dst) % alignof(T) == 0, "");
+  const int bytes_to_copy = static_cast<int>(sizeof(T)) * tile_size;
+  smem_offset += bytes_to_copy;
+  cooperative_groups::memcpy_async(group, dst, src, bytes_to_copy);
+
+  return dst;
+}
+
+template <typename MemcpyAsyncPolicy, typename Offset, typename F, typename RandomAccessIteratorOut, typename... InTs>
+_CCCL_DEVICE void transform_kernel_impl(
+  ::cuda::std::integral_constant<Algorithm, Algorithm::memcpy_async>,
+  Offset num_items,
+  int num_elem_per_thread,
+  F f,
+  RandomAccessIteratorOut out,
+  aligned_base_ptr<InTs>... aligned_ptrs)
+{
+  extern __shared__ char smem[]; // this should be __attribute((aligned(memcpy_async_alignment))), but then it clashes
+                                 // with the ublkcp kernel, which sets a higher alignment, since they are both called
+                                 // from the same kernel entry point (albeit one is always discarded). However, SMEM is
+                                 // 16-byte aligned by default.
+
+  constexpr int block_dim = MemcpyAsyncPolicy::block_threads;
+  const int tile_stride   = block_dim * num_elem_per_thread;
+  const Offset offset     = static_cast<Offset>(blockIdx.x) * tile_stride;
+  const int tile_size     = static_cast<int>(::cuda::std::min(num_items - offset, Offset{tile_stride}));
+
+  auto group      = cooperative_groups::this_thread_block();
+  int smem_offset = 0;
+  // TODO(bgruber): if we used SMEM offsets instead of pointers, we only need half the registers
+  const bool inner_blocks = 0 < blockIdx.x && blockIdx.x + 2 < gridDim.x;
+  const auto smem_ptrs    = ::cuda::std::tuple<const InTs*...>{
+    (inner_blocks ? copy_and_return_smem_dst(group, tile_size, smem, smem_offset, offset, aligned_ptrs)
+                  : copy_and_return_smem_dst_fallback(group, tile_size, smem, smem_offset, offset, aligned_ptrs))...};
+  cooperative_groups::wait(group);
+  (void) smem_ptrs;    // suppress unused warning for MSVC
+  (void) &smem_offset; // MSVC needs extra strong unused warning suppression
+
+  // move the whole index and iterator to the block/thread index, to reduce arithmetic in the loops below
+  out += offset;
+
+  // TODO(bgruber): use a polymorphic lambda in C++14
+#define MEMCPY_ASYNC_AGENT(full_tile)                                                                        \
+  _Pragma("unroll 1") /* Unroll 1 tends to improve performance, especially for smaller data types (confirmed \
+                         by benchmark) */                                                                    \
+  for (int j = 0; j < num_elem_per_thread; ++j)                                                              \
+  {                                                                                                          \
+    const int idx = j * block_dim + threadIdx.x;                                                             \
+    if (full_tile || idx < tile_size)                                                                        \
+    {                                                                                                        \
+      out[idx] = poor_apply(                                                                                 \
+        [&](const InTs* __restrict__... smem_base_ptrs) {                                                    \
+          return f(smem_base_ptrs[idx]...);                                                                  \
+        },                                                                                                   \
+        smem_ptrs);                                                                                          \
+    }                                                                                                        \
+  }
+
+  if (tile_stride == tile_size)
+  {
+    MEMCPY_ASYNC_AGENT(true);
+  }
+  else
+  {
+    MEMCPY_ASYNC_AGENT(false);
+  }
+#undef MEMCPY_ASYNC_AGENT
+}
+
 constexpr int bulk_copy_alignment     = 128;
 constexpr int bulk_copy_size_multiple = 16;
 
@@ -519,7 +650,7 @@ _CCCL_HOST_DEVICE auto make_aligned_base_ptr_kernel_arg(It ptr, int alignment) -
 // TODO(bgruber): make a variable template in C++14
 template <Algorithm Alg>
 using needs_aligned_ptr_t =
-  ::cuda::std::bool_constant<false
+  ::cuda::std::bool_constant<Alg == Algorithm::memcpy_async
 #ifdef _CUB_HAS_TRANSFORM_UBLKCP
                              || Alg == Algorithm::ublkcp
 #endif // _CUB_HAS_TRANSFORM_UBLKCP
@@ -560,6 +691,30 @@
 
+template <typename... RandomAccessIteratorsIn>
+_CCCL_HOST_DEVICE constexpr auto memcpy_async_smem_for_tile_size(int tile_size) -> int
+{
+#if _CCCL_STD_VER > 2011
+  int smem_size   = 0;
+  auto count_smem = [&](int size, int alignment) {
+    smem_size = round_up_to_po2_multiple(smem_size, alignment);
+    // max aligned_base_ptr head_padding + max padding after == 16
+    smem_size += size * tile_size + memcpy_async_alignment;
+  };
+  // TODO(bgruber): replace by fold over comma in C++17 (left to right evaluation!)
+  int dummy[] = {
+    (count_smem(sizeof(value_t<RandomAccessIteratorsIn>), alignof(value_t<RandomAccessIteratorsIn>)), 0)..., 0};
+  (void) &dummy; // need to take the address to suppress unused warnings more strongly for nvcc 11.1
+  (void) &count_smem;
+  return smem_size;
+#else // _CCCL_STD_VER > 2011
+  // we use a simplified calculation with a bit of excess in C++11 since constexpr functions are more limited
+  return sum((sizeof(value_t<RandomAccessIteratorsIn>) * tile_size + memcpy_async_alignment
+              + alignof(value_t<RandomAccessIteratorsIn>))...);
+#endif // _CCCL_STD_VER > 2011
+}
+
 template <typename... RandomAccessIteratorsIn>
 _CCCL_HOST_DEVICE constexpr auto bulk_copy_smem_for_tile_size(int tile_size) -> int
 {
@@ -611,9 +766,25 @@ struct policy_hub<RequiresStableAddress, ::cuda::std::tuple<RandomAccessIterato
     using algo_policy = prefetch_policy_t<256>;
   };
 
+  // TODO(bgruber): should we add a tuning for 750? They should have items_per_thread_from_occupancy(256, 4, ...)
+
+  // A100
+  struct policy800 : ChainedPolicy<800, policy800, policy300>
+  {
+    static constexpr int min_bif = arch_to_min_bytes_in_flight(800);
+    using async_policy           = async_copy_policy_t<256>;
+    static constexpr bool exhaust_smem =
+      memcpy_async_smem_for_tile_size<RandomAccessIteratorsIn...>(
+        async_policy::block_threads * async_policy::min_items_per_thread)
+      > 48 * 1024;
+    static constexpr bool use_fallback = RequiresStableAddress || !can_memcpy || no_input_streams || exhaust_smem;
+    static constexpr auto algorithm    = use_fallback ? Algorithm::prefetch : Algorithm::memcpy_async;
+    using algo_policy                  = ::cuda::std::_If<use_fallback, prefetch_policy_t<256>, async_policy>;
+  };
+
 #ifdef _CUB_HAS_TRANSFORM_UBLKCP
   // H100 and H200
-  struct policy900 : ChainedPolicy<900, policy900, policy300>
+  struct policy900 : ChainedPolicy<900, policy900, policy800>
   {
     static constexpr int min_bif = arch_to_min_bytes_in_flight(900);
     using async_policy           = async_copy_policy_t<256>;
@@ -636,7 +807,7 @@ struct policy_hub<RequiresStableAddress, ::cuda::std::tuple<RandomAccessIterato
 
   using max_policy = policy900;
 #else // _CUB_HAS_TRANSFORM_UBLKCP
-  using max_policy = policy300;
+  using max_policy = policy800;
 #endif // _CUB_HAS_TRANSFORM_UBLKCP
 };
@@ -748,6 +919,96 @@ struct dispatch_t<RequiresStableAddress, Offset, ::cuda::std::tuple<RandomAcces
   static constexpr int loaded_bytes_per_iter = loaded_bytes_per_iteration<RandomAccessIteratorsIn...>();
 
+  // TODO(bgruber): I want to write tests for this but those are highly depending on the architecture we are running on?
+  template <typename ActivePolicy, std::size_t... Is>
+  CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE auto configure_memcpy_async_kernel()
+    -> PoorExpected<
+      ::cuda::std::
+        tuple<THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron, decltype(CUB_DETAIL_TRANSFORM_KERNEL_PTR), int>>
+  {
+    // Benchmarking shows that even for a few iterations, this loop takes around 4-7 us, so should not be a concern.
+    using policy_t          = typename ActivePolicy::algo_policy;
+    constexpr int block_dim = policy_t::block_threads;
+    static_assert(block_dim % memcpy_async_alignment == 0,
+                  "block_threads needs to be a multiple of memcpy_async_alignment (16)"); // then tile_size is a
+                                                                                          // multiple of 16-byte
+    auto determine_element_counts = [&]() -> PoorExpected<elem_counts> {
+      const auto max_smem = get_max_shared_memory();
+      if (!max_smem)
+      {
+        return max_smem.error;
+      }
+
+      elem_counts last_counts{};
+      // Increase the number of output elements per thread until we reach the required bytes in flight.
+      static_assert(policy_t::min_items_per_thread <= policy_t::max_items_per_thread, ""); // ensures the loop below
+                                                                                           // runs at least once
+      for (int elem_per_thread = +policy_t::min_items_per_thread; elem_per_thread <= +policy_t::max_items_per_thread;
+           ++elem_per_thread)
+      {
+        const auto tile_size = block_dim * elem_per_thread;
+        const int smem_size  = memcpy_async_smem_for_tile_size<RandomAccessIteratorsIn...>(tile_size);
+        if (smem_size > *max_smem)
+        {
+#ifdef CUB_DETAIL_DEBUG_ENABLE_HOST_ASSERTIONS
+          // assert should be prevented by smem check in policy
+          assert(last_counts.elem_per_thread > 0 && "min_items_per_thread exceeds available shared memory");
+#endif // CUB_DETAIL_DEBUG_ENABLE_HOST_ASSERTIONS
+          return last_counts;
+        }
+
+        if (tile_size >= num_items)
+        {
+          return elem_counts{elem_per_thread, tile_size, smem_size};
+        }
+
+        int max_occupancy = 0;
+        const auto error  =
+          CubDebug(MaxSmOccupancy(max_occupancy, CUB_DETAIL_TRANSFORM_KERNEL_PTR, block_dim, smem_size));
+        if (error != cudaSuccess)
+        {
+          return error;
+        }
+
+        const int bytes_in_flight_SM = max_occupancy * tile_size * loaded_bytes_per_iter;
+        if (bytes_in_flight_SM >= ActivePolicy::min_bif)
+        {
+          return elem_counts{elem_per_thread, tile_size, smem_size};
+        }
+
+        last_counts = elem_counts{elem_per_thread, tile_size, smem_size};
+      }
+      return last_counts;
+    };
+    PoorExpected<elem_counts> config = [&]() {
+      NV_IF_TARGET(
+        NV_IS_HOST,
+        (
+          // this static variable exists for each template instantiation of the surrounding function and class, on which
+          // the chosen element count solely depends (assuming max SMEM is constant during a program execution)
+          static auto cached_config = determine_element_counts(); return cached_config;),
+        (
+          // we cannot cache the determined element count in device code
+          return determine_element_counts();));
+    }();
+    if (!config)
+    {
+      return config.error;
+    }
+#ifdef CUB_DETAIL_DEBUG_ENABLE_HOST_ASSERTIONS
+    assert(config->elem_per_thread > 0);
+    assert(config->tile_size > 0);
+    assert(config->tile_size % memcpy_async_alignment == 0);
+    assert((sizeof...(RandomAccessIteratorsIn) == 0) != (config->smem_size != 0)); // logical xor
+#endif // CUB_DETAIL_DEBUG_ENABLE_HOST_ASSERTIONS
+
+    const auto grid_dim = static_cast<unsigned int>(::cuda::ceil_div(num_items, Offset{config->tile_size}));
+    return ::cuda::std::make_tuple(
+      THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(grid_dim, block_dim, config->smem_size, stream),
+      CUB_DETAIL_TRANSFORM_KERNEL_PTR,
+      config->elem_per_thread);
+  }
+
 #ifdef _CUB_HAS_TRANSFORM_UBLKCP
   // TODO(bgruber): I want to write tests for this but those are highly depending on the architecture we are running on?
@@ -871,6 +1132,27 @@ struct dispatch_telem_per_thread); } + template + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t + invoke_algorithm(cuda::std::index_sequence, ::cuda::std::integral_constant) + { + auto ret = configure_memcpy_async_kernel(); + if (!ret) + { + return ret.error; + } + // TODO(bgruber): use a structured binding in C++17 + // auto [launcher, kernel, elem_per_thread] = *ret; + return ::cuda::std::get<0>(*ret).doit( + ::cuda::std::get<1>(*ret), + num_items, + ::cuda::std::get<2>(*ret), + op, + out, + make_aligned_base_ptr_kernel_arg( + THRUST_NS_QUALIFIER::unwrap_contiguous_iterator(::cuda::std::get(in)), memcpy_async_alignment)...); + } + template CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t invoke_algorithm(cuda::std::index_sequence, ::cuda::std::integral_constant) diff --git a/cub/test/catch2_test_device_transform.cu b/cub/test/catch2_test_device_transform.cu index 949fa6569a3..0b6ab1c9ccb 100644 --- a/cub/test/catch2_test_device_transform.cu +++ b/cub/test/catch2_test_device_transform.cu @@ -77,7 +77,8 @@ DECLARE_TMPL_LAUNCH_WRAPPER(transform_many_with_alg_entry_point, using algorithms = c2h::enum_type_list; REQUIRE(cub::PtxVersion(ptx_version) == cudaSuccess); \ _CCCL_DIAG_PUSH \ _CCCL_DIAG_SUPPRESS_MSVC(4127) /* conditional expression is constant */ \ + if (alg == Algorithm::memcpy_async && ptx_version < 800) \ + { \ + return; \ + } \ FILTER_UBLKCP \ _CCCL_DIAG_POP