From 351a2f7326cb5250bb912b929b7c243da2810432 Mon Sep 17 00:00:00 2001 From: Anton Gorenko Date: Tue, 14 May 2019 16:48:35 +0600 Subject: [PATCH] Implement invoke_result for device-only lambdas invoke_result is based on https://en.cppreference.com/w/cpp/types/result_of The main difference is the use of ROCPRIM_HOST_DEVICE, which allows invoke_result to be used with device-only lambdas/functors in host-only functions on HIP-clang. Workarounds with ROCPRIM_HOST_DEVICE for such lambdas are removed. --- .../rocprim/detail/match_result_type.hpp | 85 ++++++++++++++----- .../device/detail/device_transform.hpp | 7 +- .../rocprim/device/device_binary_search.hpp | 5 -- .../rocprim/device/device_scan_by_key.hpp | 5 -- .../rocprim/device/device_segmented_scan.hpp | 5 -- .../rocprim/device/device_transform.hpp | 7 +- .../rocprim/iterator/transform_iterator.hpp | 10 +-- 7 files changed, 72 insertions(+), 52 deletions(-) diff --git a/rocprim/include/rocprim/detail/match_result_type.hpp b/rocprim/include/rocprim/detail/match_result_type.hpp index 75ba31b01..e143a8aac 100644 --- a/rocprim/include/rocprim/detail/match_result_type.hpp +++ b/rocprim/include/rocprim/detail/match_result_type.hpp @@ -24,38 +24,85 @@ #include #include "../config.hpp" -#include "../types/tuple.hpp" BEGIN_ROCPRIM_NAMESPACE namespace detail { -// tuple_contains_type::value is false if Tuple is not rocprim::tuple<> or Tuple is -// rocprim::tuple<> class which does not contain element of type T; otherwise it's true. -template -struct tuple_contains_type : std::false_type {}; +// invoke_result is based on https://en.cppreference.com/w/cpp/types/result_of +// The main difference is using ROCPRIM_HOST_DEVICE, this allows to +// use invoke_result with device-only lambdas/functors in host-only functions +// on HIP-clang. 
+ +template +struct is_reference_wrapper : std::false_type {}; +template +struct is_reference_wrapper> : std::true_type {}; template -struct tuple_contains_type> : std::false_type {}; +struct invoke_impl { + template + ROCPRIM_HOST_DEVICE + static auto call(F&& f, Args&&... args) + -> decltype(std::forward(f)(std::forward(args)...)); +}; + +template +struct invoke_impl +{ + template::type, + class = typename std::enable_if::value>::type + > + ROCPRIM_HOST_DEVICE + static auto get(T&& t) -> T&&; + + template::type, + class = typename std::enable_if::value>::type + > + ROCPRIM_HOST_DEVICE + static auto get(T&& t) -> decltype(t.get()); + + template::type, + class = typename std::enable_if::value>::type, + class = typename std::enable_if::value>::type + > + ROCPRIM_HOST_DEVICE + static auto get(T&& t) -> decltype(*std::forward(t)); -template -struct tuple_contains_type> : tuple_contains_type> {}; + template::value>::type + > + ROCPRIM_HOST_DEVICE + static auto call(MT1 B::*pmf, T&& t, Args&&... args) + -> decltype((invoke_impl::get(std::forward(t)).*pmf)(std::forward(args)...)); + + template + ROCPRIM_HOST_DEVICE + static auto call(MT B::*pmd, T&& t) + -> decltype(invoke_impl::get(std::forward(t)).*pmd); +}; + +template::type> +ROCPRIM_HOST_DEVICE +auto INVOKE(F&& f, Args&&... 
args) + -> decltype(invoke_impl::call(std::forward(f), std::forward(args)...)); + +// Conforming C++14 implementation (is also a valid C++11 implementation): +template +struct invoke_result_impl { }; +template +struct invoke_result_impl(), std::declval()...))), F, Args...> +{ + using type = decltype(INVOKE(std::declval(), std::declval()...)); +}; -template -struct tuple_contains_type> : std::true_type {}; +template +struct invoke_result : invoke_result_impl {}; template struct match_result_type { -private: - #ifdef __cpp_lib_is_invocable - using binary_result_type = typename std::invoke_result::type; - #else - using binary_result_type = typename std::result_of::type; - #endif - -public: - using type = binary_result_type; + using type = typename invoke_result::type; }; } // end namespace detail diff --git a/rocprim/include/rocprim/device/detail/device_transform.hpp b/rocprim/include/rocprim/device/detail/device_transform.hpp index 24febc211..3665daded 100644 --- a/rocprim/include/rocprim/device/detail/device_transform.hpp +++ b/rocprim/include/rocprim/device/detail/device_transform.hpp @@ -26,6 +26,7 @@ #include "../../config.hpp" #include "../../detail/various.hpp" +#include "../../detail/match_result_type.hpp" #include "../../intrinsics.hpp" #include "../../functional.hpp" @@ -44,11 +45,7 @@ namespace detail template struct unpack_binary_op { - #ifdef __cpp_lib_is_invocable - using result_type = typename std::invoke_result::type; - #else - using result_type = typename std::result_of::type; - #endif + using result_type = typename ::rocprim::detail::invoke_result::type; ROCPRIM_HOST_DEVICE inline unpack_binary_op() = default; diff --git a/rocprim/include/rocprim/device/device_binary_search.hpp b/rocprim/include/rocprim/device/device_binary_search.hpp index 8cdfd2a64..5f655b227 100644 --- a/rocprim/include/rocprim/device/device_binary_search.hpp +++ b/rocprim/include/rocprim/device/device_binary_search.hpp @@ -74,12 +74,7 @@ hipError_t binary_search(void * 
temporary_storage, needles, output, needles_size, [haystack, haystack_size, search_op, compare_op] - #ifdef __HIP__ - // Workaround: hip-clang does not support std::result_of of device-only functions - ROCPRIM_HOST_DEVICE - #else ROCPRIM_DEVICE - #endif (const value_type& value) { return search_op(haystack, haystack_size, value, compare_op); diff --git a/rocprim/include/rocprim/device/device_scan_by_key.hpp b/rocprim/include/rocprim/device/device_scan_by_key.hpp index 97dad64e6..b9cee7231 100644 --- a/rocprim/include/rocprim/device/device_scan_by_key.hpp +++ b/rocprim/include/rocprim/device/device_scan_by_key.hpp @@ -164,12 +164,7 @@ hipError_t inclusive_scan_by_key(void * temporary_storage, rocprim::make_transform_iterator( rocprim::make_counting_iterator(0), [values_input, keys_input, key_compare_op] - #ifdef __HIP__ - // Workaround: hip-clang does not support std::result_of of device-only functions - ROCPRIM_HOST_DEVICE - #else ROCPRIM_DEVICE - #endif (const size_t i) { flag_type flag(true); diff --git a/rocprim/include/rocprim/device/device_segmented_scan.hpp b/rocprim/include/rocprim/device/device_segmented_scan.hpp index e171c518b..7c3a9e967 100644 --- a/rocprim/include/rocprim/device/device_segmented_scan.hpp +++ b/rocprim/include/rocprim/device/device_segmented_scan.hpp @@ -614,12 +614,7 @@ hipError_t segmented_exclusive_scan(void * temporary_storage, rocprim::make_transform_iterator( rocprim::make_counting_iterator(0), [input, head_flags, initial_value_converted, size] - #ifdef __HIP__ - // Workaround: hip-clang does not support std::result_of of device-only functions - ROCPRIM_HOST_DEVICE - #else ROCPRIM_DEVICE - #endif (const size_t i) { flag_type flag(false); diff --git a/rocprim/include/rocprim/device/device_transform.hpp b/rocprim/include/rocprim/device/device_transform.hpp index 2779a85a6..bccf3fb78 100644 --- a/rocprim/include/rocprim/device/device_transform.hpp +++ b/rocprim/include/rocprim/device/device_transform.hpp @@ -26,6 +26,7 @@ #include 
"../config.hpp" #include "../detail/various.hpp" +#include "../detail/match_result_type.hpp" #include "../types/tuple.hpp" #include "../iterator/zip_iterator.hpp" @@ -145,11 +146,7 @@ hipError_t transform(InputIterator input, bool debug_synchronous = false) { using input_type = typename std::iterator_traits::value_type; - #ifdef __cpp_lib_is_invocable - using result_type = typename std::invoke_result::type; - #else - using result_type = typename std::result_of::type; - #endif + using result_type = typename ::rocprim::detail::invoke_result::type; // Get default config if Config is default_config using config = detail::default_or_custom_config< diff --git a/rocprim/include/rocprim/iterator/transform_iterator.hpp b/rocprim/include/rocprim/iterator/transform_iterator.hpp index d53c20738..2880a1ea0 100644 --- a/rocprim/include/rocprim/iterator/transform_iterator.hpp +++ b/rocprim/include/rocprim/iterator/transform_iterator.hpp @@ -26,6 +26,7 @@ #include #include "../config.hpp" +#include "../detail/match_result_type.hpp" /// \addtogroup iteratormodule /// @{ @@ -49,17 +50,10 @@ BEGIN_ROCPRIM_NAMESPACE template< class InputIterator, class UnaryFunction, -#if defined(__cpp_lib_is_invocable) && !defined(DOXYGEN_SHOULD_SKIP_THIS) // C++17 class ValueType = - typename std::invoke_result< + typename ::rocprim::detail::invoke_result< UnaryFunction, typename std::iterator_traits::value_type >::type -#else - class ValueType = - typename std::result_of< - UnaryFunction(typename std::iterator_traits::value_type) - >::type -#endif > class transform_iterator {