From c3f7988ae4677b2ffab72b121fc16cb0eade31e1 Mon Sep 17 00:00:00 2001 From: Georgy Evtushenko Date: Sat, 6 May 2023 12:09:02 +0400 Subject: [PATCH] Rework select --- cub/device/dispatch/dispatch_select_if.cuh | 335 +++++++++++---------- 1 file changed, 180 insertions(+), 155 deletions(-) diff --git a/cub/device/dispatch/dispatch_select_if.cuh b/cub/device/dispatch/dispatch_select_if.cuh index 0c422d38c..aaa147a5d 100644 --- a/cub/device/dispatch/dispatch_select_if.cuh +++ b/cub/device/dispatch/dispatch_select_if.cuh @@ -63,7 +63,7 @@ CUB_NAMESPACE_BEGIN * Otherwise performs discontinuity selection (keep unique) */ template < - typename AgentSelectIfPolicyT, ///< Parameterized AgentSelectIfPolicyT tuning policy type + typename ChainedPolicyT, typename InputIteratorT, ///< Random-access input iterator type for reading input items typename FlagsInputIteratorT, ///< Random-access input iterator type for reading selection flags (NullType* if a selection functor or discontinuity flagging is to be used for selection) typename SelectedOutputIteratorT, ///< Random-access output iterator type for writing selected items @@ -73,7 +73,7 @@ template < typename EqualityOpT, ///< Equality operator type (NullType if selection functor or selection flags is to be used for selection) typename OffsetT, ///< Signed integer type for global offsets bool KEEP_REJECTS> ///< Whether or not we push rejected items to the back of the output -__launch_bounds__ (int(AgentSelectIfPolicyT::BLOCK_THREADS)) +__launch_bounds__ (int(ChainedPolicyT::ActivePolicy::SelectIfPolicyT::BLOCK_THREADS)) __global__ void DeviceSelectSweepKernel( InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items FlagsInputIteratorT d_flags, ///< [in] Pointer to the input sequence of selection flags (if applicable) @@ -85,16 +85,17 @@ __global__ void DeviceSelectSweepKernel( OffsetT num_items, ///< [in] Total number of input items (i.e., length of \p d_in) int num_tiles) ///< [in] Total number of tiles for the entire problem { + using AgentSelectIfPolicyT = typename ChainedPolicyT::ActivePolicy::SelectIfPolicyT; + // Thread block type for selecting data from input tiles - typedef AgentSelectIf< - AgentSelectIfPolicyT, - InputIteratorT, - FlagsInputIteratorT, - SelectedOutputIteratorT, - SelectOpT, - EqualityOpT, - OffsetT, - KEEP_REJECTS> AgentSelectIfT; + using AgentSelectIfT = AgentSelectIf; // Shared memory for AgentSelectIf __shared__ typename AgentSelectIfT::TempStorage temp_storage; @@ -165,66 +166,101 @@ struct DispatchSelectIf : SelectedPolicy // The flag value type using FlagT = cub::detail::value_t; - enum - { - INIT_KERNEL_THREADS = 128, - }; - // Tile status descriptor interface type - typedef ScanTileState ScanTileStateT; + using ScanTileStateT = ScanTileState; + static constexpr int INIT_KERNEL_THREADS = 128; - // "Opaque" policies (whose parameterizations aren't reflected in the type signature) - struct PtxSelectIfPolicyT : SelectedPolicy::SelectIfPolicyT {}; + /// Device-accessible allocation of temporary storage. + /// When `nullptr`, the required allocation size is written to `temp_storage_bytes` + /// and no work is done. + void* d_temp_storage; + /// Reference to size in bytes of `d_temp_storage` allocation + size_t& temp_storage_bytes; - /****************************************************************************** - * Utilities - ******************************************************************************/ + /// Pointer to the input sequence of data items + InputIteratorT d_in; - /** - * Initialize kernel dispatch configurations with the policies corresponding to the PTX assembly we will use - */ - template - CUB_RUNTIME_FUNCTION __forceinline__ - static void InitConfigs( - int ptx_version, - KernelConfig &select_if_config) - { - NV_IF_TARGET(NV_IS_DEVICE, - ( - (void)ptx_version; - // We're on the device, so initialize the kernel dispatch configurations with the current PTX policy - select_if_config.template Init(); - ), ( - // We're on the host, so lookup and initialize the kernel dispatch configurations with the policies that match the device's PTX version - - // (There's only one policy right now) - (void)ptx_version; - select_if_config.template Init(); - )); - } + /// Pointer to the input sequence of selection flags (if applicable) + FlagsInputIteratorT d_flags; + /// Pointer to the output sequence of selected data items + SelectedOutputIteratorT d_selected_out; - /** - * Kernel kernel dispatch configuration. - */ - struct KernelConfig - { - int block_threads; - int items_per_thread; - int tile_items; + /// Pointer to the total number of items selected (i.e., length of `d_selected_out`) + NumSelectedIteratorT d_num_selected_out; - template - CUB_RUNTIME_FUNCTION __forceinline__ - void Init() - { - block_threads = PolicyT::BLOCK_THREADS; - items_per_thread = PolicyT::ITEMS_PER_THREAD; - tile_items = block_threads * items_per_thread; - } - }; + /// Selection operator + SelectOpT select_op; + + /// Equality operator + EqualityOpT equality_op; + /// Total number of input items (i.e., length of `d_in`) + OffsetT num_items; + + /// CUDA stream to launch kernels within. Default is stream0. + cudaStream_t stream; + + int ptx_version; + + /** + * @param d_temp_storage + * Device-accessible allocation of temporary storage. + * When `nullptr`, the required allocation size is written to `temp_storage_bytes` + * and no work is done. + * + * @param temp_storage_bytes + * Reference to size in bytes of `d_temp_storage` allocation + * + * @param d_in + * Pointer to the input sequence of data items + * + * @param d_flags + * Pointer to the input sequence of selection flags (if applicable) + * + * @param d_selected_out + * Pointer to the output sequence of selected data items + * + * @param d_num_selected_out + * Pointer to the total number of items selected (i.e., length of `d_selected_out`) + * + * @param select_op + * Selection operator + * + * @param equality_op + * Equality operator + * + * @param num_items + * Total number of input items (i.e., length of `d_in`) + * + * @param stream + * CUDA stream to launch kernels within. Default is stream0. + */ + CUB_RUNTIME_FUNCTION __forceinline__ DispatchSelectIf(void *d_temp_storage, + size_t &temp_storage_bytes, + InputIteratorT d_in, + FlagsInputIteratorT d_flags, + SelectedOutputIteratorT d_selected_out, + NumSelectedIteratorT d_num_selected_out, + SelectOpT select_op, + EqualityOpT equality_op, + OffsetT num_items, + cudaStream_t stream, + int ptx_version) + : d_temp_storage(d_temp_storage) + , temp_storage_bytes(temp_storage_bytes) + , d_in(d_in) + , d_flags(d_flags) + , d_selected_out(d_selected_out) + , d_num_selected_out(d_num_selected_out) + , select_op(select_op) + , equality_op(equality_op) + , num_items(num_items) + , stream(stream) + , ptx_version(ptx_version) + {} /****************************************************************************** * Dispatch entrypoints @@ -234,27 +270,16 @@ struct DispatchSelectIf : SelectedPolicy * Internal dispatch routine for computing a device-wide selection using the * specified kernel functions. */ - template < - typename ScanInitKernelPtrT, ///< Function type of cub::DeviceScanInitKernel - typename SelectIfKernelPtrT> ///< Function type of cub::SelectIfKernelPtrT - CUB_RUNTIME_FUNCTION __forceinline__ - static cudaError_t Dispatch( - void* d_temp_storage, ///< [in] Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. - size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items - FlagsInputIteratorT d_flags, ///< [in] Pointer to the input sequence of selection flags (if applicable) - SelectedOutputIteratorT d_selected_out, ///< [in] Pointer to the output sequence of selected data items - NumSelectedIteratorT d_num_selected_out, ///< [in] Pointer to the total number of items selected (i.e., length of \p d_selected_out) - SelectOpT select_op, ///< [in] Selection operator - EqualityOpT equality_op, ///< [in] Equality operator - OffsetT num_items, ///< [in] Total number of input items (i.e., length of \p d_in) - cudaStream_t stream, ///< [in] CUDA stream to launch kernels within. Default is stream0. - int /*ptx_version*/, ///< [in] PTX version of dispatch kernels - ScanInitKernelPtrT scan_init_kernel, ///< [in] Kernel function pointer to parameterization of cub::DeviceScanInitKernel - SelectIfKernelPtrT select_if_kernel, ///< [in] Kernel function pointer to parameterization of cub::DeviceSelectSweepKernel - KernelConfig select_if_config) ///< [in] Dispatch parameters that match the policy that \p select_if_kernel was compiled for + template + CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t Invoke(ScanInitKernelPtrT scan_init_kernel, + SelectIfKernelPtrT select_if_kernel) { cudaError error = cudaSuccess; + + const int block_threads = ActivePolicyT::SelectIfPolicyT::BLOCK_THREADS; + const int items_per_thread = ActivePolicyT::SelectIfPolicyT::ITEMS_PER_THREAD; + const int tile_size = block_threads * items_per_thread; + do { // Get device ordinal @@ -262,7 +287,6 @@ struct DispatchSelectIf : SelectedPolicy if (CubDebug(error = cudaGetDevice(&device_ordinal))) break; // Number of input tiles - int tile_size = select_if_config.block_threads * select_if_config.items_per_thread; int num_tiles = static_cast(cub::DivideAndRoundUp(num_items, tile_size)); // Specify temporary storage allocation requirements @@ -331,7 +355,7 @@ struct DispatchSelectIf : SelectedPolicy int range_select_sm_occupancy; if (CubDebug(error = MaxSmOccupancy(range_select_sm_occupancy, // out select_if_kernel, - select_if_config.block_threads))) + block_threads))) { break; } @@ -341,16 +365,16 @@ struct DispatchSelectIf : SelectedPolicy scan_grid_size.x, scan_grid_size.y, scan_grid_size.z, - select_if_config.block_threads, + block_threads, (long long)stream, - select_if_config.items_per_thread, + items_per_thread, range_select_sm_occupancy); } #endif // Invoke select_if_kernel THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( - scan_grid_size, select_if_config.block_threads, 0, stream + scan_grid_size, block_threads, 0, stream ).doit(select_if_kernel, d_in, d_flags, @@ -380,8 +404,59 @@ struct DispatchSelectIf : SelectedPolicy return error; } - template - CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED + template + CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t Invoke() + { + using MaxPolicyT = typename SelectedPolicy::MaxPolicy; + + return Invoke(DeviceCompactInitKernel, + DeviceSelectSweepKernel); + } + + /** + * Internal dispatch routine + * + * @param d_temp_storage + * Device-accessible allocation of temporary storage. + * When `nullptr`, the required allocation size is written to `temp_storage_bytes` + * and no work is done. + * + * @param temp_storage_bytes + * Reference to size in bytes of `d_temp_storage` allocation + * + * @param d_in + * Pointer to the input sequence of data items + * + * @param d_flags + * Pointer to the input sequence of selection flags (if applicable) + * + * @param d_selected_out + * Pointer to the output sequence of selected data items + * + * @param d_num_selected_out + * Pointer to the total number of items selected (i.e., length of `d_selected_out`) + * + * @param select_op + * Selection operator + * + * @param equality_op + * Equality operator + * + * @param num_items + * Total number of input items (i.e., length of `d_in`) + * + * @param stream + * CUDA stream to launch kernels within. Default is stream0. + */ CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t Dispatch(void *d_temp_storage, size_t &temp_storage_bytes, @@ -392,79 +467,29 @@ struct DispatchSelectIf : SelectedPolicy SelectOpT select_op, EqualityOpT equality_op, OffsetT num_items, - cudaStream_t stream, - bool debug_synchronous, - int ptx_version, - ScanInitKernelPtrT scan_init_kernel, - SelectIfKernelPtrT select_if_kernel, - KernelConfig select_if_config) + cudaStream_t stream) { - CUB_DETAIL_RUNTIME_DEBUG_SYNC_USAGE_LOG - - return Dispatch( - d_temp_storage, - temp_storage_bytes, - d_in, - d_flags, - d_selected_out, - d_num_selected_out, - select_op, - equality_op, - num_items, - stream, - ptx_version, - scan_init_kernel, - select_if_kernel, - select_if_config); - } + using MaxPolicyT = typename SelectedPolicy::MaxPolicy; - /** - * Internal dispatch routine - */ - CUB_RUNTIME_FUNCTION __forceinline__ - static cudaError_t Dispatch( - void* d_temp_storage, ///< [in] Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done. - size_t& temp_storage_bytes, ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation - InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items - FlagsInputIteratorT d_flags, ///< [in] Pointer to the input sequence of selection flags (if applicable) - SelectedOutputIteratorT d_selected_out, ///< [in] Pointer to the output sequence of selected data items - NumSelectedIteratorT d_num_selected_out, ///< [in] Pointer to the total number of items selected (i.e., length of \p d_selected_out) - SelectOpT select_op, ///< [in] Selection operator - EqualityOpT equality_op, ///< [in] Equality operator - OffsetT num_items, ///< [in] Total number of input items (i.e., length of \p d_in) - cudaStream_t stream) ///< [in] [optional] CUDA stream to launch kernels within. Default is stream0. - { - cudaError error = cudaSuccess; - do + int ptx_version = 0; + if (cudaError_t error = CubDebug(PtxVersion(ptx_version))) { - // Get PTX version - int ptx_version = 0; - if (CubDebug(error = PtxVersion(ptx_version))) break; - - // Get kernel kernel dispatch configurations - KernelConfig select_if_config; - InitConfigs(ptx_version, select_if_config); - - // Dispatch - if (CubDebug(error = Dispatch( - d_temp_storage, - temp_storage_bytes, - d_in, - d_flags, - d_selected_out, - d_num_selected_out, - select_op, - equality_op, - num_items, - stream, - ptx_version, - DeviceCompactInitKernel, - DeviceSelectSweepKernel, - select_if_config))) break; + return error; } - while (0); - return error; + DispatchSelectIf dispatch(d_temp_storage, + temp_storage_bytes, + d_in, + d_flags, + d_selected_out, + d_num_selected_out, + select_op, + equality_op, + num_items, + stream, + ptx_version); + + return CubDebug(MaxPolicyT::Invoke(ptx_version, dispatch)); } CUB_DETAIL_RUNTIME_DEBUG_SYNC_IS_NOT_SUPPORTED