From 082566c04217c819ee8faab38039af29f04df69e Mon Sep 17 00:00:00 2001 From: Georgy Evtushenko Date: Thu, 15 Jun 2023 21:14:54 +0400 Subject: [PATCH] Split tuning --- cub/device/dispatch/dispatch_select_if.cuh | 227 +++++++++++++++++---- 1 file changed, 183 insertions(+), 44 deletions(-) diff --git a/cub/device/dispatch/dispatch_select_if.cuh b/cub/device/dispatch/dispatch_select_if.cuh index 6ad75d111..b9a077b4f 100644 --- a/cub/device/dispatch/dispatch_select_if.cuh +++ b/cub/device/dispatch/dispatch_select_if.cuh @@ -45,6 +45,8 @@ #include +#include "cub/agent/single_pass_scan_operators.cuh" +#include "cub/block/block_load.cuh" #include #include @@ -171,6 +173,7 @@ namespace select { template ::PRIMITIVE, std::size_t InputSize = sizeof(InputT)> @@ -188,20 +191,89 @@ struct sm90_tuning using delay_constructor = detail::fixed_delay_constructor_t<350, 450>; }; -// select +// select::if template -struct sm90_tuning +struct sm90_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 22; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::no_delay_constructor_t<580>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 22; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<320, 605>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 384; + static constexpr int items = 17; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<76, 1150>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 384; + static constexpr int items = 11; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<380, 1140>; +}; + +#if CUB_IS_INT128_ENABLED +template <> +struct sm90_tuning<__int128_t, false, false, false, sizeof(__int128_t)> +{ + static constexpr int threads = 512; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<460, 1145>; +}; + +template <> +struct sm90_tuning<__uint128_t, false, false, false, sizeof(__uint128_t)> +{ + static constexpr int threads = 512; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<460, 1145>; +}; +#endif + +// select::flagged +template +struct sm90_tuning { static constexpr int threads = 448; static constexpr int items = 20; static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - using delay_constructor = detail::fixed_delay_constructor_t<692, 715>; + using delay_constructor = detail::no_delay_constructor_t<715>; }; template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 448; static constexpr int items = 20; @@ -212,7 +284,7 @@ struct sm90_tuning }; template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 384; static constexpr int items = 15; @@ -223,30 +295,30 @@ struct sm90_tuning }; template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 384; static constexpr int items = 11; static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - using delay_constructor = detail::fixed_delay_constructor_t<948, 1090>; + using delay_constructor = detail::fixed_delay_constructor_t<360, 1170>; }; #if CUB_IS_INT128_ENABLED template <> -struct sm90_tuning<__int128_t, false, false, sizeof(__int128_t)> +struct sm90_tuning<__int128_t, true, false, false, sizeof(__int128_t)> { static constexpr int threads = 512; - static constexpr int items = 4; + static constexpr int items = 3; static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - using delay_constructor = detail::fixed_delay_constructor_t<460, 1145>; + using delay_constructor = detail::fixed_delay_constructor_t<284, 1130>; }; template <> -struct sm90_tuning<__uint128_t, false, false, sizeof(__uint128_t)> +struct sm90_tuning<__uint128_t, true, false, false, sizeof(__uint128_t)> { static constexpr int threads = 512; static constexpr int items = 3; @@ -257,9 +329,9 @@ struct sm90_tuning<__uint128_t, false, false, sizeof(__uint128_t)> }; #endif -// partition +// partition::if template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 384; static constexpr int items = 20; @@ -270,41 +342,110 @@ struct sm90_tuning }; template -struct sm90_tuning +struct sm90_tuning { - static constexpr int threads = 128; - static constexpr int items = 18; + static constexpr int threads = 384; + static constexpr int items = 20; static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - using delay_constructor = detail::fixed_delay_constructor_t<988, 1060>; + using delay_constructor = detail::exponential_backon_jitter_window_constructor_t<1180, 715>; }; template -struct sm90_tuning +struct sm90_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 14; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<536, 1055>; +}; + +template +struct sm90_tuning { static constexpr int threads = 128; - static constexpr int items = 18; + static constexpr int items = 12; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<512, 1075>; +}; + +#if CUB_IS_INT128_ENABLED +template <> +struct sm90_tuning<__int128_t, false, true, false, sizeof(__int128_t)> +{ + static constexpr int threads = 192; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>; +}; + +template <> +struct sm90_tuning<__uint128_t, false, true, false, sizeof(__uint128_t)> +{ + static constexpr int threads = 192; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>; +}; +#endif + +// partition::flagged +template +struct sm90_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<580, 850>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 512; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<388, 1055>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 20; static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - using delay_constructor = detail::fixed_delay_constructor_t<988, 1060>; + using delay_constructor = detail::fixed_delay_constructor_t<72, 1165>; }; template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 128; static constexpr int items = 9; static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - using delay_constructor = detail::fixed_delay_constructor_t<752, 1125>; + using delay_constructor = detail::fixed_delay_constructor_t<468, 1175>; }; #if CUB_IS_INT128_ENABLED template <> -struct sm90_tuning<__int128_t, true, false, sizeof(__int128_t)> +struct sm90_tuning<__int128_t, true, true, false, sizeof(__int128_t)> { static constexpr int threads = 160; static constexpr int items = 5; @@ -315,7 +456,7 @@ struct sm90_tuning<__int128_t, true, false, sizeof(__int128_t)> }; template <> -struct sm90_tuning<__uint128_t, true, false, sizeof(__uint128_t)> +struct sm90_tuning<__uint128_t, true, true, false, sizeof(__uint128_t)> { static constexpr int threads = 160; static constexpr int items = 5; @@ -328,7 +469,7 @@ struct sm90_tuning<__uint128_t, true, false, sizeof(__uint128_t)> } -template +template struct device_select_policy_hub { struct Policy350 : ChainedPolicy<350, Policy350, Policy350> @@ -349,7 +490,9 @@ struct device_select_policy_hub struct Policy900 : ChainedPolicy<900, Policy900, Policy350> { - using tuning = detail::select::sm90_tuning; + static constexpr bool flagged = std::is_same::value == false; + + using tuning = detail::select::sm90_tuning; using SelectIfPolicyT = AgentSelectIfPolicy, MayAlias, KEEP_REJECTS>> +template , + cub::detail::value_t, + MayAlias, + KEEP_REJECTS>> struct DispatchSelectIf : SelectedPolicy { /****************************************************************************** * Types and constants ******************************************************************************/ - // The input value type - using InputT = cub::detail::value_t; - - // The flag value type - using FlagT = cub::detail::value_t; - // Tile status descriptor interface type using ScanTileStateT = ScanTileState;