Skip to content
Merged
163 changes: 163 additions & 0 deletions cub/cub/detail/arch_dispatch.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#pragma once

#include <cub/config.cuh>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/__device/arch_id.h>
#include <cuda/std/__type_traits/is_empty.h>
#include <cuda/std/__utility/forward.h>
#include <cuda/std/__utility/integer_sequence.h>
#include <cuda/std/array>

CUB_NAMESPACE_BEGIN

namespace detail
{
#if !defined(CUB_DEFINE_RUNTIME_POLICIES) && !_CCCL_COMPILER(NVRTC)

# if _CCCL_STD_VER < 2020
//! Nullary functor handed to the dispatched callable in C++17 mode. It stores a policy selector and, when invoked,
//! evaluates that selector for the architecture fixed at compile time by LowestArchId.
template <typename PolicySelector, ::cuda::arch_id LowestArchId>
struct policy_getter_17
{
  //! The selector queried on invocation (public so the getter stays an aggregate).
  PolicySelector policy_selector;

  //! @return whatever the stored selector returns for LowestArchId
  _CCCL_API _CCCL_FORCEINLINE constexpr auto operator()() const
  {
    return this->policy_selector(LowestArchId);
  }
};

//! Starting at index i of all_arches, walks towards index 0 for as long as the preceding architecture yields a policy
//! that compares equal to the policy selected for all_arches[i].
//!
//! @param policy_selector callable mapping an arch_id to a policy; its result must be comparable with ==
//! @param i index into all_arches to start from
//! @param all_arches list of architectures to search
//! @return the architecture at the first index of the contiguous run ending at i that shares the policy of entry i
template <typename PolicySelector, size_t N>
_CCCL_API constexpr auto find_lowest_arch_with_same_policy(
  PolicySelector policy_selector, size_t i, const ::cuda::std::array<::cuda::arch_id, N>& all_arches) -> ::cuda::arch_id
{
  const auto reference_policy = policy_selector(all_arches[i]);
  for (; i > 0; --i)
  {
    // stop as soon as the next lower entry selects a different policy
    if (!(policy_selector(all_arches[i - 1]) == reference_policy))
    {
      break;
    }
  }
  return all_arches[i];
}

//! Primary template; only the specialization taking the architectures as an integer_sequence is defined.
template <int ArchMult, typename CudaArchSeq, typename PolicySelector, size_t... Is>
struct lowest_arch_resolver;

// The compile-time mapping tables are built here, outside of any template parameterized by the user-provided callable.
//
// For each entry k of the architecture list:
//  - all_arches[k] is the arch_id obtained from (CudaArches[k] * ArchMult) / 10
//  - all_policies[k] is the policy a default-constructed PolicySelector returns for all_arches[k]
//  - lowest_arch_with_same_policy[k] is the architecture at the start of the contiguous run of entries that ends at k
//    and whose policies compare equal to all_policies[k]
template <int ArchMult, int... CudaArches, typename PolicySelector, size_t... Is>
struct lowest_arch_resolver<ArchMult, ::cuda::std::integer_sequence<int, CudaArches...>, PolicySelector, Is...>
{
  // the selector is default-constructed below, so it must not carry state
  static_assert(::cuda::std::is_empty_v<PolicySelector>);
  // Is must index exactly the entries of CudaArches
  static_assert(sizeof...(CudaArches) == sizeof...(Is));

  using policy_t = decltype(PolicySelector{}(::cuda::arch_id{}));

  static constexpr ::cuda::arch_id all_arches[sizeof...(Is)] = {::cuda::arch_id{(CudaArches * ArchMult) / 10}...};
  static constexpr policy_t all_policies[sizeof...(Is)]      = {PolicySelector{}(all_arches[Is])...};

  //! Walks from index i towards index 0 while the preceding entry has a policy equal to the one at i.
  //! @return the architecture stored at the index where the walk stops
  _CCCL_API static constexpr auto find_lowest(size_t i) -> ::cuda::arch_id
  {
    const auto& wanted = all_policies[i];
    for (; i > 0; --i)
    {
      if (!(wanted == all_policies[i - 1]))
      {
        break;
      }
    }
    return all_arches[i];
  }

  static constexpr ::cuda::arch_id lowest_arch_with_same_policy[sizeof...(Is)] = {find_lowest(Is)...};
};
# endif // if _CCCL_STD_VER < 2020

//! Compares device_arch against every architecture in CudaArches (each converted via (arch * ArchMult) / 10) and calls
//! f with a policy getter for the matching entry.
//!
//! @param policy_selector callable mapping an arch_id to a policy
//! @param device_arch architecture to dispatch for; asserted to be part of the list
//! @param f callable invoked with the policy getter; must return cudaError_t
//! @return the value returned by f, or cudaErrorInvalidDeviceFunction if no list entry equals device_arch
template <int ArchMult, int... CudaArches, typename PolicySelector, typename FunctorT, size_t... Is>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch_to_arch_list(
  PolicySelector policy_selector, ::cuda::arch_id device_arch, FunctorT&& f, ::cuda::std::index_sequence<Is...>)
{
  _CCCL_ASSERT(((device_arch == ::cuda::arch_id{(CudaArches * ArchMult) / 10}) || ...),
               "device_arch must appear in the list of architectures compiled for");

  using policy_t = decltype(policy_selector(::cuda::arch_id{}));

  // stays at this error value unless one of the comparisons below matches and f is invoked
  cudaError_t status = cudaErrorInvalidDeviceFunction;
# if _CCCL_STD_VER < 2020
  // C++17: architectures sharing a policy have to be collapsed by hand. Each list entry is mapped to the lowest
  // architecture producing the same policy, so f is instantiated once per distinct policy_getter_17 type.
  using arch_resolver_t =
    lowest_arch_resolver<ArchMult, ::cuda::std::integer_sequence<int, CudaArches...>, PolicySelector, Is...>;
  (...,
   (device_arch == ::cuda::arch_id{(CudaArches * ArchMult) / 10}
      ? (status =
           f(policy_getter_17<PolicySelector, arch_resolver_t::lowest_arch_with_same_policy[Is]>{policy_selector}))
      : cudaSuccess));
# else // if _CCCL_STD_VER < 2020
  // C++20: policies are structural types, so the policy itself is wrapped into an integral_constant. Equal policies
  // for different architectures yield the same integral_constant type, hence f is instantiated once per distinct
  // policy.
  (...,
   (device_arch == ::cuda::arch_id{(CudaArches * ArchMult) / 10}
      ? (status = f(
           ::cuda::std::integral_constant<policy_t, policy_selector(::cuda::arch_id{(CudaArches * ArchMult) / 10})>{}))
      : cudaSuccess));
# endif // if _CCCL_STD_VER < 2020
  return status;
}

//! Fallback used when the list of compiled architectures is not available: dispatches over every architecture
//! returned by ::cuda::__all_arch_ids().
//!
//! @param policy_selector callable mapping an arch_id to a policy
//! @param device_arch architecture to dispatch for
//! @param f callable that is passed on to dispatch_to_arch_list
//! @param seq index sequence with one index per entry of ::cuda::__all_arch_ids()
//! @return the result of dispatch_to_arch_list
template <typename PolicySelector, typename FunctorT, size_t... Is>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch_all_arches_helper(
  PolicySelector policy_selector, ::cuda::arch_id device_arch, FunctorT&& f, ::cuda::std::index_sequence<Is...> seq)
{
  static constexpr auto all_arches = ::cuda::__all_arch_ids();
  // An ArchMult of 10 cancels the division by 10 applied in dispatch_to_arch_list, so the arch_id values are used
  // unscaled. f is a forwarding reference and is therefore forwarded, matching the other dispatch_arch code paths.
  return dispatch_to_arch_list<10, static_cast<int>(all_arches[Is])...>(
    policy_selector, device_arch, ::cuda::std::forward<FunctorT>(f), seq);
}

//! Takes a policy selector and instantiates f with the minimum possible number of nullary functor types that return a
//! policy at compile-time (if possible), then calls the instantiation matching the runtime GPU architecture.
//! Depending on the used compiler, C++ standard, and available macros, a different number of instantiations may be
//! produced.
//!
//! @param policy_selector callable mapping an arch_id to a policy
//! @param device_arch architecture of the device to dispatch for
//! @param f callable receiving the policy getter; must return cudaError_t
//! @return the error code produced by the selected dispatch path
template <typename PolicySelector, typename F>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t
dispatch_arch(PolicySelector policy_selector, ::cuda::arch_id device_arch, F&& f)
{
  // When __CUDA_ARCH_LIST__ or NV_TARGET_SM_INTEGER_LIST is defined, only those architectures are queried.
# ifdef __CUDA_ARCH_LIST__
  using compiled_arches_t = ::cuda::std::integer_sequence<int, __CUDA_ARCH_LIST__>;
  return dispatch_to_arch_list<1, __CUDA_ARCH_LIST__>(
    policy_selector,
    device_arch,
    ::cuda::std::forward<F>(f),
    ::cuda::std::make_index_sequence<compiled_arches_t::size()>{});
# elif defined(NV_TARGET_SM_INTEGER_LIST)
  using compiled_arches_t = ::cuda::std::integer_sequence<int, NV_TARGET_SM_INTEGER_LIST>;
  return dispatch_to_arch_list<10, NV_TARGET_SM_INTEGER_LIST>(
    policy_selector,
    device_arch,
    ::cuda::std::forward<F>(f),
    ::cuda::std::make_index_sequence<compiled_arches_t::size()>{});
# else
  // Some compilers do not expose the architectures being compiled for, so every known architecture is tested.
  using every_arch_indices_t = ::cuda::std::make_index_sequence<::cuda::__all_arch_ids().size()>;
  return dispatch_all_arches_helper(policy_selector, device_arch, ::cuda::std::forward<F>(f), every_arch_indices_t{});
# endif
}

#else // !defined(CUB_DEFINE_RUNTIME_POLICIES) && !_CCCL_COMPILER(NVRTC)

// When compiling CCCL.C with runtime policies, the policy selector cannot be queried at compile time. f therefore
// receives a closure that evaluates the selector for device_arch when called.
_CCCL_EXEC_CHECK_DISABLE
template <typename PolicySelector, typename F>
_CCCL_API _CCCL_FORCEINLINE cudaError_t dispatch_arch(PolicySelector policy_selector, ::cuda::arch_id device_arch, F&& f)
{
  // the closure only refers to the two parameters, which outlive the call to f
  return f([&policy_selector, &device_arch]() {
    return policy_selector(device_arch);
  });
}
#endif // !defined(CUB_DEFINE_RUNTIME_POLICIES) && !_CCCL_COMPILER(NVRTC)
} // namespace detail

CUB_NAMESPACE_END
146 changes: 146 additions & 0 deletions cub/test/catch2_test_arch_dispatch.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cub/detail/arch_dispatch.cuh>

#include <cuda/std/__algorithm/find_if.h>

#include <c2h/catch2_test_helper.h>

// CUDA_SM_LIST is the comma-separated architecture list provided by the compiler, if any.
// CUDA_SM_LIST_SCALE is the factor for which `entry * CUDA_SM_LIST_SCALE / 10` gives the arch_id value of an entry.
// Neither macro is defined when the compiler provides no such list.
#ifdef __CUDA_ARCH_LIST__
# define CUDA_SM_LIST __CUDA_ARCH_LIST__
# define CUDA_SM_LIST_SCALE 1
#elif defined(NV_TARGET_SM_INTEGER_LIST)
# define CUDA_SM_LIST NV_TARGET_SM_INTEGER_LIST
# define CUDA_SM_LIST_SCALE 10
#endif

using cuda::arch_id;

// Minimal test policy: carries only the architecture it stands for and is equality comparable.
struct a_policy
{
  arch_id value;

  // two policies are equal when they carry the same architecture
  _CCCL_API constexpr bool operator==(const a_policy& other) const noexcept
  {
    return value == other.value;
  }

  _CCCL_API constexpr bool operator!=(const a_policy& other) const noexcept
  {
    return !(*this == other);
  }
};

// Selector that produces a distinct policy for every architecture: the policy simply records the queried arch_id.
struct policy_selector_all
{
  _CCCL_API constexpr auto operator()(arch_id id) const -> a_policy
  {
    a_policy selected{id};
    return selected;
  }
};

#ifdef CUDA_SM_LIST
template <arch_id SelectedPolicyArch, int... ArchList>
void check_arch_is_in_list()
{
static_assert(((SelectedPolicyArch == arch_id{ArchList * CUDA_SM_LIST_SCALE / 10}) || ...));
}
#endif // CUDA_SM_LIST

// Callable passed to dispatch_arch together with policy_selector_all. It verifies that the policy it is handed was
// selected for exactly the architecture that was dispatched on.
struct closure_all
{
  // architecture that was passed to dispatch_arch
  arch_id id;

  template <typename PolicyGetter>
  CUB_RUNTIME_FUNCTION auto operator()(PolicyGetter policy_getter) const -> cudaError_t
  {
    // the getter must be usable in a constant expression
    constexpr a_policy active_policy = policy_getter();
#ifdef CUDA_SM_LIST
    // only architectures from the compiled list may ever be instantiated
    check_arch_is_in_list<PolicyGetter{}().value, CUDA_SM_LIST>();
#endif // CUDA_SM_LIST
    // policy_selector_all yields one policy per architecture, so an exact comparison is possible
    REQUIRE(active_policy.value == id);
    return cudaSuccess;
  }
};

// Dispatches once per architecture with a selector that yields a distinct policy for each of them. closure_all checks
// that the received policy matches the dispatched architecture and, when an architecture list is available, that only
// listed architectures are instantiated.
C2H_TEST("dispatch_arch prunes based on __CUDA_ARCH_LIST__/NV_TARGET_SM_INTEGER_LIST", "[util][dispatch]")
{
// The loop header depends on whether the compiler provided an architecture list; both variants define `id`.
#ifdef CUDA_SM_LIST
for (const int sm_val : {CUDA_SM_LIST})
{
// convert the list entry into an arch_id value
const auto id = arch_id{sm_val * CUDA_SM_LIST_SCALE / 10};
#else
for (const arch_id id : cuda::__all_arch_ids())
{
#endif
CHECK(cub::detail::dispatch_arch(policy_selector_all{}, id, closure_all{id}) == cudaSuccess);
}
}

// Callable passed to dispatch_arch that checks the received policy against a list of policy architectures.
// policy_ids holds the architectures at which a new policy starts, in ascending order. The expected policy for `id` is
// the highest entry that is <= id.
template <int NumPolicies>
struct check_policy_closure
{
  static_assert(NumPolicies > 0, "at least one policy architecture is required");

  // architecture that was passed to dispatch_arch
  arch_id id;
  cuda::std::array<arch_id, NumPolicies> policy_ids;

  template <typename PolicyGetter>
  CUB_RUNTIME_FUNCTION cudaError_t operator()(PolicyGetter policy_getter) const
  {
    constexpr a_policy active_policy = policy_getter();
    CAPTURE(id, policy_ids);
    // search from the highest policy architecture downwards for the first one not exceeding id
    const auto match = cuda::std::find_if(policy_ids.rbegin(), policy_ids.rend(), [&](arch_id policy_ver) {
      return policy_ver <= id;
    });
    // If id is below every entry, dereferencing `match` would be undefined behavior. Fall back to the lowest policy,
    // which is the default of the selectors used with this closure.
    const auto policy_arch = (match != policy_ids.rend()) ? *match : policy_ids.front();
    REQUIRE(active_policy.value == policy_arch);
    return cudaSuccess;
  }
};

// Selector with three distinct policies: one for sm_100 and above, one for sm_80 and above, and the sm_60 policy as
// the default for everything else.
struct policy_selector_some
{
  _CCCL_API constexpr auto operator()(arch_id id) const -> a_policy
  {
    // checked from the highest threshold downwards; sm_60 is the default
    const arch_id tier = (id >= arch_id::sm_100) ? arch_id::sm_100
                       : (id >= arch_id::sm_80)
                         ? arch_id::sm_80
                         : arch_id::sm_60;
    return a_policy{tier};
  }
};

// Selector with a single policy: every architecture maps to the sm_60 policy.
struct policy_selector_minimal
{
  _CCCL_API constexpr auto operator()(arch_id) const -> a_policy
  {
    // the queried architecture is ignored on purpose
    a_policy only_policy{arch_id::sm_60};
    return only_policy;
  }
};

// Dispatches once per architecture with selectors that share policies between architectures. check_policy_closure
// verifies that the received policy is the one expected for the dispatched architecture.
C2H_TEST("dispatch_arch invokes correct policy", "[util][dispatch]")
{
// The loop header depends on whether the compiler provided an architecture list; both variants define `id`.
#ifdef CUDA_SM_LIST
for (const int sm_val : {CUDA_SM_LIST})
{
// convert the list entry into an arch_id value
const auto id = arch_id{sm_val * CUDA_SM_LIST_SCALE / 10};
#else
for (const arch_id id : cuda::__all_arch_ids())
{
#endif
// three policies, starting at sm_60, sm_80 and sm_100 (mirrors policy_selector_some)
const auto closure_some =
check_policy_closure<3>{id, cuda::std::array<arch_id, 3>{arch_id::sm_60, arch_id::sm_80, arch_id::sm_100}};
CHECK(cub::detail::dispatch_arch(policy_selector_some{}, id, closure_some) == cudaSuccess);

// a single policy for all architectures (mirrors policy_selector_minimal)
const auto closure_minimal = check_policy_closure<1>{id, cuda::std::array<arch_id, 1>{arch_id::sm_60}};
CHECK(cub::detail::dispatch_arch(policy_selector_minimal{}, id, closure_minimal) == cudaSuccess);
}
}