diff --git a/cub/device/dispatch/dispatch_select_if.cuh b/cub/device/dispatch/dispatch_select_if.cuh index b9a077b4f..4c2256c72 100644 --- a/cub/device/dispatch/dispatch_select_if.cuh +++ b/cub/device/dispatch/dispatch_select_if.cuh @@ -175,6 +175,7 @@ namespace select template ::PRIMITIVE, std::size_t InputSize = sizeof(InputT)> struct sm90_tuning @@ -193,7 +194,7 @@ struct sm90_tuning // select::if template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 256; static constexpr int items = 22; @@ -204,7 +205,7 @@ struct sm90_tuning }; template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 256; static constexpr int items = 22; @@ -215,7 +216,7 @@ struct sm90_tuning }; template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 384; static constexpr int items = 17; @@ -226,7 +227,7 @@ struct sm90_tuning }; template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 384; static constexpr int items = 11; @@ -238,7 +239,7 @@ struct sm90_tuning #if CUB_IS_INT128_ENABLED template <> -struct sm90_tuning<__int128_t, false, false, false, sizeof(__int128_t)> +struct sm90_tuning<__int128_t, false, false, 4, false, sizeof(__int128_t)> { static constexpr int threads = 512; static constexpr int items = 5; @@ -249,7 +250,7 @@ struct sm90_tuning<__int128_t, false, false, false, sizeof(__int128_t)> }; template <> -struct sm90_tuning<__uint128_t, false, false, false, sizeof(__uint128_t)> +struct sm90_tuning<__uint128_t, false, false, 4, false, sizeof(__uint128_t)> { static constexpr int threads = 512; static constexpr int items = 5; @@ -262,7 +263,7 @@ struct sm90_tuning<__uint128_t, false, false, false, sizeof(__uint128_t)> // select::flagged template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 448; static constexpr int items = 20; @@ -273,7 +274,7 @@ struct sm90_tuning }; template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 448; static constexpr int items = 20; @@ -284,7 +285,7 @@ struct sm90_tuning }; template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 384; static constexpr int items = 15; @@ -295,7 +296,7 @@ struct sm90_tuning }; template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 384; static constexpr int items = 11; @@ -307,7 +308,7 @@ struct sm90_tuning #if CUB_IS_INT128_ENABLED template <> -struct sm90_tuning<__int128_t, true, false, false, sizeof(__int128_t)> +struct sm90_tuning<__int128_t, true, false, 4, false, sizeof(__int128_t)> { static constexpr int threads = 512; static constexpr int items = 3; @@ -318,7 +319,7 @@ struct sm90_tuning<__int128_t, true, false, false, sizeof(__int128_t)> }; template <> -struct sm90_tuning<__uint128_t, true, false, false, sizeof(__uint128_t)> +struct sm90_tuning<__uint128_t, true, false, 4, false, sizeof(__uint128_t)> { static constexpr int threads = 512; static constexpr int items = 3; @@ -331,7 +332,7 @@ struct sm90_tuning<__uint128_t, true, false, false, sizeof(__uint128_t)> // partition::if template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 384; static constexpr int items = 20; @@ -342,7 +343,7 @@ struct sm90_tuning }; template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 384; static constexpr int items = 20; @@ -353,7 +354,7 @@ struct sm90_tuning }; template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 256; static constexpr int items = 14; @@ -364,7 +365,7 @@ struct sm90_tuning }; template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 128; static constexpr int items = 12; @@ -376,7 +377,7 @@ struct sm90_tuning #if CUB_IS_INT128_ENABLED template <> -struct sm90_tuning<__int128_t, false, true, false, sizeof(__int128_t)> +struct sm90_tuning<__int128_t, false, true, 4, false, sizeof(__int128_t)> { static constexpr int threads = 192; static constexpr int items = 5; @@ -387,7 +388,7 @@ struct sm90_tuning<__int128_t, false, true, false, sizeof(__int128_t)> }; template <> -struct sm90_tuning<__uint128_t, false, true, false, sizeof(__uint128_t)> +struct sm90_tuning<__uint128_t, false, true, 4, false, sizeof(__uint128_t)> { static constexpr int threads = 192; static constexpr int items = 5; @@ -400,7 +401,7 @@ struct sm90_tuning<__uint128_t, false, true, false, sizeof(__uint128_t)> // partition::flagged template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 256; static constexpr int items = 20; @@ -411,7 +412,7 @@ struct sm90_tuning }; template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 512; static constexpr int items = 20; @@ -422,7 +423,7 @@ struct sm90_tuning }; template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 256; static constexpr int items = 20; @@ -433,7 +434,7 @@ struct sm90_tuning }; template -struct sm90_tuning +struct sm90_tuning { static constexpr int threads = 128; static constexpr int items = 9; @@ -445,7 +446,7 @@ struct sm90_tuning #if CUB_IS_INT128_ENABLED template <> -struct sm90_tuning<__int128_t, true, true, false, sizeof(__int128_t)> +struct sm90_tuning<__int128_t, true, true, 4, false, sizeof(__int128_t)> { static constexpr int threads = 160; static constexpr int items = 5; @@ -456,7 +457,7 @@ struct sm90_tuning<__int128_t, true, true, false, sizeof(__int128_t)> }; template <> -struct sm90_tuning<__uint128_t, true, true, false, sizeof(__uint128_t)> +struct sm90_tuning<__uint128_t, true, true, 4, false, sizeof(__uint128_t)> { static constexpr int threads = 160; static constexpr int items = 5; @@ -469,7 +470,7 @@ struct sm90_tuning<__uint128_t, true, true, false, sizeof(__uint128_t)> } -template +template struct device_select_policy_hub { struct Policy350 : ChainedPolicy<350, Policy350, Policy350> @@ -492,7 +493,7 @@ struct device_select_policy_hub { static constexpr bool flagged = std::is_same::value == false; - using tuning = detail::select::sm90_tuning; + using tuning = detail::select::sm90_tuning; using SelectIfPolicyT = AgentSelectIfPolicy, cub::detail::value_t, + OffsetT, MayAlias, KEEP_REJECTS>> struct DispatchSelectIf : SelectedPolicy