Skip to content

Commit

Permalink
Gather
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed May 20, 2023
1 parent e3d5d9a commit 0bb2e59
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 28 deletions.
23 changes: 8 additions & 15 deletions test/catch2_test_block_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include <cub/block/block_radix_sort.cuh>

#include <thrust/gather.h>
#include <thrust/host_vector.h>
#include <thrust/sequence.h>

Expand Down Expand Up @@ -424,12 +425,8 @@ radix_sort_reference(const thrust::device_vector<KeyT> &d_keys,
thrust::host_vector<KeyT> h_keys(d_keys);
thrust::host_vector<std::size_t> h_permutation =
get_permutation(h_keys, is_descending, begin_bit, end_bit);

thrust::host_vector<KeyT> result(d_keys.size());
std::transform(h_permutation.begin(),
h_permutation.end(),
result.begin(),
[&](std::size_t i) { return h_keys[i]; });
thrust::gather(h_permutation.cbegin(), h_permutation.cend(), h_keys.cbegin(), result.begin());

return result;
}
Expand All @@ -447,18 +444,14 @@ radix_sort_reference(const thrust::device_vector<KeyT> &d_keys,
result.second.resize(d_keys.size());

thrust::host_vector<KeyT> h_keys(d_keys);
thrust::host_vector<std::size_t> h_permutation = get_permutation(h_keys, is_descending, begin_bit, end_bit);

std::transform(h_permutation.begin(),
h_permutation.end(),
result.first.begin(),
[&](std::size_t i) { return h_keys[i]; });
thrust::host_vector<std::size_t> h_permutation =
get_permutation(h_keys, is_descending, begin_bit, end_bit);

thrust::host_vector<ValueT> h_values(d_values);
std::transform(h_permutation.begin(),
h_permutation.end(),
result.second.begin(),
[&](std::size_t i) { return h_values[i]; });
thrust::gather(h_permutation.cbegin(),
h_permutation.cend(),
thrust::make_zip_iterator(h_keys.cbegin(), h_values.cbegin()),
thrust::make_zip_iterator(result.first.begin(), result.second.begin()));

return result;
}
19 changes: 6 additions & 13 deletions test/catch2_test_device_radix_sort_custom.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include <thrust/detail/raw_pointer_cast.h>
#include <thrust/device_vector.h>
#include <thrust/gather.h>
#include <thrust/host_vector.h>
#include <thrust/reverse.h>
#include <thrust/sequence.h>
Expand Down Expand Up @@ -153,11 +154,8 @@ static thrust::device_vector<key> reference_sort_keys(const thrust::device_vecto
thrust::host_vector<key> h_keys(d_keys);
thrust::host_vector<std::size_t> h_permutation =
get_permutation(h_keys, is_descending, begin_bit, end_bit);

thrust::host_vector<key> result(d_keys.size());
std::transform(h_permutation.begin(), h_permutation.end(), result.begin(), [&](std::size_t i) {
return h_keys[i];
});
thrust::gather(h_permutation.cbegin(), h_permutation.cend(), h_keys.cbegin(), result.begin());
return result;
}

Expand All @@ -174,16 +172,11 @@ reference_sort_pairs(const thrust::device_vector<key> &d_keys,
get_permutation(h_keys, is_descending, begin_bit, end_bit);

thrust::host_vector<key> result_keys(d_keys.size());
std::transform(h_permutation.begin(),
h_permutation.end(),
result_keys.begin(),
[&](std::size_t i) { return h_keys[i]; });

thrust::host_vector<value> result_values(d_values.size());
std::transform(h_permutation.begin(),
h_permutation.end(),
result_values.begin(),
[&](std::size_t i) { return h_values[i]; });
thrust::gather(h_permutation.cbegin(),
h_permutation.cend(),
thrust::make_zip_iterator(h_keys.cbegin(), h_values.cbegin()),
thrust::make_zip_iterator(result_keys.begin(), result_values.begin()));

return std::make_pair(result_keys, result_values);
}
Expand Down

0 comments on commit 0bb2e59

Please sign in to comment.