Skip to content

Commit

Permalink
Offset
Browse files: browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Jun 15, 2023
1 parent 082566c commit 553b40d
Showing 1 changed file with 28 additions and 26 deletions.
54 changes: 28 additions & 26 deletions cub/device/dispatch/dispatch_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ namespace select
template <class InputT,
bool Flagged,
bool KeepRejects,
+std::size_t OffsetSize,
bool PrimitiveInput = Traits<InputT>::PRIMITIVE,
std::size_t InputSize = sizeof(InputT)>
struct sm90_tuning
Expand All @@ -193,7 +194,7 @@ struct sm90_tuning

// select::if
template <class Input>
-struct sm90_tuning<Input, false, false, true, 1>
+struct sm90_tuning<Input, false, false, 4, true, 1>
{
static constexpr int threads = 256;
static constexpr int items = 22;
Expand All @@ -204,7 +205,7 @@ struct sm90_tuning<Input, false, false, true, 1>
};

template <class Input>
-struct sm90_tuning<Input, false, false, true, 2>
+struct sm90_tuning<Input, false, false, 4, true, 2>
{
static constexpr int threads = 256;
static constexpr int items = 22;
Expand All @@ -215,7 +216,7 @@ struct sm90_tuning<Input, false, false, true, 2>
};

template <class Input>
-struct sm90_tuning<Input, false, false, true, 4>
+struct sm90_tuning<Input, false, false, 4, true, 4>
{
static constexpr int threads = 384;
static constexpr int items = 17;
Expand All @@ -226,7 +227,7 @@ struct sm90_tuning<Input, false, false, true, 4>
};

template <class Input>
-struct sm90_tuning<Input, false, false, true, 8>
+struct sm90_tuning<Input, false, false, 4, true, 8>
{
static constexpr int threads = 384;
static constexpr int items = 11;
Expand All @@ -238,7 +239,7 @@ struct sm90_tuning<Input, false, false, true, 8>

#if CUB_IS_INT128_ENABLED
template <>
-struct sm90_tuning<__int128_t, false, false, false, sizeof(__int128_t)>
+struct sm90_tuning<__int128_t, false, false, 4, false, sizeof(__int128_t)>
{
static constexpr int threads = 512;
static constexpr int items = 5;
Expand All @@ -249,7 +250,7 @@ struct sm90_tuning<__int128_t, false, false, false, sizeof(__int128_t)>
};

template <>
-struct sm90_tuning<__uint128_t, false, false, false, sizeof(__uint128_t)>
+struct sm90_tuning<__uint128_t, false, false, 4, false, sizeof(__uint128_t)>
{
static constexpr int threads = 512;
static constexpr int items = 5;
Expand All @@ -262,7 +263,7 @@ struct sm90_tuning<__uint128_t, false, false, false, sizeof(__uint128_t)>

// select::flagged
template <class Input>
-struct sm90_tuning<Input, true, false, true, 1>
+struct sm90_tuning<Input, true, false, 4, true, 1>
{
static constexpr int threads = 448;
static constexpr int items = 20;
Expand All @@ -273,7 +274,7 @@ struct sm90_tuning<Input, true, false, true, 1>
};

template <class Input>
-struct sm90_tuning<Input, true, false, true, 2>
+struct sm90_tuning<Input, true, false, 4, true, 2>
{
static constexpr int threads = 448;
static constexpr int items = 20;
Expand All @@ -284,7 +285,7 @@ struct sm90_tuning<Input, true, false, true, 2>
};

template <class Input>
-struct sm90_tuning<Input, true, false, true, 4>
+struct sm90_tuning<Input, true, false, 4, true, 4>
{
static constexpr int threads = 384;
static constexpr int items = 15;
Expand All @@ -295,7 +296,7 @@ struct sm90_tuning<Input, true, false, true, 4>
};

template <class Input>
-struct sm90_tuning<Input, true, false, true, 8>
+struct sm90_tuning<Input, true, false, 4, true, 8>
{
static constexpr int threads = 384;
static constexpr int items = 11;
Expand All @@ -307,7 +308,7 @@ struct sm90_tuning<Input, true, false, true, 8>

#if CUB_IS_INT128_ENABLED
template <>
-struct sm90_tuning<__int128_t, true, false, false, sizeof(__int128_t)>
+struct sm90_tuning<__int128_t, true, false, 4, false, sizeof(__int128_t)>
{
static constexpr int threads = 512;
static constexpr int items = 3;
Expand All @@ -318,7 +319,7 @@ struct sm90_tuning<__int128_t, true, false, false, sizeof(__int128_t)>
};

template <>
-struct sm90_tuning<__uint128_t, true, false, false, sizeof(__uint128_t)>
+struct sm90_tuning<__uint128_t, true, false, 4, false, sizeof(__uint128_t)>
{
static constexpr int threads = 512;
static constexpr int items = 3;
Expand All @@ -331,7 +332,7 @@ struct sm90_tuning<__uint128_t, true, false, false, sizeof(__uint128_t)>

// partition::if
template <class Input>
-struct sm90_tuning<Input, false, true, true, 1>
+struct sm90_tuning<Input, false, true, 4, true, 1>
{
static constexpr int threads = 384;
static constexpr int items = 20;
Expand All @@ -342,7 +343,7 @@ struct sm90_tuning<Input, false, true, true, 1>
};

template <class Input>
-struct sm90_tuning<Input, false, true, true, 2>
+struct sm90_tuning<Input, false, true, 4, true, 2>
{
static constexpr int threads = 384;
static constexpr int items = 20;
Expand All @@ -353,7 +354,7 @@ struct sm90_tuning<Input, false, true, true, 2>
};

template <class Input>
-struct sm90_tuning<Input, false, true, true, 4>
+struct sm90_tuning<Input, false, true, 4, true, 4>
{
static constexpr int threads = 256;
static constexpr int items = 14;
Expand All @@ -364,7 +365,7 @@ struct sm90_tuning<Input, false, true, true, 4>
};

template <class Input>
-struct sm90_tuning<Input, false, true, true, 8>
+struct sm90_tuning<Input, false, true, 4, true, 8>
{
static constexpr int threads = 128;
static constexpr int items = 12;
Expand All @@ -376,7 +377,7 @@ struct sm90_tuning<Input, false, true, true, 8>

#if CUB_IS_INT128_ENABLED
template <>
-struct sm90_tuning<__int128_t, false, true, false, sizeof(__int128_t)>
+struct sm90_tuning<__int128_t, false, true, 4, false, sizeof(__int128_t)>
{
static constexpr int threads = 192;
static constexpr int items = 5;
Expand All @@ -387,7 +388,7 @@ struct sm90_tuning<__int128_t, false, true, false, sizeof(__int128_t)>
};

template <>
-struct sm90_tuning<__uint128_t, false, true, false, sizeof(__uint128_t)>
+struct sm90_tuning<__uint128_t, false, true, 4, false, sizeof(__uint128_t)>
{
static constexpr int threads = 192;
static constexpr int items = 5;
Expand All @@ -400,7 +401,7 @@ struct sm90_tuning<__uint128_t, false, true, false, sizeof(__uint128_t)>

// partition::flagged
template <class Input>
-struct sm90_tuning<Input, true, true, true, 1>
+struct sm90_tuning<Input, true, true, 4, true, 1>
{
static constexpr int threads = 256;
static constexpr int items = 20;
Expand All @@ -411,7 +412,7 @@ struct sm90_tuning<Input, true, true, true, 1>
};

template <class Input>
-struct sm90_tuning<Input, true, true, true, 2>
+struct sm90_tuning<Input, true, true, 4, true, 2>
{
static constexpr int threads = 512;
static constexpr int items = 20;
Expand All @@ -422,7 +423,7 @@ struct sm90_tuning<Input, true, true, true, 2>
};

template <class Input>
-struct sm90_tuning<Input, true, true, true, 4>
+struct sm90_tuning<Input, true, true, 4, true, 4>
{
static constexpr int threads = 256;
static constexpr int items = 20;
Expand All @@ -433,7 +434,7 @@ struct sm90_tuning<Input, true, true, true, 4>
};

template <class Input>
-struct sm90_tuning<Input, true, true, true, 8>
+struct sm90_tuning<Input, true, true, 4, true, 8>
{
static constexpr int threads = 128;
static constexpr int items = 9;
Expand All @@ -445,7 +446,7 @@ struct sm90_tuning<Input, true, true, true, 8>

#if CUB_IS_INT128_ENABLED
template <>
-struct sm90_tuning<__int128_t, true, true, false, sizeof(__int128_t)>
+struct sm90_tuning<__int128_t, true, true, 4, false, sizeof(__int128_t)>
{
static constexpr int threads = 160;
static constexpr int items = 5;
Expand All @@ -456,7 +457,7 @@ struct sm90_tuning<__int128_t, true, true, false, sizeof(__int128_t)>
};

template <>
-struct sm90_tuning<__uint128_t, true, true, false, sizeof(__uint128_t)>
+struct sm90_tuning<__uint128_t, true, true, 4, false, sizeof(__uint128_t)>
{
static constexpr int threads = 160;
static constexpr int items = 5;
Expand All @@ -469,7 +470,7 @@ struct sm90_tuning<__uint128_t, true, true, false, sizeof(__uint128_t)>

}

-template <class InputT, class FlagT, bool MayAlias, bool KeepRejects>
+template <class InputT, class FlagT, class OffsetT, bool MayAlias, bool KeepRejects>
struct device_select_policy_hub
{
struct Policy350 : ChainedPolicy<350, Policy350, Policy350>
Expand All @@ -492,7 +493,7 @@ struct device_select_policy_hub
{
static constexpr bool flagged = std::is_same<FlagT, NullType>::value == false;

-using tuning = detail::select::sm90_tuning<InputT, flagged, KeepRejects>;
+using tuning = detail::select::sm90_tuning<InputT, flagged, KeepRejects, sizeof(OffsetT)>;

using SelectIfPolicyT = AgentSelectIfPolicy<tuning::threads,
tuning::items,
Expand Down Expand Up @@ -554,6 +555,7 @@ template <typename InputIteratorT,
typename SelectedPolicy =
detail::device_select_policy_hub<cub::detail::value_t<InputIteratorT>,
cub::detail::value_t<FlagsInputIteratorT>,
+OffsetT,
MayAlias,
KEEP_REJECTS>>
struct DispatchSelectIf : SelectedPolicy
Expand Down

0 comments on commit 553b40d

Please sign in to comment.