Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tune RLE on A100 #295

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
274 changes: 267 additions & 7 deletions cub/cub/device/dispatch/tuning/tuning_run_length_encode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,108 @@ struct sm90_tuning<LengthT, __uint128_t, primitive_length::yes, primitive_key::n
};
#endif

template <class LengthT,
class KeyT,
primitive_length PrimitiveLength = is_primitive_length<LengthT>(),
primitive_key PrimitiveKey = is_primitive_key<KeyT>(),
length_size LengthSize = classify_length_size<LengthT>(),
key_size KeySize = classify_key_size<KeyT>()>
struct sm80_tuning
{
static constexpr int max_input_bytes = CUB_MAX(sizeof(KeyT), sizeof(LengthT));
static constexpr int combined_input_bytes = sizeof(KeyT) + sizeof(LengthT);

static constexpr int threads = 128;

static constexpr int nominal_4b_items_per_thread = 6;

static constexpr int items =
(max_input_bytes <= 8)
? 6
: CUB_MIN(nominal_4b_items_per_thread,
CUB_MAX(1,
((nominal_4b_items_per_thread * 8) + combined_input_bytes - 1) /
combined_input_bytes));

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::default_reduce_by_key_delay_constructor_t<LengthT, int>;
};

template <class LengthT, class KeyT>
struct sm80_tuning<LengthT, KeyT, primitive_length::yes, primitive_key::yes, length_size::_4, key_size::_1>
{
static constexpr int threads = 256;

static constexpr int items = 14;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::no_delay_constructor_t<640>;
};

template <class LengthT, class KeyT>
struct sm80_tuning<LengthT, KeyT, primitive_length::yes, primitive_key::yes, length_size::_4, key_size::_2>
{
static constexpr int threads = 256;

static constexpr int items = 13;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::no_delay_constructor_t<900>;
};

template <class LengthT, class KeyT>
struct sm80_tuning<LengthT, KeyT, primitive_length::yes, primitive_key::yes, length_size::_4, key_size::_4>
{
static constexpr int threads = 256;

static constexpr int items = 13;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

using delay_constructor = detail::no_delay_constructor_t<1080>;
};

template <class LengthT, class KeyT>
struct sm80_tuning<LengthT, KeyT, primitive_length::yes, primitive_key::yes, length_size::_4, key_size::_8>
{
static constexpr int threads = 224;

static constexpr int items = 9;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<1075>;
};

#if CUB_IS_INT128_ENABLED
template <class LengthT>
struct sm80_tuning<LengthT, __int128_t, primitive_length::yes, primitive_key::no, length_size::_4, key_size::_16>
{
static constexpr int threads = 128;

static constexpr int items = 7;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<630>;
};

template <class LengthT>
struct sm80_tuning<LengthT, __uint128_t, primitive_length::yes, primitive_key::no, length_size::_4, key_size::_16>
{
static constexpr int threads = 128;

static constexpr int items = 7;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

using delay_constructor = detail::no_delay_constructor_t<630>;
};
#endif

} // namespace encode

namespace non_trivial_runs
Expand Down Expand Up @@ -295,6 +397,114 @@ struct sm90_tuning<LengthT, __uint128_t, primitive_length::yes, primitive_key::n
};
#endif

template <class LengthT,
class KeyT,
primitive_length PrimitiveLength = is_primitive_length<LengthT>(),
primitive_key PrimitiveKey = is_primitive_key<KeyT>(),
length_size LengthSize = classify_length_size<LengthT>(),
key_size KeySize = classify_key_size<KeyT>()>
struct sm80_tuning
{
static constexpr int threads = 96;

static constexpr int nominal_4b_items_per_thread = 15;

static constexpr int items = CUB_MIN(nominal_4b_items_per_thread,
CUB_MAX(1, (nominal_4b_items_per_thread * 4 / sizeof(KeyT))));

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

static constexpr bool store_with_time_slicing = true;

using delay_constructor = detail::default_reduce_by_key_delay_constructor_t<LengthT, int>;
};

template <class LengthT, class KeyT>
struct sm80_tuning<LengthT, KeyT, primitive_length::yes, primitive_key::yes, length_size::_4, key_size::_1>
{
static constexpr int threads = 192;

static constexpr int items = 20;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;

static constexpr bool store_with_time_slicing = false;

using delay_constructor = detail::no_delay_constructor_t<630>;
};

template <class LengthT, class KeyT>
struct sm80_tuning<LengthT, KeyT, primitive_length::yes, primitive_key::yes, length_size::_4, key_size::_2>
{
static constexpr int threads = 192;

static constexpr int items = 20;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

static constexpr bool store_with_time_slicing = false;

using delay_constructor = detail::no_delay_constructor_t<1015>;
};

template <class LengthT, class KeyT>
struct sm80_tuning<LengthT, KeyT, primitive_length::yes, primitive_key::yes, length_size::_4, key_size::_4>
{
static constexpr int threads = 224;

static constexpr int items = 15;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

static constexpr bool store_with_time_slicing = false;

using delay_constructor = detail::no_delay_constructor_t<915>;
};

template <class LengthT, class KeyT>
struct sm80_tuning<LengthT, KeyT, primitive_length::yes, primitive_key::yes, length_size::_4, key_size::_8>
{
static constexpr int threads = 256;

static constexpr int items = 13;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

static constexpr bool store_with_time_slicing = false;

using delay_constructor = detail::no_delay_constructor_t<1065>;
};

#if CUB_IS_INT128_ENABLED
template <class LengthT>
struct sm80_tuning<LengthT, __int128_t, primitive_length::yes, primitive_key::no, length_size::_4, key_size::_16>
{
static constexpr int threads = 192;

static constexpr int items = 13;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

static constexpr bool store_with_time_slicing = false;

using delay_constructor = detail::no_delay_constructor_t<1050>;
};

template <class LengthT>
struct sm80_tuning<LengthT, __uint128_t, primitive_length::yes, primitive_key::no, length_size::_4, key_size::_16>
{
static constexpr int threads = 192;

static constexpr int items = 13;

static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;

static constexpr bool store_with_time_slicing = false;

using delay_constructor = detail::no_delay_constructor_t<1050>;
};
#endif

} // namespace non_trivial_runs


Expand All @@ -306,8 +516,7 @@ struct device_run_length_encode_policy_hub
static constexpr int MAX_INPUT_BYTES = CUB_MAX(sizeof(KeyT), sizeof(LengthT));
static constexpr int COMBINED_INPUT_BYTES = sizeof(KeyT) + sizeof(LengthT);

/// SM35
struct Policy350 : ChainedPolicy<350, Policy350, Policy350>
struct DefaultTuning
{
static constexpr int NOMINAL_4B_ITEMS_PER_THREAD = 6;
static constexpr int ITEMS_PER_THREAD =
Expand All @@ -327,8 +536,34 @@ struct device_run_length_encode_policy_hub
detail::default_reduce_by_key_delay_constructor_t<LengthT, int>>;
};

/// SM35
struct Policy350
: DefaultTuning
, ChainedPolicy<350, Policy350, Policy350>
{};

/// SM80
struct Policy800 : ChainedPolicy<800, Policy800, Policy350>
{
using tuning = detail::rle::encode::sm80_tuning<LengthT, KeyT>;

using ReduceByKeyPolicyT =
AgentReduceByKeyPolicy<tuning::threads,
tuning::items,
tuning::load_algorithm,
LOAD_DEFAULT,
BLOCK_SCAN_WARP_SCANS,
typename tuning::delay_constructor>;
};

// SM86
struct Policy860
: DefaultTuning
, ChainedPolicy<860, Policy860, Policy800>
{};

/// SM90
struct Policy900 : ChainedPolicy<900, Policy900, Policy350>
struct Policy900 : ChainedPolicy<900, Policy900, Policy860>
{
using tuning = detail::rle::encode::sm90_tuning<LengthT, KeyT>;

Expand All @@ -347,8 +582,7 @@ struct device_run_length_encode_policy_hub
template <class LengthT, class KeyT>
struct device_non_trivial_runs_policy_hub
{
/// SM35
struct Policy350 : ChainedPolicy<350, Policy350, Policy350>
struct DefaultTuning
{
enum
{
Expand All @@ -367,9 +601,35 @@ struct device_non_trivial_runs_policy_hub
BLOCK_SCAN_WARP_SCANS,
detail::default_reduce_by_key_delay_constructor_t<int, int>>;
};


/// SM35
struct Policy350
: DefaultTuning
, ChainedPolicy<350, Policy350, Policy350>
{};

// SM80
struct Policy800 : ChainedPolicy<800, Policy800, Policy350>
{
using tuning = detail::rle::non_trivial_runs::sm80_tuning<LengthT, KeyT>;

using RleSweepPolicyT = AgentRlePolicy<tuning::threads,
tuning::items,
tuning::load_algorithm,
LOAD_DEFAULT,
tuning::store_with_time_slicing,
BLOCK_SCAN_WARP_SCANS,
typename tuning::delay_constructor>;
};

// SM86
struct Policy860
: DefaultTuning
, ChainedPolicy<860, Policy860, Policy800>
{};

// SM90
struct Policy900 : ChainedPolicy<900, Policy900, Policy350>
struct Policy900 : ChainedPolicy<900, Policy900, Policy860>
{
using tuning = detail::rle::non_trivial_runs::sm90_tuning<LengthT, KeyT>;

Expand Down
Loading