Skip to content

Commit

Permalink
Split tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Jun 15, 2023
1 parent b72e594 commit 082566c
Showing 1 changed file with 183 additions and 44 deletions.
227 changes: 183 additions & 44 deletions cub/device/dispatch/dispatch_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@

#include <thrust/system/cuda/detail/core/triple_chevron_launch.h>

#include "cub/agent/single_pass_scan_operators.cuh"
#include "cub/block/block_load.cuh"
#include <nv/target>

#include <cstdio>
Expand Down Expand Up @@ -171,6 +173,7 @@ namespace select
{

template <class InputT,
bool Flagged,
bool KeepRejects,
bool PrimitiveInput = Traits<InputT>::PRIMITIVE,
std::size_t InputSize = sizeof(InputT)>
Expand All @@ -188,20 +191,89 @@ struct sm90_tuning
using delay_constructor = detail::fixed_delay_constructor_t<350, 450>;
};

// select
// select::if
template <class Input>
struct sm90_tuning<Input, false, true, 1>
struct sm90_tuning<Input, false, false, true, 1>
{
static constexpr int threads = 256;
static constexpr int items = 22;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::no_delay_constructor_t<580>;
};

// sm90 tuning for select::if (Flagged = false, KeepRejects = false) on
// primitive inputs (PrimitiveInput = true) with InputSize = 2.
// NOTE(review): the numeric values look like benchmark-derived tuning results;
// confirm before changing any of them.
template <class Input>
struct sm90_tuning<Input, false, false, true, 2>
{
// Forwarded by Policy900 as the first two arguments of AgentSelectIfPolicy.
static constexpr int threads = 256;
static constexpr int items = 22;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<320, 605>;
};

// sm90 tuning for select::if (Flagged = false, KeepRejects = false) on
// primitive inputs (PrimitiveInput = true) with InputSize = 4.
// NOTE(review): the numeric values look like benchmark-derived tuning results;
// confirm before changing any of them.
template <class Input>
struct sm90_tuning<Input, false, false, true, 4>
{
// Forwarded by Policy900 as the first two arguments of AgentSelectIfPolicy.
static constexpr int threads = 384;
static constexpr int items = 17;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<76, 1150>;
};

// sm90 tuning for select::if (Flagged = false, KeepRejects = false) on
// primitive inputs (PrimitiveInput = true) with InputSize = 8.
// NOTE(review): the numeric values look like benchmark-derived tuning results;
// confirm before changing any of them.
template <class Input>
struct sm90_tuning<Input, false, false, true, 8>
{
// Forwarded by Policy900 as the first two arguments of AgentSelectIfPolicy.
static constexpr int threads = 384;
static constexpr int items = 11;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<380, 1140>;
};

#if CUB_IS_INT128_ENABLED
// sm90 tuning for select::if (Flagged = false, KeepRejects = false) with
// __int128_t input. Full specialization: PrimitiveInput = false and
// InputSize = sizeof(__int128_t). Guarded by CUB_IS_INT128_ENABLED.
// The __uint128_t specialization that follows uses the same values.
template <>
struct sm90_tuning<__int128_t, false, false, false, sizeof(__int128_t)>
{
// Forwarded by Policy900 as the first two arguments of AgentSelectIfPolicy.
static constexpr int threads = 512;
static constexpr int items = 5;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<460, 1145>;
};

// sm90 tuning for select::if (Flagged = false, KeepRejects = false) with
// __uint128_t input. Full specialization: PrimitiveInput = false and
// InputSize = sizeof(__uint128_t). Guarded by CUB_IS_INT128_ENABLED.
// Uses the same values as the __int128_t specialization preceding it.
template <>
struct sm90_tuning<__uint128_t, false, false, false, sizeof(__uint128_t)>
{
// Forwarded by Policy900 as the first two arguments of AgentSelectIfPolicy.
static constexpr int threads = 512;
static constexpr int items = 5;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<460, 1145>;
};
#endif

// select::flagged
template <class Input>
struct sm90_tuning<Input, true, false, true, 1>
{
static constexpr int threads = 448;
static constexpr int items = 20;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<692, 715>;
using delay_constructor = detail::no_delay_constructor_t<715>;
};

template <class Input>
struct sm90_tuning<Input, false, true, 2>
struct sm90_tuning<Input, true, false, true, 2>
{
static constexpr int threads = 448;
static constexpr int items = 20;
Expand All @@ -212,7 +284,7 @@ struct sm90_tuning<Input, false, true, 2>
};

template <class Input>
struct sm90_tuning<Input, false, true, 4>
struct sm90_tuning<Input, true, false, true, 4>
{
static constexpr int threads = 384;
static constexpr int items = 15;
Expand All @@ -223,30 +295,30 @@ struct sm90_tuning<Input, false, true, 4>
};

template <class Input>
struct sm90_tuning<Input, false, true, 8>
struct sm90_tuning<Input, true, false, true, 8>
{
static constexpr int threads = 384;
static constexpr int items = 11;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<948, 1090>;
using delay_constructor = detail::fixed_delay_constructor_t<360, 1170>;
};

#if CUB_IS_INT128_ENABLED
template <>
struct sm90_tuning<__int128_t, false, false, sizeof(__int128_t)>
struct sm90_tuning<__int128_t, true, false, false, sizeof(__int128_t)>
{
static constexpr int threads = 512;
static constexpr int items = 4;
static constexpr int items = 3;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<460, 1145>;
using delay_constructor = detail::fixed_delay_constructor_t<284, 1130>;
};

template <>
struct sm90_tuning<__uint128_t, false, false, sizeof(__uint128_t)>
struct sm90_tuning<__uint128_t, true, false, false, sizeof(__uint128_t)>
{
static constexpr int threads = 512;
static constexpr int items = 3;
Expand All @@ -257,9 +329,9 @@ struct sm90_tuning<__uint128_t, false, false, sizeof(__uint128_t)>
};
#endif

// partition
// partition::if
template <class Input>
struct sm90_tuning<Input, true, true, 1>
struct sm90_tuning<Input, false, true, true, 1>
{
static constexpr int threads = 384;
static constexpr int items = 20;
Expand All @@ -270,41 +342,110 @@ struct sm90_tuning<Input, true, true, 1>
};

template <class Input>
struct sm90_tuning<Input, true, true, 2>
struct sm90_tuning<Input, false, true, true, 2>
{
static constexpr int threads = 128;
static constexpr int items = 18;
static constexpr int threads = 384;
static constexpr int items = 20;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<988, 1060>;
using delay_constructor = detail::exponential_backon_jitter_window_constructor_t<1180, 715>;
};

template <class Input>
struct sm90_tuning<Input, true, true, 4>
struct sm90_tuning<Input, false, true, true, 4>
{
static constexpr int threads = 256;
static constexpr int items = 14;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<536, 1055>;
};

template <class Input>
struct sm90_tuning<Input, false, true, true, 8>
{
static constexpr int threads = 128;
static constexpr int items = 18;
static constexpr int items = 12;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<512, 1075>;
};

#if CUB_IS_INT128_ENABLED
// sm90 tuning for partition::if (Flagged = false, KeepRejects = true) with
// __int128_t input. Full specialization: PrimitiveInput = false and
// InputSize = sizeof(__int128_t). Guarded by CUB_IS_INT128_ENABLED.
// The __uint128_t specialization that follows uses the same values.
template <>
struct sm90_tuning<__int128_t, false, true, false, sizeof(__int128_t)>
{
// Forwarded by Policy900 as the first two arguments of AgentSelectIfPolicy.
static constexpr int threads = 192;
static constexpr int items = 5;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>;
};

// sm90 tuning for partition::if (Flagged = false, KeepRejects = true) with
// __uint128_t input. Full specialization: PrimitiveInput = false and
// InputSize = sizeof(__uint128_t). Guarded by CUB_IS_INT128_ENABLED.
// Uses the same values as the __int128_t specialization preceding it.
template <>
struct sm90_tuning<__uint128_t, false, true, false, sizeof(__uint128_t)>
{
// Forwarded by Policy900 as the first two arguments of AgentSelectIfPolicy.
static constexpr int threads = 192;
static constexpr int items = 5;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>;
};
#endif

// partition::flagged
// sm90 tuning for partition::flagged (Flagged = true, KeepRejects = true) on
// primitive inputs (PrimitiveInput = true) with InputSize = 1.
// NOTE(review): the numeric values look like benchmark-derived tuning results;
// confirm before changing any of them.
template <class Input>
struct sm90_tuning<Input, true, true, true, 1>
{
// Forwarded by Policy900 as the first two arguments of AgentSelectIfPolicy.
static constexpr int threads = 256;
static constexpr int items = 20;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<580, 850>;
};

// sm90 tuning for partition::flagged (Flagged = true, KeepRejects = true) on
// primitive inputs (PrimitiveInput = true) with InputSize = 2.
// NOTE(review): the numeric values look like benchmark-derived tuning results;
// confirm before changing any of them.
template <class Input>
struct sm90_tuning<Input, true, true, true, 2>
{
// Forwarded by Policy900 as the first two arguments of AgentSelectIfPolicy.
static constexpr int threads = 512;
static constexpr int items = 20;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<388, 1055>;
};

template <class Input>
struct sm90_tuning<Input, true, true, true, 4>
{
static constexpr int threads = 256;
static constexpr int items = 20;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<988, 1060>;
using delay_constructor = detail::fixed_delay_constructor_t<72, 1165>;
};

template <class Input>
struct sm90_tuning<Input, true, true, 8>
struct sm90_tuning<Input, true, true, true, 8>
{
static constexpr int threads = 128;
static constexpr int items = 9;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<752, 1125>;
using delay_constructor = detail::fixed_delay_constructor_t<468, 1175>;
};

#if CUB_IS_INT128_ENABLED
template <>
struct sm90_tuning<__int128_t, true, false, sizeof(__int128_t)>
struct sm90_tuning<__int128_t, true, true, false, sizeof(__int128_t)>
{
static constexpr int threads = 160;
static constexpr int items = 5;
Expand All @@ -315,7 +456,7 @@ struct sm90_tuning<__int128_t, true, false, sizeof(__int128_t)>
};

template <>
struct sm90_tuning<__uint128_t, true, false, sizeof(__uint128_t)>
struct sm90_tuning<__uint128_t, true, true, false, sizeof(__uint128_t)>
{
static constexpr int threads = 160;
static constexpr int items = 5;
Expand All @@ -328,7 +469,7 @@ struct sm90_tuning<__uint128_t, true, false, sizeof(__uint128_t)>

}

template <class InputT, bool MayAlias, bool KeepRejects>
template <class InputT, class FlagT, bool MayAlias, bool KeepRejects>
struct device_select_policy_hub
{
struct Policy350 : ChainedPolicy<350, Policy350, Policy350>
Expand All @@ -349,7 +490,9 @@ struct device_select_policy_hub

struct Policy900 : ChainedPolicy<900, Policy900, Policy350>
{
using tuning = detail::select::sm90_tuning<InputT, KeepRejects>;
static constexpr bool flagged = std::is_same<FlagT, NullType>::value == false;

using tuning = detail::select::sm90_tuning<InputT, flagged, KeepRejects>;

using SelectIfPolicyT = AgentSelectIfPolicy<tuning::threads,
tuning::items,
Expand Down Expand Up @@ -399,30 +542,26 @@ struct device_select_policy_hub
* @tparam KEEP_REJECTS
* Whether or not we push rejected items to the back of the output
*/
template <
typename InputIteratorT,
typename FlagsInputIteratorT,
typename SelectedOutputIteratorT,
typename NumSelectedIteratorT,
typename SelectOpT,
typename EqualityOpT,
typename OffsetT,
bool KEEP_REJECTS,
bool MayAlias = false,
typename SelectedPolicy =
detail::device_select_policy_hub<cub::detail::value_t<InputIteratorT>, MayAlias, KEEP_REJECTS>>
template <typename InputIteratorT,
typename FlagsInputIteratorT,
typename SelectedOutputIteratorT,
typename NumSelectedIteratorT,
typename SelectOpT,
typename EqualityOpT,
typename OffsetT,
bool KEEP_REJECTS,
bool MayAlias = false,
typename SelectedPolicy =
detail::device_select_policy_hub<cub::detail::value_t<InputIteratorT>,
cub::detail::value_t<FlagsInputIteratorT>,
MayAlias,
KEEP_REJECTS>>
struct DispatchSelectIf : SelectedPolicy
{
/******************************************************************************
* Types and constants
******************************************************************************/

// The input value type
using InputT = cub::detail::value_t<InputIteratorT>;

// The flag value type
using FlagT = cub::detail::value_t<FlagsInputIteratorT>;

// Tile status descriptor interface type
using ScanTileStateT = ScanTileState<OffsetT>;

Expand Down

0 comments on commit 082566c

Please sign in to comment.