Skip to content

Commit

Permalink
create cuda stream on each device
Browse files Browse the repository at this point in the history
  • Loading branch information
elstehle committed Dec 7, 2022
1 parent 1cce398 commit eac79ef
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 6 deletions.
21 changes: 16 additions & 5 deletions nvbench/cuda_stream.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@

#pragma once

#include <nvbench/cuda_call.cuh>

#include <cuda_runtime_api.h>

#include <nvbench/cuda_call.cuh>
#include <nvbench/detail/device_scope.cuh>
#include <nvbench/device_info.cuh>

#include <memory>
#include <optional>

namespace nvbench
{
Expand All @@ -42,10 +45,18 @@ struct cuda_stream
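* Constructs a cuda_stream that owns a new stream, created with
* `cudaStreamCreate`. If `device` holds a value, the stream is created inside
* an `nvbench::detail::device_scope` for that device's id; otherwise it is
* created without changing the device.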
*/
cuda_stream()
: m_stream{[]() {
cuda_stream(std::optional<nvbench::device_info> device)
: m_stream{[device]() {
cudaStream_t s;
NVBENCH_CUDA_CALL(cudaStreamCreate(&s));
if (device.has_value())
{
nvbench::detail::device_scope scope_guard{device.value().get_id()};
NVBENCH_CUDA_CALL(cudaStreamCreate(&s));
}
else
{
NVBENCH_CUDA_CALL(cudaStreamCreate(&s));
}
return s;
}(),
stream_deleter{true}}
Expand Down
3 changes: 2 additions & 1 deletion nvbench/state.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,6 @@ private:
std::optional<nvbench::device_info> device,
std::size_t type_config_index);

nvbench::cuda_stream m_cuda_stream;
std::reference_wrapper<const nvbench::benchmark_base> m_benchmark;
nvbench::named_values m_axis_values;
std::optional<nvbench::device_info> m_device;
Expand All @@ -277,6 +276,8 @@ private:
nvbench::float64_t m_skip_time;
nvbench::float64_t m_timeout;

nvbench::cuda_stream m_cuda_stream;

// Deadlock protection. See blocking_kernel's class doc for details.
nvbench::float64_t m_blocking_kernel_timeout{30.0};

Expand Down
2 changes: 2 additions & 0 deletions nvbench/state.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ state::state(const benchmark_base &bench)
, m_max_noise{bench.get_max_noise()}
, m_skip_time{bench.get_skip_time()}
, m_timeout{bench.get_timeout()}
, m_cuda_stream{std::nullopt}
{}

state::state(const benchmark_base &bench,
Expand All @@ -58,6 +59,7 @@ state::state(const benchmark_base &bench,
, m_max_noise{bench.get_max_noise()}
, m_skip_time{bench.get_skip_time()}
, m_timeout{bench.get_timeout()}
, m_cuda_stream{m_device}
{}

nvbench::int64_t state::get_int64(const std::string &axis_name) const
Expand Down

0 comments on commit eac79ef

Please sign in to comment.