diff --git a/CMakeLists.txt b/CMakeLists.txt index d9a58e987..42be31b77 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -124,13 +124,13 @@ set(CPACK_RPM_EXCLUDE_FROM_AUTO_FILELIST_ADDITION "\${CPACK_PACKAGING_INSTALL_PR if(HIP_PLATFORM STREQUAL "hcc") rocm_create_package( NAME rocprim - DESCRIPTION "Radeon Open Compute Parallel Primitives Libary" + DESCRIPTION "Radeon Open Compute Parallel Primitives Library" MAINTAINER "Stream HPC Maintainers " ) else() rocm_create_package( NAME rocprim-hipcub - DESCRIPTION "Radeon Open Compute Parallel Primitives Libary (hipCUB only)" + DESCRIPTION "Radeon Open Compute Parallel Primitives Library (hipCUB only)" MAINTAINER "Stream HPC Maintainers " ) endif() diff --git a/benchmark/benchmark_hc_block_histogram.cpp b/benchmark/benchmark_hc_block_histogram.cpp index 7fcc739a6..ddccd0037 100644 --- a/benchmark/benchmark_hc_block_histogram.cpp +++ b/benchmark/benchmark_hc_block_histogram.cpp @@ -35,7 +35,7 @@ #include "benchmark_utils.hpp" // rocPRIM -#include +#include #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 128; diff --git a/benchmark/benchmark_hc_block_radix_sort.cpp b/benchmark/benchmark_hc_block_radix_sort.cpp index 6f1d21dd1..63f7a6681 100644 --- a/benchmark/benchmark_hc_block_radix_sort.cpp +++ b/benchmark/benchmark_hc_block_radix_sort.cpp @@ -24,9 +24,11 @@ #include #include #include +#include #include #include #include +#include // Google Benchmark #include "benchmark/benchmark.h" diff --git a/benchmark/benchmark_hc_block_sort.cpp b/benchmark/benchmark_hc_block_sort.cpp index eb9e4b0b4..eef23fd3f 100644 --- a/benchmark/benchmark_hc_block_sort.cpp +++ b/benchmark/benchmark_hc_block_sort.cpp @@ -28,6 +28,7 @@ #include #include #include +#include // Google Benchmark #include "benchmark/benchmark.h" @@ -35,22 +36,9 @@ #include "cmdparser.hpp" #include "benchmark_utils.hpp" -// HIP API -#include -#include - // rocPRIM #include -#define HIP_CHECK(condition) \ - { \ - hipError_t error = condition; \ - if(error != hipSuccess){ \ - std::cout << "HIP error: " << error << " line: " << __LINE__ << std::endl; \ - exit(error); \ - } \ - } - #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 128; #endif diff --git a/benchmark/benchmark_hc_device_reduce.cpp b/benchmark/benchmark_hc_device_reduce.cpp index a4ed54fd2..f122f7f44 100644 --- a/benchmark/benchmark_hc_device_reduce.cpp +++ b/benchmark/benchmark_hc_device_reduce.cpp @@ -34,7 +34,7 @@ // HC API #include -// rocPRIM HIP API +// rocPRIM #include // CmdParser diff --git a/benchmark/benchmark_hc_warp_scan.cpp b/benchmark/benchmark_hc_warp_scan.cpp index f4d6c8007..ab0e896ec 100644 --- a/benchmark/benchmark_hc_warp_scan.cpp +++ b/benchmark/benchmark_hc_warp_scan.cpp @@ -34,19 +34,10 @@ // HC API #include // rocPRIM -#include +#include #include "benchmark_utils.hpp" -#define HIP_CHECK(condition) \ - { \ - hipError_t error = condition; \ - if(error != hipSuccess){ \ - std::cout << "HIP error: " << error << " line: " << __LINE__ << std::endl; \ - exit(error); \ - } \ - } - #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; #endif diff --git a/benchmark/benchmark_hc_warp_sort.cpp b/benchmark/benchmark_hc_warp_sort.cpp index 8ec9ab3d5..89c9af5bc 100644 --- a/benchmark/benchmark_hc_warp_sort.cpp +++ b/benchmark/benchmark_hc_warp_sort.cpp @@ -36,19 +36,10 @@ // HC API #include // rocPRIM -#include +#include #include "benchmark_utils.hpp" -#define HIP_CHECK(condition) \ - { \ - hipError_t error = condition; \ - if(error != hipSuccess){ \ - std::cout << "HIP error: " << error << " line: " << __LINE__ << std::endl; \ - exit(error); \ - } \ - } - #ifndef DEFAULT_N const size_t DEFAULT_N = 1024 * 1024 * 32; #endif diff --git a/benchmark/benchmark_hip_block_histogram.cpp b/benchmark/benchmark_hip_block_histogram.cpp index bdee6f690..495e7fb11 100644 --- a/benchmark/benchmark_hip_block_histogram.cpp +++ b/benchmark/benchmark_hip_block_histogram.cpp @@ -39,7 +39,7 @@ #include // rocPRIM -#include +#include #define HIP_CHECK(condition) \ { \ diff --git a/benchmark/benchmark_hip_block_reduce.cpp b/benchmark/benchmark_hip_block_reduce.cpp index 79e007455..46f697244 100644 --- a/benchmark/benchmark_hip_block_reduce.cpp +++ b/benchmark/benchmark_hip_block_reduce.cpp @@ -39,7 +39,7 @@ #include // rocPRIM -#include +#include #define HIP_CHECK(condition) \ { \ diff --git a/benchmark/benchmark_hip_device_reduce.cpp b/benchmark/benchmark_hip_device_reduce.cpp index f1f8b2342..7fdd658f7 100644 --- a/benchmark/benchmark_hip_device_reduce.cpp +++ b/benchmark/benchmark_hip_device_reduce.cpp @@ -35,7 +35,7 @@ #include // rocPRIM HIP API -#include +#include // CmdParser #include "cmdparser.hpp" diff --git a/benchmark/benchmark_hip_device_scan.cpp b/benchmark/benchmark_hip_device_scan.cpp index 117717214..53d328255 100644 --- a/benchmark/benchmark_hip_device_scan.cpp +++ b/benchmark/benchmark_hip_device_scan.cpp @@ -36,7 +36,7 @@ #include #include // rocPRIM -#include +#include #include "benchmark_utils.hpp" diff --git a/benchmark/benchmark_hip_warp_scan.cpp b/benchmark/benchmark_hip_warp_scan.cpp index c339215db..954d6e72f 100644 --- a/benchmark/benchmark_hip_warp_scan.cpp +++ b/benchmark/benchmark_hip_warp_scan.cpp @@ -36,7 +36,7 @@ #include #include // rocPRIM -#include +#include #include "benchmark_utils.hpp" diff --git a/benchmark/benchmark_hip_warp_sort.cpp b/benchmark/benchmark_hip_warp_sort.cpp index 51be3e844..b237835c7 100644 --- a/benchmark/benchmark_hip_warp_sort.cpp +++ b/benchmark/benchmark_hip_warp_sort.cpp @@ -36,7 +36,7 @@ #include #include // rocPRIM -#include +#include #include "benchmark_utils.hpp" diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 3ba4abd77..47d7791bf 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -80,7 +80,7 @@ if(BUILD_TEST) download_project( PROJ googletest GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG master + GIT_TAG release-1.8.1 INSTALL_DIR ${GTEST_ROOT} CMAKE_ARGS -DBUILD_GTEST=ON -DINSTALL_GTEST=ON -Dgtest_force_shared_crt=ON -DBUILD_SHARED_LIBS=ON -DCMAKE_INSTALL_PREFIX= LOG_DOWNLOAD TRUE diff --git a/hipcub/include/hipcub/config.hpp b/hipcub/include/hipcub/config.hpp index 1844e76ac..7eede8455 100644 --- a/hipcub/include/hipcub/config.hpp +++ b/hipcub/include/hipcub/config.hpp @@ -1,7 +1,7 @@ // Copyright (c) 2017 Advanced Micro Devices, Inc. All rights reserved. // // Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the >Software>), to deal +// of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is @@ -10,7 +10,7 @@ // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // -// THE SOFTWARE IS PROVIDED >AS IS>, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER diff --git a/rocprim/include/rocprim/block/block_scan.hpp b/rocprim/include/rocprim/block/block_scan.hpp index 08104fb25..cbd854518 100644 --- a/rocprim/include/rocprim/block/block_scan.hpp +++ b/rocprim/include/rocprim/block/block_scan.hpp @@ -404,8 +404,8 @@ class block_scan /// The signature of the \p prefix_callback_op should be equivalent to the following: /// T f(const T &block_reduction);. The signature does not need to have /// const &, but function object must not modify the objects passed to it. - /// The object will be called by the first thread in the first warp from the block with - /// block reduction of \p input values as input argument. The result will be used as the + /// The object will be called by the first warp of the block with block reduction of + /// \p input values as input argument. The result of the first thread will be used as the /// block-wide prefix. /// \param [in] scan_op - binary operation function object that will be used for scan. /// The signature of the function should be equivalent to the following: @@ -794,8 +794,8 @@ class block_scan /// The signature of the \p prefix_callback_op should be equivalent to the following: /// T f(const T &block_reduction);. The signature does not need to have /// const &, but function object must not modify the objects passed to it. - /// The object will be called by the first thread in the first warp from the block with - /// block reduction of \p input values as input argument. The result will be used as the + /// The object will be called by the first warp of the block with block reduction of + /// \p input values as input argument. The result of the first thread will be used as the /// block-wide prefix. /// \param [in] scan_op - binary operation function object that will be used for scan. /// The signature of the function should be equivalent to the following: @@ -1168,8 +1168,8 @@ class block_scan /// The signature of the \p prefix_callback_op should be equivalent to the following: /// T f(const T &block_reduction);. The signature does not need to have /// const &, but function object must not modify the objects passed to it. - /// The object will be called by the first thread in the first warp from the block with - /// block reduction of \p input values as input argument. The result will be used as the + /// The object will be called by the first warp of the block with block reduction of + /// \p input values as input argument. The result of the first thread will be used as the /// block-wide prefix. /// \param [in] scan_op - binary operation function object that will be used for scan. /// The signature of the function should be equivalent to the following: @@ -1579,8 +1579,8 @@ class block_scan /// The signature of the \p prefix_callback_op should be equivalent to the following: /// T f(const T &block_reduction);. The signature does not need to have /// const &, but function object must not modify the objects passed to it. - /// The object will be called by the first thread in the first warp from the block with - /// block reduction of \p input values as input argument. The result will be used as the + /// The object will be called by the first warp of the block with block reduction of + /// \p input values as input argument. The result of the first thread will be used as the /// block-wide prefix. /// \param [in] scan_op - binary operation function object that will be used for scan. /// The signature of the function should be equivalent to the following: diff --git a/rocprim/include/rocprim/block/detail/block_sort_bitonic.hpp b/rocprim/include/rocprim/block/detail/block_sort_bitonic.hpp index a4b56c332..12ad70041 100644 --- a/rocprim/include/rocprim/block/detail/block_sort_bitonic.hpp +++ b/rocprim/include/rocprim/block/detail/block_sort_bitonic.hpp @@ -293,8 +293,8 @@ class block_sort_bitonic copy_to_shared(kv..., flat_tid, storage); bool is_even = (flat_tid % 2) == 0; - unsigned int odd_id = (is_even) ? std::max(flat_tid, (unsigned int) 1) - 1 : std::min(flat_tid + 1, Size - 1); - unsigned int even_id = (is_even) ? std::min(flat_tid + 1, Size - 1) : std::max(flat_tid, (unsigned int) 1) - 1; + unsigned int odd_id = (is_even) ? ::rocprim::max(flat_tid, 1u) - 1 : ::rocprim::min(flat_tid + 1, Size - 1); + unsigned int even_id = (is_even) ? ::rocprim::min(flat_tid + 1, Size - 1) : ::rocprim::max(flat_tid, 1u) - 1; #pragma unroll for(unsigned int length = 0; length < Size; length++) @@ -331,13 +331,18 @@ class block_sort_bitonic copy_to_shared(kv..., flat_tid, storage); bool is_even = (flat_tid % 2 == 0); - unsigned int odd_id = (is_even) ? std::max(flat_tid, (unsigned int) 1) - 1 : std::min(flat_tid + 1, size - 1); - unsigned int even_id = (is_even) ? std::min(flat_tid + 1, size - 1) : std::max(flat_tid, (unsigned int) 1) - 1; + unsigned int odd_id = (is_even) ? ::rocprim::max(flat_tid, 1u) - 1 : ::rocprim::min(flat_tid + 1, size - 1); + unsigned int even_id = (is_even) ? ::rocprim::min(flat_tid + 1, size - 1) : ::rocprim::max(flat_tid, 1u) - 1; for(unsigned int length = 0; length < size; length++) { unsigned int next_id = (length % 2 == 0) ? even_id : odd_id; - swap(kv..., flat_tid, next_id, 0, storage, compare_function); + // Use only "valid" keys to ensure that compare_function will not use garbage keys + // for example, as indices of an array (a lookup table) + if(flat_tid < size) + { + swap(kv..., flat_tid, next_id, 0, storage, compare_function); + } ::rocprim::syncthreads(); copy_to_shared(kv..., flat_tid, storage); } diff --git a/rocprim/include/rocprim/config.hpp b/rocprim/include/rocprim/config.hpp index 91d486dd3..3fdecdbfa 100644 --- a/rocprim/include/rocprim/config.hpp +++ b/rocprim/include/rocprim/config.hpp @@ -76,6 +76,12 @@ #define ROCPRIM_DETAIL_USE_DPP true #endif +#ifdef ROCPRIM_DISABLE_LOOKBACK_SCAN + #define ROCPRIM_DETAIL_USE_LOOKBACK_SCAN false +#else + #define ROCPRIM_DETAIL_USE_LOOKBACK_SCAN true +#endif + // Defines targeted AMD architecture. Supported values: // * 803 (gfx803) // * 900 (gfx900) diff --git a/rocprim/include/rocprim/detail/radix_sort.hpp b/rocprim/include/rocprim/detail/radix_sort.hpp index df0f0e212..f1f115679 100644 --- a/rocprim/include/rocprim/detail/radix_sort.hpp +++ b/rocprim/include/rocprim/detail/radix_sort.hpp @@ -103,15 +103,33 @@ template struct radix_key_codec_base { static_assert(sizeof(Key) == 0, - "Only integral (except bool) and floating point types supported as radix sort keys"); + "Only integral and floating point types supported as radix sort keys"); }; template struct radix_key_codec_base< Key, - typename std::enable_if<::rocprim::is_integral::value && !std::is_same::value>::type + typename std::enable_if<::rocprim::is_integral::value>::type > : radix_key_codec_integral::type> { }; +template<> +struct radix_key_codec_base +{ + using bit_key_type = unsigned char; + + ROCPRIM_DEVICE inline + static bit_key_type encode(bool key) + { + return static_cast(key); + } + + ROCPRIM_DEVICE inline + static bool decode(bit_key_type bit_key) + { + return static_cast(bit_key); + } +}; + template<> struct radix_key_codec_base<::rocprim::half> : radix_key_codec_floating<::rocprim::half, unsigned short> { }; diff --git a/rocprim/include/rocprim/detail/various.hpp b/rocprim/include/rocprim/detail/various.hpp index fbac587ed..65973d37c 100644 --- a/rocprim/include/rocprim/detail/various.hpp +++ b/rocprim/include/rocprim/detail/various.hpp @@ -157,7 +157,7 @@ struct match_fundamental_type template ROCPRIM_DEVICE inline auto store_volatile(T * output, T value) - -> typename std::enable_if<::rocprim::is_fundamental::value>::type + -> typename std::enable_if::value>::type { *const_cast(output) = value; } @@ -165,7 +165,7 @@ auto store_volatile(T * output, T value) template ROCPRIM_DEVICE inline auto store_volatile(T * output, T value) - -> typename std::enable_if::value>::type + -> typename std::enable_if::value>::type { using fundamental_type = typename match_fundamental_type::type; constexpr unsigned int n = sizeof(T) / sizeof(fundamental_type); @@ -183,7 +183,7 @@ auto store_volatile(T * output, T value) template ROCPRIM_DEVICE inline auto load_volatile(T * input) - -> typename std::enable_if<::rocprim::is_fundamental::value, T>::type + -> typename std::enable_if::value, T>::type { T retval = *const_cast(input); return retval; @@ -192,7 +192,7 @@ auto load_volatile(T * input) template ROCPRIM_DEVICE inline auto load_volatile(T * input) - -> typename std::enable_if::value, T>::type + -> typename std::enable_if::value, T>::type { using fundamental_type = typename match_fundamental_type::type; constexpr unsigned int n = sizeof(T) / sizeof(fundamental_type); @@ -226,6 +226,21 @@ struct raw_storage } }; +// Checks if two iterators have the same type and value +template +inline +bool are_iterators_equal(Iterator1, Iterator2) +{ + return false; +} + +template +inline +bool are_iterators_equal(Iterator iter1, Iterator iter2) +{ + return iter1 == iter2; +} + } // end namespace detail END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/device/detail/device_merge_sort.hpp b/rocprim/include/rocprim/device/detail/device_merge_sort.hpp index b97b26f4d..55b38ffb7 100644 --- a/rocprim/include/rocprim/device/detail/device_merge_sort.hpp +++ b/rocprim/include/rocprim/device/detail/device_merge_sort.hpp @@ -335,9 +335,10 @@ void block_sort_kernel_impl(KeysInputIterator keys_input, { using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; + using stable_key_type = rocprim::tuple; constexpr bool with_values = !std::is_same::value; - const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); const unsigned int block_offset = flat_block_id * BlockSize; const unsigned int number_of_blocks = (input_size + BlockSize - 1)/BlockSize; @@ -358,13 +359,23 @@ void block_sort_kernel_impl(KeysInputIterator keys_input, value ); + // Special comparison that preserves relative order of equal keys + auto stable_compare_function = [compare_function](const stable_key_type& a, const stable_key_type& b) -> bool + { + const bool ab = compare_function(rocprim::get<0>(a), rocprim::get<0>(b)); + const bool ba = compare_function(rocprim::get<0>(b), rocprim::get<0>(a)); + return ab || (!ba && (rocprim::get<1>(a) < rocprim::get<1>(b))); + }; + + stable_key_type stable_key = rocprim::make_tuple(key[0], flat_id); block_sort_impl( - key[0], + stable_key, value[0], valid_in_last_block, last_block, - compare_function + stable_compare_function ); + key[0] = rocprim::get<0>(stable_key); block_store_impl( flat_id, diff --git a/rocprim/include/rocprim/device/detail/device_partition.hpp b/rocprim/include/rocprim/device/detail/device_partition.hpp index 4ad5d7ad8..97d39fe29 100644 --- a/rocprim/include/rocprim/device/detail/device_partition.hpp +++ b/rocprim/include/rocprim/device/detail/device_partition.hpp @@ -377,7 +377,6 @@ auto partition_scatter(ValueType (&values)[ItemsPerThread], ::rocprim::syncthreads(); // sync threads to reuse shared memory // Coalesced write from shared memory to global memory - #pragma unroll for(unsigned int i = flat_block_thread_id; i < selected_in_block; i += BlockSize) { output[selected_prefix + i] = scatter_storage[i]; @@ -458,17 +457,20 @@ void partition_kernel_impl(InputIterator input, using exchange_storage_type = value_type[items_per_block]; using raw_exchange_storage_type = typename detail::raw_storage; - ROCPRIM_SHARED_MEMORY union + ROCPRIM_SHARED_MEMORY struct { - raw_exchange_storage_type exchange_values; typename order_bid_type::storage_type ordered_bid; - typename block_load_value_type::storage_type load_values; - typename block_load_flag_type::storage_type load_flags; - typename block_discontinuity_value_type::storage_type discontinuity_values; - typename block_scan_offset_type::storage_type scan_offsets; + union + { + raw_exchange_storage_type exchange_values; + typename block_load_value_type::storage_type load_values; + typename block_load_flag_type::storage_type load_flags; + typename block_discontinuity_value_type::storage_type discontinuity_values; + typename block_scan_offset_type::storage_type scan_offsets; + }; } storage; - const auto flat_block_thread_id = ::rocprim::detail::block_thread_id<0>(); + const auto flat_block_thread_id = ::rocprim::flat_block_thread_id(); const auto flat_block_id = ordered_bid.get(flat_block_thread_id, storage.ordered_bid); const unsigned int block_offset = flat_block_id * items_per_block; const auto valid_in_last_block = size - items_per_block * (number_of_blocks - 1); @@ -550,7 +552,12 @@ void partition_kernel_impl(InputIterator input, } ::rocprim::syncthreads(); // sync threads to reuse shared memory } - else + // Workaround: Fiji (gfx803) crashes with "Memory access fault by GPU node" on HCC 1.3.18482 (ROCm 2.0) + // Instead of just `} else {` we use `} syncthreads(); if() {`, because the else-branch can be executed + // for some unknown reason and 0-th block reads incorrect addresses in lookback_scan_prefix_op::get_prefix. + ::rocprim::syncthreads(); + if(flat_block_id > 0) + // end of the workaround { ROCPRIM_SHARED_MEMORY typename offset_scan_prefix_op_type::storage_type storage_prefix_op; auto prefix_op = offset_scan_prefix_op_type( diff --git a/rocprim/include/rocprim/device/detail/device_reduce.hpp b/rocprim/include/rocprim/device/detail/device_reduce.hpp index 148a25d10..eb0f0c98a 100644 --- a/rocprim/include/rocprim/device/detail/device_reduce.hpp +++ b/rocprim/include/rocprim/device/detail/device_reduce.hpp @@ -98,7 +98,7 @@ void block_reduce_kernel_impl(InputIterator input, >; constexpr unsigned int items_per_block = block_size * items_per_thread; - const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); const unsigned int block_offset = flat_block_id * items_per_block; const unsigned int number_of_blocks = ::rocprim::detail::grid_size<0>(); diff --git a/rocprim/include/rocprim/device/detail/device_scan_lookback.hpp b/rocprim/include/rocprim/device/detail/device_scan_lookback.hpp index ec0dc5b16..d7710a3a8 100644 --- a/rocprim/include/rocprim/device/detail/device_scan_lookback.hpp +++ b/rocprim/include/rocprim/device/detail/device_scan_lookback.hpp @@ -214,15 +214,18 @@ void lookback_scan_kernel_impl(InputIterator input, result_type, BinaryFunction, LookbackScanState >; - ROCPRIM_SHARED_MEMORY union + ROCPRIM_SHARED_MEMORY struct { typename order_bid_type::storage_type ordered_bid; - typename block_load_type::storage_type load; - typename block_store_type::storage_type store; - typename block_scan_type::storage_type scan; + union + { + typename block_load_type::storage_type load; + typename block_store_type::storage_type store; + typename block_scan_type::storage_type scan; + }; } storage; - const auto flat_block_thread_id = ::rocprim::detail::block_thread_id<0>(); + const auto flat_block_thread_id = ::rocprim::flat_block_thread_id(); const auto flat_block_id = ordered_bid.get(flat_block_thread_id, storage.ordered_bid); const unsigned int block_offset = flat_block_id * items_per_block; const auto valid_in_last_block = size - items_per_block * (number_of_blocks - 1); @@ -267,7 +270,12 @@ void lookback_scan_kernel_impl(InputIterator input, scan_state.set_complete(flat_block_id, reduction); } } - else + // Workaround: Fiji (gfx803) crashes with "Memory access fault by GPU node" on HCC 1.3.18482 (ROCm 2.0) + // Instead of just `} else {` we use `} syncthreads(); if() {`, because the else-branch can be executed + // for some unknown reason and 0-th block reads incorrect addresses in lookback_scan_prefix_op::get_prefix. + ::rocprim::syncthreads(); + if(flat_block_id > 0) + // original code: else { // Scan of block values auto prefix_op = lookback_scan_prefix_op_type( diff --git a/rocprim/include/rocprim/device/detail/device_scan_reduce_then_scan.hpp b/rocprim/include/rocprim/device/detail/device_scan_reduce_then_scan.hpp index ad775804e..47be6a07d 100644 --- a/rocprim/include/rocprim/device/detail/device_scan_reduce_then_scan.hpp +++ b/rocprim/include/rocprim/device/detail/device_scan_reduce_then_scan.hpp @@ -196,8 +196,7 @@ void block_reduce_kernel_impl(InputIterator input, typename block_reduce_type::storage_type reduce; } storage; - // It's assumed kernel is executed in 1D - const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); const unsigned int flat_block_id = ::rocprim::detail::block_id<0>(); const unsigned int block_offset = flat_block_id * items_per_thread * block_size; diff --git a/rocprim/include/rocprim/device/detail/device_segmented_reduce.hpp b/rocprim/include/rocprim/device/detail/device_segmented_reduce.hpp index 64af2f6d9..996383e4f 100644 --- a/rocprim/include/rocprim/device/detail/device_segmented_reduce.hpp +++ b/rocprim/include/rocprim/device/detail/device_segmented_reduce.hpp @@ -65,7 +65,7 @@ void segmented_reduce(InputIterator input, ROCPRIM_SHARED_MEMORY typename reduce_type::storage_type reduce_storage; - const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>(); + const unsigned int flat_id = ::rocprim::flat_block_thread_id(); const unsigned int segment_id = ::rocprim::detail::block_id<0>(); const unsigned int begin_offset = begin_offsets[segment_id]; diff --git a/rocprim/include/rocprim/device/detail/lookback_scan_state.hpp b/rocprim/include/rocprim/device/detail/lookback_scan_state.hpp index f64ff7d71..23d1f684c 100644 --- a/rocprim/include/rocprim/device/detail/lookback_scan_state.hpp +++ b/rocprim/include/rocprim/device/detail/lookback_scan_state.hpp @@ -58,38 +58,25 @@ enum prefix_flag // a look-back prefix scan. Initially every prefix can be either // invalid (padding values) or empty. One thread in a block should // later set it to partial, and later to complete. -// -// is_arithmetic - arithmetic types up to 8 bytes have separate faster -// and simpler implementation. See below. -// TODO: consider other types that can be loaded in single op. -template::value> +template struct lookback_scan_state; -// Flag and prefix value are load/store in one operation. Volatile -// loads/stores are not used as there is no ordering of load/store -// operation within one prefix (prefix_type). +// Packed flag and prefix value are loaded/stored in one atomic operation. template struct lookback_scan_state { private: - using flag_type_ = + using flag_type_ = char; + + // Type which is used in store/load operations of block prefix (flag and value). + // It is 32-bit or 64-bit int and can be loaded/stored using single atomic instruction. + using prefix_underlying_type = typename std::conditional< - sizeof(T) == 8, - long long, - typename std::conditional< - sizeof(T) == 4, - int, - typename std::conditional< - sizeof(T) == 2, - short, - char - >::type - >::type + (sizeof(T) > 2), + unsigned long long, + unsigned int >::type; - // Type which is used in store/load operations of block prefix (flag and value). - // It is essential that this type is load/store using single instruction. - using prefix_underlying_type = typename make_vector_type::type; static constexpr unsigned int padding = ::rocprim::warp_size(); // Helper struct @@ -99,6 +86,8 @@ struct lookback_scan_state T value; } __attribute__((aligned(sizeof(prefix_underlying_type)))); + static_assert(sizeof(prefix_underlying_type) == sizeof(prefix_type), ""); + public: // Type used for flag/flag of block prefix using flag_type = flag_type_; @@ -124,16 +113,21 @@ struct lookback_scan_state void initialize_prefix(const unsigned int block_id, const unsigned int number_of_blocks) { - prefix_underlying_type prefix; if(block_id < number_of_blocks) { - reinterpret_cast(&prefix)->flag = PREFIX_EMPTY; - prefixes[padding + block_id] = prefix; + prefix_type prefix; + prefix.flag = PREFIX_EMPTY; + prefix_underlying_type p; + __builtin_memcpy(&p, &prefix, sizeof(prefix_type)); + prefixes[padding + block_id] = p; } if(block_id < padding) { - reinterpret_cast(&prefix)->flag = PREFIX_INVALID; - prefixes[block_id] = prefix; + prefix_type prefix; + prefix.flag = PREFIX_INVALID; + prefix_underlying_type p; + __builtin_memcpy(&p, &prefix, sizeof(prefix_type)); + prefixes[block_id] = p; } } @@ -156,10 +150,10 @@ struct lookback_scan_state prefix_type prefix; do { - ::rocprim::detail::memory_fence_system(); - auto p = prefixes[padding + block_id]; - prefix = *reinterpret_cast(&p); - } while(::rocprim::detail::warp_any(prefix.flag == PREFIX_EMPTY)); + // atomic_add(..., 0) is used to load values atomically + prefix_underlying_type p = ::rocprim::detail::atomic_add(&prefixes[padding + block_id], 0); + __builtin_memcpy(&prefix, &p, sizeof(prefix_type)); + } while(prefix.flag == PREFIX_EMPTY); // return flag = prefix.flag; @@ -171,17 +165,16 @@ struct lookback_scan_state void set(const unsigned int block_id, const flag_type flag, const T value) { prefix_type prefix = { flag, value }; - prefix_underlying_type p = *reinterpret_cast(&prefix); - prefixes[padding + block_id] = p; + prefix_underlying_type p; + __builtin_memcpy(&p, &prefix, sizeof(prefix_type)); + ::rocprim::detail::atomic_exch(&prefixes[padding + block_id], p); } prefix_underlying_type * prefixes; }; -#define ROCPRIM_DETAIL_LOOKBACK_SCAN_STATE_USE_VOLATILE 1 - -// This does not work for unknown reasons. Lookback-based scan should -// be only enabled for arithmetic types for now. +// Flag, partial and final prefixes are stored in separate arrays. +// Consistency ensured by memory fences between flag and prefixes load/store operations. template struct lookback_scan_state { @@ -237,67 +230,38 @@ struct lookback_scan_state ROCPRIM_DEVICE inline void set_partial(const unsigned int block_id, const T value) { - #ifdef ROCPRIM_DETAIL_LOOKBACK_SCAN_STATE_USE_VOLATILE - store_volatile(&prefixes_partial_values[padding + block_id], value); - ::rocprim::detail::memory_fence_device(); - store_volatile(&prefixes_flags[padding + block_id], PREFIX_PARTIAL); - #else - prefixes_partial_values[padding + block_id] = value; - // ::rocprim::detail::memory_fence_device() (aka __threadfence()) should be - // enough, but does not work when T is 32 bytes or bigger. - ::rocprim::detail::memory_fence_system(); - prefixes_flags[padding + block_id] = PREFIX_PARTIAL; - #endif + store_volatile(&prefixes_partial_values[padding + block_id], value); + ::rocprim::detail::memory_fence_device(); + store_volatile(&prefixes_flags[padding + block_id], PREFIX_PARTIAL); } ROCPRIM_DEVICE inline void set_complete(const unsigned int block_id, const T value) { - #ifdef ROCPRIM_DETAIL_LOOKBACK_SCAN_STATE_USE_VOLATILE - store_volatile(&prefixes_complete_values[padding + block_id], value); - ::rocprim::detail::memory_fence_device(); - store_volatile(&prefixes_flags[padding + block_id], PREFIX_COMPLETE); - #else - prefixes_complete_values[padding + block_id] = value; - // ::rocprim::detail::memory_fence_device() (aka __threadfence()) should be - // enough, but does not work when T is 32 bytes or bigger. - ::rocprim::detail::memory_fence_system(); - prefixes_flags[padding + block_id] = PREFIX_COMPLETE; - #endif + store_volatile(&prefixes_complete_values[padding + block_id], value); + ::rocprim::detail::memory_fence_device(); + store_volatile(&prefixes_flags[padding + block_id], PREFIX_COMPLETE); } // block_id must be > 0 ROCPRIM_DEVICE inline void get(const unsigned int block_id, flag_type& flag, T& value) { - #ifdef ROCPRIM_DETAIL_LOOKBACK_SCAN_STATE_USE_VOLATILE - do - { - ::rocprim::detail::memory_fence_system(); - flag = load_volatile(&prefixes_flags[padding + block_id]); - } while(flag == PREFIX_EMPTY); - - if(flag == PREFIX_PARTIAL) - value = load_volatile(&prefixes_partial_values[padding + block_id]); - else - value = load_volatile(&prefixes_complete_values[padding + block_id]); - #else - do - { - ::rocprim::detail::memory_fence_system(); - flag = prefixes_flags[padding + block_id]; - } while(flag == PREFIX_EMPTY); - - if(flag == PREFIX_PARTIAL) - value = prefixes_partial_values[padding + block_id]; - else - value = prefixes_complete_values[padding + block_id]; - #endif + do + { + flag = load_volatile(&prefixes_flags[padding + block_id]); + ::rocprim::detail::memory_fence_device(); + } while(flag == PREFIX_EMPTY); + + if(flag == PREFIX_PARTIAL) + value = load_volatile(&prefixes_partial_values[padding + block_id]); + else + value = load_volatile(&prefixes_complete_values[padding + block_id]); } private: flag_type * prefixes_flags; - // We need to seprate arrays for partial and final prefixes, because + // We need to separate arrays for partial and final prefixes, because // value can be overwritten before flag is changed (flag and value are // not stored in single instruction). T * prefixes_partial_values; diff --git a/rocprim/include/rocprim/device/detail/ordered_block_id.hpp b/rocprim/include/rocprim/device/detail/ordered_block_id.hpp index 1760f8eb3..118c6b9e5 100644 --- a/rocprim/include/rocprim/device/detail/ordered_block_id.hpp +++ b/rocprim/include/rocprim/device/detail/ordered_block_id.hpp @@ -69,10 +69,9 @@ struct ordered_block_id ROCPRIM_DEVICE inline id_type get(unsigned int tid, storage_type& storage) { - constexpr static id_type max = std::numeric_limits::max(); if(tid == 0) { - storage.id = ::rocprim::detail::atomic_wrapinc(this->id, max); + storage.id = ::rocprim::detail::atomic_add(this->id, 1); } ::rocprim::syncthreads(); return storage.id; diff --git a/rocprim/include/rocprim/device/device_merge_sort_config.hpp b/rocprim/include/rocprim/device/device_merge_sort_config.hpp index 44610cf12..3f932c241 100644 --- a/rocprim/include/rocprim/device/device_merge_sort_config.hpp +++ b/rocprim/include/rocprim/device/device_merge_sort_config.hpp @@ -42,42 +42,47 @@ using merge_sort_config = kernel_config; namespace detail { +// TODO investigate why some tests fail with block size > 256 template struct merge_sort_config_803 { - static constexpr size_t key_value_size = sizeof(Key) + sizeof(Value); - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(key_value_size, 8); + // static constexpr size_t key_value_size = sizeof(Key) + sizeof(Value); + // static constexpr unsigned int item_scale = + // ::rocprim::detail::ceiling_div(key_value_size, 8); - using type = merge_sort_config<::rocprim::max(256U, 1024U / item_scale)>; + // using type = merge_sort_config<::rocprim::max(256U, 1024U / item_scale)>; + using type = merge_sort_config<256U>; }; template struct merge_sort_config_803 { - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(sizeof(Key), 8); + // static constexpr unsigned int item_scale = + // ::rocprim::detail::ceiling_div(sizeof(Key), 8); - using type = merge_sort_config<::rocprim::max(256U, 1024U / item_scale)>; + // using type = merge_sort_config<::rocprim::max(256U, 1024U / item_scale)>; + using type = merge_sort_config<256U>; }; template struct merge_sort_config_900 { - static constexpr size_t key_value_size = sizeof(Key) + sizeof(Value); - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(key_value_size, 16); + // static constexpr size_t key_value_size = sizeof(Key) + sizeof(Value); + // static constexpr unsigned int item_scale = + // ::rocprim::detail::ceiling_div(key_value_size, 16); - using type = merge_sort_config<::rocprim::max(256U, 1024U / item_scale)>; + // using type = merge_sort_config<::rocprim::max(256U, 1024U / item_scale)>; + using type = merge_sort_config<256U>; }; template struct merge_sort_config_900 { - static constexpr unsigned int item_scale = - ::rocprim::detail::ceiling_div(sizeof(Key), 16); + // static constexpr unsigned int item_scale = + // ::rocprim::detail::ceiling_div(sizeof(Key), 16); - using type = merge_sort_config<::rocprim::max(256U, 1024U / item_scale)>; + // using type = merge_sort_config<::rocprim::max(256U, 1024U / item_scale)>; + using type = merge_sort_config<256U>; }; template diff --git a/rocprim/include/rocprim/device/device_merge_sort_hip.hpp b/rocprim/include/rocprim/device/device_merge_sort_hip.hpp index 978c44aed..01b564c85 100644 --- a/rocprim/include/rocprim/device/device_merge_sort_hip.hpp +++ b/rocprim/include/rocprim/device/device_merge_sort_hip.hpp @@ -167,7 +167,7 @@ hipError_t merge_sort_impl(void * temporary_storage, char* ptr = reinterpret_cast(temporary_storage); key_type * keys_buffer = reinterpret_cast(ptr); ptr += keys_bytes; - value_type* values_buffer = + value_type * values_buffer = with_values ? reinterpret_cast(ptr) : nullptr; // Start point for time measurements @@ -213,20 +213,20 @@ hipError_t merge_sort_impl(void * temporary_storage, if(temporary_store) { - if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - ::rocprim::transform( + hipError_t error = ::rocprim::transform( keys_buffer, keys_output, size, ::rocprim::identity(), stream, debug_synchronous ); + if(error != hipSuccess) return error; if(with_values) { - ::rocprim::transform( + hipError_t error = ::rocprim::transform( values_buffer, values_output, size, ::rocprim::identity(), stream, debug_synchronous ); + if(error != hipSuccess) return error; } - ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("transform", size, start); } return hipSuccess; diff --git a/rocprim/include/rocprim/device/device_radix_sort_hc.hpp b/rocprim/include/rocprim/device/device_radix_sort_hc.hpp index 4d5d83631..19fce06f7 100644 --- a/rocprim/include/rocprim/device/device_radix_sort_hc.hpp +++ b/rocprim/include/rocprim/device/device_radix_sort_hc.hpp @@ -76,9 +76,9 @@ void radix_sort_iteration(KeysInputIterator keys_input, unsigned int size, unsigned int * batch_digit_counts, unsigned int * digit_counts, + bool from_input, bool to_output, unsigned int bit, - unsigned int begin_bit, unsigned int end_bit, unsigned int blocks_per_full_batch, unsigned int full_batches, @@ -92,8 +92,6 @@ void radix_sort_iteration(KeysInputIterator keys_input, // iteration has a shorter mask. const unsigned int current_radix_bits = ::rocprim::min(RadixBits, end_bit - bit); - const bool is_first_iteration = (bit == begin_bit); - std::chrono::high_resolution_clock::time_point start; if(debug_synchronous) @@ -104,7 +102,7 @@ void radix_sort_iteration(KeysInputIterator keys_input, } if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - if(is_first_iteration) + if(from_input) { hc::parallel_for_each( acc_view, @@ -188,7 +186,7 @@ void radix_sort_iteration(KeysInputIterator keys_input, ROCPRIM_DETAIL_HC_SYNC("scan_digits", radix_size, start) if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - if(is_first_iteration) + if(from_input) { if(to_output) { @@ -365,7 +363,7 @@ void radix_sort_impl(void * temporary_storage, ptr += batch_digit_counts_bytes; unsigned int * digit_counts = reinterpret_cast(ptr); ptr += digit_counts_bytes; - if(!with_double_buffer) + if(!with_double_buffer) { keys_tmp = reinterpret_cast(ptr); ptr += keys_bytes; @@ -373,19 +371,45 @@ void radix_sort_impl(void * temporary_storage, } bool to_output = with_double_buffer || (iterations - 1) % 2 == 0; + bool from_input = true; + if(!with_double_buffer && to_output) + { + // Copy input keys and values if necessary (in-place sorting: input and output iterators are equal) + const bool keys_equal = ::rocprim::detail::are_iterators_equal(keys_input, keys_output); + const bool values_equal = with_values && ::rocprim::detail::are_iterators_equal(values_input, values_output); + if(keys_equal || values_equal) + { + ::rocprim::transform( + keys_input, keys_tmp, size, + ::rocprim::identity(), acc_view, debug_synchronous + ); + + if(with_values) + { + ::rocprim::transform( + values_input, values_tmp, size, + ::rocprim::identity(), acc_view, debug_synchronous + ); + } + + from_input = false; + } + } + unsigned int bit = begin_bit; for(unsigned int i = 0; i < long_iterations; i++) { radix_sort_iteration( keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, size, batch_digit_counts, digit_counts, - to_output, - bit, begin_bit, end_bit, + from_input, to_output, + bit, end_bit, blocks_per_full_batch, full_batches, batches, acc_view, debug_synchronous ); is_result_in_output = to_output; + from_input = false; to_output = !to_output; bit += config::long_radix_bits; } @@ -394,13 +418,14 @@ void radix_sort_impl(void * temporary_storage, radix_sort_iteration( keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, size, batch_digit_counts, digit_counts, - to_output, - bit, begin_bit, end_bit, + from_input, to_output, + bit, end_bit, blocks_per_full_batch, full_batches, batches, acc_view, debug_synchronous ); is_result_in_output = to_output; + from_input = false; to_output = !to_output; bit += config::short_radix_bits; } diff --git a/rocprim/include/rocprim/device/device_radix_sort_hip.hpp b/rocprim/include/rocprim/device/device_radix_sort_hip.hpp index 75057f556..b9507ad94 100644 --- a/rocprim/include/rocprim/device/device_radix_sort_hip.hpp +++ b/rocprim/include/rocprim/device/device_radix_sort_hip.hpp @@ -154,9 +154,9 @@ hipError_t radix_sort_iteration(KeysInputIterator keys_input, unsigned int size, unsigned int * batch_digit_counts, unsigned int * digit_counts, + bool from_input, bool to_output, unsigned int bit, - unsigned int begin_bit, unsigned int end_bit, unsigned int blocks_per_full_batch, unsigned int full_batches, @@ -170,8 +170,6 @@ hipError_t radix_sort_iteration(KeysInputIterator keys_input, // iteration has a shorter mask. const unsigned int current_radix_bits = ::rocprim::min(RadixBits, end_bit - bit); - const bool is_first_iteration = (bit == begin_bit); - std::chrono::high_resolution_clock::time_point start; if(debug_synchronous) @@ -182,7 +180,7 @@ hipError_t radix_sort_iteration(KeysInputIterator keys_input, } if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - if(is_first_iteration) + if(from_input) { hipLaunchKernelGGL( HIP_KERNEL_NAME(fill_digit_counts_kernel< @@ -243,7 +241,7 @@ hipError_t radix_sort_iteration(KeysInputIterator keys_input, ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("scan_digits", radix_size, start) if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - if(is_first_iteration) + if(from_input) { if(to_output) { @@ -415,20 +413,48 @@ hipError_t radix_sort_impl(void * temporary_storage, } bool to_output = with_double_buffer || (iterations - 1) % 2 == 0; + bool from_input = true; + if(!with_double_buffer && to_output) + { + // Copy input keys and values if necessary (in-place sorting: input and output iterators are equal) + const bool keys_equal = ::rocprim::detail::are_iterators_equal(keys_input, keys_output); + const bool values_equal = with_values && ::rocprim::detail::are_iterators_equal(values_input, values_output); + if(keys_equal || values_equal) + { + hipError_t error = ::rocprim::transform( + keys_input, keys_tmp, size, + ::rocprim::identity(), stream, debug_synchronous + ); + if(error != hipSuccess) return error; + + if(with_values) + { + hipError_t error = ::rocprim::transform( + values_input, values_tmp, size, + ::rocprim::identity(), stream, debug_synchronous + ); + if(error != hipSuccess) return error; + } + + from_input = false; + } + } + unsigned int bit = begin_bit; for(unsigned int i = 0; i < long_iterations; i++) { hipError_t error = radix_sort_iteration( keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, size, batch_digit_counts, digit_counts, - to_output, - bit, begin_bit, end_bit, + from_input, to_output, + bit, end_bit, blocks_per_full_batch, full_batches, batches, stream, debug_synchronous ); if(error != hipSuccess) return error; is_result_in_output = to_output; + from_input = false; to_output = !to_output; bit += config::long_radix_bits; } @@ -437,14 +463,15 @@ hipError_t radix_sort_impl(void * temporary_storage, hipError_t error = radix_sort_iteration( keys_input, keys_tmp, keys_output, values_input, values_tmp, values_output, size, batch_digit_counts, digit_counts, - to_output, - bit, begin_bit, end_bit, + from_input, to_output, + bit, end_bit, blocks_per_full_batch, full_batches, batches, stream, debug_synchronous ); if(error != hipSuccess) return error; is_result_in_output = to_output; + from_input = false; to_output = !to_output; bit += config::short_radix_bits; } diff --git a/rocprim/include/rocprim/device/device_scan_config.hpp b/rocprim/include/rocprim/device/device_scan_config.hpp index 1a8b7c00f..126064732 100644 --- a/rocprim/include/rocprim/device/device_scan_config.hpp +++ b/rocprim/include/rocprim/device/device_scan_config.hpp @@ -41,12 +41,14 @@ BEGIN_ROCPRIM_NAMESPACE /// /// \tparam BlockSize - number of threads in a block. /// \tparam ItemsPerThread - number of items processed by each thread. +/// \tparam UseLookback - whether to use lookback scan or reduce-then-scan algorithm. /// \tparam BlockLoadMethod - method for loading input values. /// \tparam StoreLoadMethod - method for storing values. /// \tparam BlockScanMethod - algorithm for block scan. template< unsigned int BlockSize, unsigned int ItemsPerThread, + bool UseLookback, ::rocprim::block_load_method BlockLoadMethod, ::rocprim::block_store_method BlockStoreMethod, ::rocprim::block_scan_algorithm BlockScanMethod @@ -57,6 +59,8 @@ struct scan_config static constexpr unsigned int block_size = BlockSize; /// \brief Number of items processed by each thread. static constexpr unsigned int items_per_thread = ItemsPerThread; + /// \brief Whether to use lookback scan or reduce-then-scan algorithm. + static constexpr bool use_lookback = UseLookback; /// \brief Method for loading input values. static constexpr block_load_method block_load_method = BlockLoadMethod; /// \brief Method for storing values. @@ -77,6 +81,7 @@ struct scan_config_803 using type = scan_config< 256, ::rocprim::max(1u, 16u / item_scale), + ROCPRIM_DETAIL_USE_LOOKBACK_SCAN, ::rocprim::block_load_method::block_load_transpose, ::rocprim::block_store_method::block_store_transpose, ::rocprim::block_scan_algorithm::using_warp_scan @@ -92,6 +97,7 @@ struct scan_config_900 using type = scan_config< 256, ::rocprim::max(1u, 16u / item_scale), + ROCPRIM_DETAIL_USE_LOOKBACK_SCAN, ::rocprim::block_load_method::block_load_transpose, ::rocprim::block_store_method::block_store_transpose, ::rocprim::block_scan_algorithm::using_warp_scan diff --git a/rocprim/include/rocprim/device/device_scan_hc.hpp b/rocprim/include/rocprim/device/device_scan_hc.hpp index 98752dd20..aa3fcb074 100644 --- a/rocprim/include/rocprim/device/device_scan_hc.hpp +++ b/rocprim/include/rocprim/device/device_scan_hc.hpp @@ -56,7 +56,6 @@ namespace detail template< bool Exclusive, - bool UseLoopback, class Config, class InputIterator, class OutputIterator, @@ -73,7 +72,7 @@ auto scan_impl(void * temporary_storage, BinaryFunction scan_op, hc::accelerator_view acc_view, const bool debug_synchronous) - -> typename std::enable_if::type + -> typename std::enable_if::type { using input_type = typename std::iterator_traits::value_type; using output_type = typename std::iterator_traits::value_type; @@ -82,10 +81,7 @@ auto scan_impl(void * temporary_storage, >::type; // Get default config if Config is default_config - using config = default_or_custom_config< - Config, - default_scan_config - >; + using config = Config; constexpr unsigned int block_size = config::block_size; constexpr unsigned int items_per_thread = config::items_per_thread; @@ -142,7 +138,7 @@ auto scan_impl(void * temporary_storage, auto nested_temp_storage_size = storage_size - (number_of_blocks * sizeof(result_type)); if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - scan_impl( + scan_impl( nested_temp_storage, nested_temp_storage_size, block_prefixes, // input @@ -190,7 +186,6 @@ auto scan_impl(void * temporary_storage, template< bool Exclusive, - bool UseLoopback, class Config, class InputIterator, class OutputIterator, @@ -207,7 +202,7 @@ auto scan_impl(void * temporary_storage, BinaryFunction scan_op, hc::accelerator_view acc_view, const bool debug_synchronous) - -> typename std::enable_if::type + -> typename std::enable_if::type { using input_type = typename std::iterator_traits::value_type; using output_type = typename std::iterator_traits::value_type; @@ -216,10 +211,7 @@ auto scan_impl(void * temporary_storage, >::type; // Get default config if Config is default_config - using config = default_or_custom_config< - Config, - default_scan_config - >; + using config = Config; using scan_state_type = detail::lookback_scan_state; using ordered_block_id_type = detail::ordered_block_id; @@ -407,10 +399,13 @@ void inclusive_scan(void * temporary_storage, input_type, output_type, BinaryFunction >::type; - // Lookback scan has problems with types that are not arithmetic - // TODO: Investigate why the compiler never finishes linking if half is used - // Workaround: rocprim::is_arithmetic is replaced by std::is_arithmetic - detail::scan_impl::value, Config>( + // Get default config if Config is default_config + using config = detail::default_or_custom_config< + Config, + detail::default_scan_config + >; + + detail::scan_impl( temporary_storage, storage_size, // result_type() is a dummy initial value (not used) input, output, result_type(), size, @@ -526,10 +521,13 @@ void exclusive_scan(void * temporary_storage, input_type, output_type, BinaryFunction >::type; - // Lookback scan has problems with types that are not arithmetic - // TODO: Investigate why the compiler never finishes linking if half is used - // Workaround: rocprim::is_arithmetic is replaced by std::is_arithmetic - detail::scan_impl::value, Config>( + // Get default config if Config is default_config + using config = detail::default_or_custom_config< + Config, + detail::default_scan_config + >; + + detail::scan_impl( temporary_storage, storage_size, input, output, initial_value, size, scan_op, acc_view, debug_synchronous diff --git a/rocprim/include/rocprim/device/device_scan_hip.hpp b/rocprim/include/rocprim/device/device_scan_hip.hpp index d518c6703..f88a1cff4 100644 --- a/rocprim/include/rocprim/device/device_scan_hip.hpp +++ b/rocprim/include/rocprim/device/device_scan_hip.hpp @@ -170,7 +170,6 @@ void init_lookback_scan_state_kernel(LookBackScanState lookback_scan_state, template< bool Exclusive, - bool UseLoopback, class Config, class InputIterator, class OutputIterator, @@ -187,7 +186,7 @@ auto scan_impl(void * temporary_storage, BinaryFunction scan_op, const hipStream_t stream, bool debug_synchronous) - -> typename std::enable_if::type + -> typename std::enable_if::type { using input_type = typename std::iterator_traits::value_type; using output_type = typename std::iterator_traits::value_type; @@ -195,11 +194,7 @@ auto scan_impl(void * temporary_storage, input_type, output_type, BinaryFunction >::type; - // Get default config if Config is default_config - using config = default_or_custom_config< - Config, - default_scan_config - >; + using config = Config; constexpr unsigned int block_size = config::block_size; constexpr unsigned int items_per_thread = config::items_per_thread; @@ -252,7 +247,7 @@ auto scan_impl(void * temporary_storage, auto nested_temp_storage_size = storage_size - (number_of_blocks * sizeof(result_type)); if(debug_synchronous) start = std::chrono::high_resolution_clock::now(); - auto error = scan_impl( + auto error = scan_impl( nested_temp_storage, nested_temp_storage_size, block_prefixes, // input @@ -305,7 +300,6 @@ auto scan_impl(void * temporary_storage, template< bool Exclusive, - bool UseLoopback, class Config, class InputIterator, class OutputIterator, @@ -322,7 +316,7 @@ auto scan_impl(void * temporary_storage, BinaryFunction scan_op, const hipStream_t stream, bool debug_synchronous) - -> typename std::enable_if::type + -> typename std::enable_if::type { using input_type = typename std::iterator_traits::value_type; using output_type = typename std::iterator_traits::value_type; @@ -330,11 +324,7 @@ auto scan_impl(void * temporary_storage, input_type, output_type, BinaryFunction >::type; - // Get default config if Config is default_config - using config = default_or_custom_config< - Config, - default_scan_config - >; + using config = Config; using scan_state_type = detail::lookback_scan_state; using ordered_block_id_type = detail::ordered_block_id; @@ -518,10 +508,13 @@ hipError_t inclusive_scan(void * temporary_storage, input_type, output_type, BinaryFunction >::type; - // Lookback scan has problems with types that are not arithmetic - // TODO: Investigate why the compiler never finishes linking if half is used - // Workaround: rocprim::is_arithmetic is replaced by std::is_arithmetic - return detail::scan_impl::value, Config>( + // Get default config if Config is default_config + using config = detail::default_or_custom_config< + Config, + detail::default_scan_config + >; + + return detail::scan_impl( temporary_storage, storage_size, // result_type() is a dummy initial value (not used) input, output, result_type(), size, @@ -637,10 +630,13 @@ hipError_t exclusive_scan(void * temporary_storage, input_type, output_type, BinaryFunction >::type; - // Lookback scan has problems with types that are not arithmetic - // TODO: Investigate why the compiler never finishes linking if half is used - // Workaround: rocprim::is_arithmetic is replaced by std::is_arithmetic - return detail::scan_impl::value, Config>( + // Get default config if Config is default_config + using config = detail::default_or_custom_config< + Config, + detail::default_scan_config + >; + + return detail::scan_impl( temporary_storage, storage_size, input, output, initial_value, size, scan_op, stream, debug_synchronous diff --git a/rocprim/include/rocprim/intrinsics/atomic.hpp b/rocprim/include/rocprim/intrinsics/atomic.hpp index 0e373531e..6804e2c5d 100644 --- a/rocprim/include/rocprim/intrinsics/atomic.hpp +++ b/rocprim/include/rocprim/intrinsics/atomic.hpp @@ -76,6 +76,26 @@ namespace detail return ::atomicInc(address, value); #endif } + + ROCPRIM_DEVICE inline + unsigned int atomic_exch(unsigned int * address, unsigned int value) + { + #ifdef ROCPRIM_HC_API + return hc::atomic_exchange(address, value); + #else + return ::atomicExch(address, value); + #endif + } + + ROCPRIM_DEVICE inline + unsigned long long atomic_exch(unsigned long long * address, unsigned long long value) + { + #ifdef ROCPRIM_HC_API + return hc::atomic_exchange(reinterpret_cast(address), static_cast(value)); + #else + return ::atomicExch(address, value); + #endif + } } END_ROCPRIM_NAMESPACE diff --git a/rocprim/include/rocprim/intrinsics/thread.hpp b/rocprim/include/rocprim/intrinsics/thread.hpp index 0a00e2099..978652252 100644 --- a/rocprim/include/rocprim/intrinsics/thread.hpp +++ b/rocprim/include/rocprim/intrinsics/thread.hpp @@ -258,36 +258,61 @@ namespace detail } #ifdef ROCPRIM_HIP_API + ROCPRIM_DEVICE inline - void memory_fence_system(void) + void memory_fence_system() { ::__threadfence_system(); } ROCPRIM_DEVICE inline - void memory_fence_block(void) + void memory_fence_block() { ::__threadfence_block(); } ROCPRIM_DEVICE inline - void memory_fence_device(void) + void memory_fence_device() { ::__threadfence(); } + #else - // __threadfence_system() + + extern "C" ROCPRIM_DEVICE void __atomic_work_item_fence(unsigned int, unsigned int, unsigned int); + + // Works like __threadfence_system() ROCPRIM_DEVICE inline - void memory_fence_system(void) + void memory_fence_system() { - std::atomic_thread_fence(std::memory_order_seq_cst); + __atomic_work_item_fence( + 0, + /* memory_order_seq_cst */ __ATOMIC_SEQ_CST, + /* memory_scope_all_svm_devices */ __OPENCL_MEMORY_SCOPE_ALL_SVM_DEVICES + ); } // Works like __threadfence_block() - extern __attribute__((const)) ROCPRIM_DEVICE void memory_fence_block() __asm("__llvm_fence_sc_wg"); + ROCPRIM_DEVICE inline + void memory_fence_block() + { + __atomic_work_item_fence( + 0, + /* memory_order_seq_cst */ __ATOMIC_SEQ_CST, + /* memory_scope_work_group */ __OPENCL_MEMORY_SCOPE_WORK_GROUP + ); + } // Works like __threadfence() - extern __attribute__((const)) ROCPRIM_DEVICE void memory_fence_device() __asm("__llvm_fence_sc_dev"); + ROCPRIM_DEVICE inline + void memory_fence_device() + { + __atomic_work_item_fence( + 0, + /* memory_order_seq_cst */ __ATOMIC_SEQ_CST, + /* memory_scope_device */ __OPENCL_MEMORY_SCOPE_DEVICE + ); + } #endif } diff --git a/rocprim/include/rocprim/type_traits.hpp b/rocprim/include/rocprim/type_traits.hpp index df9860bd1..5cc0ff34d 100644 --- a/rocprim/include/rocprim/type_traits.hpp +++ b/rocprim/include/rocprim/type_traits.hpp @@ -48,7 +48,7 @@ using is_integral = std::is_integral; /// \brief Behaves like std::is_arithmetic, but also includes half-precision /// floating point type (\ref rocprim::half). -template< class T > +template struct is_arithmetic : std::integral_constant< bool, @@ -58,7 +58,7 @@ struct is_arithmetic /// \brief Behaves like std::is_fundamental, but also includes half-precision /// floating point type (\ref rocprim::half). -template< class T > +template struct is_fundamental : std::integral_constant< bool, @@ -72,7 +72,7 @@ using is_unsigned = std::is_unsigned; /// \brief Behaves like std::is_signed, but also includes half-precision /// floating point type (\ref rocprim::half). -template +template struct is_signed : std::integral_constant< bool, @@ -82,7 +82,7 @@ struct is_signed /// \brief Behaves like std::is_scalar, but also includes half-precision /// floating point type (\ref rocprim::half). -template +template struct is_scalar : std::integral_constant< bool, @@ -92,7 +92,7 @@ struct is_scalar /// \brief Behaves like std::is_compound, but also supports half-precision /// floating point type (\ref rocprim::half). `value` for \ref rocprim::half is `false`. -template +template struct is_compound : std::integral_constant< bool, @@ -104,4 +104,4 @@ END_ROCPRIM_NAMESPACE /// @} // end of group utilsmodule_typetraits -#endif // ROCPRIM_TYPE_TRAITS_HPP_ \ No newline at end of file +#endif // ROCPRIM_TYPE_TRAITS_HPP_ diff --git a/test/rocprim/test_hc_device_scan.cpp b/test/rocprim/test_hc_device_scan.cpp index 90bc74018..1405966ec 100644 --- a/test/rocprim/test_hc_device_scan.cpp +++ b/test/rocprim/test_hc_device_scan.cpp @@ -64,18 +64,28 @@ class RocprimDeviceScanTests : public ::testing::Test }; typedef ::testing::Types< + // Small + DeviceScanParams, DeviceScanParams, + DeviceScanParams, DeviceScanParams, + DeviceScanParams >, + DeviceScanParams, + // Large + DeviceScanParams >, + DeviceScanParams >, + DeviceScanParams >, + DeviceScanParams >, + DeviceScanParams >, DeviceScanParams >, - DeviceScanParams, DeviceScanParams >, DeviceScanParams >, + DeviceScanParams >, DeviceScanParams< test_utils::custom_test_type, test_utils::custom_test_type, rp::plus > >, - DeviceScanParams, - DeviceScanParams + DeviceScanParams > > RocprimDeviceScanTestsParams; std::vector get_sizes() @@ -83,7 +93,8 @@ std::vector get_sizes() std::vector sizes = { 1, 10, 53, 211, 1024, 2048, 5096, - 34567, (1 << 18) + 34567, (1 << 18), + (1 << 20) - 12345 }; const std::vector random_sizes = test_utils::get_random_data(3, 1, 100000); sizes.insert(sizes.end(), random_sizes.begin(), random_sizes.end()); @@ -117,7 +128,7 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) // Get size of d_temp_storage rocprim::inclusive_scan( nullptr, temp_storage_size_bytes, - rocprim::make_constant_iterator(345), + rocprim::make_constant_iterator(T(345)), d_checking_output, 0, scan_op_type(), acc_view, debug_synchronous ); @@ -129,7 +140,7 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) // Run rocprim::inclusive_scan( d_temp_storage.accelerator_pointer(), temp_storage_size_bytes, - rocprim::make_constant_iterator(345), + rocprim::make_constant_iterator(T(345)), d_checking_output, 0, scan_op_type(), acc_view, debug_synchronous ); diff --git a/test/rocprim/test_hc_intrinsics.cpp b/test/rocprim/test_hc_intrinsics.cpp index 96477d0d8..681cd1cf8 100644 --- a/test/rocprim/test_hc_intrinsics.cpp +++ b/test/rocprim/test_hc_intrinsics.cpp @@ -93,7 +93,7 @@ struct custom_16aligned this->f += rhs.f; return *this; } -} __attribute__((aligned(16)));; +} __attribute__((aligned(16))); inline ROCPRIM_HOST_DEVICE custom_16aligned operator+(custom_16aligned lhs, const custom_16aligned& rhs) diff --git a/test/rocprim/test_hip_block_scan.cpp b/test/rocprim/test_hip_block_scan.cpp index dace5313c..b7c514f28 100644 --- a/test/rocprim/test_hip_block_scan.cpp +++ b/test/rocprim/test_hip_block_scan.cpp @@ -30,7 +30,7 @@ #include "test_utils.hpp" -#define HIP_CHECK(error) ASSERT_EQ(static_cast(error), hipSuccess) +#define HIP_CHECK(error) ASSERT_EQ(error, hipSuccess) namespace rp = rocprim; @@ -90,6 +90,9 @@ typedef ::testing::Types< params, params, params, + // custom structs tests + params, 128>, + params, 256U>, // ----------------------------------------------------------------------- // rocprim::block_scan_algorithm::reduce_then_scan // ----------------------------------------------------------------------- @@ -103,7 +106,9 @@ typedef ::testing::Types< params, params, params, - params + params, + params, 140, 1, rocprim::block_scan_algorithm::reduce_then_scan>, + params, 201U, 1, rocprim::block_scan_algorithm::reduce_then_scan> > SingleValueTestParams; TYPED_TEST_CASE(RocprimBlockScanSingleValueTests, SingleValueTestParams); @@ -183,10 +188,7 @@ TYPED_TEST(RocprimBlockScanSingleValueTests, InclusiveScan) ); // Validating results - for(size_t i = 0; i < output.size(); i++) - { - ASSERT_EQ(output[i], expected[i]); - } + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output, expected, 0.01)); HIP_CHECK(hipFree(device_output)); } @@ -289,15 +291,8 @@ TYPED_TEST(RocprimBlockScanSingleValueTests, InclusiveScanReduce) ); // Validating results - for(size_t i = 0; i < output.size(); i++) - { - ASSERT_EQ(output[i], expected[i]); - } - - for(size_t i = 0; i < output_reductions.size(); i++) - { - ASSERT_EQ(output_reductions[i], expected_reductions[i]); - } + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output, expected, 0.01)); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output_reductions, expected_reductions, 0.01)); HIP_CHECK(hipFree(device_output)); HIP_CHECK(hipFree(device_output_reductions)); @@ -316,7 +311,7 @@ void inclusive_scan_prefix_callback_kernel(T* device_output, T* device_output_bp auto prefix_callback = [&prefix_value](T reduction) { T prefix = prefix_value; - prefix_value += reduction; + prefix_value = prefix_value + reduction; return prefix; }; @@ -413,15 +408,8 @@ TYPED_TEST(RocprimBlockScanSingleValueTests, InclusiveScanPrefixCallback) ); // Validating results - for(size_t i = 0; i < output.size(); i++) - { - ASSERT_EQ(output[i], expected[i]); - } - - for(size_t i = 0; i < output_block_prefixes.size(); i++) - { - ASSERT_EQ(output_block_prefixes[i], expected_block_prefixes[i]); - } + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output, expected, 0.01)); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output_block_prefixes, expected_block_prefixes, 0.01)); HIP_CHECK(hipFree(device_output)); HIP_CHECK(hipFree(device_output_bp)); @@ -504,10 +492,7 @@ TYPED_TEST(RocprimBlockScanSingleValueTests, ExclusiveScan) ); // Validating results - for(size_t i = 0; i < output.size(); i++) - { - ASSERT_EQ(output[i], expected[i]); - } + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output, expected, 0.01)); HIP_CHECK(hipFree(device_output)); } @@ -568,7 +553,7 @@ TYPED_TEST(RocprimBlockScanSingleValueTests, ExclusiveScanReduce) for(size_t j = 0; j < block_size; j++) { auto idx = i * block_size + j; - expected_reductions[i] += output[idx]; + expected_reductions[i] = expected_reductions[i] + output[idx]; } } @@ -619,15 +604,8 @@ TYPED_TEST(RocprimBlockScanSingleValueTests, ExclusiveScanReduce) ); // Validating results - for(size_t i = 0; i < output.size(); i++) - { - ASSERT_EQ(output[i], expected[i]); - } - - for(size_t i = 0; i < output_reductions.size(); i++) - { - ASSERT_EQ(output_reductions[i], expected_reductions[i]); - } + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output, expected, 0.01)); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output_reductions, expected_reductions, 0.01)); HIP_CHECK(hipFree(device_output)); HIP_CHECK(hipFree(device_output_reductions)); @@ -646,7 +624,7 @@ void exclusive_scan_prefix_callback_kernel(T* device_output, T* device_output_bp auto prefix_callback = [&prefix_value](T reduction) { T prefix = prefix_value; - prefix_value += reduction; + prefix_value = prefix_value + reduction; return prefix; }; @@ -698,7 +676,7 @@ TYPED_TEST(RocprimBlockScanSingleValueTests, ExclusiveScanPrefixCallback) for(size_t j = 0; j < block_size; j++) { auto idx = i * block_size + j; - expected_block_prefixes[i] += output[idx]; + expected_block_prefixes[i] = expected_block_prefixes[i] + output[idx]; } } @@ -749,98 +727,13 @@ TYPED_TEST(RocprimBlockScanSingleValueTests, ExclusiveScanPrefixCallback) ); // Validating results - for(size_t i = 0; i < output.size(); i++) - { - ASSERT_EQ(output[i], expected[i]); - } - - for(size_t i = 0; i < output_block_prefixes.size(); i++) - { - ASSERT_EQ(output_block_prefixes[i], expected_block_prefixes[i]); - } + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output, expected, 0.01)); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output_block_prefixes, expected_block_prefixes, 0.01)); HIP_CHECK(hipFree(device_output)); HIP_CHECK(hipFree(device_output_bp)); } -TYPED_TEST(RocprimBlockScanSingleValueTests, CustomStruct) -{ - using base_type = typename TestFixture::type; - using T = test_utils::custom_test_type; - constexpr auto algorithm = TestFixture::algorithm; - constexpr size_t block_size = TestFixture::block_size; - - // Given block size not supported - if(block_size > test_utils::get_max_block_size()) - { - return; - } - - const size_t size = block_size * 113; - const size_t grid_size = size / block_size; - // Generate data - std::vector output(size); - { - std::vector random_values = - test_utils::get_random_data(2 * output.size(), 2, 100); - for(size_t i = 0; i < output.size(); i++) - { - output[i].x = random_values[i], - output[i].y = random_values[i + output.size()]; - } - } - - // Calculate expected results on host - std::vector expected(output.size(), T(0)); - for(size_t i = 0; i < output.size() / block_size; i++) - { - for(size_t j = 0; j < block_size; j++) - { - auto idx = i * block_size + j; - expected[idx] = output[idx] + expected[j > 0 ? idx-1 : idx]; - } - } - - // Writing to device memory - T* device_output; - HIP_CHECK(hipMalloc(&device_output, output.size() * sizeof(typename decltype(output)::value_type))); - - HIP_CHECK( - hipMemcpy( - device_output, output.data(), - output.size() * sizeof(T), - hipMemcpyHostToDevice - ) - ); - - // Launching kernel - hipLaunchKernelGGL( - HIP_KERNEL_NAME(inclusive_scan_kernel), - dim3(grid_size), dim3(block_size), 0, 0, - device_output - ); - - HIP_CHECK(hipPeekAtLastError()); - HIP_CHECK(hipDeviceSynchronize()); - - // Read from device memory - HIP_CHECK( - hipMemcpy( - output.data(), device_output, - output.size() * sizeof(T), - hipMemcpyDeviceToHost - ) - ); - - // Validating results - for(size_t i = 0; i < output.size(); i++) - { - ASSERT_EQ(output[i], expected[i]); - } - - HIP_CHECK(hipFree(device_output)); -} - // --------------------------------------------------------- // Test for scan ops taking array of values as input // --------------------------------------------------------- @@ -872,6 +765,8 @@ typedef ::testing::Types< params, params, params, + params, 110, 4>, + params, 256U, 3>, // ----------------------------------------------------------------------- // rocprim::block_scan_algorithm::reduce_then_scan // ----------------------------------------------------------------------- @@ -886,7 +781,9 @@ typedef ::testing::Types< params, params, params, - params + params, + params, 256, 5, rocprim::block_scan_algorithm::reduce_then_scan>, + params, 180, 3, rocprim::block_scan_algorithm::reduce_then_scan> > InputArrayTestParams; TYPED_TEST_CASE(RocprimBlockScanInputArrayTests, InputArrayTestParams); @@ -982,13 +879,7 @@ TYPED_TEST(RocprimBlockScanInputArrayTests, InclusiveScan) ); // Validating results - for(size_t i = 0; i < output.size(); i++) - { - ASSERT_NEAR( - output[i], expected[i], - static_cast(0.05) * expected[i] - ); - } + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output, expected, 0.01)); HIP_CHECK(hipFree(device_output)); } @@ -1117,21 +1008,8 @@ TYPED_TEST(RocprimBlockScanInputArrayTests, InclusiveScanReduce) ); // Validating results - for(size_t i = 0; i < output.size(); i++) - { - ASSERT_NEAR( - output[i], expected[i], - static_cast(0.05) * expected[i] - ); - } - - for(size_t i = 0; i < output_reductions.size(); i++) - { - ASSERT_NEAR( - output_reductions[i], expected_reductions[i], - static_cast(0.05) * expected_reductions[i] - ); - } + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output, expected, 0.01)); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output_reductions, expected_reductions, 0.01)); HIP_CHECK(hipFree(device_output)); HIP_CHECK(hipFree(device_output_reductions)); @@ -1151,7 +1029,7 @@ void inclusive_scan_array_prefix_callback_kernel(T* device_output, T* device_out auto prefix_callback = [&prefix_value](T reduction) { T prefix = prefix_value; - prefix_value += reduction; + prefix_value = prefix_value + reduction; return prefix; }; @@ -1270,21 +1148,8 @@ TYPED_TEST(RocprimBlockScanInputArrayTests, InclusiveScanPrefixCallback) ); // Validating results - for(size_t i = 0; i < output.size(); i++) - { - ASSERT_NEAR( - output[i], expected[i], - static_cast(0.05) * expected[i] - ); - } - - for(size_t i = 0; i < output_block_prefixes.size(); i++) - { - ASSERT_NEAR( - output_block_prefixes[i], expected_block_prefixes[i], - static_cast(0.05) * expected_block_prefixes[i] - ); - } + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output, expected, 0.01)); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output_block_prefixes, expected_block_prefixes, 0.01)); HIP_CHECK(hipFree(device_output)); HIP_CHECK(hipFree(device_output_bp)); @@ -1381,13 +1246,7 @@ TYPED_TEST(RocprimBlockScanInputArrayTests, ExclusiveScan) ); // Validating results - for(size_t i = 0; i < output.size(); i++) - { - ASSERT_NEAR( - output[i], expected[i], - static_cast(0.05) * expected[i] - ); - } + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output, expected, 0.01)); HIP_CHECK(hipFree(device_output)); } @@ -1461,7 +1320,7 @@ TYPED_TEST(RocprimBlockScanInputArrayTests, ExclusiveScanReduce) } for(size_t j = 0; j < items_per_block; j++) { - expected_reductions[i] += output[i * items_per_block + j]; + expected_reductions[i] = expected_reductions[i] + output[i * items_per_block + j]; } } @@ -1514,21 +1373,8 @@ TYPED_TEST(RocprimBlockScanInputArrayTests, ExclusiveScanReduce) ); // Validating results - for(size_t i = 0; i < output.size(); i++) - { - ASSERT_NEAR( - output[i], expected[i], - static_cast(0.05) * expected[i] - ); - } - - for(size_t i = 0; i < output_reductions.size(); i++) - { - ASSERT_NEAR( - output_reductions[i], expected_reductions[i], - static_cast(0.05) * expected_reductions[i] - ); - } + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output, expected, 0.01)); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output_reductions, expected_reductions, 0.01)); } template< @@ -1549,7 +1395,7 @@ void exclusive_scan_prefix_callback_array_kernel( auto prefix_callback = [&prefix_value](T reduction) { T prefix = prefix_value; - prefix_value += reduction; + prefix_value = prefix_value + reduction; return prefix; }; @@ -1612,7 +1458,7 @@ TYPED_TEST(RocprimBlockScanInputArrayTests, ExclusiveScanPrefixCallback) for(size_t j = 0; j < items_per_block; j++) { auto idx = i * items_per_block + j; - expected_block_prefixes[i] += output[idx]; + expected_block_prefixes[i] = expected_block_prefixes[i] + output[idx]; } } @@ -1665,21 +1511,8 @@ TYPED_TEST(RocprimBlockScanInputArrayTests, ExclusiveScanPrefixCallback) ); // Validating results - for(size_t i = 0; i < output.size(); i++) - { - ASSERT_NEAR( - output[i], expected[i], - static_cast(0.05) * expected[i] - ); - } - - for(size_t i = 0; i < output_block_prefixes.size(); i++) - { - ASSERT_NEAR( - output_block_prefixes[i], expected_block_prefixes[i], - static_cast(0.05) * expected_block_prefixes[i] - ); - } + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output, expected, 0.01)); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output_block_prefixes, expected_block_prefixes, 0.01)); HIP_CHECK(hipFree(device_output)); HIP_CHECK(hipFree(device_output_bp)); diff --git a/test/rocprim/test_hip_device_merge_sort.cpp b/test/rocprim/test_hip_device_merge_sort.cpp index f56c9bd62..7e89edb97 100644 --- a/test/rocprim/test_hip_device_merge_sort.cpp +++ b/test/rocprim/test_hip_device_merge_sort.cpp @@ -35,20 +35,21 @@ #include "test_utils.hpp" -#define HIP_CHECK(error) \ - ASSERT_EQ(static_cast(error),hipSuccess) +#define HIP_CHECK(error) ASSERT_EQ(error, hipSuccess) namespace rp = rocprim; // Params for tests template< class KeyType, - class ValueType = KeyType + class ValueType = KeyType, + class CompareFunction = ::rocprim::less > struct DeviceSortParams { using key_type = KeyType; using value_type = ValueType; + using compare_function = CompareFunction; }; // --------------------------------------------------------- @@ -61,26 +62,30 @@ class RocprimDeviceSortTests : public ::testing::Test public: using key_type = typename Params::key_type; using value_type = typename Params::value_type; + using compare_function = typename Params::compare_function; const bool debug_synchronous = false; }; typedef ::testing::Types< + DeviceSortParams, + DeviceSortParams>, DeviceSortParams, DeviceSortParams>, DeviceSortParams, DeviceSortParams, - DeviceSortParams, - DeviceSortParams>, + DeviceSortParams>, + DeviceSortParams>, DeviceSortParams>, - DeviceSortParams> + DeviceSortParams, test_utils::custom_test_type> > RocprimDeviceSortTestsParams; std::vector get_sizes() { std::vector sizes = { 1, 10, 53, 211, - 1024, 2048, 5096, - 34567, (1 << 17) - 1220 + 128, 256, 512, + 1024, 2048, 5000, + 34567, (1 << 17) - 1220, (1 << 20) - 123 }; const std::vector random_sizes = test_utils::get_random_data(5, 1, 100000); sizes.insert(sizes.end(), random_sizes.begin(), random_sizes.end()); @@ -93,15 +98,19 @@ TYPED_TEST_CASE(RocprimDeviceSortTests, RocprimDeviceSortTestsParams); TYPED_TEST(RocprimDeviceSortTests, SortKey) { using key_type = typename TestFixture::key_type; + using compare_function = typename TestFixture::compare_function; const bool debug_synchronous = TestFixture::debug_synchronous; - const std::vector sizes = get_sizes(); - for(auto size : sizes) + bool in_place = false; + + for(size_t size : get_sizes()) { hipStream_t stream = 0; // default SCOPED_TRACE(testing::Message() << "with size = " << size); + in_place = !in_place; + // Generate data std::vector input = test_utils::get_random_data(size, 0, size); std::vector output(size); @@ -109,7 +118,14 @@ TYPED_TEST(RocprimDeviceSortTests, SortKey) key_type * d_input; key_type * d_output; HIP_CHECK(hipMalloc(&d_input, input.size() * sizeof(key_type))); - HIP_CHECK(hipMalloc(&d_output, output.size() * sizeof(key_type))); + if(in_place) + { + d_output = d_input; + } + else + { + HIP_CHECK(hipMalloc(&d_output, output.size() * sizeof(key_type))); + } HIP_CHECK( hipMemcpy( d_input, input.data(), @@ -119,15 +135,17 @@ TYPED_TEST(RocprimDeviceSortTests, SortKey) ); HIP_CHECK(hipDeviceSynchronize()); + // compare function + compare_function compare_op; + // Calculate expected results on host std::vector expected(input); - std::sort( + std::stable_sort( expected.begin(), - expected.end() + expected.end(), + compare_op ); - // compare function - ::rocprim::less lesser_op; // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -136,7 +154,7 @@ TYPED_TEST(RocprimDeviceSortTests, SortKey) rocprim::merge_sort( d_temp_storage, temp_storage_size_bytes, d_input, d_output, input.size(), - lesser_op, stream, debug_synchronous + compare_op, stream, debug_synchronous ) ); @@ -152,7 +170,7 @@ TYPED_TEST(RocprimDeviceSortTests, SortKey) rocprim::merge_sort( d_temp_storage, temp_storage_size_bytes, d_input, d_output, input.size(), - lesser_op, stream, debug_synchronous + compare_op, stream, debug_synchronous ) ); HIP_CHECK(hipPeekAtLastError()); @@ -175,7 +193,10 @@ TYPED_TEST(RocprimDeviceSortTests, SortKey) } hipFree(d_input); - hipFree(d_output); + if(!in_place) + { + hipFree(d_output); + } hipFree(d_temp_storage); } } @@ -184,31 +205,39 @@ TYPED_TEST(RocprimDeviceSortTests, SortKeyValue) { using key_type = typename TestFixture::key_type; using value_type = typename TestFixture::value_type; + using compare_function = typename TestFixture::compare_function; const bool debug_synchronous = TestFixture::debug_synchronous; - const std::vector sizes = get_sizes(); - for(auto size : sizes) + bool in_place = false; + + for(size_t size : get_sizes()) { hipStream_t stream = 0; // default SCOPED_TRACE(testing::Message() << "with size = " << size); + in_place = !in_place; + // Generate data - std::vector keys_input(size); - std::iota(keys_input.begin(), keys_input.end(), 0); - std::shuffle( - keys_input.begin(), - keys_input.end(), - std::mt19937{std::random_device{}()} - ); - std::vector values_input = test_utils::get_random_data(size, -1000, 1000); + std::vector keys_input = test_utils::get_random_data(size, 0, size); + + std::vector values_input(size); + std::iota(values_input.begin(), values_input.end(), 0); + std::vector keys_output(size); std::vector values_output(size); key_type * d_keys_input; key_type * d_keys_output; HIP_CHECK(hipMalloc(&d_keys_input, keys_input.size() * sizeof(key_type))); - HIP_CHECK(hipMalloc(&d_keys_output, keys_output.size() * sizeof(key_type))); + if(in_place) + { + d_keys_output = d_keys_input; + } + else + { + HIP_CHECK(hipMalloc(&d_keys_output, keys_output.size() * sizeof(key_type))); + } HIP_CHECK( hipMemcpy( d_keys_input, keys_input.data(), @@ -221,7 +250,14 @@ TYPED_TEST(RocprimDeviceSortTests, SortKeyValue) value_type * d_values_input; value_type * d_values_output; HIP_CHECK(hipMalloc(&d_values_input, values_input.size() * sizeof(value_type))); - HIP_CHECK(hipMalloc(&d_values_output, values_output.size() * sizeof(value_type))); + if(in_place) + { + d_values_output = d_values_input; + } + else + { + HIP_CHECK(hipMalloc(&d_values_output, values_output.size() * sizeof(value_type))); + } HIP_CHECK( hipMemcpy( d_values_input, values_input.data(), @@ -231,6 +267,9 @@ TYPED_TEST(RocprimDeviceSortTests, SortKeyValue) ); HIP_CHECK(hipDeviceSynchronize()); + // compare function + compare_function compare_op; + // Calculate expected results on host using key_value = std::pair; std::vector expected(size); @@ -238,13 +277,12 @@ TYPED_TEST(RocprimDeviceSortTests, SortKeyValue) { expected[i] = key_value(keys_input[i], values_input[i]); } - std::sort( + std::stable_sort( expected.begin(), - expected.end() + expected.end(), + [compare_op](const key_value& a, const key_value& b) { return compare_op(a.first, b.first); } ); - // compare function - ::rocprim::less lesser_op; // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -254,7 +292,7 @@ TYPED_TEST(RocprimDeviceSortTests, SortKeyValue) d_temp_storage, temp_storage_size_bytes, d_keys_input, d_keys_output, d_values_input, d_values_output, keys_input.size(), - lesser_op, stream, debug_synchronous + compare_op, stream, debug_synchronous ) ); @@ -271,7 +309,7 @@ TYPED_TEST(RocprimDeviceSortTests, SortKeyValue) d_temp_storage, temp_storage_size_bytes, d_keys_input, d_keys_output, d_values_input, d_values_output, keys_input.size(), - lesser_op, stream, debug_synchronous + compare_op, stream, debug_synchronous ) ); HIP_CHECK(hipPeekAtLastError()); @@ -302,9 +340,12 @@ TYPED_TEST(RocprimDeviceSortTests, SortKeyValue) } hipFree(d_keys_input); - hipFree(d_keys_output); hipFree(d_values_input); - hipFree(d_values_output); + if(!in_place) + { + hipFree(d_keys_output); + hipFree(d_values_output); + } hipFree(d_temp_storage); } } diff --git a/test/rocprim/test_hip_device_radix_sort.cpp b/test/rocprim/test_hip_device_radix_sort.cpp index b5dd81333..7a3ab0ec6 100644 --- a/test/rocprim/test_hip_device_radix_sort.cpp +++ b/test/rocprim/test_hip_device_radix_sort.cpp @@ -40,7 +40,7 @@ namespace rp = rocprim; -#define HIP_CHECK(error) ASSERT_EQ(static_cast(error), hipSuccess) +#define HIP_CHECK(error) ASSERT_EQ(error, hipSuccess) template< class Key, @@ -154,13 +154,16 @@ TYPED_TEST(RocprimDeviceRadixSort, SortKeys) const bool debug_synchronous = false; - const std::vector sizes = get_sizes(); - for(size_t size : sizes) + bool in_place = false; + + for(size_t size : get_sizes()) { if(size > (1 << 20) && !check_huge_sizes) continue; SCOPED_TRACE(testing::Message() << "with size = " << size); + in_place = !in_place; + // Generate data std::vector keys_input; if(rp::is_floating_point::value) @@ -179,7 +182,14 @@ TYPED_TEST(RocprimDeviceRadixSort, SortKeys) key_type * d_keys_input; key_type * d_keys_output; HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(key_type))); - HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); + if(in_place) + { + d_keys_output = d_keys_input; + } + else + { + HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); + } HIP_CHECK( hipMemcpy( d_keys_input, keys_input.data(), @@ -232,8 +242,6 @@ TYPED_TEST(RocprimDeviceRadixSort, SortKeys) ); } - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_keys_input)); std::vector keys_output(size); HIP_CHECK( @@ -244,7 +252,12 @@ TYPED_TEST(RocprimDeviceRadixSort, SortKeys) ) ); - HIP_CHECK(hipFree(d_keys_output)); + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_keys_input)); + if(!in_place) + { + HIP_CHECK(hipFree(d_keys_output)); + } ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, expected)); } @@ -263,13 +276,16 @@ TYPED_TEST(RocprimDeviceRadixSort, SortPairs) const bool debug_synchronous = false; - const std::vector sizes = get_sizes(); - for(size_t size : sizes) + bool in_place = false; + + for(size_t size : get_sizes()) { if(size > (1 << 20) && !check_huge_sizes) continue; SCOPED_TRACE(testing::Message() << "with size = " << size); + in_place = !in_place; + // Generate data std::vector keys_input; if(rp::is_floating_point::value) @@ -291,7 +307,14 @@ TYPED_TEST(RocprimDeviceRadixSort, SortPairs) key_type * d_keys_input; key_type * d_keys_output; HIP_CHECK(hipMalloc(&d_keys_input, size * sizeof(key_type))); - HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); + if(in_place) + { + d_keys_output = d_keys_input; + } + else + { + HIP_CHECK(hipMalloc(&d_keys_output, size * sizeof(key_type))); + } HIP_CHECK( hipMemcpy( d_keys_input, keys_input.data(), @@ -303,7 +326,14 @@ TYPED_TEST(RocprimDeviceRadixSort, SortPairs) value_type * d_values_input; value_type * d_values_output; HIP_CHECK(hipMalloc(&d_values_input, size * sizeof(value_type))); - HIP_CHECK(hipMalloc(&d_values_output, size * sizeof(value_type))); + if(in_place) + { + d_values_output = d_values_input; + } + else + { + HIP_CHECK(hipMalloc(&d_values_output, size * sizeof(value_type))); + } HIP_CHECK( hipMemcpy( d_values_input, values_input.data(), @@ -369,9 +399,6 @@ TYPED_TEST(RocprimDeviceRadixSort, SortPairs) ); } - HIP_CHECK(hipFree(d_temporary_storage)); - HIP_CHECK(hipFree(d_keys_input)); - HIP_CHECK(hipFree(d_values_input)); std::vector keys_output(size); HIP_CHECK( @@ -391,8 +418,14 @@ TYPED_TEST(RocprimDeviceRadixSort, SortPairs) ) ); - HIP_CHECK(hipFree(d_keys_output)); - HIP_CHECK(hipFree(d_values_output)); + HIP_CHECK(hipFree(d_temporary_storage)); + HIP_CHECK(hipFree(d_keys_input)); + HIP_CHECK(hipFree(d_values_input)); + if(!in_place) + { + HIP_CHECK(hipFree(d_keys_output)); + HIP_CHECK(hipFree(d_values_output)); + } ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(keys_output, keys_expected)); ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(values_output, values_expected)); diff --git a/test/rocprim/test_hip_device_scan.cpp b/test/rocprim/test_hip_device_scan.cpp index a18036637..2990d42c9 100644 --- a/test/rocprim/test_hip_device_scan.cpp +++ b/test/rocprim/test_hip_device_scan.cpp @@ -36,7 +36,7 @@ namespace rp = rocprim; -#define HIP_CHECK(error) ASSERT_EQ(static_cast(error),hipSuccess) +#define HIP_CHECK(error) ASSERT_EQ(error, hipSuccess) // Params for tests template< @@ -71,18 +71,31 @@ class RocprimDeviceScanTests : public ::testing::Test }; typedef ::testing::Types< + // Small + DeviceScanParams, DeviceScanParams, + DeviceScanParams, DeviceScanParams, + DeviceScanParams >, + DeviceScanParams, + DeviceScanParams, + // Large + DeviceScanParams >, + DeviceScanParams >, + DeviceScanParams >, + DeviceScanParams >, + DeviceScanParams >, DeviceScanParams, true>, - DeviceScanParams, DeviceScanParams >, DeviceScanParams >, + DeviceScanParams >, DeviceScanParams< test_utils::custom_test_type, test_utils::custom_test_type, rp::plus >, true >, - DeviceScanParams, - DeviceScanParams + DeviceScanParams >, + DeviceScanParams >, + DeviceScanParams > > RocprimDeviceScanTestsParams; std::vector get_sizes() @@ -90,7 +103,8 @@ std::vector get_sizes() std::vector sizes = { 1, 10, 53, 211, 1024, 2048, 5096, - 34567, (1 << 18) + 34567, (1 << 18), + (1 << 20) - 12345 }; const std::vector random_sizes = test_utils::get_random_data(3, 1, 100000); sizes.insert(sizes.end(), random_sizes.begin(), random_sizes.end()); @@ -126,7 +140,7 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) HIP_CHECK( rocprim::inclusive_scan( d_temp_storage, temp_storage_size_bytes, - rocprim::make_constant_iterator(345), + rocprim::make_constant_iterator(T(345)), d_checking_output, 0, scan_op_type(), stream, debug_synchronous ) @@ -139,7 +153,7 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) HIP_CHECK( rocprim::inclusive_scan( d_temp_storage, temp_storage_size_bytes, - rocprim::make_constant_iterator(345), + rocprim::make_constant_iterator(T(345)), d_checking_output, 0, scan_op_type(), stream, debug_synchronous ) diff --git a/test/rocprim/test_utils.hpp b/test/rocprim/test_utils.hpp index 65f35d6d0..6c407d584 100644 --- a/test/rocprim/test_utils.hpp +++ b/test/rocprim/test_utils.hpp @@ -346,6 +346,11 @@ struct is_custom_test_type : std::false_type { }; +template +struct is_custom_test_array_type : std::false_type +{ +}; + template struct inner_type { @@ -362,8 +367,10 @@ struct custom_test_type T x; T y; + // Non-zero values in default constructor for checking reduce and scan: + // ensure that scan_op(custom_test_type(), value) != value ROCPRIM_HOST_DEVICE inline - custom_test_type() {} + custom_test_type() : x(12), y(34) {} ROCPRIM_HOST_DEVICE inline custom_test_type(T x, T y) : x(x), y(y) {} @@ -427,6 +434,124 @@ struct custom_test_type } }; +// Custom type used in tests +template +struct custom_test_array_type +{ + using value_type = T; + static constexpr size_t size = N; + + T values[N]; + + ROCPRIM_HOST_DEVICE inline + custom_test_array_type() + { + for(size_t i = 0; i < N; i++) + { + values[i] = T(i + 1); + } + } + + ROCPRIM_HOST_DEVICE inline + custom_test_array_type(T v) + { + for(size_t i = 0; i < N; i++) + { + values[i] = v; + } + } + + template + ROCPRIM_HOST_DEVICE inline + custom_test_array_type(const custom_test_array_type& other) + { + for(size_t i = 0; i < N; i++) + { + values[i] = other.values[i]; + } + } + + ROCPRIM_HOST_DEVICE inline + ~custom_test_array_type() {} + + ROCPRIM_HOST_DEVICE inline + custom_test_array_type& operator=(const custom_test_array_type& other) + { + for(size_t i = 0; i < N; i++) + { + values[i] = other.values[i]; + } + return *this; + } + + ROCPRIM_HOST_DEVICE inline + custom_test_array_type operator+(const custom_test_array_type& other) const + { + custom_test_array_type result; + for(size_t i = 0; i < N; i++) + { + result.values[i] = values[i] + other.values[i]; + } + return result; + } + + ROCPRIM_HOST_DEVICE inline + custom_test_array_type operator-(const custom_test_array_type& other) const + { + custom_test_array_type result; + for(size_t i = 0; i < N; i++) + { + result.values[i] = values[i] - other.values[i]; + } + return result; + } + + ROCPRIM_HOST_DEVICE inline + bool operator<(const custom_test_array_type& other) const + { + for(size_t i = 0; i < N; i++) + { + if(values[i] >= other.values[i]) + { + return false; + } + } + return true; + } + + ROCPRIM_HOST_DEVICE inline + bool operator>(const custom_test_array_type& other) const + { + for(size_t i = 0; i < N; i++) + { + if(values[i] <= other.values[i]) + { + return false; + } + } + return true; + } + + ROCPRIM_HOST_DEVICE inline + bool operator==(const custom_test_array_type& other) const + { + for(size_t i = 0; i < N; i++) + { + if(values[i] != other.values[i]) + { + return false; + } + } + return true; + } + + ROCPRIM_HOST_DEVICE inline + bool operator!=(const custom_test_array_type& other) const + { + return !(*this == other); + } +}; + template inline std::ostream& operator<<(std::ostream& stream, const custom_test_type& value) @@ -435,17 +560,46 @@ std::ostream& operator<<(std::ostream& stream, return stream; } +template inline +std::ostream& operator<<(std::ostream& stream, + const custom_test_array_type& value) +{ + stream << "["; + for(size_t i = 0; i < N; i++) + { + stream << value.values[i]; + if(i != N - 1) + { + stream << "; "; + } + } + stream << "]"; + return stream; +} + template struct is_custom_test_type> : std::true_type { }; +template +struct is_custom_test_array_type> : std::true_type +{ +}; + + template struct inner_type> { using type = T; }; +template +struct inner_type> +{ + using type = T; +}; + namespace detail { template @@ -494,9 +648,35 @@ inline auto get_random_data(size_t size, typename T::value_type min, typename T: return data; } +template +inline auto get_random_data(size_t size, typename T::value_type min, typename T::value_type max) + -> typename std::enable_if< + is_custom_test_array_type::value && std::is_integral::value, + std::vector + >::type +{ + std::random_device rd; + std::default_random_engine gen(rd()); + std::uniform_int_distribution distribution(min, max); + std::vector data(size); + std::generate( + data.begin(), data.end(), + [&]() + { + T result; + for(size_t i = 0; i < T::size; i++) + { + result.values[i] = distribution(gen); + } + return result; + } + ); + return data; +} + template inline auto get_random_value(typename T::value_type min, typename T::value_type max) - -> typename std::enable_if::value, T>::type + -> typename std::enable_if::value || is_custom_test_array_type::value, T>::type { return get_random_data(1, min, max)[0]; } @@ -504,20 +684,25 @@ inline auto get_random_value(typename T::value_type min, typename T::value_type template auto assert_near(const std::vector& result, const std::vector& expected, const float percent) - -> typename std::enable_if::value && std::is_arithmetic::value>::type + -> typename std::enable_if::value>::type { ASSERT_EQ(result.size(), expected.size()); for(size_t i = 0; i < result.size(); i++) { - if(std::is_integral::value) - { - ASSERT_EQ(result[i], expected[i]) << "where index = " << i; - } - else - { - auto diff = std::max(std::abs(percent * expected[i]), T(percent)); - ASSERT_NEAR(result[i], expected[i], diff) << "where index = " << i; - } + auto diff = std::max(std::abs(percent * expected[i]), T(percent)); + ASSERT_NEAR(result[i], expected[i], diff) << "where index = " << i; + } +} + +template +auto assert_near(const std::vector& result, const std::vector& expected, const float percent) + -> typename std::enable_if::value>::type +{ + (void)percent; + ASSERT_EQ(result.size(), expected.size()); + for(size_t i = 0; i < result.size(); i++) + { + ASSERT_EQ(result[i], expected[i]) << "where index = " << i; } } @@ -532,60 +717,44 @@ void assert_near(const std::vector& result, const std::vector -auto assert_near(const T& result, const T& expected, const float percent) - -> typename std::enable_if::value && std::is_arithmetic::value>::type +auto assert_near(const std::vector>& result, const std::vector>& expected, const float percent) + -> typename std::enable_if::value>::type { - auto diff = std::max(std::abs(percent * expected), T(percent)); - if(std::is_integral::value) diff = 0; - ASSERT_NEAR(result, expected, diff); + ASSERT_EQ(result.size(), expected.size()); + for(size_t i = 0; i < result.size(); i++) + { + auto diff1 = std::max(std::abs(percent * expected[i].x), T(percent)); + auto diff2 = std::max(std::abs(percent * expected[i].y), T(percent)); + ASSERT_NEAR(result[i].x, expected[i].x, diff1) << "where index = " << i; + ASSERT_NEAR(result[i].y, expected[i].y, diff2) << "where index = " << i; + } } - template auto assert_near(const T& result, const T& expected, const float percent) - -> typename std::enable_if::value>::type + -> typename std::enable_if::value>::type { - using value_type = typename T::value_type; - auto diff1 = std::max(std::abs(percent * expected.x), value_type(percent)); - auto diff2 = std::max(std::abs(percent * expected.y), value_type(percent)); - if(std::is_integral::value) - { - diff1 = 0; - diff2 = 0; - } - ASSERT_NEAR(result.x, expected.x, diff1); - ASSERT_NEAR(result.y, expected.y, diff2); + auto diff = std::max(std::abs(percent * expected), T(percent)); + ASSERT_NEAR(result, expected, diff); } template -auto assert_near(const std::vector& result, const std::vector& expected, const float percent) - -> typename std::enable_if::value>::type +auto assert_near(const T& result, const T& expected, const float percent) + -> typename std::enable_if::value>::type { - using value_type = typename T::value_type; - ASSERT_EQ(result.size(), expected.size()); - for(size_t i = 0; i < result.size(); i++) - { - auto diff1 = std::max(std::abs(percent * expected[i].x), value_type(percent)); - auto diff2 = std::max(std::abs(percent * expected[i].y), value_type(percent)); - if(std::is_integral::value) - { - diff1 = 0; - diff2 = 0; - } - ASSERT_NEAR(result[i].x, expected[i].x, diff1) << "where index = " << i; - ASSERT_NEAR(result[i].y, expected[i].y, diff2) << "where index = " << i; - } + (void)percent; + ASSERT_EQ(result, expected); } + template -auto assert_near(const std::vector& result, const std::vector& expected, const float) - -> typename std::enable_if::value && !std::is_arithmetic::value>::type +auto assert_near(const custom_test_type& result, const custom_test_type& expected, const float percent) + -> typename std::enable_if::value>::type { - ASSERT_EQ(result.size(), expected.size()); - for(size_t i = 0; i < result.size(); i++) - { - ASSERT_EQ(result[i], expected[i]) << "where index = " << i; - } + auto diff1 = std::max(std::abs(percent * expected.x), T(percent)); + auto diff2 = std::max(std::abs(percent * expected.y), T(percent)); + ASSERT_NEAR(result.x, expected.x, diff1); + ASSERT_NEAR(result.y, expected.y, diff2); } template