Skip to content

Commit

Permalink
Tune select and partition for SM90
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Jun 15, 2023
1 parent f76fbda commit b72e594
Showing 1 changed file with 187 additions and 13 deletions.
200 changes: 187 additions & 13 deletions cub/device/dispatch/dispatch_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,168 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::SelectIfPolicyT::BLOCK_THREA
namespace detail
{

template <class InputT, bool MayAlias>
namespace select
{

/**
 * @brief Fallback tuning record for select/partition, used whenever none of
 *        the specializations of `sm90_tuning` matches.
 *
 * `Policy900` forwards `threads`, `items`, `load_algorithm` and
 * `delay_constructor` to `AgentSelectIfPolicy`, in that order (with two fixed
 * arguments in between).
 *
 * @tparam InputT      Value type of the input sequence
 * @tparam KeepRejects `false` for the select group, `true` for the partition group
 * @tparam IsPrimitive Defaults to `Traits<InputT>::PRIMITIVE`
 * @tparam InputBytes  Defaults to `sizeof(InputT)`
 */
template <typename InputT,
          bool KeepRejects,
          bool IsPrimitive       = Traits<InputT>::PRIMITIVE,
          std::size_t InputBytes = sizeof(InputT)>
struct sm90_tuning
{
  using delay_constructor = detail::fixed_delay_constructor_t<350, 450>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int threads = 128;

  // Item count chosen for a 4-byte input type; `items` is derived from it.
  static constexpr int nominal_4b_items_per_thread = 10;

  // Scale the nominal count inversely with the input size, clamped to the
  // range [1, nominal_4b_items_per_thread].
  static constexpr int items =
    CUB_MIN(nominal_4b_items_per_thread, CUB_MAX(1, (nominal_4b_items_per_thread * 4 / InputBytes)));
};

// Select tunings (KeepRejects == false)
/// Select tuning (KeepRejects == false) for primitive input types of size 1 byte.
template <typename InputT>
struct sm90_tuning<InputT, false, true, 1>
{
  using delay_constructor = detail::fixed_delay_constructor_t<692, 715>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 20;
  static constexpr int threads = 448;
};

/// Select tuning (KeepRejects == false) for primitive input types of size 2 bytes.
template <typename InputT>
struct sm90_tuning<InputT, false, true, 2>
{
  using delay_constructor = detail::fixed_delay_constructor_t<504, 765>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 20;
  static constexpr int threads = 448;
};

/// Select tuning (KeepRejects == false) for primitive input types of size 4 bytes.
template <typename InputT>
struct sm90_tuning<InputT, false, true, 4>
{
  using delay_constructor = detail::fixed_delay_constructor_t<415, 1125>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 15;
  static constexpr int threads = 384;
};

/// Select tuning (KeepRejects == false) for primitive input types of size 8 bytes.
template <typename InputT>
struct sm90_tuning<InputT, false, true, 8>
{
  using delay_constructor = detail::fixed_delay_constructor_t<948, 1090>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 11;
  static constexpr int threads = 384;
};

#if CUB_IS_INT128_ENABLED
// Select tunings (KeepRejects == false) for the 128-bit integer types.
//
// NOTE(review): both specializations name `false` as the third argument, so
// they are only picked up by a two-argument use such as the one in Policy900
// if Traits<T>::PRIMITIVE evaluates to false for these types — confirm
// against the Traits definition.

/// Select tuning for `__int128_t`.
template <>
struct sm90_tuning<__int128_t, false, false, sizeof(__int128_t)>
{
  using delay_constructor = detail::fixed_delay_constructor_t<460, 1145>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 4;
  static constexpr int threads = 512;
};

/// Select tuning for `__uint128_t`.
template <>
struct sm90_tuning<__uint128_t, false, false, sizeof(__uint128_t)>
{
  using delay_constructor = detail::fixed_delay_constructor_t<284, 1130>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 3;
  static constexpr int threads = 512;
};
#endif

// Partition tunings (KeepRejects == true)
/// Partition tuning (KeepRejects == true) for primitive input types of size 1 byte.
template <typename InputT>
struct sm90_tuning<InputT, true, true, 1>
{
  using delay_constructor = detail::fixed_delay_constructor_t<908, 995>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 20;
  static constexpr int threads = 384;
};

/// Partition tuning (KeepRejects == true) for primitive input types of size 2 bytes.
template <typename InputT>
struct sm90_tuning<InputT, true, true, 2>
{
  using delay_constructor = detail::fixed_delay_constructor_t<988, 1060>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 18;
  static constexpr int threads = 128;
};

/// Partition tuning (KeepRejects == true) for primitive input types of size 4 bytes.
/// The values are the same as those of the 2-byte partition specialization.
template <typename InputT>
struct sm90_tuning<InputT, true, true, 4>
{
  using delay_constructor = detail::fixed_delay_constructor_t<988, 1060>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 18;
  static constexpr int threads = 128;
};

/// Partition tuning (KeepRejects == true) for primitive input types of size 8 bytes.
template <typename InputT>
struct sm90_tuning<InputT, true, true, 8>
{
  using delay_constructor = detail::fixed_delay_constructor_t<752, 1125>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 9;
  static constexpr int threads = 128;
};

#if CUB_IS_INT128_ENABLED
// Partition tunings (KeepRejects == true) for the 128-bit integer types.
// Both types carry the same values.
//
// NOTE(review): both specializations name `false` as the third argument, so
// they are only picked up by a two-argument use such as the one in Policy900
// if Traits<T>::PRIMITIVE evaluates to false for these types — confirm
// against the Traits definition.

/// Partition tuning for `__int128_t`.
template <>
struct sm90_tuning<__int128_t, true, false, sizeof(__int128_t)>
{
  using delay_constructor = detail::fixed_delay_constructor_t<720, 1105>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 5;
  static constexpr int threads = 160;
};

/// Partition tuning for `__uint128_t`.
template <>
struct sm90_tuning<__uint128_t, true, false, sizeof(__uint128_t)>
{
  using delay_constructor = detail::fixed_delay_constructor_t<720, 1105>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 5;
  static constexpr int threads = 160;
};
#endif

}

template <class InputT, bool MayAlias, bool KeepRejects>
struct device_select_policy_hub
{
struct Policy350 : ChainedPolicy<350, Policy350, Policy350>
Expand All @@ -186,7 +347,19 @@ struct device_select_policy_hub
detail::fixed_delay_constructor_t<350, 450>>;
};

using MaxPolicy = Policy350;
// Policy chained after Policy350 under the 900 tag. Its agent parameters are
// taken from the sm90_tuning record selected by InputT and KeepRejects.
struct Policy900 : ChainedPolicy<900, Policy900, Policy350>
{
// Only the first two template arguments are given here; the remaining two
// fall back to sm90_tuning's defaults (Traits<InputT>::PRIMITIVE and
// sizeof(InputT)), which is what picks a specialization.
using tuning = detail::select::sm90_tuning<InputT, KeepRejects>;

// threads, items, load algorithm and delay constructor come from the tuning
// record; LOAD_DEFAULT and BLOCK_SCAN_WARP_SCANS are fixed for every input
// type.
using SelectIfPolicyT = AgentSelectIfPolicy<tuning::threads,
tuning::items,
tuning::load_algorithm,
LOAD_DEFAULT,
BLOCK_SCAN_WARP_SCANS,
typename tuning::delay_constructor>;
};

using MaxPolicy = Policy900;
};

} // detail
Expand Down Expand Up @@ -226,17 +399,18 @@ struct device_select_policy_hub
* @tparam KEEP_REJECTS
* Whether or not we push rejected items to the back of the output
*/
template <typename InputIteratorT,
typename FlagsInputIteratorT,
typename SelectedOutputIteratorT,
typename NumSelectedIteratorT,
typename SelectOpT,
typename EqualityOpT,
typename OffsetT,
bool KEEP_REJECTS,
bool MayAlias = false,
typename SelectedPolicy =
detail::device_select_policy_hub<cub::detail::value_t<InputIteratorT>, MayAlias>>
template <
typename InputIteratorT,
typename FlagsInputIteratorT,
typename SelectedOutputIteratorT,
typename NumSelectedIteratorT,
typename SelectOpT,
typename EqualityOpT,
typename OffsetT,
bool KEEP_REJECTS,
bool MayAlias = false,
typename SelectedPolicy =
detail::device_select_policy_hub<cub::detail::value_t<InputIteratorT>, MayAlias, KEEP_REJECTS>>
struct DispatchSelectIf : SelectedPolicy
{
/******************************************************************************
Expand Down

0 comments on commit b72e594

Please sign in to comment.