Skip to content

Commit

Permalink
Fix unique by key
Browse files (browse the repository at this point in the history)
  • Loading branch information
gevtushenko committed Jun 5, 2023
1 parent b37d87f commit 8d66e1f
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions cub/device/dispatch/dispatch_unique_by_key.cuh
Original file line number | Diff line number | Diff line change
Expand Up @@ -49,7 +49,7 @@ CUB_NAMESPACE_BEGIN
* Unique by key kernel entry point (multi-block)
*/
template <
typename AgentUniqueByKeyPolicyT, ///< Parameterized AgentUniqueByKeyPolicy tuning policy type
typename ChainedPolicyT,
typename KeyInputIteratorT, ///< Random-access input iterator type for keys
typename ValueInputIteratorT, ///< Random-access input iterator type for values
typename KeyOutputIteratorT, ///< Random-access output iterator type for keys
Expand All @@ -58,7 +58,7 @@ template <
typename ScanTileStateT, ///< Tile status interface type
typename EqualityOpT, ///< Equality operator type
typename OffsetT> ///< Signed integer type for global offsets
__launch_bounds__ (int(AgentUniqueByKeyPolicyT::UniqueByKeyPolicyT::BLOCK_THREADS))
__launch_bounds__ (int(ChainedPolicyT::ActivePolicy::UniqueByKeyPolicyT::BLOCK_THREADS))
__global__ void DeviceUniqueByKeySweepKernel(
KeyInputIteratorT d_keys_in, ///< [in] Pointer to the input sequence of keys
ValueInputIteratorT d_values_in, ///< [in] Pointer to the input sequence of values
Expand All @@ -70,15 +70,16 @@ __global__ void DeviceUniqueByKeySweepKernel(
OffsetT num_items, ///< [in] Total number of input items (i.e., length of \p d_keys_in or \p d_values_in)
int num_tiles) ///< [in] Total number of tiles for the entire problem
{
using AgentUniqueByKeyPolicyT = typename ChainedPolicyT::ActivePolicy::UniqueByKeyPolicyT;

// Thread block type for selecting data from input tiles
using AgentUniqueByKeyT = AgentUniqueByKey<
typename AgentUniqueByKeyPolicyT::UniqueByKeyPolicyT,
KeyInputIteratorT,
ValueInputIteratorT,
KeyOutputIteratorT,
ValueOutputIteratorT,
EqualityOpT,
OffsetT>;
using AgentUniqueByKeyT = AgentUniqueByKey<AgentUniqueByKeyPolicyT,
KeyInputIteratorT,
ValueInputIteratorT,
KeyOutputIteratorT,
ValueOutputIteratorT,
EqualityOpT,
OffsetT>;

// Shared memory for AgentUniqueByKey
__shared__ typename AgentUniqueByKeyT::TempStorage temp_storage;
Expand All @@ -101,7 +102,8 @@ struct DeviceUniqueByKeyPolicy
using KeyT = typename std::iterator_traits<KeyInputIteratorT>::value_type;

// SM350
struct Policy350 : ChainedPolicy<350, Policy350, Policy350> {
struct Policy350 : ChainedPolicy<350, Policy350, Policy350>
{
const static int INPUT_SIZE = sizeof(KeyT);
enum
{
Expand All @@ -114,7 +116,7 @@ struct DeviceUniqueByKeyPolicy
cub::BLOCK_LOAD_WARP_TRANSPOSE,
cub::LOAD_LDG,
cub::BLOCK_SCAN_WARP_SCANS,
detail::default_delay_constructor_t>;
detail::default_delay_constructor_t<int>>;
};

// SM520
Expand All @@ -132,7 +134,7 @@ struct DeviceUniqueByKeyPolicy
cub::BLOCK_LOAD_WARP_TRANSPOSE,
cub::LOAD_LDG,
cub::BLOCK_SCAN_WARP_SCANS,
detail::default_delay_constructor_t>;
detail::default_delay_constructor_t<int>>;
};

/// MaxPolicy
Expand Down Expand Up @@ -393,11 +395,13 @@ struct DispatchUniqueByKey: SelectedPolicy
CUB_RUNTIME_FUNCTION __host__ __forceinline__
cudaError_t Invoke()
{
using MaxPolicyT = typename DispatchUniqueByKey::MaxPolicy;

// Ensure kernels are instantiated.
return Invoke<ActivePolicyT>(
DeviceCompactInitKernel<ScanTileStateT, NumSelectedIteratorT>,
DeviceUniqueByKeySweepKernel<
ActivePolicyT,
MaxPolicyT,
KeyInputIteratorT,
ValueInputIteratorT,
KeyOutputIteratorT,
Expand Down

0 comments on commit 8d66e1f

Please sign in to comment.