@@ -746,13 +746,38 @@ _CCCL_HOST_DEVICE inline PoorExpected<int> get_max_shared_memory()
   return max_smem;
 }
 
+_CCCL_HOST_DEVICE inline PoorExpected<int> get_sm_count()
+{
+  int device = 0;
+  auto error = CubDebug(cudaGetDevice(&device));
+  if (error != cudaSuccess)
+  {
+    return error;
+  }
+
+  int sm_count = 0;
+  error = CubDebug(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device));
+  if (error != cudaSuccess)
+  {
+    return error;
+  }
+
+  return sm_count;
+}
+
 struct elem_counts
 {
   int elem_per_thread;
   int tile_size;
   int smem_size;
 };
 
+struct prefetch_config
+{
+  int max_occupancy;
+  int sm_count;
+};
+
 template <bool RequiresStableAddress,
           typename Offset,
           typename RandomAccessIteratorTupleIn,
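
The new get_sm_count() helper uses the same error-propagation style as get_max_shared_memory(): it returns either a value or a cudaError_t through the PoorExpected<int> wrapper. The sketch below is a minimal approximation of how such an expected-like wrapper behaves; the names expected_or_error and query_sm_count are illustrative and do not reflect CUB's actual PoorExpected definition.

// Illustrative approximation of an expected-like wrapper; names are hypothetical,
// not CUB's actual PoorExpected implementation.
#include <cuda_runtime_api.h>

template <typename T>
struct expected_or_error
{
  T value{};
  cudaError_t error = cudaSuccess;

  expected_or_error(T v)
      : value(v)
  {} // success: implicit construction from a value
  expected_or_error(cudaError_t e)
      : error(e)
  {} // failure: implicit construction from an error code

  explicit operator bool() const
  {
    return error == cudaSuccess;
  }
  const T& operator*() const
  {
    return value;
  }
  const T* operator->() const
  {
    return &value;
  }
};

// Usage mirroring get_sm_count(): each CUDA runtime call either propagates its
// error code or, on success, the queried value travels through the wrapper.
inline expected_or_error<int> query_sm_count()
{
  int device = 0;
  cudaError_t error = cudaGetDevice(&device);
  if (error != cudaSuccess)
  {
    return error;
  }
  int sm_count = 0;
  error = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
  if (error != cudaSuccess)
  {
    return error;
  }
  return sm_count;
}

With implicit construction from both the value type and cudaError_t, the early `return error;` statements and the final `return sm_count;` convert without boilerplate, and callers test success with `if (!result)` and read the value via `*result` or `result->`, as the second hunk below does with `config`.
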
@@ -933,19 +958,49 @@ struct dispatch_t<RequiresStableAddress,
   {
     using policy_t          = typename ActivePolicy::algo_policy;
     constexpr int block_dim = policy_t::block_threads;
-    int max_occupancy = 0;
-    const auto error  = CubDebug(MaxSmOccupancy(max_occupancy, CUB_DETAIL_TRANSFORM_KERNEL_PTR, block_dim, 0));
-    if (error != cudaSuccess)
+
+    auto determine_config = [&]() -> PoorExpected<prefetch_config> {
+      int max_occupancy = 0;
+      const auto error  = CubDebug(MaxSmOccupancy(max_occupancy, CUB_DETAIL_TRANSFORM_KERNEL_PTR, block_dim, 0));
+      if (error != cudaSuccess)
+      {
+        return error;
+      }
+      const auto sm_count = get_sm_count();
+      if (!sm_count)
+      {
+        return sm_count.error;
+      }
+      return prefetch_config{max_occupancy, *sm_count};
+    };
+
+    PoorExpected<prefetch_config> config = [&]() {
+      NV_IF_TARGET(
+        NV_IS_HOST,
+        (
+          // this static variable exists for each template instantiation of the surrounding function and class, on which
+          // the chosen element count solely depends (assuming max SMEM is constant during a program execution)
+          static auto cached_config = determine_config(); return cached_config;),
+        (
+          // we cannot cache the determined element count in device code
+          return determine_config();));
+    }();
+    if (!config)
     {
-      return error;
+      return config.error;
     }
 
     const int items_per_thread =
       loaded_bytes_per_iter == 0
         ? +policy_t::items_per_thread_no_input
-        : ::cuda::ceil_div(ActivePolicy::min_bif, max_occupancy * block_dim * loaded_bytes_per_iter);
-    const int items_per_thread_clamped =
-      ::cuda::std::clamp(items_per_thread, +policy_t::min_items_per_thread, +policy_t::max_items_per_thread);
+        : ::cuda::ceil_div(ActivePolicy::min_bif, config->max_occupancy * block_dim * loaded_bytes_per_iter);
+
+    // Generate at least one block per SM. This improves tiny problem sizes (e.g. 2^16 elements).
+    const int items_per_thread_evenly_spread =
+      static_cast<int>(::cuda::std::min(Offset{items_per_thread}, num_items / (config->sm_count * block_dim)));
+
+    const int items_per_thread_clamped = ::cuda::std::clamp(
+      items_per_thread_evenly_spread, +policy_t::min_items_per_thread, +policy_t::max_items_per_thread);
     const int tile_size  = block_dim * items_per_thread_clamped;
     const auto grid_dim  = static_cast<unsigned int>(::cuda::ceil_div(num_items, Offset{tile_size}));
     return CubDebug(