Skip to content

Commit

Permalink
Simplify temporary storage allocation
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed May 22, 2024
1 parent 3e293a8 commit 5c80ab2
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions cub/test/catch2_test_launch_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,17 +179,15 @@ void device_side_api_launch(ActionT action, Args... args)
template <class ActionT, class... Args>
void host_side_api_launch(ActionT action, Args... args)
{
std::uint8_t* d_temp_storage = nullptr;
std::size_t temp_storage_bytes{};
cudaError_t error = action(d_temp_storage, temp_storage_bytes, args...);
cudaError_t error = action(nullptr, temp_storage_bytes, args...);
REQUIRE(cudaSuccess == cudaPeekAtLastError());
REQUIRE(cudaSuccess == cudaDeviceSynchronize());
REQUIRE(cudaSuccess == error);

c2h::device_vector<std::uint8_t> temp_storage(temp_storage_bytes);
d_temp_storage = thrust::raw_pointer_cast(temp_storage.data());

error = action(d_temp_storage, temp_storage_bytes, args...);
error = action(thrust::raw_pointer_cast(temp_storage.data()), temp_storage_bytes, args...);
REQUIRE(cudaSuccess == cudaPeekAtLastError());
REQUIRE(cudaSuccess == cudaDeviceSynchronize());
REQUIRE(cudaSuccess == error);
Expand Down

0 comments on commit 5c80ab2

Please sign in to comment.