
Commit 074ad49

Evenly spread elements (small problem sizes)
1 parent 9f0d63b

File tree

1 file changed: 62 additions, 7 deletions


Diff for: cub/cub/device/dispatch/dispatch_transform.cuh

@@ -746,13 +746,38 @@ _CCCL_HOST_DEVICE inline PoorExpected<int> get_max_shared_memory()
   return max_smem;
 }
 
+_CCCL_HOST_DEVICE inline PoorExpected<int> get_sm_count()
+{
+  int device = 0;
+  auto error = CubDebug(cudaGetDevice(&device));
+  if (error != cudaSuccess)
+  {
+    return error;
+  }
+
+  int sm_count = 0;
+  error = CubDebug(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device));
+  if (error != cudaSuccess)
+  {
+    return error;
+  }
+
+  return sm_count;
+}
+
 struct elem_counts
 {
   int elem_per_thread;
   int tile_size;
   int smem_size;
 };
 
+struct prefetch_config
+{
+  int max_occupancy;
+  int sm_count;
+};
+
 template <bool RequiresStableAddress,
           typename Offset,
           typename RandomAccessIteratorTupleIn,
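
The new get_sm_count() helper mirrors get_max_shared_memory() above: it returns a PoorExpected<int> that carries either the multiprocessor count or the cudaError_t from the query, so callers can propagate failures without exceptions. As a rough, self-contained illustration of that pattern (this is not CUB's actual PoorExpected type, just a sketch consistent with how it is used in this diff: operator bool, operator*, and an .error member), a host-only analogue could look like:

#include <cuda_runtime.h>

// Sketch of an expected-like wrapper, modeled on how PoorExpected is used in
// this diff (construct from a value or a cudaError_t, test with operator bool,
// read the value with operator*, read the failure via .error). Not CUB's
// actual implementation.
template <typename T>
struct expected_sketch
{
  T value{};
  cudaError_t error = cudaSuccess;

  expected_sketch(T v) : value(v) {}            // success path
  expected_sketch(cudaError_t e) : error(e) {}  // failure path
  explicit operator bool() const { return error == cudaSuccess; }
  T operator*() const { return value; }
};

// Host-only analogue of get_sm_count(): query the SM count of the current
// device and propagate any CUDA error to the caller.
inline expected_sketch<int> query_sm_count()
{
  int device = 0;
  if (auto e = cudaGetDevice(&device); e != cudaSuccess)
  {
    return e;
  }
  int sm_count = 0;
  if (auto e = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device); e != cudaSuccess)
  {
    return e;
  }
  return sm_count;
}

Dispatch code can then write "if (!result) { return result.error; }" and dereference the value on the success path, which is exactly the shape of the sm_count handling inside determine_config in the second hunk below.
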
@@ -933,19 +958,49 @@ struct dispatch_t<RequiresStableAddress,
   {
     using policy_t = typename ActivePolicy::algo_policy;
     constexpr int block_dim = policy_t::block_threads;
-    int max_occupancy = 0;
-    const auto error = CubDebug(MaxSmOccupancy(max_occupancy, CUB_DETAIL_TRANSFORM_KERNEL_PTR, block_dim, 0));
-    if (error != cudaSuccess)
+
+    auto determine_config = [&]() -> PoorExpected<prefetch_config> {
+      int max_occupancy = 0;
+      const auto error = CubDebug(MaxSmOccupancy(max_occupancy, CUB_DETAIL_TRANSFORM_KERNEL_PTR, block_dim, 0));
+      if (error != cudaSuccess)
+      {
+        return error;
+      }
+      const auto sm_count = get_sm_count();
+      if (!sm_count)
+      {
+        return sm_count.error;
+      }
+      return prefetch_config{max_occupancy, *sm_count};
+    };
+
+    PoorExpected<prefetch_config> config = [&]() {
+      NV_IF_TARGET(
+        NV_IS_HOST,
+        (
+          // this static variable exists for each template instantiation of the surrounding function and class, on which
+          // the chosen element count solely depends (assuming max SMEM is constant during a program execution)
+          static auto cached_config = determine_config(); return cached_config;),
+        (
+          // we cannot cache the determined element count in device code
+          return determine_config();));
+    }();
+    if (!config)
     {
-      return error;
+      return config.error;
     }
 
     const int items_per_thread =
       loaded_bytes_per_iter == 0
         ? +policy_t::items_per_thread_no_input
-        : ::cuda::ceil_div(ActivePolicy::min_bif, max_occupancy * block_dim * loaded_bytes_per_iter);
-    const int items_per_thread_clamped =
-      ::cuda::std::clamp(items_per_thread, +policy_t::min_items_per_thread, +policy_t::max_items_per_thread);
+        : ::cuda::ceil_div(ActivePolicy::min_bif, config->max_occupancy * block_dim * loaded_bytes_per_iter);
+
+    // Generate at least one block per SM. This improves tiny problem sizes (e.g. 2^16 elements).
+    const int items_per_thread_evenly_spread =
+      static_cast<int>(::cuda::std::min(Offset{items_per_thread}, num_items / (config->sm_count * block_dim)));
+
+    const int items_per_thread_clamped = ::cuda::std::clamp(
+      items_per_thread_evenly_spread, +policy_t::min_items_per_thread, +policy_t::max_items_per_thread);
     const int tile_size = block_dim * items_per_thread_clamped;
     const auto grid_dim = static_cast<unsigned int>(::cuda::ceil_div(num_items, Offset{tile_size}));
     return CubDebug(

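On the host, the NV_IF_TARGET(NV_IS_HOST, ...) branch wraps determine_config() in a function-local static, so the occupancy and SM-count queries run once per template instantiation of the dispatch function rather than on every API call; device code has no such static, so it recomputes the configuration each time. A stripped-down, host-only sketch of that caching idea (with a hypothetical query_device_config() standing in for the MaxSmOccupancy and get_sm_count() calls):

#include <cuda_runtime.h>

// Hypothetical stand-in for the values bundled into prefetch_config.
struct device_config_sketch
{
  int max_occupancy;
  int sm_count;
};

// Stand-in for the driver queries; error handling omitted for brevity.
inline device_config_sketch query_device_config()
{
  int device = 0, sm_count = 0;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
  // In the real code path, max_occupancy comes from cub::MaxSmOccupancy;
  // a fixed placeholder value is used here.
  return device_config_sketch{8, sm_count};
}

inline device_config_sketch cached_device_config()
{
  // The function-local static is initialized exactly once (thread-safe since
  // C++11), so the driver is queried only on the first call; the diff achieves
  // the same effect with "static auto cached_config = determine_config();"
  // inside the NV_IS_HOST branch.
  static const device_config_sketch cfg = query_device_config();
  return cfg;
}
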
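
To see what the extra min() buys for tiny inputs, take the 2^16-element case mentioned in the new comment. The concrete numbers below are assumptions for illustration only (108 SMs as on an A100-class GPU, 256 threads per block, and an uncapped items_per_thread of 8 from the bytes-in-flight heuristic); they are not taken from the actual tuning policies:

#include <algorithm>
#include <cstdio>

int main()
{
  // Illustrative assumptions, not actual CUB tuning values:
  const long long num_items   = 1 << 16; // 65,536 elements (the "tiny" case from the diff comment)
  const int sm_count          = 108;     // e.g. an A100-class GPU
  const int block_dim         = 256;     // hypothetical block size
  const int items_per_thread  = 8;       // hypothetical result of the bytes-in-flight heuristic

  // Without the cap: ceil(65,536 / (256 * 8)) = 32 blocks, so most of the 108 SMs sit idle.
  const long long tile_uncapped = static_cast<long long>(block_dim) * items_per_thread;
  const long long grid_uncapped = (num_items + tile_uncapped - 1) / tile_uncapped;

  // With the cap from the diff: aim for at least one block per SM.
  // 65,536 / (108 * 256) = 2, so items per thread drops from 8 to 2.
  const long long cap = num_items / (static_cast<long long>(sm_count) * block_dim);
  const long long items_evenly_spread = std::min<long long>(items_per_thread, cap);
  const long long tile_capped = block_dim * items_evenly_spread;              // 512
  const long long grid_capped = (num_items + tile_capped - 1) / tile_capped;  // 128 blocks

  std::printf("uncapped: %lld blocks, capped: %lld blocks, for %d SMs\n",
              grid_uncapped, grid_capped, sm_count);
}

The subsequent clamp to policy_t::min_items_per_thread / max_items_per_thread still applies, so the cap can only lower items per thread within the tuned range; for large num_items the cap exceeds the heuristic value, min() returns the heuristic value, and the dispatch behaves as before.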