Skip to content

Commit

Permalink
Introduce SM90 tuning policy into scan
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Jun 7, 2023
1 parent b87c356 commit 7f02615
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 4 deletions.
2 changes: 1 addition & 1 deletion benchmarks/scripts/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def iterate_case_dfs(args, callable):

for gpu in ctk_cub_df['gpu'].unique():
target_df = ctk_cub_df[ctk_cub_df['gpu'] == gpu]
target_df.drop(columns=['ctk', 'cub', 'gpu'], inplace=True)
target_df = target_df.drop(columns=['ctk', 'cub', 'gpu'])
target_df = compute_speedup(target_df)

for ct_point in ct_space(target_df):
Expand Down
115 changes: 112 additions & 3 deletions cub/device/dispatch/dispatch_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,101 @@ __launch_bounds__(int(ChainedPolicyT::ActivePolicy::ScanPolicyT::BLOCK_THREADS))
* Policy
******************************************************************************/

template <typename AccumT> ///< Data type
namespace detail
{
namespace scan
{

/// @brief Fallback SM90 scan tuning, selected when none of the
///        specializations of `sm90_tuning` matches.
///
/// @tparam AccumT            Accumulator type
/// @tparam PrimitiveBinaryOp Compile-time flag; `Policy900` passes
///                           `detail::basic_binary_op_t<ScanOpT>::value` here
/// @tparam AccumSize         Defaults to `sizeof(AccumT)`
/// @tparam PrimitiveType     Defaults to `Traits<AccumT>::PRIMITIVE`
template <typename AccumT,
          bool PrimitiveBinaryOp,
          std::size_t AccumSize = sizeof(AccumT),
          bool PrimitiveType    = Traits<AccumT>::PRIMITIVE>
struct sm90_tuning
{
  // Delay constructor chosen by the library default for this accumulator type.
  using delay_constructor = detail::default_delay_constructor_t<AccumT>;

  static constexpr int items   = 15;
  static constexpr int threads = 128;
};

/// SM90 scan tuning selected when `PrimitiveBinaryOp` and `PrimitiveType`
/// are both `true` and `AccumSize` is 1.
template <typename AccumT>
struct sm90_tuning<AccumT, true, 1, true>
{
  using delay_constructor = detail::fixed_delay_constructor_t<168, 1140>;

  static constexpr int items   = 22;
  static constexpr int threads = 192;
};

/// SM90 scan tuning selected when `PrimitiveBinaryOp` and `PrimitiveType`
/// are both `true` and `AccumSize` is 2.
template <typename AccumT>
struct sm90_tuning<AccumT, true, 2, true>
{
  using delay_constructor = detail::fixed_delay_constructor_t<376, 1125>;

  static constexpr int items   = 12;
  static constexpr int threads = 512;
};

/// SM90 scan tuning selected when `PrimitiveBinaryOp` and `PrimitiveType`
/// are both `true` and `AccumSize` is 4.
template <typename AccumT>
struct sm90_tuning<AccumT, true, 4, true>
{
  using delay_constructor = detail::fixed_delay_constructor_t<648, 1245>;

  static constexpr int items   = 24;
  static constexpr int threads = 128;
};

/// SM90 scan tuning selected when `PrimitiveBinaryOp` and `PrimitiveType`
/// are both `true` and `AccumSize` is 8.
template <typename AccumT>
struct sm90_tuning<AccumT, true, 8, true>
{
  using delay_constructor = detail::fixed_delay_constructor_t<632, 1290>;

  static constexpr int items   = 24;
  static constexpr int threads = 224;
};

#if CUB_IS_INT128_ENABLED
/// SM90 scan tuning for `__int128_t` accumulators with `PrimitiveBinaryOp`
/// set to `true`. Only compiled when `CUB_IS_INT128_ENABLED` is set.
template <>
struct sm90_tuning<__int128_t, true, 16, false>
{
  using delay_constructor = detail::fixed_delay_constructor_t<860, 630>;

  static constexpr int items   = 21;
  static constexpr int threads = 576;
};

/// SM90 scan tuning for `__uint128_t` accumulators with `PrimitiveBinaryOp`
/// set to `true`. Uses the same values as the `__int128_t` specialization.
template <>
struct sm90_tuning<__uint128_t, true, 16, false>
{
  using delay_constructor = detail::fixed_delay_constructor_t<860, 630>;

  static constexpr int items   = 21;
  static constexpr int threads = 576;
};
#endif

/// Explicit SM90 scan tuning for `float` accumulators with
/// `PrimitiveBinaryOp` set to `true`.
template <>
struct sm90_tuning<float, true, sizeof(float), true>
{
  using delay_constructor = detail::fixed_delay_constructor_t<688, 1140>;

  static constexpr int items   = 24;
  static constexpr int threads = 128;
};

/// Explicit SM90 scan tuning for `double` accumulators with
/// `PrimitiveBinaryOp` set to `true`.
template <>
struct sm90_tuning<double, true, sizeof(double), true>
{
  using delay_constructor = detail::fixed_delay_constructor_t<576, 1215>;

  static constexpr int items   = 24;
  static constexpr int threads = 224;
};

} // namespace scan
} // namespace detail

template <typename AccumT, typename ScanOpT = Sum>
struct DeviceScanPolicy
{
// For large values, use timesliced loads/stores to fit shared memory.
Expand Down Expand Up @@ -271,7 +365,22 @@ struct DeviceScanPolicy
detail::default_delay_constructor_t<AccumT>>;
};

using MaxPolicy = Policy600;
/// SM900
///
/// Chains to `Policy600` and takes its thread count, item count and delay
/// constructor from `detail::scan::sm90_tuning`, which is selected by the
/// accumulator type and by whether `ScanOpT` is flagged by
/// `detail::basic_binary_op_t`.
struct Policy900 : ChainedPolicy<900, Policy900, Policy600>
{
  using tuning =
    detail::scan::sm90_tuning<AccumT, detail::basic_binary_op_t<ScanOpT>::value>;

  // Only the first two arguments and the delay constructor come from the
  // tuning; the remaining arguments are fixed for this policy.
  using ScanPolicyT = policy_t<tuning::threads, tuning::items, AccumT,
                               ScanTransposedLoad, LOAD_DEFAULT,
                               ScanTransposedStore, BLOCK_SCAN_WARP_SCANS,
                               typename tuning::delay_constructor>;
};

using MaxPolicy = Policy900;
};

/******************************************************************************
Expand Down Expand Up @@ -312,7 +421,7 @@ template <typename InputIteratorT,
cub::detail::value_t<InputIteratorT>,
typename InitValueT::value_type>,
cub::detail::value_t<InputIteratorT>>,
typename SelectedPolicy = DeviceScanPolicy<AccumT>>
typename SelectedPolicy = DeviceScanPolicy<AccumT, ScanOpT>>
struct DispatchScan : SelectedPolicy
{
//---------------------------------------------------------------------
Expand Down
27 changes: 27 additions & 0 deletions cub/thread/thread_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,33 @@ struct ArgMin
}
};

namespace detail
{
/// @brief Compile-time flag marking a binary operator as "basic".
///
/// The primary template reports `false` for every operator type; the
/// specializations of this trait opt individual operators in.
///
/// @tparam OpT Binary operator type being queried
template <typename OpT>
struct basic_binary_op_t
{
  static constexpr bool value = false;
};

/// `Sum` is flagged as a basic binary operator.
template <>
struct basic_binary_op_t<Sum>
{
  static constexpr bool value = true;
};

/// `Min` is flagged as a basic binary operator.
template <>
struct basic_binary_op_t<Min>
{
  static constexpr bool value = true;
};

/// `Max` is flagged as a basic binary operator.
template <>
struct basic_binary_op_t<Max>
{
  static constexpr bool value = true;
};
} // namespace detail

/// @brief Default cast functor
template <typename B>
struct CastOp
Expand Down

0 comments on commit 7f02615

Please sign in to comment.