Skip to content

Commit

Permalink
Add thrust::inclusive_scan cuda par with init value
Browse files Browse the repository at this point in the history
  • Loading branch information
gonidelis committed Jul 17, 2024
1 parent 5a95b2b commit fa8c7e3
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 0 deletions.
103 changes: 103 additions & 0 deletions thrust/thrust/system/cuda/detail/async/inclusive_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,92 @@ async_inclusive_scan_n(execution_policy<DerivedPolicy>& policy, ForwardIt first,
return ev;
}

// Launches an inclusive scan over `n` elements starting at `first`, seeded
// with `init` and combined with `op`, writing the results through `out`.
//
// The work is enqueued on the stream obtained from `policy`. The returned
// event is built by make_dependent_event and keeps alive the temporary
// storage, a non-owning handle to the user stream (only when that stream is
// not the default stream), and the dependencies extracted from `policy`.
//
// A non-success cudaError_t from either dispatch call is handed to
// thrust::cuda_cub::throw_on_error.
template <typename DerivedPolicy,
          typename ForwardIt,
          typename Size,
          typename OutputIt,
          typename InitialValueType,
          typename BinaryOp>
unique_eager_event async_inclusive_scan_n(
  execution_policy<DerivedPolicy>& policy, ForwardIt first, Size n, OutputIt out, InitialValueType init, BinaryOp op)
{
  // The dispatcher receives the seed wrapped in CUB's InputValue holder.
  using WrappedInitT = cub::detail::InputValue<InitialValueType>;
  // The scan policy is keyed on the value type of the input iterator.
  using ScanPolicyT = cub::DeviceScanPolicy<typename thrust::iterator_traits<ForwardIt>::value_type, BinaryOp>;
  // Selects the inclusive flavour of the scan dispatcher.
  constexpr bool force_inclusive = true;

  // One dispatcher per offset width; THRUST_INDEX_TYPE_DISPATCH2 picks between them.
  using ScanDispatch32 = cub::
    DispatchScan<ForwardIt, OutputIt, BinaryOp, WrappedInitT, thrust::detail::int32_t, InitialValueType, ScanPolicyT, force_inclusive>;
  using ScanDispatch64 = cub::
    DispatchScan<ForwardIt, OutputIt, BinaryOp, WrappedInitT, thrust::detail::int64_t, InitialValueType, ScanPolicyT, force_inclusive>;

  WrappedInitT wrapped_init(init);

  const auto allocator = get_async_device_allocator(policy);
  unique_eager_event event;

  // First pass: a null storage pointer asks the dispatcher how much temporary
  // device storage the scan needs.
  cudaError_t error;
  size_t temp_storage_bytes = 0;
  {
    THRUST_INDEX_TYPE_DISPATCH2(
      error,
      ScanDispatch32::Dispatch,
      ScanDispatch64::Dispatch,
      n,
      (nullptr, temp_storage_bytes, first, out, op, wrapped_init, n_fixed, nullptr));
    thrust::cuda_cub::throw_on_error(
      error,
      "after determining tmp storage "
      "requirements for inclusive_scan");
  }

  // Obtain the temporary storage from the async device allocator.
  auto temp_storage = uninitialized_allocate_unique_n<thrust::detail::uint8_t>(allocator, temp_storage_bytes);
  void* const d_temp_storage = raw_pointer_cast(temp_storage.get());

  // Build the event: it takes ownership of the temporary storage and of the
  // dependencies carried by the policy.
  cudaStream_t const stream = thrust::cuda_cub::stream(policy);

  if (thrust::cuda_cub::default_stream() == stream)
  {
    // Default stream: only the storage and the policy's dependencies are kept.
    event = make_dependent_event(std::tuple_cat(
      std::make_tuple(std::move(temp_storage)),
      extract_dependencies(std::move(thrust::detail::derived_cast(policy)))));
  }
  else
  {
    // User-provided stream: additionally keep a non-owning handle to it.
    event = make_dependent_event(std::tuple_cat(
      std::make_tuple(std::move(temp_storage), unique_stream(nonowning, stream)),
      extract_dependencies(std::move(thrust::detail::derived_cast(policy)))));
  }

  // Second pass: run the scan on the user stream with the real storage.
  {
    THRUST_INDEX_TYPE_DISPATCH2(
      error,
      ScanDispatch32::Dispatch,
      ScanDispatch64::Dispatch,
      n,
      (d_temp_storage, temp_storage_bytes, first, out, op, wrapped_init, n_fixed, stream));
    thrust::cuda_cub::throw_on_error(error, "after dispatching inclusive_scan kernel");
  }

  return event;
}

} // namespace detail
} // namespace cuda
} // namespace system
Expand All @@ -140,6 +226,23 @@ auto async_inclusive_scan(
THRUST_RETURNS(thrust::system::cuda::detail::async_inclusive_scan_n(
policy, first, distance(first, THRUST_FWD(last)), THRUST_FWD(out), THRUST_FWD(op)))

// ADL entry point for the asynchronous inclusive scan that takes an initial
// value.
//
// Measures the input range with distance(first, last) and passes that count,
// together with out, init and op, to
// thrust::system::cuda::detail::async_inclusive_scan_n. `first` is handed on
// as an lvalue; last, out, init and op go through THRUST_FWD.
//
// NOTE(review): the return type and the function body are produced by the
// THRUST_RETURNS macro, which is defined outside this view - presumably it
// derives the return type from the forwarded call expression; confirm.
template <typename DerivedPolicy,
typename ForwardIt,
typename Sentinel,
typename OutputIt,
typename InitialValueType,
typename BinaryOp>
auto async_inclusive_scan(
execution_policy<DerivedPolicy>& policy,
ForwardIt first,
Sentinel&& last,
OutputIt&& out,
InitialValueType&& init,
BinaryOp&& op)
THRUST_RETURNS(thrust::system::cuda::detail::async_inclusive_scan_n(
policy, first, distance(first, THRUST_FWD(last)), THRUST_FWD(out), THRUST_FWD(init), THRUST_FWD(op)))

} // namespace cuda_cub

THRUST_NAMESPACE_END
Expand Down
99 changes: 99 additions & 0 deletions thrust/thrust/system/cuda/detail/scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,76 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl(
return result + num_items;
}

// Inclusive scan of `num_items` elements starting at `first`, seeded with
// `init` and combined with `scan_op`; results are written through `result`.
//
// The scan is dispatched through cub::DispatchScan on the stream obtained
// from `policy`, using a 32-bit or 64-bit offset type as selected by
// THRUST_INDEX_TYPE_DISPATCH2 from `num_items`. A non-success cudaError_t
// from the dispatcher or from synchronize_optional is handed to
// thrust::cuda_cub::throw_on_error.
//
// Returns `result + num_items`.
_CCCL_EXEC_CHECK_DISABLE
template <typename Derived, typename InputIt, typename Size, typename OutputIt, typename InitValueT, typename ScanOp>
_CCCL_HOST_DEVICE OutputIt inclusive_scan_n_impl(
  thrust::cuda_cub::execution_policy<Derived>& policy,
  InputIt first,
  Size num_items,
  OutputIt result,
  InitValueT init,
  ScanOp scan_op)
{
  // The dispatcher receives the seed wrapped in CUB's InputValue holder.
  using InputValueT = cub::detail::InputValue<InitValueT>;
  // Accumulator type used to select the scan policy.
  using AccumT = cub::detail::accumulator_t<ScanOp, InitValueT, cub::detail::value_t<InputIt>>;
  // Selects the inclusive flavour of the scan dispatcher.
  constexpr bool ForceInclusive = true;

  // One dispatcher per offset width; the macro below picks between them.
  using Dispatch32 =
    cub::DispatchScan<InputIt,
                      OutputIt,
                      ScanOp,
                      InputValueT,
                      thrust::detail::int32_t,
                      InitValueT,
                      cub::DeviceScanPolicy<AccumT, ScanOp>,
                      ForceInclusive>;
  using Dispatch64 =
    cub::DispatchScan<InputIt,
                      OutputIt,
                      ScanOp,
                      InputValueT,
                      thrust::detail::int64_t,
                      InitValueT,
                      cub::DeviceScanPolicy<AccumT, ScanOp>,
                      ForceInclusive>;

  cudaStream_t stream = thrust::cuda_cub::stream(policy);
  cudaError_t status;

  // Wrap the seed once and reuse it for both dispatch calls.
  InputValueT init_value(init);

  // Determine temporary storage requirements: a null storage pointer makes
  // the dispatcher report the required size in tmp_size.
  size_t tmp_size = 0;
  {
    THRUST_INDEX_TYPE_DISPATCH2(
      status,
      Dispatch32::Dispatch,
      Dispatch64::Dispatch,
      num_items,
      (nullptr, tmp_size, first, result, scan_op, init_value, num_items_fixed, stream));
    thrust::cuda_cub::throw_on_error(
      status,
      "after determining tmp storage "
      "requirements for inclusive_scan");
  }

  // Run scan:
  {
    // Allocate temporary storage; it lives until the end of this scope, i.e.
    // until after the optional synchronization below.
    thrust::detail::temporary_array<thrust::detail::uint8_t, Derived> tmp{policy, tmp_size};
    THRUST_INDEX_TYPE_DISPATCH2(
      status,
      Dispatch32::Dispatch,
      Dispatch64::Dispatch,
      num_items,
      (tmp.data().get(), tmp_size, first, result, scan_op, init_value, num_items_fixed, stream));
    thrust::cuda_cub::throw_on_error(status, "after dispatching inclusive_scan kernel");
    thrust::cuda_cub::throw_on_error(
      thrust::cuda_cub::synchronize_optional(policy), "inclusive_scan failed to synchronize");
  }

  return result + num_items;
}

_CCCL_EXEC_CHECK_DISABLE
template <typename Derived, typename InputIt, typename Size, typename OutputIt, typename InitValueT, typename ScanOp>
_CCCL_HOST_DEVICE OutputIt exclusive_scan_n_impl(
Expand Down Expand Up @@ -159,6 +229,21 @@ _CCCL_HOST_DEVICE OutputIt exclusive_scan_n_impl(
//-------------------------

// Inclusive scan of `num_items` elements starting at `first`, seeded with
// `init` and combined with `scan_op`; results are written through `result`.
//
// THRUST_CDP_DISPATCH receives two alternatives: the CUB-backed
// inclusive_scan_n_impl, and a fallback that re-enters thrust::inclusive_scan
// with the policy converted by cvt_to_seq.
// NOTE(review): which alternative runs is decided by THRUST_CDP_DISPATCH,
// defined outside this view - presumably the fallback is used when device-side
// kernel launch is unavailable; confirm.
//
// Both alternatives must receive `init`; otherwise the fallback would compute
// an unseeded scan and disagree with the CUB path.
//
// Returns the output iterator produced by the selected alternative.
_CCCL_EXEC_CHECK_DISABLE
template <typename Derived, typename InputIt, typename Size, typename OutputIt, typename T, typename ScanOp>
_CCCL_HOST_DEVICE OutputIt inclusive_scan_n(
  thrust::cuda_cub::execution_policy<Derived>& policy,
  InputIt first,
  Size num_items,
  OutputIt result,
  T init,
  ScanOp scan_op)
{
  THRUST_CDP_DISPATCH(
    (result = thrust::cuda_cub::detail::inclusive_scan_n_impl(policy, first, num_items, result, init, scan_op);),
    (result = thrust::inclusive_scan(
       cvt_to_seq(derived_cast(policy)), first, first + num_items, result, init, scan_op);));
  return result;
}

template <typename Derived, typename InputIt, typename Size, typename OutputIt, typename ScanOp>
_CCCL_HOST_DEVICE OutputIt inclusive_scan_n(
thrust::cuda_cub::execution_policy<Derived>& policy, InputIt first, Size num_items, OutputIt result, ScanOp scan_op)
Expand All @@ -178,6 +263,20 @@ _CCCL_HOST_DEVICE OutputIt inclusive_scan(
return thrust::cuda_cub::inclusive_scan_n(policy, first, num_items, result, scan_op);
}

// Iterator-pair overload of the seeded inclusive scan: measures
// [first, last) with thrust::distance and hands the element count, together
// with `init` and `scan_op`, to thrust::cuda_cub::inclusive_scan_n.
// Returns whatever inclusive_scan_n returns.
template <typename Derived, typename InputIt, typename OutputIt, typename T, typename ScanOp>
_CCCL_HOST_DEVICE OutputIt inclusive_scan(
  thrust::cuda_cub::execution_policy<Derived>& policy,
  InputIt first,
  InputIt last,
  OutputIt result,
  T init,
  ScanOp scan_op)
{
  // The count is carried in the iterator's own difference type.
  using difference_type = typename thrust::iterator_traits<InputIt>::difference_type;

  const difference_type count = thrust::distance(first, last);

  return thrust::cuda_cub::inclusive_scan_n(policy, first, count, result, init, scan_op);
}

template <typename Derived, typename InputIt, typename OutputIt>
_CCCL_HOST_DEVICE OutputIt
inclusive_scan(thrust::cuda_cub::execution_policy<Derived>& policy, InputIt first, InputIt last, OutputIt result)
Expand Down

0 comments on commit fa8c7e3

Please sign in to comment.