Skip to content

Commit

Permalink
Primitive
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Jun 7, 2023
1 parent 6a9be57 commit 566593d
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions cub/device/dispatch/dispatch_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,19 @@ namespace detail
namespace scan
{
// TODO Only for sum/max as ScanOp?
// Primary sm90_tuning template: the fallback used when no specialization
// matches. It exposes two integer constants (`threads`, `items`) and a
// `delay_constructor` type alias.
//
// AccumSize defaults to sizeof(AccumT); in the three-parameter header,
// Primitive defaults to Traits<AccumT>::PRIMITIVE.
//
// NOTE(review): this span is a rendered diff without +/- markers, so two
// template headers and two `delay_constructor` aliases appear back to back.
// Presumably the first of each pair is the removed line and the second the
// added one -- confirm against the repository before relying on either.
template <class AccumT, std::size_t AccumSize = sizeof(AccumT)>
template <class AccumT,
std::size_t AccumSize = sizeof(AccumT),
bool Primitive = Traits<AccumT>::PRIMITIVE>
struct sm90_tuning
{
static constexpr int threads = 128;
static constexpr int items = 15;

using delay_constructor = detail::default_delay_constructor_t<AccumT>;
using delay_constructor = detail::exponential_backoff_constructor_t<32, 1140>;
};

template <class AccumT>
struct sm90_tuning<AccumT, 1>
struct sm90_tuning<AccumT, 1, true>
{
static constexpr int threads = 192;
static constexpr int items = 22;
Expand All @@ -225,7 +227,7 @@ struct sm90_tuning<AccumT, 1>
};

template <class AccumT>
struct sm90_tuning<AccumT, 2>
struct sm90_tuning<AccumT, 2, true>
{
static constexpr int threads = 512;
static constexpr int items = 12;
Expand All @@ -234,7 +236,7 @@ struct sm90_tuning<AccumT, 2>
};

template <class AccumT>
struct sm90_tuning<AccumT, 4>
struct sm90_tuning<AccumT, 4, true>
{
static constexpr int threads = 128;
static constexpr int items = 24;
Expand All @@ -243,25 +245,36 @@ struct sm90_tuning<AccumT, 4>
};

// Partial specialization of sm90_tuning for AccumSize == 8.
// The three-argument head additionally fixes the third template argument
// (Primitive in the primary template) to true.
//
// NOTE(review): two struct heads appear here because this is a rendered diff
// without +/- markers; presumably the two-argument head is the removed line
// and the three-argument head the added one -- confirm against the repository.
template <class AccumT>
struct sm90_tuning<AccumT, 8>
struct sm90_tuning<AccumT, 8, true>
{
static constexpr int threads = 224;
static constexpr int items = 24;

using delay_constructor = detail::fixed_delay_constructor_t<632, 1290>;
};

template <class AccumT>
struct sm90_tuning<AccumT, 16>
#if CUB_IS_INT128_ENABLED
// Full specialization of sm90_tuning for __int128_t with AccumSize == 16 and
// Primitive == false. Only compiled when CUB_IS_INT128_ENABLED is set (see
// the enclosing #if).
// NOTE(review): presumably Traits<__int128_t>::PRIMITIVE evaluates to false,
// so the primary template's default arguments select this specialization --
// confirm against the Traits definition.
template <>
struct sm90_tuning<__int128_t, 16, false>
{
static constexpr int threads = 576;
static constexpr int items = 21;

using delay_constructor = detail::fixed_delay_constructor_t<860, 630>;
};

// Full specialization of sm90_tuning for __uint128_t with AccumSize == 16 and
// Primitive == false. It carries the same threads / items / delay_constructor
// values as the __int128_t specialization that precedes it, and sits inside
// the same CUB_IS_INT128_ENABLED guard.
// NOTE(review): presumably Traits<__uint128_t>::PRIMITIVE evaluates to false,
// so the primary template's default arguments select this specialization --
// confirm against the Traits definition.
template <>
struct sm90_tuning<__uint128_t, 16, false>
{
static constexpr int threads = 576;
static constexpr int items = 21;

using delay_constructor = detail::fixed_delay_constructor_t<860, 630>;
};
#endif

template <>
struct sm90_tuning<float, sizeof(float)>
struct sm90_tuning<float, sizeof(float), true>
{
static constexpr int threads = 128;
static constexpr int items = 24;
Expand All @@ -270,7 +283,7 @@ struct sm90_tuning<float, sizeof(float)>
};

template <>
struct sm90_tuning<double, sizeof(double)>
struct sm90_tuning<double, sizeof(double), true>
{
static constexpr int threads = 224;
static constexpr int items = 24;
Expand Down

0 comments on commit 566593d

Please sign in to comment.