Skip to content

Commit

Permalink
Revert unrelated changes
Browse files Browse the repository at this point in the history
Signed-off-by: Nghia Truong <[email protected]>
  • Loading branch information
ttnghia committed Oct 11, 2024
1 parent 9e23b41 commit f8957f1
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 15 deletions.
12 changes: 6 additions & 6 deletions cpp/include/cudf/detail/gather.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@

#include <thrust/functional.h>
#include <thrust/gather.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/logical.h>

Expand Down Expand Up @@ -539,13 +538,14 @@ void gather_bitmask(table_device_view input,

constexpr size_type block_size = 256;
using Selector = gather_bitmask_functor<Op, decltype(gather_map_begin)>;
auto const selector = Selector{input, masks, gather_map_begin};
auto const counting_it = thrust::make_counting_iterator(0);
auto const mask_size_it = thrust::make_constant_iterator(mask_size);
auto selector = Selector{input, masks, gather_map_begin};
auto counting_it = thrust::make_counting_iterator(0);
auto kernel =
valid_if_n_kernel<decltype(counting_it), decltype(counting_it), Selector, block_size>;

cudf::detail::grid_1d grid{mask_size, block_size, 1};
valid_if_n_kernel<block_size><<<grid.num_blocks, block_size, 0, stream.value()>>>(
counting_it, counting_it, selector, masks, mask_count, mask_size_it, valid_counts);
kernel<<<grid.num_blocks, block_size, 0, stream.value()>>>(
counting_it, counting_it, selector, masks, mask_count, mask_size, valid_counts);
}

template <typename MapIterator>
Expand Down
16 changes: 7 additions & 9 deletions cpp/include/cudf/detail/valid_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ std::pair<rmm::device_buffer, size_type> valid_if(InputIterator begin,
* p: [](size_type col, size_type row){ return col == row; }
* masks: [[b00...], [b00...], [b00...]]
* mask_count: 3
* mask_num_bits: [2]
* mask_num_bits: 2
* valid_counts: [0, 0, 0]
*
* Example Results:
Expand All @@ -148,30 +148,28 @@ std::pair<rmm::device_buffer, size_type> valid_if(InputIterator begin,
* @param valid_counts Used to obtain the total number of valid bits for each
* mask.
*/
template <int32_t block_size,
typename InputIterator1,
template <typename InputIterator1,
typename InputIterator2,
typename MaskSizeIterator,
typename BinaryPredicate>
typename BinaryPredicate,
int32_t block_size>
CUDF_KERNEL void valid_if_n_kernel(InputIterator1 begin1,
InputIterator2 begin2,
BinaryPredicate p,
bitmask_type* masks[],
size_type mask_count,
MaskSizeIterator mask_num_bits,
size_type mask_num_bits,
size_type* valid_counts)
{
for (size_type mask_idx = 0; mask_idx < mask_count; mask_idx++) {
auto const mask = masks[mask_idx];
if (mask == nullptr) { continue; }

auto const mask_size = mask_num_bits[mask_idx];
auto block_offset = blockIdx.x * blockDim.x;
auto warp_valid_count = static_cast<size_type>(0);

while (block_offset < mask_size) {
while (block_offset < mask_num_bits) {
auto const thread_idx = block_offset + threadIdx.x;
auto const thread_active = thread_idx < mask_size;
auto const thread_active = thread_idx < mask_num_bits;
auto const arg_1 = *(begin1 + mask_idx);
auto const arg_2 = *(begin2 + thread_idx);
auto const bit_is_valid = thread_active && p(arg_1, arg_2);
Expand Down

0 comments on commit f8957f1

Please sign in to comment.