From 8d66e1fe9e253afbeea0f50ad53050bd1add1ee3 Mon Sep 17 00:00:00 2001 From: Georgy Evtushenko Date: Mon, 5 Jun 2023 14:33:23 +0400 Subject: [PATCH] Fix unique by key --- .../dispatch/dispatch_unique_by_key.cuh | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/cub/device/dispatch/dispatch_unique_by_key.cuh b/cub/device/dispatch/dispatch_unique_by_key.cuh index 4b44f1a8b..02ecd33e6 100644 --- a/cub/device/dispatch/dispatch_unique_by_key.cuh +++ b/cub/device/dispatch/dispatch_unique_by_key.cuh @@ -49,7 +49,7 @@ CUB_NAMESPACE_BEGIN * Unique by key kernel entry point (multi-block) */ template < - typename AgentUniqueByKeyPolicyT, ///< Parameterized AgentUniqueByKeyPolicy tuning policy type + typename ChainedPolicyT, typename KeyInputIteratorT, ///< Random-access input iterator type for keys typename ValueInputIteratorT, ///< Random-access input iterator type for values typename KeyOutputIteratorT, ///< Random-access output iterator type for keys @@ -58,7 +58,7 @@ template < typename ScanTileStateT, ///< Tile status interface type typename EqualityOpT, ///< Equality operator type typename OffsetT> ///< Signed integer type for global offsets -__launch_bounds__ (int(AgentUniqueByKeyPolicyT::UniqueByKeyPolicyT::BLOCK_THREADS)) +__launch_bounds__ (int(ChainedPolicyT::ActivePolicy::UniqueByKeyPolicyT::BLOCK_THREADS)) __global__ void DeviceUniqueByKeySweepKernel( KeyInputIteratorT d_keys_in, ///< [in] Pointer to the input sequence of keys ValueInputIteratorT d_values_in, ///< [in] Pointer to the input sequence of values @@ -70,15 +70,16 @@ __global__ void DeviceUniqueByKeySweepKernel( OffsetT num_items, ///< [in] Total number of input items (i.e., length of \p d_keys_in or \p d_values_in) int num_tiles) ///< [in] Total number of tiles for the entire problem { + using AgentUniqueByKeyPolicyT = typename ChainedPolicyT::ActivePolicy::UniqueByKeyPolicyT; + // Thread block type for selecting data from input tiles - using AgentUniqueByKeyT = AgentUniqueByKey< - typename AgentUniqueByKeyPolicyT::UniqueByKeyPolicyT, - KeyInputIteratorT, - ValueInputIteratorT, - KeyOutputIteratorT, - ValueOutputIteratorT, - EqualityOpT, - OffsetT>; + using AgentUniqueByKeyT = AgentUniqueByKey; // Shared memory for AgentUniqueByKey __shared__ typename AgentUniqueByKeyT::TempStorage temp_storage; @@ -101,7 +102,8 @@ struct DeviceUniqueByKeyPolicy using KeyT = typename std::iterator_traits::value_type; // SM350 - struct Policy350 : ChainedPolicy<350, Policy350, Policy350> { + struct Policy350 : ChainedPolicy<350, Policy350, Policy350> + { const static int INPUT_SIZE = sizeof(KeyT); enum { @@ -114,7 +116,7 @@ struct DeviceUniqueByKeyPolicy cub::BLOCK_LOAD_WARP_TRANSPOSE, cub::LOAD_LDG, cub::BLOCK_SCAN_WARP_SCANS, - detail::default_delay_constructor_t>; + detail::default_delay_constructor_t>; }; // SM520 @@ -132,7 +134,7 @@ struct DeviceUniqueByKeyPolicy cub::BLOCK_LOAD_WARP_TRANSPOSE, cub::LOAD_LDG, cub::BLOCK_SCAN_WARP_SCANS, - detail::default_delay_constructor_t>; + detail::default_delay_constructor_t>; }; /// MaxPolicy @@ -393,11 +395,13 @@ struct DispatchUniqueByKey: SelectedPolicy CUB_RUNTIME_FUNCTION __host__ __forceinline__ cudaError_t Invoke() { + using MaxPolicyT = typename DispatchUniqueByKey::MaxPolicy; + // Ensure kernels are instantiated. return Invoke( DeviceCompactInitKernel, DeviceUniqueByKeySweepKernel< - ActivePolicyT, + MaxPolicyT, KeyInputIteratorT, ValueInputIteratorT, KeyOutputIteratorT,