Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow mutation through a transform_iterator #2006

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions thrust/testing/transform_iterator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <thrust/sequence.h>

#include <memory>
#include <vector>

#include <unittest/unittest.h>

Expand Down Expand Up @@ -108,3 +109,110 @@ void TestTransformIteratorNonCopyable()
}

DECLARE_UNITTEST(TestTransformIteratorNonCopyable);

struct flip_value
{
_CCCL_HOST_DEVICE bool operator()(bool b) const
{
return !b;
}
};

struct pass_ref
{
_CCCL_HOST_DEVICE const bool& operator()(const bool& b) const
{
return b;
}
};

// TODO(bgruber): replace by libc++ with C++14
struct forward
{
template <class _Tp>
constexpr _Tp&& operator()(_Tp&& __t) const noexcept
{
return _CUDA_VSTD::forward<_Tp>(__t);
}
};

void TestTransformIteratorReferenceAndValueType()
{
using ::cuda::std::is_same;
using ::cuda::std::negate;
{
thrust::host_vector<bool> v;
auto it = v.begin();
static_assert(is_same<decltype(*it), bool&>::value, "");
auto it_tr_val = thrust::make_transform_iterator(it, flip_value{});
static_assert(is_same<decltype(*it_tr_val), bool>::value, "");
auto it_tr_ref = thrust::make_transform_iterator(it, pass_ref{});
static_assert(is_same<decltype(*it_tr_ref), const bool&>::value, "");
auto it_tr_fwd = thrust::make_transform_iterator(it, forward{});
}

{
thrust::device_vector<bool> v;
auto it = v.begin();
static_assert(is_same<decltype(*it), thrust::device_reference<bool>>::value, "");
auto it_tr_val = thrust::make_transform_iterator(it, flip_value{});
static_assert(is_same<decltype(*it_tr_val), bool>::value, "");
auto it_tr_ref = thrust::make_transform_iterator(it, pass_ref{});
static_assert(is_same<decltype(*it_tr_ref), const bool&>::value, "");
auto it_tr_fwd = thrust::make_transform_iterator(it, forward{});
static_assert(is_same<decltype(*it_tr_fwd), const bool&>::value, ""); // device ref. is decayed
}

{
std::vector<bool> v;
auto it = v.begin();
static_assert(is_same<decltype(*it), std::vector<bool>::reference>::value, "");
auto it_tr_val = thrust::make_transform_iterator(it, flip_value{});
static_assert(is_same<decltype(*it_tr_val), bool>::value, "");
auto it_tr_ref = thrust::make_transform_iterator(it, pass_ref{});
static_assert(is_same<decltype(*it_tr_ref), const bool&>::value, "");
auto it_tr_fwd = thrust::make_transform_iterator(it, forward{});
static_assert(is_same<decltype(*it_tr_fwd), std::vector<bool>::reference&&>::value, ""); // No handling for std
}
}
DECLARE_UNITTEST(TestTransformIteratorReferenceAndValueType);

struct foo
{
int x, y;
};

struct access_x
{
_CCCL_HOST_DEVICE int& operator()(foo& f) const noexcept
{
return f.x;
}
};

template <template <typename...> class SrcVec, template <typename...> class DstVec = SrcVec>
void TestTransformIteratorAsDestinationWith()
{
constexpr auto n = 10;
SrcVec<int> src(n, 1234);
DstVec<foo> dst(n, foo{1, 2});

thrust::copy(src.begin(), src.end(), thrust::make_transform_iterator(dst.begin(), access_x{}));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: What happens if this is attempted on cross-system host<->device vectors?


const thrust::host_vector<foo>& dst_h = dst; // no copy when Vec is a host vector
for (const auto& f : dst_h)
{
ASSERT_EQUAL(f.x, 1234);
ASSERT_EQUAL(f.y, 2);
}
}

void TestTransformIteratorAsDestination()
{
TestTransformIteratorAsDestinationWith<thrust::host_vector>();
TestTransformIteratorAsDestinationWith<thrust::device_vector>();

TestTransformIteratorAsDestinationWith<thrust::host_vector, thrust::device_vector>();
TestTransformIteratorAsDestinationWith<thrust::device_vector, thrust::host_vector>();
}
DECLARE_UNITTEST(TestTransformIteratorAsDestination);
2 changes: 1 addition & 1 deletion thrust/thrust/detail/reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ class reference
pointer const ptr;

// `thrust::detail::is_wrapped_reference` is a trait that indicates whether
// a type is a fancy reference. It detects such types by loooking for a
// a type is a fancy reference. It detects such types by looking for a
// nested `wrapped_reference_hint` type.
struct wrapped_reference_hint
{};
Expand Down
13 changes: 6 additions & 7 deletions thrust/thrust/iterator/detail/transform_iterator.inl
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,14 @@ namespace detail
template <class UnaryFunc, class Iterator, class Reference, class Value>
struct make_transform_iterator_base
{
private:
// FIXME(bgruber): the next line should be correct, but thrust::identity<T> lies and advertises a ::return_type of T,
// while its operator() returns const T& (which __invoke_of correctly detects), which causes transform_iterator to
// crash during dereferencing.
// using wrapped_func_ret_t = ::cuda::std::__invoke_of<UnaryFunc, iterator_value_t<Iterator>>;
using wrapped_func_ret_t = result_of_adaptable_function<UnaryFunc(iterator_value_t<Iterator>)>;
using func_input_t =
::cuda::std::_If<is_wrapped_reference<::cuda::std::__decay_t<iterator_reference_t<Iterator>>>::value,
const iterator_value_t<Iterator>&, // decay the reference to the value type
iterator_reference_t<Iterator>>;
using func_return_t = ::cuda::std::__invoke_of<UnaryFunc, func_input_t>;

// By default, dereferencing the iterator yields the same as the function.
using reference = typename ia_dflt_help<Reference, wrapped_func_ret_t>::type;
using reference = typename ia_dflt_help<Reference, func_return_t>::type;
using value_type = typename ia_dflt_help<Value, remove_cvref<reference>>::type;

public:
Expand Down
23 changes: 19 additions & 4 deletions thrust/thrust/iterator/transform_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,8 @@ class transform_iterator
/*! \endcond
*/

public:
/*! Null constructor does nothing.
*/
using reference = typename super_t::reference;

transform_iterator() = default;

transform_iterator(transform_iterator const&) = default;
Expand Down Expand Up @@ -299,8 +298,24 @@ class transform_iterator
// See goo.gl/LELTNp
THRUST_DISABLE_MSVC_WARNING_BEGIN(4172)

_CCCL_HOST_DEVICE reference dereference() const
{
// TODO(bgruber): use an if constexpr in C++17
return dereference_impl(
::cuda::std::bool_constant<
detail::is_wrapped_reference<::cuda::std::__decay_t<iterator_reference_t<Iterator>>>::value>{});
}

_CCCL_EXEC_CHECK_DISABLE
_CCCL_HOST_DEVICE
reference dereference_impl(::cuda::std::false_type /* iterator does not return a wrapped/proxy reference */) const
{
return m_f(*this->base());
}

_CCCL_EXEC_CHECK_DISABLE
_CCCL_HOST_DEVICE typename super_t::reference dereference() const
_CCCL_HOST_DEVICE
reference dereference_impl(::cuda::std::true_type /* iterator returns a wrapped/proxy reference */) const
{
// Create a temporary to allow iterators with wrapped references to
// convert to their value type before calling m_f. Note that this
Expand Down
Loading