diff --git a/csrc/selective_scan/reverse_scan.cuh b/csrc/selective_scan/reverse_scan.cuh index d19397879..0f4346b20 100644 --- a/csrc/selective_scan/reverse_scan.cuh +++ b/csrc/selective_scan/reverse_scan.cuh @@ -96,12 +96,14 @@ struct WarpReverseScan { /// Whether the logical warp size and the PTX warp size coincide - // In hipcub, warp_threads is defined as HIPCUB_WARP_THREADS ::rocprim::warp_size() - // While in cub, it's defined as a macro that takes a redundant unused argument. #ifndef USE_ROCM #define WARP_THREADS CUB_WARP_THREADS(0) #else - #define WARP_THREADS HIPCUB_WARP_THREADS + #if ROCM_MAJOR_VERSION >= 7 + #define WARP_THREADS rocprim::arch::wavefront::max_size() + #else + #define WARP_THREADS HIPCUB_WARP_THREADS + #endif #endif static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == WARP_THREADS); /// The number of warp scan steps diff --git a/setup.py b/setup.py index f61ca90d3..2ed98fad9 100755 --- a/setup.py +++ b/setup.py @@ -157,6 +157,7 @@ def append_nvcc_threads(nvcc_extra_args): "Refer to the README.md for detailed instructions.", UserWarning ) + cc_flag.append(f"-DROCM_MAJOR_VERSION={hip_version.major}") cc_flag.append("-DBUILD_PYTHON_PACKAGE")