Skip to content

Commit

Permalink
Tune select and partition for SM90
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Jun 15, 2023
1 parent f76fbda commit b4b28c6
Showing 1 changed file with 324 additions and 9 deletions.
333 changes: 324 additions & 9 deletions cub/device/dispatch/dispatch_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@

#include <thrust/system/cuda/detail/core/triple_chevron_launch.h>

#include "cub/agent/single_pass_scan_operators.cuh"
#include "cub/block/block_load.cuh"
#include <nv/target>

#include <cstdio>
Expand Down Expand Up @@ -167,7 +169,308 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::SelectIfPolicyT::BLOCK_THREA
namespace detail
{

template <class InputT, bool MayAlias>
namespace select
{

/**
 * Fallback SM90 tuning, used whenever no specialization below matches the
 * (input type, Flagged, KeepRejects, offset size, primitive-ness, input size)
 * combination.
 *
 * The members are consumed by the policy hub, which forwards `threads`,
 * `items`, `load_algorithm` and `delay_constructor` to the agent policy.
 */
template <class InputT,
          bool Flagged,
          bool KeepRejects,
          std::size_t OffsetSize,
          bool PrimitiveInput   = Traits<InputT>::PRIMITIVE,
          std::size_t InputSize = sizeof(InputT)>
struct sm90_tuning
{
  using delay_constructor = detail::fixed_delay_constructor_t<350, 450>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int threads = 128;

  // Item count that applies to inputs of four bytes (and smaller).
  static constexpr int nominal_4b_items_per_thread = 10;

  // Scale the nominal count inversely with the input size: never more than
  // the nominal count, never fewer than one item.
  static constexpr int items =
    CUB_MIN(nominal_4b_items_per_thread,
            CUB_MAX(1, (nominal_4b_items_per_thread * 4 / InputSize)));
};

// select::if
// SM90 tuning for select::if (Flagged = false, KeepRejects = false) with
// 4-byte offsets and primitive 1-byte inputs.
template <class InputT>
struct sm90_tuning<InputT, false, false, 4, true, 1>
{
  using delay_constructor = detail::no_delay_constructor_t<580>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 22;
  static constexpr int threads = 256;
};

// SM90 tuning for select::if (Flagged = false, KeepRejects = false) with
// 4-byte offsets and primitive 2-byte inputs.
template <class InputT>
struct sm90_tuning<InputT, false, false, 4, true, 2>
{
  using delay_constructor = detail::fixed_delay_constructor_t<320, 605>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;

  static constexpr int items   = 22;
  static constexpr int threads = 256;
};

// SM90 tuning for select::if (Flagged = false, KeepRejects = false) with
// 4-byte offsets and primitive 4-byte inputs.
template <class InputT>
struct sm90_tuning<InputT, false, false, 4, true, 4>
{
  using delay_constructor = detail::fixed_delay_constructor_t<76, 1150>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;

  static constexpr int items   = 17;
  static constexpr int threads = 384;
};

// SM90 tuning for select::if (Flagged = false, KeepRejects = false) with
// 4-byte offsets and primitive 8-byte inputs.
template <class InputT>
struct sm90_tuning<InputT, false, false, 4, true, 8>
{
  using delay_constructor = detail::fixed_delay_constructor_t<380, 1140>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;

  static constexpr int items   = 11;
  static constexpr int threads = 384;
};

#if CUB_IS_INT128_ENABLED
// SM90 tuning for select::if (Flagged = false, KeepRejects = false) with
// 4-byte offsets and __int128_t inputs (matched with PrimitiveInput = false).
template <>
struct sm90_tuning<__int128_t, false, false, 4, false, sizeof(__int128_t)>
{
  using delay_constructor = detail::fixed_delay_constructor_t<460, 1145>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 5;
  static constexpr int threads = 512;
};

// SM90 tuning for select::if (Flagged = false, KeepRejects = false) with
// 4-byte offsets and __uint128_t inputs (matched with PrimitiveInput = false).
template <>
struct sm90_tuning<__uint128_t, false, false, 4, false, sizeof(__uint128_t)>
{
  using delay_constructor = detail::fixed_delay_constructor_t<460, 1145>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 5;
  static constexpr int threads = 512;
};
#endif

// select::flagged
// SM90 tuning for select::flagged (Flagged = true, KeepRejects = false) with
// 4-byte offsets and primitive 1-byte inputs.
template <class InputT>
struct sm90_tuning<InputT, true, false, 4, true, 1>
{
  using delay_constructor = detail::no_delay_constructor_t<715>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 20;
  static constexpr int threads = 448;
};

// SM90 tuning for select::flagged (Flagged = true, KeepRejects = false) with
// 4-byte offsets and primitive 2-byte inputs.
template <class InputT>
struct sm90_tuning<InputT, true, false, 4, true, 2>
{
  using delay_constructor = detail::fixed_delay_constructor_t<504, 765>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 20;
  static constexpr int threads = 448;
};

// SM90 tuning for select::flagged (Flagged = true, KeepRejects = false) with
// 4-byte offsets and primitive 4-byte inputs.
template <class InputT>
struct sm90_tuning<InputT, true, false, 4, true, 4>
{
  using delay_constructor = detail::fixed_delay_constructor_t<415, 1125>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 15;
  static constexpr int threads = 384;
};

// SM90 tuning for select::flagged (Flagged = true, KeepRejects = false) with
// 4-byte offsets and primitive 8-byte inputs.
template <class InputT>
struct sm90_tuning<InputT, true, false, 4, true, 8>
{
  using delay_constructor = detail::fixed_delay_constructor_t<360, 1170>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 11;
  static constexpr int threads = 384;
};

#if CUB_IS_INT128_ENABLED
// SM90 tuning for select::flagged (Flagged = true, KeepRejects = false) with
// 4-byte offsets and __int128_t inputs (matched with PrimitiveInput = false).
template <>
struct sm90_tuning<__int128_t, true, false, 4, false, sizeof(__int128_t)>
{
  using delay_constructor = detail::fixed_delay_constructor_t<284, 1130>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 3;
  static constexpr int threads = 512;
};

// SM90 tuning for select::flagged (Flagged = true, KeepRejects = false) with
// 4-byte offsets and __uint128_t inputs (matched with PrimitiveInput = false).
template <>
struct sm90_tuning<__uint128_t, true, false, 4, false, sizeof(__uint128_t)>
{
  using delay_constructor = detail::fixed_delay_constructor_t<284, 1130>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 3;
  static constexpr int threads = 512;
};
#endif

// partition::if
// SM90 tuning for partition::if (Flagged = false, KeepRejects = true) with
// 4-byte offsets and primitive 1-byte inputs.
template <class InputT>
struct sm90_tuning<InputT, false, true, 4, true, 1>
{
  using delay_constructor = detail::fixed_delay_constructor_t<908, 995>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 20;
  static constexpr int threads = 384;
};

// SM90 tuning for partition::if (Flagged = false, KeepRejects = true) with
// 4-byte offsets and primitive 2-byte inputs.
template <class InputT>
struct sm90_tuning<InputT, false, true, 4, true, 2>
{
  using delay_constructor =
    detail::exponential_backon_jitter_window_constructor_t<1180, 715>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 20;
  static constexpr int threads = 384;
};

// SM90 tuning for partition::if (Flagged = false, KeepRejects = true) with
// 4-byte offsets and primitive 4-byte inputs.
template <class InputT>
struct sm90_tuning<InputT, false, true, 4, true, 4>
{
  using delay_constructor = detail::fixed_delay_constructor_t<536, 1055>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 14;
  static constexpr int threads = 256;
};

// SM90 tuning for partition::if (Flagged = false, KeepRejects = true) with
// 4-byte offsets and primitive 8-byte inputs.
template <class InputT>
struct sm90_tuning<InputT, false, true, 4, true, 8>
{
  using delay_constructor = detail::fixed_delay_constructor_t<512, 1075>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;

  static constexpr int items   = 12;
  static constexpr int threads = 128;
};

#if CUB_IS_INT128_ENABLED
// SM90 tuning for partition::if (Flagged = false, KeepRejects = true) with
// 4-byte offsets and __int128_t inputs (matched with PrimitiveInput = false).
template <>
struct sm90_tuning<__int128_t, false, true, 4, false, sizeof(__int128_t)>
{
  using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;

  static constexpr int items   = 5;
  static constexpr int threads = 192;
};

// SM90 tuning for partition::if (Flagged = false, KeepRejects = true) with
// 4-byte offsets and __uint128_t inputs (matched with PrimitiveInput = false).
template <>
struct sm90_tuning<__uint128_t, false, true, 4, false, sizeof(__uint128_t)>
{
  using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;

  static constexpr int items   = 5;
  static constexpr int threads = 192;
};
#endif

// partition::flagged
// SM90 tuning for partition::flagged (Flagged = true, KeepRejects = true)
// with 4-byte offsets and primitive 1-byte inputs.
template <class InputT>
struct sm90_tuning<InputT, true, true, 4, true, 1>
{
  using delay_constructor = detail::fixed_delay_constructor_t<580, 850>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 20;
  static constexpr int threads = 256;
};

// SM90 tuning for partition::flagged (Flagged = true, KeepRejects = true)
// with 4-byte offsets and primitive 2-byte inputs.
template <class InputT>
struct sm90_tuning<InputT, true, true, 4, true, 2>
{
  using delay_constructor = detail::fixed_delay_constructor_t<388, 1055>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 20;
  static constexpr int threads = 512;
};

// SM90 tuning for partition::flagged (Flagged = true, KeepRejects = true)
// with 4-byte offsets and primitive 4-byte inputs.
template <class InputT>
struct sm90_tuning<InputT, true, true, 4, true, 4>
{
  using delay_constructor = detail::fixed_delay_constructor_t<72, 1165>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 20;
  static constexpr int threads = 256;
};

// SM90 tuning for partition::flagged (Flagged = true, KeepRejects = true)
// with 4-byte offsets and primitive 8-byte inputs.
template <class InputT>
struct sm90_tuning<InputT, true, true, 4, true, 8>
{
  using delay_constructor = detail::fixed_delay_constructor_t<468, 1175>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 9;
  static constexpr int threads = 128;
};

#if CUB_IS_INT128_ENABLED
// SM90 tuning for partition::flagged (Flagged = true, KeepRejects = true)
// with 4-byte offsets and __int128_t inputs (matched with
// PrimitiveInput = false).
template <>
struct sm90_tuning<__int128_t, true, true, 4, false, sizeof(__int128_t)>
{
  using delay_constructor = detail::fixed_delay_constructor_t<720, 1105>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 5;
  static constexpr int threads = 160;
};

// SM90 tuning for partition::flagged (Flagged = true, KeepRejects = true)
// with 4-byte offsets and __uint128_t inputs (matched with
// PrimitiveInput = false).
template <>
struct sm90_tuning<__uint128_t, true, true, 4, false, sizeof(__uint128_t)>
{
  using delay_constructor = detail::fixed_delay_constructor_t<720, 1105>;

  static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

  static constexpr int items   = 5;
  static constexpr int threads = 160;
};
#endif

}

template <class InputT, class FlagT, class OffsetT, bool MayAlias, bool KeepRejects>
struct device_select_policy_hub
{
struct Policy350 : ChainedPolicy<350, Policy350, Policy350>
Expand All @@ -186,7 +489,21 @@ struct device_select_policy_hub
detail::fixed_delay_constructor_t<350, 450>>;
};

using MaxPolicy = Policy350;
// Policy for compute capability 9.0; chains to Policy350 as its predecessor.
// The agent policy parameters are taken from the matching
// detail::select::sm90_tuning instantiation.
struct Policy900 : ChainedPolicy<900, Policy900, Policy350>
{
  // Any flag type other than NullType selects the "flagged" tunings.
  static constexpr bool flagged = !std::is_same<FlagT, NullType>::value;

  // Tuning lookup keyed on input type, flagged-ness, reject handling and
  // the byte width of the offset type.
  using tuning =
    detail::select::sm90_tuning<InputT, flagged, KeepRejects, sizeof(OffsetT)>;

  using SelectIfPolicyT = AgentSelectIfPolicy<tuning::threads,
                                              tuning::items,
                                              tuning::load_algorithm,
                                              LOAD_DEFAULT,
                                              BLOCK_SCAN_WARP_SCANS,
                                              typename tuning::delay_constructor>;
};

using MaxPolicy = Policy900;
};

} // detail
Expand Down Expand Up @@ -236,19 +553,17 @@ template <typename InputIteratorT,
bool KEEP_REJECTS,
bool MayAlias = false,
typename SelectedPolicy =
detail::device_select_policy_hub<cub::detail::value_t<InputIteratorT>, MayAlias>>
detail::device_select_policy_hub<cub::detail::value_t<InputIteratorT>,
cub::detail::value_t<FlagsInputIteratorT>,
OffsetT,
MayAlias,
KEEP_REJECTS>>
struct DispatchSelectIf : SelectedPolicy
{
/******************************************************************************
* Types and constants
******************************************************************************/

// The input value type
using InputT = cub::detail::value_t<InputIteratorT>;

// The flag value type
using FlagT = cub::detail::value_t<FlagsInputIteratorT>;

// Tile status descriptor interface type
using ScanTileStateT = ScanTileState<OffsetT>;

Expand Down

0 comments on commit b4b28c6

Please sign in to comment.