Skip to content

Commit

Permalink
Merge pull request #70 from ROCmSoftwarePlatform/develop_stream_20190515
Browse files Browse the repository at this point in the history
Implement invoke_result for device-only lambdas
  • Loading branch information
saadrahim committed May 16, 2019
2 parents 9962bf8 + d2d2d3a commit 360401f
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 52 deletions.
85 changes: 66 additions & 19 deletions rocprim/include/rocprim/detail/match_result_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,38 +24,85 @@
#include <type_traits>

#include "../config.hpp"
#include "../types/tuple.hpp"

BEGIN_ROCPRIM_NAMESPACE
namespace detail
{

// Primary template of tuple_contains_type. It is the fallback chosen whenever
// no partial specialization applies, and it reports ::value == false.
// Overall contract of the trait: ::value is true only when the second argument
// is a rocprim::tuple<> that contains an element of the first argument's type;
// for any non-tuple type, or a tuple without such an element, it is false.
template<class Element, class Candidate>
struct tuple_contains_type : std::integral_constant<bool, false> {};
// invoke_result is based on https://en.cppreference.com/w/cpp/types/result_of
// The main difference is the use of ROCPRIM_HOST_DEVICE, which allows
// invoke_result to be used with device-only lambdas/functors in host-only
// functions on HIP-clang.

// Trait that detects std::reference_wrapper. ::value is false for every type
// except an (unqualified) std::reference_wrapper<U>, for which it is true.
template <class Candidate>
struct is_reference_wrapper : std::integral_constant<bool, false> {};

// Specialization selected when the argument is std::reference_wrapper<Wrapped>.
template <class Wrapped>
struct is_reference_wrapper<std::reference_wrapper<Wrapped>> : std::integral_constant<bool, true> {};

// NOTE(review): the tuple_contains_type line below sits between the template
// header and `struct invoke_impl`. This looks like a removed diff line shown
// next to its replacement -- confirm against the real header before relying
// on this text.
template<class T>
struct tuple_contains_type<T, ::rocprim::tuple<>> : std::false_type {};
// invoke_impl: helper whose static call() describes, through its trailing
// return type, the result of invoking a callable.
struct invoke_impl {
// call() for a callable that is invoked directly as f(args...).
// Declaration only (no body): the visible use is inside decltype in INVOKE,
// where only the return type matters.
template<class F, class... Args>
ROCPRIM_HOST_DEVICE
static auto call(F&& f, Args&&... args)
-> decltype(std::forward<F>(f)(std::forward<Args>(args)...));
};

// Specialization of invoke_impl for pointers to members (`MT B::*`), i.e.
// pointers to member functions or member data of class B.
// All members are declarations only; they are written so that their trailing
// return types can be inspected with decltype.
template<class B, class MT>
struct invoke_impl<MT B::*>
{
// get(), case 1: the decayed argument type is B or a class derived from B.
// The object itself is used, so the result type is T&&.
template<class T, class Td = typename std::decay<T>::type,
class = typename std::enable_if<std::is_base_of<B, Td>::value>::type
>
ROCPRIM_HOST_DEVICE
static auto get(T&& t) -> T&&;

// get(), case 2: the decayed argument type is a std::reference_wrapper.
// The result type is whatever t.get() yields.
template<class T, class Td = typename std::decay<T>::type,
class = typename std::enable_if<is_reference_wrapper<Td>::value>::type
>
ROCPRIM_HOST_DEVICE
static auto get(T&& t) -> decltype(t.get());

// get(), case 3: neither of the above. The argument is dereferenced, so the
// result type is that of *t.
template<class T, class Td = typename std::decay<T>::type,
class = typename std::enable_if<!std::is_base_of<B, Td>::value>::type,
class = typename std::enable_if<!is_reference_wrapper<Td>::value>::type
>
ROCPRIM_HOST_DEVICE
static auto get(T&& t) -> decltype(*std::forward<T>(t));

// NOTE(review): the tuple_contains_type specialization on the next two lines
// appears in the middle of this struct. It looks like a removed diff line
// rendered inline -- confirm against the real header.
template<class T, class U, class... Ts>
struct tuple_contains_type<T, ::rocprim::tuple<U, Ts...>> : tuple_contains_type<T, ::rocprim::tuple<Ts...>> {};
// call() for a pointer to member function (enabled only when MT1 is a
// function type): the result type is that of (get(t).*pmf)(args...).
template<class T, class... Args, class MT1,
class = typename std::enable_if<std::is_function<MT1>::value>::type
>
ROCPRIM_HOST_DEVICE
static auto call(MT1 B::*pmf, T&& t, Args&&... args)
-> decltype((invoke_impl::get(std::forward<T>(t)).*pmf)(std::forward<Args>(args)...));

// call() for a pointer to member data: takes only the object argument, and the
// result type is that of get(t).*pmd.
template<class T>
ROCPRIM_HOST_DEVICE
static auto call(MT B::*pmd, T&& t)
-> decltype(invoke_impl::get(std::forward<T>(t)).*pmd);
};

// INVOKE: declaration-only function whose return type is the result of
// invoking `callable` with `call_args`. The decayed callable type selects the
// matching invoke_impl (generic callable vs. pointer to member), and the call
// is forwarded to its static call().
template<class Callable,
         class... CallArgs,
         class DecayedCallable = typename std::decay<Callable>::type>
ROCPRIM_HOST_DEVICE
auto INVOKE(Callable&& callable, CallArgs&&... call_args)
    -> decltype(invoke_impl<DecayedCallable>::call(std::forward<Callable>(callable),
                                                   std::forward<CallArgs>(call_args)...));

// Conforming C++14 implementation (is also a valid C++11 implementation):
// Primary template: exposes no nested `type`. It is used whenever the partial
// specialization below is not viable, i.e. the INVOKE expression is ill-formed.
template <typename AlwaysVoid, typename /* Callable */, typename... /* CallArgs */>
struct invoke_result_impl { };

// Specialization picked when INVOKE(Callable, CallArgs...) is well-formed: the
// first argument then evaluates to void and matches the `void` that
// invoke_result passes in. `type` is the type of that INVOKE expression.
template <typename Callable, typename... CallArgs>
struct invoke_result_impl<
    decltype(void(INVOKE(std::declval<Callable>(), std::declval<CallArgs>()...))),
    Callable,
    CallArgs...>
{
    using type = decltype(INVOKE(std::declval<Callable>(), std::declval<CallArgs>()...));
};

// Match case of tuple_contains_type: the first element type of the
// rocprim::tuple is T itself, so ::value is true.
template<class T, class... Rest>
struct tuple_contains_type<T, ::rocprim::tuple<T, Rest...>> : std::integral_constant<bool, true> {};
// invoke_result<Callable, CallArgs...>::type is the type obtained by
// INVOKE-ing a Callable with CallArgs.... It inherits from invoke_result_impl
// with `void` as the detection argument, so when the invocation is ill-formed
// there is no nested `type`.
template <class Callable, class... CallArgs>
struct invoke_result : invoke_result_impl<void, Callable, CallArgs...> {};

// match_result_type<InputType, BinaryFunction>::type is the result type of
// calling BinaryFunction with two arguments of type InputType.
//
// NOTE(review): this text declares `type` twice -- once through the
// std::invoke_result / std::result_of based `binary_result_type`, and once
// through detail::invoke_result. This looks like removed and added diff lines
// rendered together; confirm which alias the real header keeps.
template<class InputType, class BinaryFunction>
struct match_result_type
{
private:
// Standard-library based deduction: std::invoke_result when the library
// advertises __cpp_lib_is_invocable, std::result_of otherwise.
#ifdef __cpp_lib_is_invocable
using binary_result_type = typename std::invoke_result<BinaryFunction, InputType, InputType>::type;
#else
using binary_result_type = typename std::result_of<BinaryFunction(InputType, InputType)>::type;
#endif

public:
using type = binary_result_type;
// Deduction through the invoke_result defined in this header, whose helpers
// are marked ROCPRIM_HOST_DEVICE.
using type = typename invoke_result<BinaryFunction, InputType, InputType>::type;
};

} // end namespace detail
Expand Down
7 changes: 2 additions & 5 deletions rocprim/include/rocprim/device/detail/device_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include "../../config.hpp"
#include "../../detail/various.hpp"
#include "../../detail/match_result_type.hpp"

#include "../../intrinsics.hpp"
#include "../../functional.hpp"
Expand All @@ -44,11 +45,7 @@ namespace detail
template<class T1, class T2, class BinaryFunction>
struct unpack_binary_op
{
#ifdef __cpp_lib_is_invocable
using result_type = typename std::invoke_result<BinaryFunction, T1, T2>::type;
#else
using result_type = typename std::result_of<BinaryFunction(T1, T2)>::type;
#endif
using result_type = typename ::rocprim::detail::invoke_result<BinaryFunction, T1, T2>::type;

ROCPRIM_HOST_DEVICE inline
unpack_binary_op() = default;
Expand Down
5 changes: 0 additions & 5 deletions rocprim/include/rocprim/device/device_binary_search.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,7 @@ hipError_t binary_search(void * temporary_storage,
needles, output,
needles_size,
[haystack, haystack_size, search_op, compare_op]
#ifdef __HIP__
// Workaround: hip-clang does not support std::result_of of device-only functions
ROCPRIM_HOST_DEVICE
#else
ROCPRIM_DEVICE
#endif
(const value_type& value)
{
return search_op(haystack, haystack_size, value, compare_op);
Expand Down
5 changes: 0 additions & 5 deletions rocprim/include/rocprim/device/device_scan_by_key.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,12 +164,7 @@ hipError_t inclusive_scan_by_key(void * temporary_storage,
rocprim::make_transform_iterator(
rocprim::make_counting_iterator<size_t>(0),
[values_input, keys_input, key_compare_op]
#ifdef __HIP__
// Workaround: hip-clang does not support std::result_of of device-only functions
ROCPRIM_HOST_DEVICE
#else
ROCPRIM_DEVICE
#endif
(const size_t i)
{
flag_type flag(true);
Expand Down
5 changes: 0 additions & 5 deletions rocprim/include/rocprim/device/device_segmented_scan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -614,12 +614,7 @@ hipError_t segmented_exclusive_scan(void * temporary_storage,
rocprim::make_transform_iterator(
rocprim::make_counting_iterator<size_t>(0),
[input, head_flags, initial_value_converted, size]
#ifdef __HIP__
// Workaround: hip-clang does not support std::result_of of device-only functions
ROCPRIM_HOST_DEVICE
#else
ROCPRIM_DEVICE
#endif
(const size_t i)
{
flag_type flag(false);
Expand Down
7 changes: 2 additions & 5 deletions rocprim/include/rocprim/device/device_transform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include "../config.hpp"
#include "../detail/various.hpp"
#include "../detail/match_result_type.hpp"
#include "../types/tuple.hpp"
#include "../iterator/zip_iterator.hpp"

Expand Down Expand Up @@ -145,11 +146,7 @@ hipError_t transform(InputIterator input,
bool debug_synchronous = false)
{
using input_type = typename std::iterator_traits<InputIterator>::value_type;
#ifdef __cpp_lib_is_invocable
using result_type = typename std::invoke_result<UnaryFunction, input_type>::type;
#else
using result_type = typename std::result_of<UnaryFunction(input_type)>::type;
#endif
using result_type = typename ::rocprim::detail::invoke_result<UnaryFunction, input_type>::type;

// Get default config if Config is default_config
using config = detail::default_or_custom_config<
Expand Down
10 changes: 2 additions & 8 deletions rocprim/include/rocprim/iterator/transform_iterator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <type_traits>

#include "../config.hpp"
#include "../detail/match_result_type.hpp"

/// \addtogroup iteratormodule
/// @{
Expand All @@ -49,17 +50,10 @@ BEGIN_ROCPRIM_NAMESPACE
template<
class InputIterator,
class UnaryFunction,
#if defined(__cpp_lib_is_invocable) && !defined(DOXYGEN_SHOULD_SKIP_THIS) // C++17
class ValueType =
typename std::invoke_result<
typename ::rocprim::detail::invoke_result<
UnaryFunction, typename std::iterator_traits<InputIterator>::value_type
>::type
#else
class ValueType =
typename std::result_of<
UnaryFunction(typename std::iterator_traits<InputIterator>::value_type)
>::type
#endif
>
class transform_iterator
{
Expand Down

0 comments on commit 360401f

Please sign in to comment.