Merge pull request #174 from sony/feature/20190624-2stream-conv-bwd
Concurrent streams in backward conv & disable NaN propagation in MaxPoolingCudnn
TakuyaNarihira authored Jul 11, 2019
2 parents a16372b + f25d89f commit f5bc805
Showing 7 changed files with 117 additions and 16 deletions.
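
For context, the concurrent-stream change below follows the usual CUDA event-synchronization pattern: an auxiliary stream waits on an event recorded on the default stream before running the data-gradient convolution, and the default stream waits on an event recorded on the auxiliary stream before continuing. A minimal standalone sketch of that pattern is shown here; the names are placeholders and error checks are omitted, so it is an illustration rather than nnabla code.

// Sketch of the two-stream backward pattern applied in this commit.
// Placeholder names; the commented-out launches stand in for the cuDNN calls.
#include <cuda_runtime.h>

void two_stream_backward_sketch() {
  cudaStream_t dgrad_stream;
  cudaEvent_t default_done, dgrad_done;
  cudaStreamCreateWithFlags(&dgrad_stream, cudaStreamNonBlocking);
  cudaEventCreateWithFlags(&default_done, cudaEventDisableTiming);
  cudaEventCreateWithFlags(&dgrad_done, cudaEventDisableTiming);

  // The auxiliary stream must not start before its inputs (dy, w), produced
  // on the default stream, are ready.
  cudaEventRecord(default_done, 0 /* default stream */);
  cudaStreamWaitEvent(dgrad_stream, default_done, 0);

  // backward_data<<<grid, block, 0, dgrad_stream>>>(...);  // data gradient
  // backward_filter<<<grid, block, 0, 0>>>(...);           // filter gradient, default stream

  // The default stream must not proceed until the data gradient issued on the
  // auxiliary stream has finished.
  cudaEventRecord(dgrad_done, dgrad_stream);
  cudaStreamWaitEvent(0 /* default stream */, dgrad_done, 0);

  cudaEventDestroy(default_done);
  cudaEventDestroy(dgrad_done);
  cudaStreamDestroy(dgrad_stream);
}

This is what wait_default_on_dgrad() and wait_dgrad_on_default() in convolution.cu below implement, with the stream and events managed by the Cuda singleton.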
12 changes: 12 additions & 0 deletions include/nbla/cuda/cuda.hpp
@@ -32,6 +32,11 @@ namespace nbla {

using std::unordered_map;

/**
* Enum for nbla global streams.
*/
enum CudaStreamId { CONVOLUTION_BWD, MAX_COUNT };

/**
Singleton class for storing some handles or configs for CUDA Computation.
*/
@@ -74,6 +79,11 @@ class NBLA_CUDA_API Cuda {
*/
shared_ptr<Allocator> naive_allocator();

/** Get auxiliary stream
*/
shared_ptr<cudaStream_t> get_stream(unsigned int flag, CudaStreamId streamId,
int device = -1);

protected:
std::mutex mtx_cublas_;
std::mutex mtx_curand_;
@@ -92,6 +102,8 @@ class NBLA_CUDA_API Cuda {
*/
shared_ptr<Allocator> naive_allocator_;
shared_ptr<Allocator> caching_allocator_;
// stream pool -> <device, <id, stream>>
unordered_map<int, unordered_map<int, shared_ptr<cudaStream_t>>> streams_;

private:
friend SingletonManager;
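
For reference, a hypothetical call site for the get_stream accessor declared above; it assumes the SingletonManager accessor used elsewhere in this diff, and the concrete usage appears in convolution.cu further down.

// Hypothetical usage only; mirrors the call added to
// ConvolutionCudaCudnn<T>::setup_impl in this commit.
shared_ptr<cudaStream_t> dgrad_stream =
    SingletonManager::get<Cuda>()->get_stream(
        cudaStreamNonBlocking, CudaStreamId::CONVOLUTION_BWD, /*device=*/-1);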
5 changes: 3 additions & 2 deletions include/nbla/cuda/cudnn/cudnn.hpp
@@ -355,7 +355,7 @@ class NBLA_CUDA_API CudnnHandleManager {
/**
Get cuDNN handle for device.
*/
cudnnHandle_t handle(int device = -1);
cudnnHandle_t handle(int device = -1, cudaStream_t stream = 0);

/** Hash map for CudnnConvResource.
*/
@@ -398,7 +398,8 @@ class NBLA_CUDA_API CudnnHandleManager {
void set_deterministic_option(bool value);

protected:
map<int, cudnnHandle_t> handles_;
unordered_map<int, unordered_map<cudaStream_t, shared_ptr<cudnnHandle_t>>>
handles_;
int workspace_limit_{0}; ///< Workspace limit in bytes.
bool deterministic_option_{false}; ///< Choose deterministic algorithms

7 changes: 7 additions & 0 deletions include/nbla/cuda/cudnn/function/convolution.hpp
@@ -72,6 +72,10 @@ template <typename T> class ConvolutionCudaCudnn : public Convolution<T> {
protected:
int device_;
cudnnHandle_t cudnn_handle_;
cudnnHandle_t dgrad_handle_;
shared_ptr<cudaEvent_t> default_event_;
shared_ptr<cudaEvent_t> dgrad_event_;
shared_ptr<cudaStream_t> dgrad_stream_;
#if CUDNN_VERSION < 7000
int x_offset_;
int w_offset_;
@@ -84,6 +88,9 @@ template <typename T> class ConvolutionCudaCudnn : public Convolution<T> {
virtual void backward_impl(const Variables &inputs, const Variables &outputs,
const vector<bool> &propagate_down,
const vector<bool> &accum);

void wait_default_on_dgrad();
void wait_dgrad_on_default();
};
}
#endif
39 changes: 39 additions & 0 deletions src/nbla/cuda/cuda.cpp
@@ -42,6 +42,12 @@ Cuda::~Cuda() {
}
}
}

for (auto &all_streams : this->streams_) {
for (auto &stream : all_streams.second) {
NBLA_CUDA_CHECK(cudaStreamDestroy(*(stream.second)));
}
}
}

cublasHandle_t Cuda::cublas_handle(int device) {
@@ -121,6 +127,39 @@ std::shared_ptr<cudaEvent_t> Cuda::cuda_event(unsigned int flags, int device) {
});
}

shared_ptr<cudaStream_t> Cuda::get_stream(unsigned int flags,
CudaStreamId streamId, int device) {
if (device < 0) {
device = cuda_get_device();
}

int streamIdInt = static_cast<int>(streamId);

auto device_streams = this->streams_[device];
auto it = device_streams.find(streamIdInt);

// Stream has already been created.
if (it != device_streams.end()) {
// check flags
auto stream = it->second;
unsigned int register_flags;
NBLA_CUDA_CHECK(cudaStreamGetFlags(*stream, &register_flags));
NBLA_CHECK(flags == register_flags, error_code::value,
"flag mismatch. StreamId: %u, flags created before: %u, flags "
"requested: %u",
streamId, register_flags, flags);
return it->second;
}

// Create stream.
auto stream = shared_ptr<cudaStream_t>(new cudaStream_t());
NBLA_CUDA_CHECK(cudaStreamCreateWithFlags(stream.get(), flags));

this->streams_[device].insert({streamIdInt, stream});

return stream;
}

curandGenerator_t Cuda::curand_generator() {
// Get current device
int device = cuda_get_device();
26 changes: 17 additions & 9 deletions src/nbla/cuda/cudnn/cudnn.cpp
@@ -543,6 +543,7 @@ CudnnTensorDescriptor::CudnnTensorDescriptor() {
CudnnTensorDescriptor::~CudnnTensorDescriptor() {
NBLA_CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc));
}

////////////////////////////////////////
// Cudnn Pooling Wrapper
////////////////////////////////////////
@@ -579,7 +580,7 @@ CudnnPooling::CudnnPooling(const vector<int> &inshape,
// Create pooling descriptor.
#if CUDNN_VERSION >= 5000
NBLA_CUDNN_CHECK(cudnnSetPoolingNdDescriptor(
pooling_desc_.desc, mode, CUDNN_PROPAGATE_NAN, cfg.kernel.size(),
pooling_desc_.desc, mode, CUDNN_NOT_PROPAGATE_NAN, cfg.kernel.size(),
cfg.kernel.data(), cfg.pad.data(), cfg.stride.data()));
#else
NBLA_CUDNN_CHECK(cudnnSetPoolingNdDescriptor(
@@ -665,21 +666,28 @@ void CudnnSoftmax::backward(const void *alpha, const void *y, const void *dy,
CudnnHandleManager::CudnnHandleManager() {}

CudnnHandleManager::~CudnnHandleManager() {
for (auto handle : this->handles_) {
NBLA_CUDNN_CHECK(cudnnDestroy(handle.second));
for (auto dev_handles : this->handles_) {
for (auto handle : dev_handles.second) {
NBLA_CUDNN_CHECK(cudnnDestroy(*handle.second));
}
}
}

cudnnHandle_t CudnnHandleManager::handle(int device) {
cudnnHandle_t CudnnHandleManager::handle(int device, cudaStream_t stream) {
if (device < 0) {
NBLA_CUDA_CHECK(cudaGetDevice(&device));
}
if (this->handles_.count(device) == 0) {
cudnnHandle_t handle;
NBLA_CUDNN_CHECK(cudnnCreate(&handle));
this->handles_[device] = handle;
auto &dev_handles = this->handles_[device];
auto handle = dev_handles[stream];
if (handle) {
return *handle;
}
return this->handles_[device];

handle = make_shared<cudnnHandle_t>();
NBLA_CUDNN_CHECK(cudnnCreate(handle.get()));
NBLA_CUDNN_CHECK(cudnnSetStream(*handle, stream));
dev_handles[stream] = handle;
return *handle;
}

int CudnnHandleManager::get_workspace_limit_in_bytes() {
42 changes: 37 additions & 5 deletions src/nbla/cuda/cudnn/function/generic/convolution.cu
@@ -23,6 +23,18 @@

namespace nbla {

template <typename T> void ConvolutionCudaCudnn<T>::wait_default_on_dgrad() {
NBLA_CUDA_CHECK(cudaEventRecord(*(this->default_event_), 0));
NBLA_CUDA_CHECK(
cudaStreamWaitEvent(*(this->dgrad_stream_), *(this->default_event_), 0));
}

template <typename T> void ConvolutionCudaCudnn<T>::wait_dgrad_on_default() {
NBLA_CUDA_CHECK(
cudaEventRecord(*(this->dgrad_event_), *(this->dgrad_stream_)));
NBLA_CUDA_CHECK(cudaStreamWaitEvent(0, *(this->dgrad_event_), 0));
}

template <typename T>
void ConvolutionCudaCudnn<T>::setup_impl(const Variables &inputs,
const Variables &outputs) {
@@ -36,6 +48,16 @@ void ConvolutionCudaCudnn<T>::setup_impl(const Variables &inputs,
Convolution<T>::setup_impl(inputs, outputs);
cudnn_handle_ = SingletonManager::get<CudnnHandleManager>()->handle(device_);

dgrad_event_ =
SingletonManager::get<Cuda>()->cuda_event(cudaEventDisableTiming);
default_event_ =
SingletonManager::get<Cuda>()->cuda_event(cudaEventDisableTiming);

dgrad_stream_ = SingletonManager::get<Cuda>()->get_stream(
cudaStreamNonBlocking, nbla::CudaStreamId::CONVOLUTION_BWD, device_);
dgrad_handle_ = SingletonManager::get<CudnnHandleManager>()->handle(
device_, *dgrad_stream_);

#if CUDNN_VERSION < 7000
x_offset_ = this->inner_size_i_ / this->group_;
y_offset_ = this->inner_size_o_ / this->group_;
@@ -127,7 +149,6 @@ void ConvolutionCudaCudnn<T>::backward_impl(const Variables &inputs,
const Variables &outputs,
const vector<bool> &propagate_down,
const vector<bool> &accum) {

if (!(propagate_down[0] || propagate_down[1] ||
(inputs.size() == 3 && propagate_down[2]))) {
return;
@@ -150,22 +171,32 @@ void ConvolutionCudaCudnn<T>::backward_impl(const Variables &inputs,
}
auto alpha = get_cudnn_scalar_arg<T>(1);
auto workspace_size = rsc_->workspace_size();
unique_ptr<CudaCachedArray> workspace_arr;
void *workspace{nullptr};
unique_ptr<CudaCachedArray> workspace_arr, workspace_arr_dgrad;
void *workspace{nullptr}, *workspace_dgrad{nullptr};
if (workspace_size) {
workspace_arr.reset(
new CudaCachedArray(workspace_size, dtypes::BYTE, this->ctx_));
workspace = workspace_arr->pointer<void>();
workspace_arr_dgrad.reset(
new CudaCachedArray(workspace_size, dtypes::BYTE, this->ctx_));
workspace_dgrad = workspace_arr_dgrad->pointer<void>();
}
#if CUDNN_VERSION >= 7000
if (propagate_down[0]) {
this->wait_default_on_dgrad();
auto beta = get_cudnn_scalar_arg<T>(accum[0] ? 1 : 0);
NBLA_CUDNN_CHECK(cudnnConvolutionBackwardData(
cudnn_handle_, &alpha, rsc_->w_desc, w, rsc_->y_desc, dy,
rsc_->conv_dgrad_desc.desc, rsc_->bwd_data_algo, workspace,
dgrad_handle_, &alpha, rsc_->w_desc, w, rsc_->y_desc, dy,
rsc_->conv_dgrad_desc.desc, rsc_->bwd_data_algo, workspace_dgrad,
rsc_->bwd_data_workspace_size, &beta, rsc_->x_desc, dx));
}
if (propagate_down[1]) {
    /** Note:
     * If the backward pass of the first convolution layer is slow, check the
     * value of beta. When beta = 1, cuDNN calls not first_layer_wgrad_kernel
     * (which is faster than the others) but a slower kernel.
     */
auto beta = get_cudnn_scalar_arg<T>(accum[1] ? 1 : 0);
NBLA_CUDNN_CHECK(cudnnConvolutionBackwardFilter(
cudnn_handle_, &alpha, rsc_->x_desc, x, rsc_->y_desc, dy,
Expand All @@ -177,6 +208,7 @@ void ConvolutionCudaCudnn<T>::backward_impl(const Variables &inputs,
NBLA_CUDNN_CHECK(cudnnConvolutionBackwardBias(
cudnn_handle_, &alpha, rsc_->y_desc, dy, &beta, rsc_->b_desc, db));
}
this->wait_dgrad_on_default();
#else
for (int g = 0; g < this->group_; ++g) {
if (propagate_down[0]) {
2 changes: 2 additions & 0 deletions src/nbla/cuda/cudnn/function/generic/deconvolution.cu
@@ -180,13 +180,15 @@ void DeconvolutionCudaCudnn<T>::backward_impl(
for (int g = 0; g < this->group_; ++g) {
if (propagate_down[0]) {
auto beta = get_cudnn_scalar_arg<T>(accum[0] ? 1 : 0);
      // TODO: enable asynchronous stream execution here as well, as done for convolution.
NBLA_CUDNN_CHECK(cudnnConvolutionForward(
cudnn_handle_, &alpha, rsc_->x_desc, dx + x_offset_ * g, rsc_->w_desc,
w + w_offset_ * g, rsc_->conv_desc.desc, rsc_->fwd_algo, workspace,
rsc_->fwd_workspace_size, &beta, rsc_->y_desc, dy + y_offset_ * g));
}
if (propagate_down[1]) {
auto beta = get_cudnn_scalar_arg<T>(accum[1] ? 1 : 0);
      // TODO: enable asynchronous stream execution here as well, as done for convolution.
NBLA_CUDNN_CHECK(cudnnConvolutionBackwardFilter(
cudnn_handle_, &alpha, rsc_->x_desc, dx + x_offset_ * g, rsc_->y_desc,
y + y_offset_ * g, rsc_->conv_wgrad_desc.desc, rsc_->bwd_filter_algo,
