Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow move of cuda stream #203

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions include/dca/linalg/util/cublas_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,33 @@ namespace util {

class CublasHandle {
public:
CublasHandle() {
CublasHandle() noexcept {
cublasStatus_t ret = cublasCreate(&handle_);
checkRC(ret);
}

CublasHandle& operator=(const CublasHandle& other) = delete;

CublasHandle(CublasHandle&& other) {
CublasHandle(CublasHandle&& other) noexcept {
std::swap(handle_, other.handle_);
}

CublasHandle& operator=(CublasHandle&& other) noexcept {
std::swap(handle_, other.handle_);
return *this;
}

~CublasHandle() {
if (handle_)
cublasDestroy(handle_);
}

void setStream(cudaStream_t stream) {
void setStream(cudaStream_t stream) noexcept {
cublasStatus_t ret = cublasSetStream(handle_, stream);
checkRC(ret);
}

operator cublasHandle_t() const {
operator cublasHandle_t() const noexcept {
return handle_;
}

Expand Down
16 changes: 11 additions & 5 deletions include/dca/linalg/util/cuda_stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,23 @@ namespace util {

class CudaStream {
public:
CudaStream() {
CudaStream() noexcept {
cudaStreamCreate(&stream_);
}

CudaStream(const CudaStream& other) = delete;
CudaStream& operator=(const CudaStream& other) = delete;

CudaStream(CudaStream&& other) {
CudaStream(CudaStream&& other) noexcept {
std::swap(stream_, other.stream_);
}

void sync() const {
CudaStream& operator=(CudaStream&& other) noexcept {
std::swap(stream_, other.stream_);
return *this;
}

void sync() const noexcept {
checkRC(cudaStreamSynchronize(stream_));
}

Expand All @@ -45,7 +51,7 @@ class CudaStream {
cudaStreamDestroy(stream_);
}

operator cudaStream_t() const {
operator cudaStream_t() const noexcept {
return stream_;
}

Expand All @@ -60,7 +66,7 @@ class CudaStream {
public:
CudaStream() = default;

void sync() const {}
void sync() const noexcept {}
};

#endif // DCA_HAVE_CUDA
Expand Down
24 changes: 18 additions & 6 deletions include/dca/linalg/util/magma_queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,41 @@ namespace util {

class MagmaQueue {
public:
MagmaQueue() {
MagmaQueue() noexcept {
magma_queue_create(&queue_);
}

MagmaQueue(const MagmaQueue&) = delete;
MagmaQueue& operator=(const MagmaQueue&) = delete;

MagmaQueue(MagmaQueue&& rhs) noexcept {
std::swap(queue_, rhs.queue_);
}

MagmaQueue& operator=(MagmaQueue&& rhs) noexcept {
std::swap(queue_, rhs.queue_);
return *this;
}

~MagmaQueue() {
magma_queue_destroy(queue_);
}

inline operator magma_queue_t() {
operator magma_queue_t() const noexcept {
return queue_;
}

cudaStream_t getStream() const {
cudaStream_t getStream() const noexcept {
return magma_queue_get_cuda_stream(queue_);
}

private:
magma_queue_t queue_ = nullptr;
};

} // util
} // linalg
} // dca
} // namespace util
} // namespace linalg
} // namespace dca

#endif // DCA_HAVE_CUDA
#endif // DCA_LINALG_UTIL_MAGMA_QUEUE_HPP
8 changes: 4 additions & 4 deletions include/dca/linalg/util/stream_container.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ class StreamContainer {
public:
StreamContainer(int max_threads = 0) : streams_(max_threads * streams_per_thread_) {}

int get_max_threads() const {
int get_max_threads() const noexcept {
return streams_.size() / streams_per_thread_;
}

int get_streams_per_thread() const {
int get_streams_per_thread() const noexcept {
return streams_per_thread_;
}

Expand All @@ -46,7 +46,7 @@ class StreamContainer {
// Returns the 'stream_id'-th stream associated with thread 'thread_id'.
// Preconditions: 0 <= thread_id < get_max_threads(),
// 0 <= stream_id < streams_per_thread_.
CudaStream& operator()(int thread_id, int stream_id) {
CudaStream& operator()(int thread_id, int stream_id) noexcept {
assert(thread_id >= 0 && thread_id < get_max_threads());
assert(stream_id >= 0 && stream_id < streams_per_thread_);
return streams_[stream_id + streams_per_thread_ * thread_id];
Expand All @@ -55,7 +55,7 @@ class StreamContainer {
// Synchronizes the 'stream_id'-th stream associated with thread 'thread_id'.
// Preconditions: 0 <= thread_id < get_max_threads(),
// 0 <= stream_id < streams_per_thread_.
void sync(int thread_id, int stream_id) {
void sync(int thread_id, int stream_id) noexcept {
operator()(thread_id, stream_id).sync();
}

Expand Down