diff --git a/cub/device/dispatch/dispatch_select_if.cuh b/cub/device/dispatch/dispatch_select_if.cuh index c64885dc5..6ad75d111 100644 --- a/cub/device/dispatch/dispatch_select_if.cuh +++ b/cub/device/dispatch/dispatch_select_if.cuh @@ -167,7 +167,168 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::SelectIfPolicyT::BLOCK_THREA namespace detail { -template +namespace select +{ + +template ::PRIMITIVE, + std::size_t InputSize = sizeof(InputT)> +struct sm90_tuning +{ + static constexpr int threads = 128; + + static constexpr int nominal_4b_items_per_thread = 10; + + static constexpr int items = CUB_MIN(nominal_4b_items_per_thread, + CUB_MAX(1, (nominal_4b_items_per_thread * 4 / InputSize))); + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<350, 450>; +}; + +// select +template +struct sm90_tuning +{ + static constexpr int threads = 448; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<692, 715>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 448; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<504, 765>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 384; + static constexpr int items = 15; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<415, 1125>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 384; + static constexpr int items = 11; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<948, 1090>; +}; + +#if CUB_IS_INT128_ENABLED +template <> +struct sm90_tuning<__int128_t, false, false, sizeof(__int128_t)> +{ + static constexpr int threads = 512; + static constexpr int items = 4; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<460, 1145>; +}; + +template <> +struct sm90_tuning<__uint128_t, false, false, sizeof(__uint128_t)> +{ + static constexpr int threads = 512; + static constexpr int items = 3; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<284, 1130>; +}; +#endif + +// partition +template +struct sm90_tuning +{ + static constexpr int threads = 384; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<908, 995>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 128; + static constexpr int items = 18; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<988, 1060>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 128; + static constexpr int items = 18; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<988, 1060>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 128; + static constexpr int items = 9; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<752, 1125>; +}; + +#if CUB_IS_INT128_ENABLED +template <> +struct sm90_tuning<__int128_t, true, false, sizeof(__int128_t)> +{ + static constexpr int threads = 160; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<720, 1105>; +}; + +template <> +struct sm90_tuning<__uint128_t, true, false, sizeof(__uint128_t)> +{ + static constexpr int threads = 160; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<720, 1105>; +}; +#endif + +} + +template struct device_select_policy_hub { struct Policy350 : ChainedPolicy<350, Policy350, Policy350> @@ -186,7 +347,19 @@ struct device_select_policy_hub detail::fixed_delay_constructor_t<350, 450>>; }; - using MaxPolicy = Policy350; + struct Policy900 : ChainedPolicy<900, Policy900, Policy350> + { + using tuning = detail::select::sm90_tuning; + + using SelectIfPolicyT = AgentSelectIfPolicy; + }; + + using MaxPolicy = Policy900; }; } // detail @@ -226,17 +399,18 @@ struct device_select_policy_hub * @tparam KEEP_REJECTS * Whether or not we push rejected items to the back of the output */ -template , MayAlias>> +template < + typename InputIteratorT, + typename FlagsInputIteratorT, + typename SelectedOutputIteratorT, + typename NumSelectedIteratorT, + typename SelectOpT, + typename EqualityOpT, + typename OffsetT, + bool KEEP_REJECTS, + bool MayAlias = false, + typename SelectedPolicy = + detail::device_select_policy_hub, MayAlias, KEEP_REJECTS>> struct DispatchSelectIf : SelectedPolicy { /******************************************************************************