From 93501cfa2eb9c78bb03307fb0d6b7cbe31d1f60a Mon Sep 17 00:00:00 2001 From: Stanley Tsang Date: Tue, 27 Aug 2024 09:42:11 -0600 Subject: [PATCH] Cherry-pick: Optimize block_reduce_warp_reduce when block size is the same as warp size (#599) * Optimize block_reduce_warp_reduce when block size == warp size * Make conditional constexpr --- CHANGELOG.md | 7 +- .../block/detail/block_reduce_warp_reduce.hpp | 66 +++++++++++-------- 2 files changed, 43 insertions(+), 30 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 829a5bdd5..b843b5371 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,12 @@ Documentation for rocPRIM is available at [https://rocm.docs.amd.com/projects/rocPRIM/en/latest/](https://rocm.docs.amd.com/projects/rocPRIM/en/latest/). -## Unreleased rocPRIM-3.2.0 for ROCm 6.2.0 +## rocPRIM-3.2.1 for ROCm 6.2.1 + +### Optimizations +* Improved performance of block_reduce_warp_reduce when warp size == block size. + +## rocPRIM-3.2.0 for ROCm 6.2.0 ### Additions diff --git a/rocprim/include/rocprim/block/detail/block_reduce_warp_reduce.hpp b/rocprim/include/rocprim/block/detail/block_reduce_warp_reduce.hpp index 590b57220..a6957f6a8 100644 --- a/rocprim/include/rocprim/block/detail/block_reduce_warp_reduce.hpp +++ b/rocprim/include/rocprim/block/detail/block_reduce_warp_reduce.hpp @@ -178,21 +178,25 @@ class block_reduce_warp_reduce input, output, num_valid, reduce_op ); - // i-th warp will have its partial stored in storage_.warp_partials[i-1] - if(lane_id == 0) + // Final reduction across warps is only required if there is more than 1 warp + if ROCPRIM_IF_CONSTEXPR (warps_no_ > 1) { - storage_.warp_partials[warp_id] = output; - } - ::rocprim::syncthreads(); - - if(flat_tid < warps_no_) - { - // Use warp partial to calculate the final reduce results for every thread - auto warp_partial = storage_.warp_partials[lane_id]; - - warp_reduce( - warp_partial, output, warps_no_, reduce_op - ); + // i-th warp will have its partial stored in storage_.warp_partials[i-1] + if(lane_id == 0) + { + storage_.warp_partials[warp_id] = output; + } + ::rocprim::syncthreads(); + + if(flat_tid < warps_no_) + { + // Use warp partial to calculate the final reduce results for every thread + auto warp_partial = storage_.warp_partials[lane_id]; + + warp_reduce( + warp_partial, output, warps_no_, reduce_op + ); + } } } @@ -244,22 +248,26 @@ class block_reduce_warp_reduce input, output, num_valid, reduce_op ); - // i-th warp will have its partial stored in storage_.warp_partials[i-1] - if(lane_id == 0) + // Final reduction across warps is only required if there is more than 1 warp + if ROCPRIM_IF_CONSTEXPR (warps_no_ > 1) { - storage_.warp_partials[warp_id] = output; - } - ::rocprim::syncthreads(); - - if(flat_tid < warps_no_) - { - // Use warp partial to calculate the final reduce results for every thread - auto warp_partial = storage_.warp_partials[lane_id]; - - unsigned int valid_warps_no = (valid_items + warp_size_ - 1) / warp_size_; - warp_reduce_output_type().reduce( - warp_partial, output, valid_warps_no, reduce_op - ); + // i-th warp will have its partial stored in storage_.warp_partials[i-1] + if(lane_id == 0) + { + storage_.warp_partials[warp_id] = output; + } + ::rocprim::syncthreads(); + + if(flat_tid < warps_no_) + { + // Use warp partial to calculate the final reduce results for every thread + auto warp_partial = storage_.warp_partials[lane_id]; + + unsigned int valid_warps_no = (valid_items + warp_size_ - 1) / warp_size_; + warp_reduce_output_type().reduce( + warp_partial, output, valid_warps_no, reduce_op + ); + } } } };