From b4b28c664a82791064d24af049d5286d61cba6bc Mon Sep 17 00:00:00 2001 From: Georgy Evtushenko Date: Mon, 12 Jun 2023 10:56:26 +0400 Subject: [PATCH] Tune select and partition for SM90 --- cub/device/dispatch/dispatch_select_if.cuh | 333 ++++++++++++++++++++- 1 file changed, 324 insertions(+), 9 deletions(-) diff --git a/cub/device/dispatch/dispatch_select_if.cuh b/cub/device/dispatch/dispatch_select_if.cuh index c64885dc5..4c2256c72 100644 --- a/cub/device/dispatch/dispatch_select_if.cuh +++ b/cub/device/dispatch/dispatch_select_if.cuh @@ -45,6 +45,8 @@ #include +#include "cub/agent/single_pass_scan_operators.cuh" +#include "cub/block/block_load.cuh" #include #include @@ -167,7 +169,308 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::SelectIfPolicyT::BLOCK_THREA namespace detail { -template +namespace select +{ + +template ::PRIMITIVE, + std::size_t InputSize = sizeof(InputT)> +struct sm90_tuning +{ + static constexpr int threads = 128; + + static constexpr int nominal_4b_items_per_thread = 10; + + static constexpr int items = CUB_MIN(nominal_4b_items_per_thread, + CUB_MAX(1, (nominal_4b_items_per_thread * 4 / InputSize))); + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<350, 450>; +}; + +// select::if +template +struct sm90_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 22; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::no_delay_constructor_t<580>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 22; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<320, 605>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 384; + static constexpr int items = 17; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<76, 1150>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 384; + static constexpr int items = 11; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<380, 1140>; +}; + +#if CUB_IS_INT128_ENABLED +template <> +struct sm90_tuning<__int128_t, false, false, 4, false, sizeof(__int128_t)> +{ + static constexpr int threads = 512; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<460, 1145>; +}; + +template <> +struct sm90_tuning<__uint128_t, false, false, 4, false, sizeof(__uint128_t)> +{ + static constexpr int threads = 512; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<460, 1145>; +}; +#endif + +// select::flagged +template +struct sm90_tuning +{ + static constexpr int threads = 448; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::no_delay_constructor_t<715>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 448; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<504, 765>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 384; + static constexpr int items = 15; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<415, 1125>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 384; + static constexpr int items = 11; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<360, 1170>; +}; + +#if CUB_IS_INT128_ENABLED +template <> +struct sm90_tuning<__int128_t, true, false, 4, false, sizeof(__int128_t)> +{ + static constexpr int threads = 512; + static constexpr int items = 3; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<284, 1130>; +}; + +template <> +struct sm90_tuning<__uint128_t, true, false, 4, false, sizeof(__uint128_t)> +{ + static constexpr int threads = 512; + static constexpr int items = 3; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<284, 1130>; +}; +#endif + +// partition::if +template +struct sm90_tuning +{ + static constexpr int threads = 384; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<908, 995>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 384; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::exponential_backon_jitter_window_constructor_t<1180, 715>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 14; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<536, 1055>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 128; + static constexpr int items = 12; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<512, 1075>; +}; + +#if CUB_IS_INT128_ENABLED +template <> +struct sm90_tuning<__int128_t, false, true, 4, false, sizeof(__int128_t)> +{ + static constexpr int threads = 192; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>; +}; + +template <> +struct sm90_tuning<__uint128_t, false, true, 4, false, sizeof(__uint128_t)> +{ + static constexpr int threads = 192; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>; +}; +#endif + +// partition::flagged +template +struct sm90_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<580, 850>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 512; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<388, 1055>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<72, 1165>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 128; + static constexpr int items = 9; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<468, 1175>; +}; + +#if CUB_IS_INT128_ENABLED +template <> +struct sm90_tuning<__int128_t, true, true, 4, false, sizeof(__int128_t)> +{ + static constexpr int threads = 160; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<720, 1105>; +}; + +template <> +struct sm90_tuning<__uint128_t, true, true, 4, false, sizeof(__uint128_t)> +{ + static constexpr int threads = 160; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<720, 1105>; +}; +#endif + +} + +template struct device_select_policy_hub { struct Policy350 : ChainedPolicy<350, Policy350, Policy350> @@ -186,7 +489,21 @@ struct device_select_policy_hub detail::fixed_delay_constructor_t<350, 450>>; }; - using MaxPolicy = Policy350; + struct Policy900 : ChainedPolicy<900, Policy900, Policy350> + { + static constexpr bool flagged = std::is_same::value == false; + + using tuning = detail::select::sm90_tuning; + + using SelectIfPolicyT = AgentSelectIfPolicy; + }; + + using MaxPolicy = Policy900; }; } // detail @@ -236,19 +553,17 @@ template , MayAlias>> + detail::device_select_policy_hub, + cub::detail::value_t, + OffsetT, + MayAlias, + KEEP_REJECTS>> struct DispatchSelectIf : SelectedPolicy { /****************************************************************************** * Types and constants ******************************************************************************/ - // The input value type - using InputT = cub::detail::value_t; - - // The flag value type - using FlagT = cub::detail::value_t; - // Tile status descriptor interface type using ScanTileStateT = ScanTileState;