From c84d218624e8ee70e7dcb73075a1685eeb763581 Mon Sep 17 00:00:00 2001 From: Georgy Evtushenko Date: Mon, 19 Jun 2023 20:14:55 +0400 Subject: [PATCH] Refactor SM90 tuning --- cub/device/dispatch/dispatch_scan.cuh | 143 +----- cub/device/dispatch/dispatch_select_if.cuh | 348 +------------- cub/device/dispatch/tuning/tuning_scan.cuh | 205 +++++++++ .../dispatch/tuning/tuning_select_if.cuh | 426 ++++++++++++++++++ 4 files changed, 637 insertions(+), 485 deletions(-) create mode 100644 cub/device/dispatch/tuning/tuning_scan.cuh create mode 100644 cub/device/dispatch/tuning/tuning_select_if.cuh diff --git a/cub/device/dispatch/dispatch_scan.cuh b/cub/device/dispatch/dispatch_scan.cuh index 1529e21bb..a4bf2eff5 100644 --- a/cub/device/dispatch/dispatch_scan.cuh +++ b/cub/device/dispatch/dispatch_scan.cuh @@ -35,10 +35,9 @@ #pragma once -#include - #include #include +#include #include #include #include @@ -48,6 +47,8 @@ #include +#include + CUB_NAMESPACE_BEGIN /****************************************************************************** @@ -195,144 +196,6 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanPolicyT::BLOCK_THREADS)) .ConsumeRange(num_items, tile_state, start_tile); } -/****************************************************************************** - * Policy - ******************************************************************************/ - -namespace detail -{ -namespace scan -{ - -template -struct tuning -{ - static constexpr int threads = Threads; - static constexpr int items = Items; - - using delay_constructor = detail::fixed_delay_constructor_t; -}; - -template ::PRIMITIVE, - std::size_t AccumSize = sizeof(AccumT)> -struct sm90_tuning -{ - static constexpr int threads = 128; - static constexpr int items = 15; - - using delay_constructor = detail::default_delay_constructor_t; -}; - -// clang-format off -template struct sm90_tuning : tuning<192, 22, 168, 1140> {}; -template struct sm90_tuning : tuning<512, 12, 376, 1125> {}; -template struct sm90_tuning : tuning<128, 24, 648, 1245> {}; -template struct sm90_tuning : tuning<224, 24, 632, 1290> {}; - -template <> struct sm90_tuning : tuning<128, 24, 688, 1140> {}; -template <> struct sm90_tuning : tuning<224, 24, 576, 1215> {}; - -#if CUB_IS_INT128_ENABLED -template <> struct sm90_tuning< __int128_t, true, false, sizeof(__int128_t)> : tuning<576, 21, 860, 630> {}; -template <> struct sm90_tuning<__uint128_t, true, false, sizeof(__uint128_t)> : tuning<576, 21, 860, 630> {}; -#endif -// clang-format on - -} // namespace scan -} // namespace detail - -template -struct DeviceScanPolicy -{ - // For large values, use timesliced loads/stores to fit shared memory. - static constexpr bool LargeValues = sizeof(AccumT) > 128; - static constexpr BlockLoadAlgorithm ScanTransposedLoad = - LargeValues ? BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED - : BLOCK_LOAD_WARP_TRANSPOSE; - static constexpr BlockStoreAlgorithm ScanTransposedStore = - LargeValues ? BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED - : BLOCK_STORE_WARP_TRANSPOSE; - - template - using policy_t = - AgentScanPolicy, - DelayConstructorT>; - - /// SM350 - struct Policy350 : ChainedPolicy<350, Policy350, Policy350> - { - // GTX Titan: 29.5B items/s (232.4 GB/s) @ 48M 32-bit T - using ScanPolicyT = policy_t<128, - 12, ///< Threads per block, items per thread - AccumT, - BLOCK_LOAD_DIRECT, - LOAD_CA, - BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED, - BLOCK_SCAN_RAKING, - detail::default_delay_constructor_t>; - }; - - /// SM520 - struct Policy520 : ChainedPolicy<520, Policy520, Policy350> - { - // Titan X: 32.47B items/s @ 48M 32-bit T - using ScanPolicyT = policy_t<128, - 12, ///< Threads per block, items per thread - AccumT, - BLOCK_LOAD_DIRECT, - LOAD_CA, - ScanTransposedStore, - BLOCK_SCAN_WARP_SCANS, - detail::default_delay_constructor_t>; - }; - - /// SM600 - struct Policy600 : ChainedPolicy<600, Policy600, Policy520> - { - using ScanPolicyT = policy_t<128, - 15, ///< Threads per block, items per thread - AccumT, - ScanTransposedLoad, - LOAD_DEFAULT, - ScanTransposedStore, - BLOCK_SCAN_WARP_SCANS, - detail::default_delay_constructor_t>; - }; - - /// SM900 - struct Policy900 : ChainedPolicy<900, Policy900, Policy600> - { - using tuning = detail::scan::sm90_tuning::value>; - - using ScanPolicyT = policy_t; - }; - - using MaxPolicy = Policy900; -}; - /****************************************************************************** * Dispatch ******************************************************************************/ diff --git a/cub/device/dispatch/dispatch_select_if.cuh b/cub/device/dispatch/dispatch_select_if.cuh index 7755a3343..a2f3536ed 100644 --- a/cub/device/dispatch/dispatch_select_if.cuh +++ b/cub/device/dispatch/dispatch_select_if.cuh @@ -37,6 +37,7 @@ #include #include #include +#include #include #include #include @@ -45,11 +46,11 @@ #include -#include - #include #include +#include + CUB_NAMESPACE_BEGIN /****************************************************************************** @@ -164,349 +165,6 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::SelectIfPolicyT::BLOCK_THREA } -namespace detail -{ - -namespace select -{ - -template ::PRIMITIVE, - std::size_t InputSize = sizeof(InputT)> -struct sm90_tuning -{ - static constexpr int threads = 128; - - static constexpr int nominal_4b_items_per_thread = 10; - - static constexpr int items = CUB_MIN(nominal_4b_items_per_thread, - CUB_MAX(1, (nominal_4b_items_per_thread * 4 / InputSize))); - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::fixed_delay_constructor_t<350, 450>; -}; - -// select::if -template -struct sm90_tuning -{ - static constexpr int threads = 256; - static constexpr int items = 22; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::no_delay_constructor_t<580>; -}; - -template -struct sm90_tuning -{ - static constexpr int threads = 256; - static constexpr int items = 22; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; - - using delay_constructor = detail::fixed_delay_constructor_t<320, 605>; -}; - -template -struct sm90_tuning -{ - static constexpr int threads = 384; - static constexpr int items = 17; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; - - using delay_constructor = detail::fixed_delay_constructor_t<76, 1150>; -}; - -template -struct sm90_tuning -{ - static constexpr int threads = 384; - static constexpr int items = 11; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; - - using delay_constructor = detail::fixed_delay_constructor_t<380, 1140>; -}; - -#if CUB_IS_INT128_ENABLED -template <> -struct sm90_tuning<__int128_t, false, false, 4, false, sizeof(__int128_t)> -{ - static constexpr int threads = 512; - static constexpr int items = 5; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::fixed_delay_constructor_t<460, 1145>; -}; - -template <> -struct sm90_tuning<__uint128_t, false, false, 4, false, sizeof(__uint128_t)> -{ - static constexpr int threads = 512; - static constexpr int items = 5; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::fixed_delay_constructor_t<460, 1145>; -}; -#endif - -// select::flagged -template -struct sm90_tuning -{ - static constexpr int threads = 448; - static constexpr int items = 20; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::no_delay_constructor_t<715>; -}; - -template -struct sm90_tuning -{ - static constexpr int threads = 448; - static constexpr int items = 20; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::fixed_delay_constructor_t<504, 765>; -}; - -template -struct sm90_tuning -{ - static constexpr int threads = 384; - static constexpr int items = 15; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::fixed_delay_constructor_t<415, 1125>; -}; - -template -struct sm90_tuning -{ - static constexpr int threads = 384; - static constexpr int items = 11; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::fixed_delay_constructor_t<360, 1170>; -}; - -#if CUB_IS_INT128_ENABLED -template <> -struct sm90_tuning<__int128_t, true, false, 4, false, sizeof(__int128_t)> -{ - static constexpr int threads = 512; - static constexpr int items = 3; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::fixed_delay_constructor_t<284, 1130>; -}; - -template <> -struct sm90_tuning<__uint128_t, true, false, 4, false, sizeof(__uint128_t)> -{ - static constexpr int threads = 512; - static constexpr int items = 3; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::fixed_delay_constructor_t<284, 1130>; -}; -#endif - -// partition::if -template -struct sm90_tuning -{ - static constexpr int threads = 384; - static constexpr int items = 20; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::fixed_delay_constructor_t<908, 995>; -}; - -template -struct sm90_tuning -{ - static constexpr int threads = 320; - static constexpr int items = 14; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::fixed_delay_constructor_t<500, 560>; -}; - -template -struct sm90_tuning -{ - static constexpr int threads = 256; - static constexpr int items = 14; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::fixed_delay_constructor_t<536, 1055>; -}; - -template -struct sm90_tuning -{ - static constexpr int threads = 128; - static constexpr int items = 12; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; - - using delay_constructor = detail::fixed_delay_constructor_t<512, 1075>; -}; - -#if CUB_IS_INT128_ENABLED -template <> -struct sm90_tuning<__int128_t, false, true, 4, false, sizeof(__int128_t)> -{ - static constexpr int threads = 192; - static constexpr int items = 5; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; - - using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>; -}; - -template <> -struct sm90_tuning<__uint128_t, false, true, 4, false, sizeof(__uint128_t)> -{ - static constexpr int threads = 192; - static constexpr int items = 5; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; - - using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>; -}; -#endif - -// partition::flagged -template -struct sm90_tuning -{ - static constexpr int threads = 256; - static constexpr int items = 20; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::fixed_delay_constructor_t<580, 850>; -}; - -template -struct sm90_tuning -{ - static constexpr int threads = 512; - static constexpr int items = 20; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::fixed_delay_constructor_t<388, 1055>; -}; - -template -struct sm90_tuning -{ - static constexpr int threads = 256; - static constexpr int items = 20; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::fixed_delay_constructor_t<72, 1165>; -}; - -template -struct sm90_tuning -{ - static constexpr int threads = 224; - static constexpr int items = 6; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::fixed_delay_constructor_t<532, 1180>; -}; - -#if CUB_IS_INT128_ENABLED -template <> -struct sm90_tuning<__int128_t, true, true, 4, false, sizeof(__int128_t)> -{ - static constexpr int threads = 160; - static constexpr int items = 5; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::fixed_delay_constructor_t<720, 1105>; -}; - -template <> -struct sm90_tuning<__uint128_t, true, true, 4, false, sizeof(__uint128_t)> -{ - static constexpr int threads = 160; - static constexpr int items = 5; - - static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; - - using delay_constructor = detail::fixed_delay_constructor_t<720, 1105>; -}; -#endif - -} - -template -struct device_select_policy_hub -{ - struct Policy350 : ChainedPolicy<350, Policy350, Policy350> - { - static constexpr int NOMINAL_4B_ITEMS_PER_THREAD = 10; - - static constexpr int ITEMS_PER_THREAD = - CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, - CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(InputT)))); - - using SelectIfPolicyT = AgentSelectIfPolicy<128, - ITEMS_PER_THREAD, - BLOCK_LOAD_DIRECT, - MayAlias ? LOAD_CA : LOAD_LDG, - BLOCK_SCAN_WARP_SCANS, - detail::fixed_delay_constructor_t<350, 450>>; - }; - - struct Policy900 : ChainedPolicy<900, Policy900, Policy350> - { - static constexpr bool flagged = std::is_same::value == false; - - using tuning = detail::select::sm90_tuning; - - using SelectIfPolicyT = AgentSelectIfPolicy; - }; - - using MaxPolicy = Policy900; -}; - -} // detail - - /****************************************************************************** * Dispatch ******************************************************************************/ diff --git a/cub/device/dispatch/tuning/tuning_scan.cuh b/cub/device/dispatch/tuning/tuning_scan.cuh new file mode 100644 index 000000000..f40357158 --- /dev/null +++ b/cub/device/dispatch/tuning/tuning_scan.cuh @@ -0,0 +1,205 @@ +/****************************************************************************** + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +CUB_NAMESPACE_BEGIN + +namespace detail +{ +namespace scan +{ + +enum class keep_rejects { no, yes }; +enum class primitive_accum { no, yes }; +enum class primitive_op { no, yes }; +enum class offset_size { _4, _8, unknown }; +enum class accum_size { _1, _2, _4, _8, _16, unknown }; + +template +constexpr primitive_accum is_primitive_accum() +{ + return Traits::PRIMITIVE ? primitive_accum::yes : primitive_accum::no; +} + +template +constexpr primitive_op is_primitive_op() +{ + return basic_binary_op_t::value ? primitive_op::yes : primitive_op::no; +} + +template +constexpr accum_size classify_accum_size() +{ + return sizeof(AccumT) == 1 ? accum_size::_1 + : sizeof(AccumT) == 2 ? accum_size::_2 + : sizeof(AccumT) == 4 ? accum_size::_4 + : sizeof(AccumT) == 8 ? accum_size::_8 + : sizeof(AccumT) == 16 ? accum_size::_16 + : accum_size::unknown; +} + +template +struct tuning +{ + static constexpr int threads = Threads; + static constexpr int items = Items; + + using delay_constructor = detail::fixed_delay_constructor_t; +}; + +template (), + accum_size AccumSize = classify_accum_size()> +struct sm90_tuning +{ + static constexpr int threads = 128; + static constexpr int items = 15; + + using delay_constructor = detail::default_delay_constructor_t; +}; + +// clang-format off +template struct sm90_tuning : tuning<192, 22, 168, 1140> {}; +template struct sm90_tuning : tuning<512, 12, 376, 1125> {}; +template struct sm90_tuning : tuning<128, 24, 648, 1245> {}; +template struct sm90_tuning : tuning<224, 24, 632, 1290> {}; + +template <> struct sm90_tuning : tuning<128, 24, 688, 1140> {}; +template <> struct sm90_tuning : tuning<224, 24, 576, 1215> {}; + +#if CUB_IS_INT128_ENABLED +template <> struct sm90_tuning< __int128_t, primitive_op::yes, primitive_accum::no, accum_size::_16> : tuning<576, 21, 860, 630> {}; +template <> struct sm90_tuning<__uint128_t, primitive_op::yes, primitive_accum::no, accum_size::_16> : tuning<576, 21, 860, 630> {}; +#endif +// clang-format on + +} // namespace scan +} // namespace detail + + +template +struct DeviceScanPolicy +{ + // For large values, use timesliced loads/stores to fit shared memory. + static constexpr bool LargeValues = sizeof(AccumT) > 128; + static constexpr BlockLoadAlgorithm ScanTransposedLoad = + LargeValues ? BLOCK_LOAD_WARP_TRANSPOSE_TIMESLICED + : BLOCK_LOAD_WARP_TRANSPOSE; + static constexpr BlockStoreAlgorithm ScanTransposedStore = + LargeValues ? BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED + : BLOCK_STORE_WARP_TRANSPOSE; + + template + using policy_t = + AgentScanPolicy, + DelayConstructorT>; + + /// SM350 + struct Policy350 : ChainedPolicy<350, Policy350, Policy350> + { + // GTX Titan: 29.5B items/s (232.4 GB/s) @ 48M 32-bit T + using ScanPolicyT = policy_t<128, + 12, ///< Threads per block, items per thread + AccumT, + BLOCK_LOAD_DIRECT, + LOAD_CA, + BLOCK_STORE_WARP_TRANSPOSE_TIMESLICED, + BLOCK_SCAN_RAKING, + detail::default_delay_constructor_t>; + }; + + /// SM520 + struct Policy520 : ChainedPolicy<520, Policy520, Policy350> + { + // Titan X: 32.47B items/s @ 48M 32-bit T + using ScanPolicyT = policy_t<128, + 12, ///< Threads per block, items per thread + AccumT, + BLOCK_LOAD_DIRECT, + LOAD_CA, + ScanTransposedStore, + BLOCK_SCAN_WARP_SCANS, + detail::default_delay_constructor_t>; + }; + + /// SM600 + struct Policy600 : ChainedPolicy<600, Policy600, Policy520> + { + using ScanPolicyT = policy_t<128, + 15, ///< Threads per block, items per thread + AccumT, + ScanTransposedLoad, + LOAD_DEFAULT, + ScanTransposedStore, + BLOCK_SCAN_WARP_SCANS, + detail::default_delay_constructor_t>; + }; + + /// SM900 + struct Policy900 : ChainedPolicy<900, Policy900, Policy600> + { + using tuning = detail::scan::sm90_tuning()>; + + using ScanPolicyT = policy_t; + }; + + using MaxPolicy = Policy900; +}; + + +CUB_NAMESPACE_END diff --git a/cub/device/dispatch/tuning/tuning_select_if.cuh b/cub/device/dispatch/tuning/tuning_select_if.cuh new file mode 100644 index 000000000..171d2d598 --- /dev/null +++ b/cub/device/dispatch/tuning/tuning_select_if.cuh @@ -0,0 +1,426 @@ +/****************************************************************************** + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +CUB_NAMESPACE_BEGIN + +namespace detail +{ + +namespace select +{ + +enum class flagged { no, yes }; +enum class keep_rejects { no, yes }; +enum class primitive { no, yes }; +enum class offset_size { _4, _8, unknown }; +enum class input_size { _1, _2, _4, _8, _16, unknown }; + +template +constexpr primitive is_primitive() +{ + return Traits::PRIMITIVE ? primitive::yes : primitive::no; +} + +template +constexpr flagged is_flagged() +{ + return std::is_same::value ? flagged::no : flagged::yes; +} + +template +constexpr keep_rejects are_rejects_kept() +{ + return KeepRejects ? keep_rejects::yes : keep_rejects::no; +} + +template +constexpr input_size classify_input_size() +{ + return sizeof(InputT) == 1 ? input_size::_1 + : sizeof(InputT) == 2 ? input_size::_2 + : sizeof(InputT) == 4 ? input_size::_4 + : sizeof(InputT) == 8 ? input_size::_8 + : sizeof(InputT) == 16 ? input_size::_16 + : input_size::unknown; +} + +template +constexpr offset_size classify_offset_size() +{ + return sizeof(OffsetT) == 4 ? offset_size::_4 + : sizeof(OffsetT) == 8 ? offset_size::_8 + : offset_size::unknown; +} + +template (), + input_size InputSize = classify_input_size()> +struct sm90_tuning +{ + static constexpr int threads = 128; + + static constexpr int nominal_4b_items_per_thread = 10; + + static constexpr int items = + CUB_MIN(nominal_4b_items_per_thread, + CUB_MAX(1, (nominal_4b_items_per_thread * 4 / sizeof(InputT)))); + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<350, 450>; +}; + +// select::if +template +struct sm90_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 22; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::no_delay_constructor_t<580>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 22; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<320, 605>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 384; + static constexpr int items = 17; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<76, 1150>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 384; + static constexpr int items = 11; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<380, 1140>; +}; + +#if CUB_IS_INT128_ENABLED +template <> +struct sm90_tuning<__int128_t, flagged::no, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16> +{ + static constexpr int threads = 512; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<460, 1145>; +}; + +template <> +struct sm90_tuning<__uint128_t, flagged::no, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16> +{ + static constexpr int threads = 512; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<460, 1145>; +}; +#endif + +// select::flagged +template +struct sm90_tuning +{ + static constexpr int threads = 448; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::no_delay_constructor_t<715>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 448; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<504, 765>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 384; + static constexpr int items = 15; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<415, 1125>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 384; + static constexpr int items = 11; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<360, 1170>; +}; + +#if CUB_IS_INT128_ENABLED +template <> +struct sm90_tuning<__int128_t, flagged::yes, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16> +{ + static constexpr int threads = 512; + static constexpr int items = 3; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<284, 1130>; +}; + +template <> +struct sm90_tuning<__uint128_t, flagged::yes, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16> +{ + static constexpr int threads = 512; + static constexpr int items = 3; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<284, 1130>; +}; +#endif + +// partition::if +template +struct sm90_tuning +{ + static constexpr int threads = 384; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<908, 995>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 320; + static constexpr int items = 14; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<500, 560>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 14; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<536, 1055>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 128; + static constexpr int items = 12; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<512, 1075>; +}; + +#if CUB_IS_INT128_ENABLED +template <> +struct sm90_tuning<__int128_t, flagged::no, keep_rejects::yes, offset_size::_4, primitive::no, input_size::_16> +{ + static constexpr int threads = 192; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>; +}; + +template <> +struct sm90_tuning<__uint128_t, flagged::no, keep_rejects::yes, offset_size::_4, primitive::no, input_size::_16> +{ + static constexpr int threads = 192; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE; + + using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>; +}; +#endif + +// partition::flagged +template +struct sm90_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<580, 850>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 512; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<388, 1055>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 256; + static constexpr int items = 20; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<72, 1165>; +}; + +template +struct sm90_tuning +{ + static constexpr int threads = 224; + static constexpr int items = 6; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<532, 1180>; +}; + +#if CUB_IS_INT128_ENABLED +template <> +struct sm90_tuning<__int128_t, flagged::yes, keep_rejects::yes, offset_size::_4, primitive::no, input_size::_16> +{ + static constexpr int threads = 160; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<720, 1105>; +}; + +template <> +struct sm90_tuning<__uint128_t, flagged::yes, keep_rejects::yes, offset_size::_4, primitive::no, input_size::_16> +{ + static constexpr int threads = 160; + static constexpr int items = 5; + + static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT; + + using delay_constructor = detail::fixed_delay_constructor_t<720, 1105>; +}; +#endif + +} // namespace select + +template +struct device_select_policy_hub +{ + struct Policy350 : ChainedPolicy<350, Policy350, Policy350> + { + static constexpr int NOMINAL_4B_ITEMS_PER_THREAD = 10; + + static constexpr int ITEMS_PER_THREAD = + CUB_MIN(NOMINAL_4B_ITEMS_PER_THREAD, + CUB_MAX(1, (NOMINAL_4B_ITEMS_PER_THREAD * 4 / sizeof(InputT)))); + + using SelectIfPolicyT = AgentSelectIfPolicy<128, + ITEMS_PER_THREAD, + BLOCK_LOAD_DIRECT, + MayAlias ? LOAD_CA : LOAD_LDG, + BLOCK_SCAN_WARP_SCANS, + detail::fixed_delay_constructor_t<350, 450>>; + }; + + struct Policy900 : ChainedPolicy<900, Policy900, Policy350> + { + using tuning = detail::select::sm90_tuning(), + select::are_rejects_kept(), + select::classify_offset_size()>; + + using SelectIfPolicyT = AgentSelectIfPolicy; + }; + + using MaxPolicy = Policy900; +}; + +} // detail + +CUB_NAMESPACE_END