Skip to content

Commit

Permalink
Further formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed May 6, 2023
1 parent c3f7988 commit 94a96e1
Showing 1 changed file with 144 additions and 42 deletions.
186 changes: 144 additions & 42 deletions cub/device/dispatch/dispatch_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@
******************************************************************************/

/**
* \file
* cub::DeviceSelect provides device-wide, parallel operations for selecting items from sequences of data items residing within device-accessible memory.
* @file
* cub::DeviceSelect provides device-wide, parallel operations for selecting items from sequences
* of data items residing within device-accessible memory.
*/

#pragma once
Expand Down Expand Up @@ -61,29 +62,84 @@ CUB_NAMESPACE_BEGIN
* Performs functor-based selection if SelectOpT functor type != NullType
* Otherwise performs flag-based selection if FlagsInputIterator's value type != NullType
* Otherwise performs discontinuity selection (keep unique)
*
* @tparam InputIteratorT
* Random-access input iterator type for reading input items
*
* @tparam FlagsInputIteratorT
* Random-access input iterator type for reading selection flags (NullType* if a selection functor
* or discontinuity flagging is to be used for selection)
*
* @tparam SelectedOutputIteratorT
* Random-access output iterator type for writing selected items
*
* @tparam NumSelectedIteratorT
* Output iterator type for recording the number of items selected
*
* @tparam ScanTileStateT
* Tile status interface type
*
* @tparam SelectOpT
* Selection operator type (NullType if selection flags or discontinuity flagging is
* to be used for selection)
*
* @tparam EqualityOpT
* Equality operator type (NullType if selection functor or selection flags are
* to be used for selection)
*
* @tparam OffsetT
* Signed integer type for global offsets
*
* @tparam KEEP_REJECTS
* Whether or not we push rejected items to the back of the output
*
* @param[in] d_in
* Pointer to the input sequence of data items
*
* @param[in] d_flags
* Pointer to the input sequence of selection flags (if applicable)
*
* @param[out] d_selected_out
* Pointer to the output sequence of selected data items
*
* @param[out] d_num_selected_out
* Pointer to the total number of items selected (i.e., length of \p d_selected_out)
*
* @param[in] tile_status
* Tile status interface
*
* @param[in] select_op
* Selection operator
*
* @param[in] equality_op
* Equality operator
*
* @param[in] num_items
* Total number of input items (i.e., length of \p d_in)
*
* @param[in] num_tiles
* Total number of tiles for the entire problem
*/
template <
typename ChainedPolicyT,
typename InputIteratorT, ///< Random-access input iterator type for reading input items
typename FlagsInputIteratorT, ///< Random-access input iterator type for reading selection flags (NullType* if a selection functor or discontinuity flagging is to be used for selection)
typename SelectedOutputIteratorT, ///< Random-access output iterator type for writing selected items
typename NumSelectedIteratorT, ///< Output iterator type for recording the number of items selected
typename ScanTileStateT, ///< Tile status interface type
typename SelectOpT, ///< Selection operator type (NullType if selection flags or discontinuity flagging is to be used for selection)
typename EqualityOpT, ///< Equality operator type (NullType if selection functor or selection flags is to be used for selection)
typename OffsetT, ///< Signed integer type for global offsets
bool KEEP_REJECTS> ///< Whether or not we push rejected items to the back of the output
__launch_bounds__ (int(ChainedPolicyT::ActivePolicy::SelectIfPolicyT::BLOCK_THREADS))
__global__ void DeviceSelectSweepKernel(
InputIteratorT d_in, ///< [in] Pointer to the input sequence of data items
FlagsInputIteratorT d_flags, ///< [in] Pointer to the input sequence of selection flags (if applicable)
SelectedOutputIteratorT d_selected_out, ///< [out] Pointer to the output sequence of selected data items
NumSelectedIteratorT d_num_selected_out, ///< [out] Pointer to the total number of items selected (i.e., length of \p d_selected_out)
ScanTileStateT tile_status, ///< [in] Tile status interface
SelectOpT select_op, ///< [in] Selection operator
EqualityOpT equality_op, ///< [in] Equality operator
OffsetT num_items, ///< [in] Total number of input items (i.e., length of \p d_in)
int num_tiles) ///< [in] Total number of tiles for the entire problem
template <typename ChainedPolicyT,
typename InputIteratorT,
typename FlagsInputIteratorT,
typename SelectedOutputIteratorT,
typename NumSelectedIteratorT,
typename ScanTileStateT,
typename SelectOpT,
typename EqualityOpT,
typename OffsetT,
bool KEEP_REJECTS>
__launch_bounds__(int(ChainedPolicyT::ActivePolicy::SelectIfPolicyT::BLOCK_THREADS)) __global__
void DeviceSelectSweepKernel(InputIteratorT d_in,
FlagsInputIteratorT d_flags,
SelectedOutputIteratorT d_selected_out,
NumSelectedIteratorT d_num_selected_out,
ScanTileStateT tile_status,
SelectOpT select_op,
EqualityOpT equality_op,
OffsetT num_items,
int num_tiles)
{
using AgentSelectIfPolicyT = typename ChainedPolicyT::ActivePolicy::SelectIfPolicyT;

Expand Down Expand Up @@ -141,19 +197,45 @@ struct device_select_policy_hub

/**
* Utility class for dispatching the appropriately-tuned kernels for DeviceSelect
*
* @tparam InputIteratorT
* Random-access input iterator type for reading input items
*
* @tparam FlagsInputIteratorT
* Random-access input iterator type for reading selection flags
* (NullType* if a selection functor or discontinuity flagging is to be used for selection)
*
* @tparam SelectedOutputIteratorT
* Random-access output iterator type for writing selected items
*
* @tparam NumSelectedIteratorT
* Output iterator type for recording the number of items selected
*
* @tparam SelectOpT
* Selection operator type (NullType if selection flags or discontinuity flagging is
* to be used for selection)
*
* @tparam EqualityOpT
* Equality operator type (NullType if selection functor or selection flags are to
* be used for selection)
*
* @tparam OffsetT
* Signed integer type for global offsets
*
* @tparam KEEP_REJECTS
* Whether or not we push rejected items to the back of the output
*/
template <
typename InputIteratorT, ///< Random-access input iterator type for reading input items
typename FlagsInputIteratorT, ///< Random-access input iterator type for reading selection flags (NullType* if a selection functor or discontinuity flagging is to be used for selection)
typename SelectedOutputIteratorT, ///< Random-access output iterator type for writing selected items
typename NumSelectedIteratorT, ///< Output iterator type for recording the number of items selected
typename SelectOpT, ///< Selection operator type (NullType if selection flags or discontinuity flagging is to be used for selection)
typename EqualityOpT, ///< Equality operator type (NullType if selection functor or selection flags is to be used for selection)
typename OffsetT, ///< Signed integer type for global offsets
bool KEEP_REJECTS, ///< Whether or not we push rejected items to the back of the output
bool MayAlias = false,
typename SelectedPolicy =
detail::device_select_policy_hub<cub::detail::value_t<InputIteratorT>, MayAlias>>
template <typename InputIteratorT,
typename FlagsInputIteratorT,
typename SelectedOutputIteratorT,
typename NumSelectedIteratorT,
typename SelectOpT,
typename EqualityOpT,
typename OffsetT,
bool KEEP_REJECTS,
bool MayAlias = false,
typename SelectedPolicy =
detail::device_select_policy_hub<cub::detail::value_t<InputIteratorT>, MayAlias>>
struct DispatchSelectIf : SelectedPolicy
{
/******************************************************************************
Expand Down Expand Up @@ -284,27 +366,42 @@ struct DispatchSelectIf : SelectedPolicy
{
// Get device ordinal
int device_ordinal;
if (CubDebug(error = cudaGetDevice(&device_ordinal))) break;
if (CubDebug(error = cudaGetDevice(&device_ordinal)))
{
break;
}

// Number of input tiles
int num_tiles = static_cast<int>(cub::DivideAndRoundUp(num_items, tile_size));

// Specify temporary storage allocation requirements
size_t allocation_sizes[1];
if (CubDebug(error = ScanTileStateT::AllocationSize(num_tiles, allocation_sizes[0]))) break; // bytes needed for tile status descriptors

// bytes needed for tile status descriptors
if (CubDebug(error = ScanTileStateT::AllocationSize(num_tiles, allocation_sizes[0])))
{
break;
}

// Compute allocation pointers into the single storage blob (or compute the necessary size of the blob)
void* allocations[1] = {};
if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes))) break;
if (d_temp_storage == NULL)
if (CubDebug(error = AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes)))
{
break;
}

if (d_temp_storage == nullptr)
{
// Return if the caller is simply requesting the size of the storage allocation
break;
}

// Construct the tile status interface
ScanTileStateT tile_status;
if (CubDebug(error = tile_status.Init(num_tiles, allocations[0], allocation_sizes[0]))) break;
if (CubDebug(error = tile_status.Init(num_tiles, allocations[0], allocation_sizes[0])))
{
break;
}

// Log scan_init_kernel configuration
int init_grid_size = CUB_MAX(1, cub::DivideAndRoundUp(num_tiles, INIT_KERNEL_THREADS));
Expand Down Expand Up @@ -336,11 +433,16 @@ struct DispatchSelectIf : SelectedPolicy

// Return if empty problem
if (num_items == 0)
{
break;
}

// Get max x-dimension of grid
int max_dim_x;
if (CubDebug(error = cudaDeviceGetAttribute(&max_dim_x, cudaDevAttrMaxGridDimX, device_ordinal))) break;
if (CubDebug(error = cudaDeviceGetAttribute(&max_dim_x, cudaDevAttrMaxGridDimX, device_ordinal)))
{
break;
}

// Get grid size for scanning tiles
dim3 scan_grid_size;
Expand Down

0 comments on commit 94a96e1

Please sign in to comment.