Skip to content

Commit

Permalink
Allow mutation through a transform_iterator
Browse files Browse the repository at this point in the history
But only if the transform iterator's base iterator returns a true l-value reference (and not a proxy reference).
  • Loading branch information
bernhardmgruber committed Jul 18, 2024
1 parent 56d99db commit ec3c5ae
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
30 changes: 30 additions & 0 deletions thrust/testing/transform_iterator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,33 @@ void TestTransformIteratorNonCopyable()
}

DECLARE_UNITTEST(TestTransformIteratorNonCopyable);

struct foo
{
int x, y;
};

struct access_x
{
_CCCL_HOST_DEVICE int& operator()(foo& f) const noexcept
{
return f.x;
}
};

void TestTransformIteratorAsDestination()
{
constexpr auto n = 10;
thrust::host_vector<int> src(n, 1234);
thrust::host_vector<foo> dst(n, foo{1, 2});

thrust::copy(src.begin(), src.end(), thrust::make_transform_iterator(dst.begin(), access_x{}));

for (const auto& f : dst)
{
ASSERT_EQUAL(f.x, 1234);
ASSERT_EQUAL(f.y, 2);
}
}

DECLARE_UNITTEST(TestTransformIteratorAsDestination);
10 changes: 7 additions & 3 deletions thrust/thrust/iterator/detail/transform_iterator.inl
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,14 @@ template <class UnaryFunc, class Iterator, class Reference, class Value>
struct transform_iterator_base
{
private:
using unary_func_input_t = ::cuda::std::_If<
::cuda::std::is_same<::cuda::std::__decay_t<iterator_reference_t<Iterator>>, iterator_value_t<Iterator>>::value,
iterator_reference_t<Iterator>,
const iterator_value_t<Iterator>&>;

// By default, dereferencing the iterator yields the same as the function.
using reference = typename thrust::detail::ia_dflt_help<
Reference,
thrust::detail::result_of_adaptable_function<UnaryFunc(typename thrust::iterator_value<Iterator>::type)>>::type;
using reference = typename thrust::detail::
ia_dflt_help<Reference, thrust::detail::result_of_adaptable_function<UnaryFunc(unary_func_input_t)>>::type;

// To get the default for Value: remove cvref on the result type.
using value_type = typename thrust::detail::ia_dflt_help<Value, thrust::remove_cvref<reference>>::type;
Expand Down
17 changes: 16 additions & 1 deletion thrust/thrust/iterator/transform_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -298,8 +298,23 @@ class transform_iterator
// See goo.gl/LELTNp
THRUST_DISABLE_MSVC_WARNING_BEGIN(4172)

_CCCL_EXEC_CHECK_DISABLE
_CCCL_HOST_DEVICE typename super_t::reference dereference() const
{
// TODO(bgruber): use an if constexpr in C++17
return dereference_impl(
::cuda::std::is_same<::cuda::std::__decay_t<iterator_reference_t<Iterator>>, iterator_value_t<Iterator>>{});
}

_CCCL_EXEC_CHECK_DISABLE
_CCCL_HOST_DEVICE
typename super_t::reference dereference_impl(::cuda::std::true_type /* iterator returns a T& */) const
{
return m_f(*this->base());
}

_CCCL_EXEC_CHECK_DISABLE
_CCCL_HOST_DEVICE
typename super_t::reference dereference_impl(::cuda::std::false_type /* iterator returns a proxy ref */) const
{
// Create a temporary to allow iterators with wrapped references to
// convert to their value type before calling m_f. Note that this
Expand Down

0 comments on commit ec3c5ae

Please sign in to comment.