Skip to content

Commit

Permalink
Several improvements to zip_iterator/zip_function (#1710)
Browse files Browse the repository at this point in the history
* Improve zip_iterator documentation
* Re-expose zip_iterator's IteratorTuple
* Test zip_iterator construction from iterator tuple
* Expose zip_function's underlying function
* Add default ctor for zip_function
* Simplify zip_function documentation example
  • Loading branch information
bernhardmgruber committed May 7, 2024
1 parent 3104dd0 commit fb83b4a
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 54 deletions.
14 changes: 14 additions & 0 deletions thrust/testing/zip_function.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,20 @@ struct SumThreeTuple
THRUST_DECLTYPE_RETURNS(thrust::get<0>(x) + thrust::get<1>(x) + thrust::get<2>(x))
}; // end SumThreeTuple

template <typename T>
struct TestZipFunctionCtor
{
void operator()()
{
ASSERT_EQUAL(thrust::zip_function<SumThree>()(thrust::make_tuple(1, 2, 3)), SumThree{}(1, 2, 3));
ASSERT_EQUAL(thrust::zip_function<SumThree>(SumThree{})(thrust::make_tuple(1, 2, 3)), SumThree{}(1, 2, 3));
# ifdef __cpp_deduction_guides
ASSERT_EQUAL(thrust::zip_function(SumThree{})(thrust::make_tuple(1, 2, 3)), SumThree{}(1, 2, 3));
# endif // __cpp_deduction_guides
}
};
SimpleUnitTest<TestZipFunctionCtor, type_list<int>> TestZipFunctionCtorInstance;

template <typename T>
struct TestZipFunctionTransform
{
Expand Down
1 change: 1 addition & 0 deletions thrust/testing/zip_iterator.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct TestZipIteratorManipulation

// test construction
ZipIterator iter0 = make_zip_iterator(t);
ASSERT_EQUAL(true, iter0 == ZipIterator{t});

ASSERT_EQUAL_QUIET(v0.begin(), get<0>(iter0.get_iterator_tuple()));
ASSERT_EQUAL_QUIET(v1.begin(), get<1>(iter0.get_iterator_tuple()));
Expand Down
40 changes: 16 additions & 24 deletions thrust/thrust/iterator/zip_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,25 +69,20 @@ THRUST_NAMESPACE_BEGIN
* #include <thrust/tuple.h>
* #include <thrust/device_vector.h>
* ...
* thrust::device_vector<int> int_v(3);
* int_v[0] = 0; int_v[1] = 1; int_v[2] = 2;
* thrust::device_vector<int> int_v{0, 1, 2};
* thrust::device_vector<float> float_v{0.0f, 1.0f, 2.0f};
* thrust::device_vector<char> char_v{'a', 'b', 'c'};
*
* thrust::device_vector<float> float_v(3);
* float_v[0] = 0.0f; float_v[1] = 1.0f; float_v[2] = 2.0f;
* // aliases for iterators
* using IntIterator = thrust::device_vector<int>::iterator;
* using FloatIterator = thrust::device_vector<float>::iterator;
* using CharIterator = thrust::device_vector<char>::iterator;
*
* thrust::device_vector<char> char_v(3);
* char_v[0] = 'a'; char_v[1] = 'b'; char_v[2] = 'c';
*
* // typedef these iterators for shorthand
* typedef thrust::device_vector<int>::iterator IntIterator;
* typedef thrust::device_vector<float>::iterator FloatIterator;
* typedef thrust::device_vector<char>::iterator CharIterator;
*
* // typedef a tuple of these iterators
* typedef thrust::tuple<IntIterator, FloatIterator, CharIterator> IteratorTuple;
* // alias for a tuple of these iterators
* using IteratorTuple = thrust::tuple<IntIterator, FloatIterator, CharIterator>;
*
* // typedef the zip_iterator of this tuple
* typedef thrust::zip_iterator<IteratorTuple> ZipIterator;
* using ZipIterator = thrust::zip_iterator<IteratorTuple>;
*
* // finally, create the zip_iterator
* ZipIterator iter(thrust::make_tuple(int_v.begin(), float_v.begin(), char_v.begin()));
Expand Down Expand Up @@ -116,15 +111,8 @@ THRUST_NAMESPACE_BEGIN
*
* int main()
* {
* thrust::device_vector<int> int_in(3), int_out(3);
* int_in[0] = 0;
* int_in[1] = 1;
* int_in[2] = 2;
*
* thrust::device_vector<float> float_in(3), float_out(3);
* float_in[0] = 0.0f;
* float_in[1] = 10.0f;
* float_in[2] = 20.0f;
* thrust::device_vector<int> int_in{0, 1, 2}, int_out(3);
* thrust::device_vector<float> float_in{0.0f, 10.0f, 20.0f}, float_out(3);
*
* thrust::copy(thrust::make_zip_iterator(thrust::make_tuple(int_in.begin(), float_in.begin())),
* thrust::make_zip_iterator(thrust::make_tuple(int_in.end(), float_in.end())),
Expand All @@ -146,6 +134,10 @@ template <typename IteratorTuple>
class zip_iterator : public detail::zip_iterator_base<IteratorTuple>::type
{
public:
/*! The underlying iterator tuple type. Alias to zip_iterator's first template argument.
*/
using iterator_tuple = IteratorTuple;

/*! Default constructor does nothing.
*/
#if defined(_CCCL_COMPILER_MSVC_2017)
Expand Down
55 changes: 25 additions & 30 deletions thrust/thrust/zip_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,54 +95,40 @@ _CCCL_HOST_DEVICE auto apply_impl(Function&& func, Tuple&& args, index_sequence<
* #include <thrust/zip_function.h>
*
* struct SumTuple {
* float operator()(Tuple tup) {
* return std::get<0>(tup) + std::get<1>(tup) + std::get<2>(tup);
* float operator()(auto tup) const {
* return thrust::get<0>(tup) + thrust::get<1>(tup) + thrust::get<2>(tup);
* }
* };
* struct SumArgs {
* float operator()(float a, float b, float c) {
* float operator()(float a, float b, float c) const {
* return a + b + c;
* }
* };
*
* int main() {
* thrust::device_vector<float> A(3);
* thrust::device_vector<float> B(3);
* thrust::device_vector<float> C(3);
* thrust::device_vector<float> A{0.f, 1.f, 2.f};
* thrust::device_vector<float> B{1.f, 2.f, 3.f};
* thrust::device_vector<float> C{2.f, 3.f, 4.f};
* thrust::device_vector<float> D(3);
* A[0] = 0.f; A[1] = 1.f; A[2] = 2.f;
* B[0] = 1.f; B[1] = 2.f; B[2] = 3.f;
* C[0] = 2.f; C[1] = 3.f; C[2] = 4.f;
*
* // The following four invocations of transform are equivalent
* auto begin = thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin()));
* auto end = thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end()));
*
* // The following four invocations of transform are equivalent:
* // Transform with 3-tuple
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
* D.begin(),
* SumTuple{});
* thrust::transform(begin, end, D.begin(), SumTuple{});
*
* // Transform with 3 parameters
* thrust::zip_function<SumArgs> adapted{};
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
* D.begin(),
* adapted);
* thrust::transform(begin, end, D.begin(), adapted);
*
* // Transform with 3 parameters with convenience function
* thrust::zip_function<SumArgs> adapted{};
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
* D.begin(),
* thrust::make_zip_function(SumArgs{}));
* thrust::transform(begin, end, D.begin(), thrust::make_zip_function(SumArgs{}));
*
* // Transform with 3 parameters with convenience function and lambda
* thrust::zip_function<SumArgs> adapted{};
* thrust::transform(thrust::make_zip_iterator(thrust::make_tuple(A.begin(), B.begin(), C.begin())),
* thrust::make_zip_iterator(thrust::make_tuple(A.end(), B.end(), C.end())),
* D.begin(),
* thrust::make_zip_function([] (float a, float b, float c) {
* return a + b + c;
* }));
* thrust::transform(begin, end, D.begin(), thrust::make_zip_function([] (float a, float b, float c) {
* return a + b + c;
* }));
* return 0;
* }
* \endcode
Expand All @@ -154,6 +140,9 @@ template <typename Function>
class zip_function
{
public:
//! Default constructs the contained function object.
zip_function() = default;

_CCCL_HOST_DEVICE zip_function(Function func)
: func(std::move(func))
{}
Expand Down Expand Up @@ -181,6 +170,12 @@ class zip_function

# endif // _CCCL_STD_VER

//! Returns a reference to the underlying function.
_CCCL_HOST_DEVICE Function& underlying_function() const
{
return func;
}

private:
mutable Function func;
};
Expand Down

0 comments on commit fb83b4a

Please sign in to comment.