Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tune Select and Partition on A100 #289

Merged
merged 6 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
339 changes: 331 additions & 8 deletions cub/cub/device/dispatch/tuning/tuning_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ struct sm90_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primit
static constexpr int threads = 256;
static constexpr int items = 22;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<320, 605>;
};
Expand All @@ -135,7 +135,7 @@ struct sm90_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primit
static constexpr int threads = 384;
static constexpr int items = 17;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<76, 1150>;
};
Expand All @@ -146,7 +146,7 @@ struct sm90_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primit
static constexpr int threads = 384;
static constexpr int items = 11;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<380, 1140>;
};
Expand Down Expand Up @@ -284,7 +284,7 @@ struct sm90_tuning<Input, flagged::no, keep_rejects::yes, offset_size::_4, primi
static constexpr int threads = 128;
static constexpr int items = 12;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<512, 1075>;
};
Expand All @@ -296,7 +296,7 @@ struct sm90_tuning<__int128_t, flagged::no, keep_rejects::yes, offset_size::_4,
static constexpr int threads = 192;
static constexpr int items = 5;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>;
};
Expand All @@ -307,7 +307,7 @@ struct sm90_tuning<__uint128_t, flagged::no, keep_rejects::yes, offset_size::_4,
static constexpr int threads = 192;
static constexpr int items = 5;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_TRANSPOSE;
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<1616, 1115>;
};
Expand Down Expand Up @@ -382,12 +382,310 @@ struct sm90_tuning<__uint128_t, flagged::yes, keep_rejects::yes, offset_size::_4
};
#endif


template <class InputT,
flagged,
keep_rejects,
offset_size OffsetSize,
primitive = is_primitive<InputT>(),
input_size InputSize = classify_input_size<InputT>()>
struct sm80_tuning
{
static constexpr int threads = 128;

static constexpr int nominal_4b_items_per_thread = 10;

static constexpr int items =
CUB_MIN(nominal_4b_items_per_thread,
CUB_MAX(1, (nominal_4b_items_per_thread * 4 / sizeof(InputT))));

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<350, 450>;
};

// select::if
template <class Input>
struct sm80_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_1>
{
static constexpr int threads = 992;
static constexpr int items = 20;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::no_delay_constructor_t<395>;
};

template <class Input>
struct sm80_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_2>
{
static constexpr int threads = 576;
static constexpr int items = 14;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::no_delay_constructor_t<870>;
};

template <class Input>
struct sm80_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_4>
{
static constexpr int threads = 256;
static constexpr int items = 18;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<1130>;
};

template <class Input>
struct sm80_tuning<Input, flagged::no, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_8>
{
static constexpr int threads = 192;
static constexpr int items = 10;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<832, 1165>;
};

#if CUB_IS_INT128_ENABLED
template <>
struct sm80_tuning<__int128_t, flagged::no, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16>
{
static constexpr int threads = 384;
static constexpr int items = 4;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::no_delay_constructor_t<1140>;
};

template <>
struct sm80_tuning<__uint128_t, flagged::no, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16>
{
static constexpr int threads = 384;
static constexpr int items = 4;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::no_delay_constructor_t<1140>;
};
#endif

// select::flagged
template <class Input>
struct sm80_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_1>
{
static constexpr int threads = 224;
static constexpr int items = 20;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::no_delay_constructor_t<735>;
};

template <class Input>
struct sm80_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_2>
{
static constexpr int threads = 256;
static constexpr int items = 20;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<1155>;
};

template <class Input>
struct sm80_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_4>
{
static constexpr int threads = 320;
static constexpr int items = 10;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<124, 1115>;
};

template <class Input>
struct sm80_tuning<Input, flagged::yes, keep_rejects::no, offset_size::_4, primitive::yes, input_size::_8>
{
static constexpr int threads = 384;
static constexpr int items = 6;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::no_delay_constructor_t<1130>;
};

#if CUB_IS_INT128_ENABLED
template <>
struct sm80_tuning<__int128_t, flagged::yes, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16>
{
static constexpr int threads = 256;
static constexpr int items = 5;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<464, 1025>;
};

template <>
struct sm80_tuning<__uint128_t, flagged::yes, keep_rejects::no, offset_size::_4, primitive::no, input_size::_16>
{
static constexpr int threads = 256;
static constexpr int items = 5;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<464, 1025>;
};
#endif

// partition::if
template <class Input>
struct sm80_tuning<Input, flagged::no, keep_rejects::yes, offset_size::_4, primitive::yes, input_size::_1>
{
static constexpr int threads = 512;
static constexpr int items = 20;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::no_delay_constructor_t<510>;
};

template <class Input>
struct sm80_tuning<Input, flagged::no, keep_rejects::yes, offset_size::_4, primitive::yes, input_size::_2>
{
static constexpr int threads = 224;
static constexpr int items = 18;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<1045>;
};

template <class Input>
struct sm80_tuning<Input, flagged::no, keep_rejects::yes, offset_size::_4, primitive::yes, input_size::_4>
{
static constexpr int threads = 192;
static constexpr int items = 15;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::no_delay_constructor_t<1040>;
};

template <class Input>
struct sm80_tuning<Input, flagged::no, keep_rejects::yes, offset_size::_4, primitive::yes, input_size::_8>
{
static constexpr int threads = 192;
static constexpr int items = 10;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<68, 1160>;
};

#if CUB_IS_INT128_ENABLED
template <>
struct sm80_tuning<__int128_t, flagged::no, keep_rejects::yes, offset_size::_4, primitive::no, input_size::_16>
{
static constexpr int threads = 256;
static constexpr int items = 5;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<400, 1090>;
};

template <>
struct sm80_tuning<__uint128_t, flagged::no, keep_rejects::yes, offset_size::_4, primitive::no, input_size::_16>
{
static constexpr int threads = 256;
static constexpr int items = 5;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<400, 1090>;
};
#endif

// partition::flagged
template <class Input>
struct sm80_tuning<Input, flagged::yes, keep_rejects::yes, offset_size::_4, primitive::yes, input_size::_1>
{
static constexpr int threads = 512;
static constexpr int items = 20;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::no_delay_constructor_t<595>;
};

template <class Input>
struct sm80_tuning<Input, flagged::yes, keep_rejects::yes, offset_size::_4, primitive::yes, input_size::_2>
{
static constexpr int threads = 224;
static constexpr int items = 18;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<1105>;
};

template <class Input>
struct sm80_tuning<Input, flagged::yes, keep_rejects::yes, offset_size::_4, primitive::yes, input_size::_4>
{
static constexpr int threads = 192;
static constexpr int items = 12;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::fixed_delay_constructor_t<912, 1025>;
};

template <class Input>
struct sm80_tuning<Input, flagged::yes, keep_rejects::yes, offset_size::_4, primitive::yes, input_size::_8>
{
static constexpr int threads = 192;
static constexpr int items = 12;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<884, 1130>;
};

#if CUB_IS_INT128_ENABLED
template <>
struct sm80_tuning<__int128_t, flagged::yes, keep_rejects::yes, offset_size::_4, primitive::no, input_size::_16>
{
static constexpr int threads = 256;
static constexpr int items = 5;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<400, 1090>;
};

template <>
struct sm80_tuning<__uint128_t, flagged::yes, keep_rejects::yes, offset_size::_4, primitive::no, input_size::_16>
{
static constexpr int threads = 256;
static constexpr int items = 5;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::fixed_delay_constructor_t<400, 1090>;
};
#endif

} // namespace select

template <class InputT, class FlagT, class OffsetT, bool MayAlias, bool KeepRejects>
struct device_select_policy_hub
{
struct Policy350 : ChainedPolicy<350, Policy350, Policy350>
struct DefaultTuning
{
static constexpr int NOMINAL_4B_ITEMS_PER_THREAD = 10;

Expand All @@ -403,7 +701,32 @@ struct device_select_policy_hub
detail::fixed_delay_constructor_t<350, 450>>;
};

struct Policy900 : ChainedPolicy<900, Policy900, Policy350>
struct Policy350
: DefaultTuning
, ChainedPolicy<350, Policy350, Policy350>
{};

struct Policy800 : ChainedPolicy<800, Policy800, Policy350>
{
using tuning = detail::select::sm80_tuning<InputT,
select::is_flagged<FlagT>(),
select::are_rejects_kept<KeepRejects>(),
select::classify_offset_size<OffsetT>()>;

using SelectIfPolicyT = AgentSelectIfPolicy<tuning::threads,
tuning::items,
tuning::load_algorithm,
LOAD_DEFAULT,
BLOCK_SCAN_WARP_SCANS,
typename tuning::delay_constructor>;
};

struct Policy860
: DefaultTuning
, ChainedPolicy<860, Policy860, Policy800>
{};

struct Policy900 : ChainedPolicy<900, Policy900, Policy860>
{
using tuning = detail::select::sm90_tuning<InputT,
select::is_flagged<FlagT>(),
Expand Down
Loading
Loading