diff --git a/cub/cub/detail/arch_dispatch.cuh b/cub/cub/detail/arch_dispatch.cuh
new file mode 100644
index 00000000000..012ab791019
--- /dev/null
+++ b/cub/cub/detail/arch_dispatch.cuh
@@ -0,0 +1,163 @@
+// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#pragma once
+
+#include <cub/config.cuh>
+
+#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
+#  pragma GCC system_header
+#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
+#  pragma clang system_header
+#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
+#  pragma system_header
+#endif // no system header
+
+#include
+#include
+#include
+#include
+#include
+
+CUB_NAMESPACE_BEGIN
+
+namespace detail
+{
+#if !defined(CUB_DEFINE_RUNTIME_POLICIES) && !_CCCL_COMPILER(NVRTC)
+
+#  if _CCCL_STD_VER < 2020
+template <typename PolicySelector, ::cuda::arch_id LowestArchId>
+struct policy_getter_17
+{
+  PolicySelector policy_selector;
+
+  _CCCL_API _CCCL_FORCEINLINE constexpr auto operator()() const
+  {
+    return policy_selector(LowestArchId);
+  }
+};
+
+template <typename PolicySelector, size_t N>
+_CCCL_API constexpr auto find_lowest_arch_with_same_policy(
+  PolicySelector policy_selector, size_t i, const ::cuda::std::array<::cuda::arch_id, N>& all_arches) -> ::cuda::arch_id
+{
+  const auto policy = policy_selector(all_arches[i]);
+  while (i > 0 && policy_selector(all_arches[i - 1]) == policy)
+  {
+    --i;
+  }
+  return all_arches[i];
+}
+
+template <typename ArchList, typename PolicySelector, size_t... Is>
+struct lowest_arch_resolver;
+
+// we keep the compile-time build up of the mapping table outside a template parameterized by a user-provided callable
+template <int ArchMult, int... CudaArches, typename PolicySelector, size_t... Is>
+struct lowest_arch_resolver<cuda::std::integer_sequence<int, ArchMult, CudaArches...>, PolicySelector, Is...>
+{
+  static_assert(::cuda::std::is_empty_v<PolicySelector>);
+  static_assert(sizeof...(CudaArches) == sizeof...(Is));
+
+  using policy_t = decltype(PolicySelector{}(::cuda::arch_id{}));
+
+  static constexpr ::cuda::arch_id all_arches[sizeof...(Is)] = {::cuda::arch_id{(CudaArches * ArchMult) / 10}...};
+  static constexpr policy_t all_policies[sizeof...(Is)] = {PolicySelector{}(all_arches[Is])...};
+
+  _CCCL_API static constexpr auto find_lowest(size_t i) -> ::cuda::arch_id
+  {
+    const auto& policy = all_policies[i];
+    while (i > 0 && policy == all_policies[i - 1])
+    {
+      --i;
+    }
+    return all_arches[i];
+  }
+
+  static constexpr ::cuda::arch_id lowest_arch_with_same_policy[sizeof...(Is)] = {find_lowest(Is)...};
+};
+#  endif // if _CCCL_STD_VER < 2020
+
+template <int ArchMult, int... CudaArches, typename PolicySelector, typename FunctorT, size_t... Is>
+CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch_to_arch_list(
+  PolicySelector policy_selector, ::cuda::arch_id device_arch, FunctorT&& f, ::cuda::std::index_sequence<Is...>)
+{
+  _CCCL_ASSERT(((device_arch == ::cuda::arch_id{(CudaArches * ArchMult) / 10}) || ...),
+               "device_arch must appear in the list of architectures compiled for");
+
+  using policy_t = decltype(policy_selector(::cuda::arch_id{}));
+
+  cudaError_t e = cudaErrorInvalidDeviceFunction;
+#  if _CCCL_STD_VER >= 2020
+  // In C++20, we just create an integral_constant holding the policy, because policies are structural types in C++20.
+  // This causes f to be only instantiated for each distinct policy, since the same policy for different arches results
+  // in the same integral_constant type passed to f
+  (...,
+   (device_arch == ::cuda::arch_id{(CudaArches * ArchMult) / 10}
+      ? (e = f(
+           ::cuda::std::integral_constant<policy_t, PolicySelector{}(::cuda::arch_id{(CudaArches * ArchMult) / 10})>{}))
+      : cudaSuccess));
+#  else // if _CCCL_STD_VER >= 2020
+  // In C++17, we have to collapse architectures with the same policies ourselves, so we instantiate call_for_arch once
+  // per policy on the lowest ArchId which produces the same policy
+  using resolver_t =
+    lowest_arch_resolver<cuda::std::integer_sequence<int, ArchMult, CudaArches...>, PolicySelector, Is...>;
+  (...,
+   (device_arch == ::cuda::arch_id{(CudaArches * ArchMult) / 10}
+      ? (e = f(policy_getter_17<PolicySelector, resolver_t::lowest_arch_with_same_policy[Is]>{policy_selector}))
+      : cudaSuccess));
+
+#  endif // if _CCCL_STD_VER >= 2020
+  return e;
+}
+
+template <typename PolicySelector, typename FunctorT, size_t... Is>
+CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch_all_arches_helper(
+  PolicySelector policy_selector, ::cuda::arch_id device_arch, FunctorT&& f, ::cuda::std::index_sequence<Is...> seq)
+{
+  static constexpr auto all_arches = ::cuda::__all_arch_ids();
+  return dispatch_to_arch_list<10, static_cast<int>(all_arches[Is])...>(policy_selector, device_arch, f, seq);
+}
+
+//! Takes a policy hub and instantiates f with the minimum possible number of nullary functor types that return a policy
+//! at compile-time (if possible), and then calls the appropriate instantiation based on a runtime GPU architecture.
+//! Depending on the used compiler, C++ standard, and available macros, a different number of instantiations may be
+//! produced.
+template <typename PolicySelector, typename F>
+CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t
+dispatch_arch(PolicySelector policy_selector, ::cuda::arch_id device_arch, F&& f)
+{
+  // if we have __CUDA_ARCH_LIST__ or NV_TARGET_SM_INTEGER_LIST, we only poll the policy hub for those arches.
+#  ifdef __CUDA_ARCH_LIST__
+  [[maybe_unused]] static constexpr auto arch_seq = ::cuda::std::integer_sequence<int, __CUDA_ARCH_LIST__>{};
+  return dispatch_to_arch_list<1, __CUDA_ARCH_LIST__>(
+    policy_selector, device_arch, ::cuda::std::forward<F>(f), ::cuda::std::make_index_sequence<arch_seq.size()>{});
+#  elif defined(NV_TARGET_SM_INTEGER_LIST)
+  [[maybe_unused]] static constexpr auto arch_seq = ::cuda::std::integer_sequence<int, NV_TARGET_SM_INTEGER_LIST>{};
+  return dispatch_to_arch_list<10, NV_TARGET_SM_INTEGER_LIST>(
+    policy_selector, device_arch, ::cuda::std::forward<F>(f), ::cuda::std::make_index_sequence<arch_seq.size()>{});
+#  else
+  // some compilers don't tell us what arches we are compiling for, so we test all of them
+  return dispatch_all_arches_helper(
+    policy_selector,
+    device_arch,
+    ::cuda::std::forward<F>(f),
+    ::cuda::std::make_index_sequence<::cuda::__all_arch_ids().size()>{});
+#  endif
+}
+
+#else // !defined(CUB_DEFINE_RUNTIME_POLICIES) && !_CCCL_COMPILER(NVRTC)
+
+// if we are compiling CCCL.C with runtime policies, we cannot query the policy hub at compile time
+_CCCL_EXEC_CHECK_DISABLE
+template <typename PolicySelector, typename F>
+_CCCL_API _CCCL_FORCEINLINE cudaError_t dispatch_arch(PolicySelector policy_selector, ::cuda::arch_id device_arch, F&& f)
+{
+  return f([&] {
+    return policy_selector(device_arch);
+  });
+}
+#endif // !defined(CUB_DEFINE_RUNTIME_POLICIES) && !_CCCL_COMPILER(NVRTC)
+} // namespace detail
+
+CUB_NAMESPACE_END
diff --git a/cub/test/catch2_test_arch_dispatch.cu b/cub/test/catch2_test_arch_dispatch.cu
new file mode 100644
index 00000000000..b82a8e44813
--- /dev/null
+++ b/cub/test/catch2_test_arch_dispatch.cu
@@ -0,0 +1,146 @@
+// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include <cub/detail/arch_dispatch.cuh>
+
+#include
+
+#include <c2h/catch2_test_helper.h>
+
+#ifdef __CUDA_ARCH_LIST__
+#  define CUDA_SM_LIST __CUDA_ARCH_LIST__
+#  define CUDA_SM_LIST_SCALE 1
+#elif defined(NV_TARGET_SM_INTEGER_LIST)
+#  define CUDA_SM_LIST NV_TARGET_SM_INTEGER_LIST
+#  define CUDA_SM_LIST_SCALE 10
+#endif
+
+using cuda::arch_id;
+
+struct a_policy
+{
+  arch_id value;
+
+  _CCCL_API constexpr bool operator==(const a_policy& other) const noexcept
+  {
+    return value == other.value;
+  }
+
+  _CCCL_API constexpr bool operator!=(const a_policy& other) const noexcept
+  {
+    return value != other.value;
+  }
+};
+
+struct policy_selector_all
+{
+  _CCCL_API constexpr auto operator()(arch_id id) const -> a_policy
+  {
+    return a_policy{id};
+  }
+};
+
+#ifdef CUDA_SM_LIST
+template <arch_id SelectedPolicyArch, int... ArchList>
+void check_arch_is_in_list()
+{
+  static_assert(((SelectedPolicyArch == arch_id{ArchList * CUDA_SM_LIST_SCALE / 10}) || ...));
+}
+#endif // CUDA_SM_LIST
+
+struct closure_all
+{
+  arch_id id;
+
+  template <typename PolicyGetter>
+  CUB_RUNTIME_FUNCTION auto operator()(PolicyGetter policy_getter) const -> cudaError_t
+  {
+#ifdef CUDA_SM_LIST
+    check_arch_is_in_list<policy_getter().value, CUDA_SM_LIST>();
+#endif // CUDA_SM_LIST
+    constexpr a_policy active_policy = policy_getter();
+    // since an individual policy is generated per architecture, we can do an exact comparison here
+    REQUIRE(active_policy.value == id);
+    return cudaSuccess;
+  }
+};
+
+C2H_TEST("dispatch_arch prunes based on __CUDA_ARCH_LIST__/NV_TARGET_SM_INTEGER_LIST", "[util][dispatch]")
+{
+#ifdef CUDA_SM_LIST
+  for (const int sm_val : {CUDA_SM_LIST})
+  {
+    const auto id = arch_id{sm_val * CUDA_SM_LIST_SCALE / 10};
+#else
+  for (const arch_id id : cuda::__all_arch_ids())
+  {
+#endif
+    CHECK(cub::detail::dispatch_arch(policy_selector_all{}, id, closure_all{id}) == cudaSuccess);
+  }
+}
+
+template <size_t N>
+struct check_policy_closure
+{
+  arch_id id;
+  cuda::std::array<arch_id, N> policy_ids;
+
+  template <typename PolicyGetter>
+  CUB_RUNTIME_FUNCTION cudaError_t operator()(PolicyGetter policy_getter) const
+  {
+    constexpr a_policy active_policy = policy_getter();
+    CAPTURE(id, policy_ids);
+    const auto policy_arch = *cuda::std::find_if(policy_ids.rbegin(), policy_ids.rend(), [&](arch_id policy_ver) {
+      return policy_ver <= id;
+    });
+    REQUIRE(active_policy.value == policy_arch);
+    return cudaSuccess;
+  }
+};
+
+// distinct policies for 60+, 80+ and 100+
+struct policy_selector_some
+{
+  _CCCL_API constexpr auto operator()(arch_id id) const -> a_policy
+  {
+    if (id >= arch_id::sm_100)
+    {
+      return a_policy{arch_id::sm_100};
+    }
+    if (id >= arch_id::sm_80)
+    {
+      return a_policy{arch_id::sm_80};
+    }
+    // default is policy 60
+    return a_policy{arch_id::sm_60};
+  }
+};
+
+// only a single policy
+struct policy_selector_minimal
+{
+  _CCCL_API constexpr auto operator()(arch_id) const -> a_policy
+  {
+    // default is policy 60
+    return a_policy{arch_id::sm_60};
+  }
+};
+
+C2H_TEST("dispatch_arch invokes correct policy", "[util][dispatch]")
+{
+#ifdef CUDA_SM_LIST
+  for (const int sm_val : {CUDA_SM_LIST})
+  {
+    const auto id = arch_id{sm_val * CUDA_SM_LIST_SCALE / 10};
+#else
+  for (const arch_id id : cuda::__all_arch_ids())
+  {
+#endif
+    const auto closure_some =
+      check_policy_closure<3>{id, cuda::std::array{arch_id::sm_60, arch_id::sm_80, arch_id::sm_100}};
+    CHECK(cub::detail::dispatch_arch(policy_selector_some{}, id, closure_some) == cudaSuccess);
+
+    const auto closure_minimal = check_policy_closure<1>{id, cuda::std::array{arch_id::sm_60}};
+    CHECK(cub::detail::dispatch_arch(policy_selector_minimal{}, id, closure_minimal) == cudaSuccess);
+  }
+}
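For context, the following is a minimal usage sketch, not part of the diff above, showing how a caller could drive cub::detail::dispatch_arch. It mirrors the calling convention exercised by the tests: a stateless selector maps a cuda::arch_id to a policy, and the closure receives a nullary, constexpr-evaluable policy getter, being instantiated once per distinct policy. The names my_policy, my_policy_selector, my_launcher, and launch_for_arch, the sm_100 threshold, and the tuning numbers are hypothetical illustrations; only cub::detail::dispatch_arch, cuda::arch_id, CUB_RUNTIME_FUNCTION, and the policy-getter protocol come from the code above.

// Minimal sketch under the assumptions stated above; not part of the diff.
#include <cub/detail/arch_dispatch.cuh>

// A hypothetical tuning policy. It must be equality-comparable (and a structural type for the
// C++20 path) so that dispatch_arch can collapse architectures that share the same policy.
struct my_policy
{
  int block_threads;
  int items_per_thread;

  constexpr bool operator==(const my_policy& other) const
  {
    return block_threads == other.block_threads && items_per_thread == other.items_per_thread;
  }
};

// A hypothetical, stateless selector mapping an architecture to its tuning policy,
// in the spirit of policy_selector_some in the test above. The numbers are arbitrary.
struct my_policy_selector
{
  constexpr my_policy operator()(cuda::arch_id id) const
  {
    return id >= cuda::arch_id::sm_100 ? my_policy{512, 4} : my_policy{256, 4};
  }
};

// The closure receives the nullary policy getter selected for the device architecture and can
// read the policy at compile time to configure a launch.
struct my_launcher
{
  int num_items;

  template <typename PolicyGetter>
  CUB_RUNTIME_FUNCTION cudaError_t operator()(PolicyGetter policy_getter) const
  {
    constexpr my_policy policy = policy_getter();
    const int tile_size        = policy.block_threads * policy.items_per_thread;
    const int grid_size        = (num_items + tile_size - 1) / tile_size;
    // A real dispatcher would launch its kernel here with <<<grid_size, policy.block_threads>>>.
    (void) grid_size;
    return cudaSuccess;
  }
};

// device_arch would typically come from querying the current device's compute capability.
cudaError_t launch_for_arch(cuda::arch_id device_arch, int num_items)
{
  return cub::detail::dispatch_arch(my_policy_selector{}, device_arch, my_launcher{num_items});
}

Because equal policies collapse to a single getter type (via the integral_constant in C++20 or the lowest-architecture mapping in C++17), the closure body above is compiled only once per distinct tuning configuration rather than once per compiled architecture.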