Commit

Merge pull request #175 from sony/feature/20190206-double-backward-in-python

Double backward
TakuyaNarihira authored Jul 22, 2019
2 parents f5bc805 + 202987a commit 92d5918
Showing 51 changed files with 412 additions and 67 deletions.
3 changes: 3 additions & 0 deletions build-tools/code_generator/function_types.yaml
@@ -498,3 +498,6 @@ GatherNd:
ScatterNd:
float: [float]
half: [Half]
MaxPoolingBackward:
float: [float]
half: [Half]
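For context, this yaml entry registers the new MaxPoolingBackward function with the code generator for the float and half CUDA types. A minimal sketch of the double backward it supports from Python, assuming the companion nn.grad API in the nnabla core package that this feature targets (CUDA context setup omitted):

import numpy as np
import nnabla as nn
import nnabla.functions as F

x = nn.Variable((2, 3, 8, 8), need_grad=True)
x.d = np.random.randn(*x.shape)
y = F.max_pooling(x, kernel=(2, 2), stride=(2, 2))
loss = F.sum(y * y)
gx = nn.grad([loss], [x])[0]  # first-order gradient, built as a graph
gx2 = F.sum(gx * gx)          # scalar function of that gradient
gx2.forward()
gx2.backward()                # second-order pass; exercises MaxPoolingBackward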
1 change: 1 addition & 0 deletions include/nbla/cuda/function/broadcast.hpp
@@ -29,6 +29,7 @@ template <typename T> class BroadcastCuda : public Broadcast<T> {
// Variables for backward.
shared_ptr<Function> f_transpose_, f_sum_;
VariablePtr trp_input_, trp_output_, sum_input_, sum_output_;
vector<int> broadcast_dims_;

public:
typedef typename CudaType<T>::type Tc;
49 changes: 49 additions & 0 deletions include/nbla/cuda/function/max_pooling_backward.hpp
@@ -0,0 +1,49 @@
// Copyright (c) 2017 Sony Corporation. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef NBLA_CUDA_FUNCTION_MAX_POOLING_BACKWARD_HPP
#define NBLA_CUDA_FUNCTION_MAX_POOLING_BACKWARD_HPP

#include <nbla/cuda/cuda.hpp>
#include <nbla/function/max_pooling_backward.hpp>

namespace nbla {

template <typename T>
class MaxPoolingBackwardCuda : public MaxPoolingBackward<T> {
public:
typedef typename CudaType<T>::type Tcu;

explicit MaxPoolingBackwardCuda(const Context &ctx, const vector<int> &kernel,
const vector<int> &stride, bool ignore_border,
const vector<int> &pad, bool channel_last)
: MaxPoolingBackward<T>(ctx, kernel, stride, ignore_border, pad,
channel_last),
device_(std::stoi(ctx.device_id)) {}
virtual ~MaxPoolingBackwardCuda() {}
virtual string name() { return "MaxPoolingBackwardCuda"; }
virtual vector<string> allowed_array_classes() {
return SingletonManager::get<Cuda>()->array_classes();
}

protected:
int device_;
virtual void setup_impl(const Variables &inputs, const Variables &outputs);
virtual void forward_impl(const Variables &inputs, const Variables &outputs);
virtual void backward_impl(const Variables &inputs, const Variables &outputs,
const vector<bool> &propagate_down,
const vector<bool> &accum);
};
}
#endif
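The header above follows the usual NNabla CUDA function pattern: MaxPoolingBackwardCuda derives from the CPU reference class it includes, records the device id from the context, and overrides setup_impl, forward_impl, and backward_impl; the function_types.yaml entry earlier in this diff is what drives generation of the float and half CUDA registrations.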
22 changes: 16 additions & 6 deletions include/nbla/cuda/function/utils/base_transform_unary.cuh
@@ -118,6 +118,12 @@ void backward_impl_transform_unary(const Variables &inputs,
NAME##UnaryOpCuda(this->args_)); \
}

#define NBLA_DEFINE_GRAD_DEPENDS_OUTPUT_DATA(NAME, DEP_Y) \
template <typename T> \
bool NAME##Cuda<T>::grad_depends_output_data(int i, int o) const { \
return DEP_Y; \
};

// ----------------------------------------------------------------------------
// Zero argument
// ----------------------------------------------------------------------------
@@ -138,11 +144,13 @@ void backward_impl_transform_unary(const Variables &inputs,

#define NBLA_DEFINE_TRANSFORM_UNARY_CUDA_NO_GRAD(NAME, OP) \
NBLA_DEFINE_UNARY_OP_CUDA_NO_GRAD(NAME, OP); \
NBLA_DEFINE_TRANSFORM_UNARY_CUDA_FORWARD_BACKWARD(NAME)
NBLA_DEFINE_TRANSFORM_UNARY_CUDA_FORWARD_BACKWARD(NAME) \
NBLA_DEFINE_GRAD_DEPENDS_OUTPUT_DATA(NAME, false)

#define NBLA_DEFINE_TRANSFORM_UNARY_CUDA(NAME, OP, GOP) \
#define NBLA_DEFINE_TRANSFORM_UNARY_CUDA(NAME, OP, GOP, DEP_Y) \
NBLA_DEFINE_UNARY_OP_CUDA(NAME, OP, GOP); \
NBLA_DEFINE_TRANSFORM_UNARY_CUDA_FORWARD_BACKWARD(NAME)
NBLA_DEFINE_TRANSFORM_UNARY_CUDA_FORWARD_BACKWARD(NAME) \
NBLA_DEFINE_GRAD_DEPENDS_OUTPUT_DATA(NAME, DEP_Y)

// ----------------------------------------------------------------------------
// One argument
@@ -157,9 +165,10 @@ void backward_impl_transform_unary(const Variables &inputs,
NBLA_DEFINE_UNARY_OP_CUDA_BACKWARD(GOP) \
}

#define NBLA_DEFINE_TRANSFORM_UNARY_CUDA_1(NAME, OP, GOP, A0) \
#define NBLA_DEFINE_TRANSFORM_UNARY_CUDA_1(NAME, OP, GOP, A0, DEP_Y) \
NBLA_DEFINE_UNARY_OP_CUDA_1(NAME, OP, GOP, A0); \
NBLA_DEFINE_TRANSFORM_UNARY_CUDA_FORWARD_BACKWARD(NAME)
NBLA_DEFINE_TRANSFORM_UNARY_CUDA_FORWARD_BACKWARD(NAME) \
NBLA_DEFINE_GRAD_DEPENDS_OUTPUT_DATA(NAME, DEP_Y)

#define NBLA_DEFINE_UNARY_OP_CUDA_1_NO_GRAD(NAME, OP, A0) \
class NAME##UnaryOpCuda : public BaseUnaryOpCuda { \
@@ -172,6 +181,7 @@ void backward_impl_transform_unary(const Variables &inputs,

#define NBLA_DEFINE_TRANSFORM_UNARY_CUDA_1_NO_GRAD(NAME, OP, A0) \
NBLA_DEFINE_UNARY_OP_CUDA_1_NO_GRAD(NAME, OP, A0); \
NBLA_DEFINE_TRANSFORM_UNARY_CUDA_FORWARD_BACKWARD(NAME)
NBLA_DEFINE_TRANSFORM_UNARY_CUDA_FORWARD_BACKWARD(NAME) \
NBLA_DEFINE_GRAD_DEPENDS_OUTPUT_DATA(NAME, false)
}
#endif
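For reference, each NBLA_DEFINE_GRAD_DEPENDS_OUTPUT_DATA(NAME, DEP_Y) use above defines NAME##Cuda<T>::grad_depends_output_data(int i, int o) to return DEP_Y. The unary functions changed below all pass false, since their gradients read the input x (and dy) but never the output y; presumably this lets the graph engine avoid retaining output buffers that a backward pass will not touch.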
6 changes: 6 additions & 0 deletions include/nbla/cuda/function/utils/base_transform_unary.hpp
@@ -72,6 +72,9 @@ protected: \
return create_##NAME(this->ctx_); \
} \
NBLA_DECLARE_TRANSFORM_UNARY_CUDA_FORWARD_BACKWARD(); \
\
public: \
virtual bool grad_depends_output_data(int i, int o) const; \
}

// ----------------------------------------------------------------------------
@@ -90,6 +93,9 @@ protected: \
return create_##NAME(this->ctx_, std::get<0>(this->args_)); \
} \
NBLA_DECLARE_TRANSFORM_UNARY_CUDA_FORWARD_BACKWARD(); \
\
public: \
virtual bool grad_depends_output_data(int i, int o) const; \
}
}
#endif
11 changes: 10 additions & 1 deletion src/nbla/cuda/cudnn/function/generic/batch_normalization.cu
@@ -53,8 +53,17 @@ void BatchNormalizationCudaCudnn<T>::setup_impl(const Variables &inputs,
int H = this->size2_;
int W = 1;
mode_ = CUDNN_BATCHNORM_SPATIAL;
// Channel last is restricted for spatial input
bool channel_last = this->axes_[0] == inputs[0]->ndim() - 1;
if (channel_last) {
if (inputs[0]->ndim() == 2) { // typical 1-d affine output with shape (N, C)
mode_ = CUDNN_BATCHNORM_PER_ACTIVATION;
NBLA_CUDNN_CHECK(
cudnnSetTensor4dDescriptor(input_desc_.desc, CUDNN_TENSOR_NHWC,
cudnn_data_type<T>::type(), N, C, H, W));
NBLA_CUDNN_CHECK(
cudnnSetTensor4dDescriptor(output_desc_.desc, CUDNN_TENSOR_NHWC,
cudnn_data_type<T>::type(), N, C, H, W));
} else if (channel_last) {
// To prevent NOT SUPPORTED error in cuDNN, N and H are recalculated.
// (Large N is not allowed.)
N = inputs[0]->shape()[0];
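The new branch covers 2-D inputs of shape (N, C), i.e. a typical affine output whose last axis is both the feature axis and the normalization axis; that case is switched to CUDNN_BATCHNORM_PER_ACTIVATION with NHWC tensor descriptors instead of the spatial mode. A sketch of a graph that takes this path, assuming a cuDNN-enabled context (setup omitted) and the parametric-function wrapper:

import nnabla as nn
import nnabla.parametric_functions as PF

x = nn.Variable((32, 128))  # (N, C), e.g. the output of an affine layer
h = PF.batch_normalization(x, axes=[1], batch_stat=True)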
3 changes: 2 additions & 1 deletion src/nbla/cuda/function/generic/abs.cu
@@ -17,5 +17,6 @@

namespace nbla {

NBLA_DEFINE_TRANSFORM_UNARY_CUDA(Abs, std::abs(x), (x < (T)0) ? -dy : dy);
NBLA_DEFINE_TRANSFORM_UNARY_CUDA(Abs, std::abs(x), (x < (T)0) ? -dy : dy,
false);
}
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/acos.cu
@@ -21,5 +21,5 @@

namespace nbla {

NBLA_DEFINE_TRANSFORM_UNARY_CUDA(ACos, acos(x), -dy *rsqrt(1 - x * x));
NBLA_DEFINE_TRANSFORM_UNARY_CUDA(ACos, acos(x), -dy *rsqrt(1 - x * x), false);
}
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/acosh.cu
@@ -21,5 +21,5 @@

namespace nbla {

NBLA_DEFINE_TRANSFORM_UNARY_CUDA(ACosh, acosh(x), dy / sqrt(x * x - 1));
NBLA_DEFINE_TRANSFORM_UNARY_CUDA(ACosh, acosh(x), dy / sqrt(x * x - 1), false);
}
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/add_scalar.cu
@@ -19,5 +19,5 @@

namespace nbla {

NBLA_DEFINE_TRANSFORM_UNARY_CUDA_1(AddScalar, x + (T)a0, dy, double);
NBLA_DEFINE_TRANSFORM_UNARY_CUDA_1(AddScalar, x + (T)a0, dy, double, false);
}
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/asin.cu
@@ -21,5 +21,5 @@

namespace nbla {

NBLA_DEFINE_TRANSFORM_UNARY_CUDA(ASin, asin(x), dy *rsqrt(1 - x * x));
NBLA_DEFINE_TRANSFORM_UNARY_CUDA(ASin, asin(x), dy *rsqrt(1 - x * x), false);
}
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/asinh.cu
@@ -21,5 +21,5 @@

namespace nbla {

NBLA_DEFINE_TRANSFORM_UNARY_CUDA(ASinh, asinh(x), dy *rsqrt(x *x + 1));
NBLA_DEFINE_TRANSFORM_UNARY_CUDA(ASinh, asinh(x), dy *rsqrt(x *x + 1), false);
}
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/atan.cu
@@ -21,5 +21,5 @@

namespace nbla {

NBLA_DEFINE_TRANSFORM_UNARY_CUDA(ATan, atan(x), dy / (1 + x * x));
NBLA_DEFINE_TRANSFORM_UNARY_CUDA(ATan, atan(x), dy / (1 + x * x), false);
}
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/atanh.cu
@@ -21,5 +21,5 @@

namespace nbla {

NBLA_DEFINE_TRANSFORM_UNARY_CUDA(ATanh, atanh(x), dy / (1 - x * x));
NBLA_DEFINE_TRANSFORM_UNARY_CUDA(ATanh, atanh(x), dy / (1 - x * x), false);
}
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/binary_sigmoid.cu
@@ -22,5 +22,5 @@
namespace nbla {

NBLA_DEFINE_TRANSFORM_UNARY_CUDA(BinarySigmoid, (x > (T)0) ? (T)1 : (T)0,
(abs(x) >= (T)1) ? (T)0 : dy *(T)0.5);
(abs(x) >= (T)1) ? (T)0 : dy *(T)0.5, false);
}
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/binary_tanh.cu
@@ -22,5 +22,5 @@
namespace nbla {

NBLA_DEFINE_TRANSFORM_UNARY_CUDA(BinaryTanh, (x > (T)0) ? (T)1 : (T)-1,
(abs(x) >= (T)1) ? (T)0 : dy);
(abs(x) >= (T)1) ? (T)0 : dy, false);
}
30 changes: 14 additions & 16 deletions src/nbla/cuda/function/generic/broadcast.cu
@@ -43,12 +43,10 @@ void BroadcastCuda<T>::setup_impl(const Variables &inputs,
broadcast_dims.push_back(d);
}
}
broadcast_dims_ = broadcast_dims;
if (broadcast_dims.size() == 0)
return;
sum_input_ = make_shared<Variable>(outputs[0]->grad());
sum_output_ = make_shared<Variable>();
f_sum_ = create_Sum(this->ctx_, /*axis*/ broadcast_dims, /*keepdims*/ false);
f_sum_->setup(Variables{sum_input_.get()}, Variables{sum_output_.get()});
f_sum_ = create_Sum(this->ctx_, /*axis*/ broadcast_dims, /*keepdims*/ true);
}

// ----------------------------------------------------------------------------
@@ -156,21 +154,21 @@ void BroadcastCuda<T>::backward_impl(const Variables &inputs,
const vector<bool> &accum) {
if (!propagate_down[0])
return;
shared_ptr<Variable> sum_input = make_shared<Variable>(outputs[0]->grad());
shared_ptr<Variable> sum_output;
if (f_sum_) {
sum_input_->set_grad(outputs[0]->grad()); // What is this??? Seems like no
// effect. set_data()? set_data is
// done in setup_impl anyway.
if (!accum[0]) {
auto data_backup = sum_output_->data()->array();
sum_output_->data()->set_array(inputs[0]->grad()->array());
f_sum_->forward(Variables{sum_input_.get()},
Variables{sum_output_.get()});
sum_output_->data()->set_array(data_backup);
sum_output = make_shared<Variable>(inputs[0]->grad());
f_sum_->setup(Variables{sum_input.get()}, Variables{sum_output.get()});
f_sum_->forward(Variables{sum_input.get()}, Variables{sum_output.get()});
return;
}
f_sum_->forward(Variables{sum_input_.get()}, Variables{sum_output_.get()});
} else if (!accum[0]) {
inputs[0]->grad()->zero();
sum_output = make_shared<Variable>(inputs[0]->shape());
f_sum_->setup(Variables{sum_input.get()}, Variables{sum_output.get()});
f_sum_->forward(Variables{sum_input.get()}, Variables{sum_output.get()});
} else {
if (!accum[0])
inputs[0]->grad()->zero();
}
auto _get = [this](Variable *v) {
return v->get_data_pointer<Tc>(this->ctx_);
Expand All @@ -179,7 +177,7 @@ void BroadcastCuda<T>::backward_impl(const Variables &inputs,
return v->get_grad_pointer<Tc>(this->ctx_);
};
cuda_set_device(device_);
const Tc *g = f_sum_ ? _get(sum_output_.get()) : _gget(outputs[0]);
const Tc *g = f_sum_ ? _get(sum_output.get()) : _gget(outputs[0]);
Tc *dx = inputs[0]->cast_grad_and_get_pointer<Tc>(this->ctx_, false);
NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_add_grad, inputs[0]->size(), g, dx);
}
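The rewritten backward drops the cached sum_input_/sum_output_ members and the array-swapping workaround questioned in the old inline comment: it wraps the incoming gradients in fresh local Variables on every call, sets up f_sum_ with keepdims=true over broadcast_dims_, and either writes the reduction straight into the input gradient or accumulates it via kernel_add_grad. A NumPy analogue of the computation, as a sketch:

import numpy as np

def broadcast_backward(g_y, x_shape, broadcast_dims, g_x, accum):
    # Reduce the output gradient over the broadcast axes, keeping dims
    # (mirrors f_sum_ with keepdims=true), then match the input shape.
    g = g_y.sum(axis=tuple(broadcast_dims), keepdims=True) if broadcast_dims else g_y
    g = g.reshape(x_shape)
    return g_x + g if accum else g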
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/ceil.cu
@@ -17,5 +17,5 @@

namespace nbla {

NBLA_DEFINE_TRANSFORM_UNARY_CUDA(Ceil, ceil(x), dy);
NBLA_DEFINE_TRANSFORM_UNARY_CUDA(Ceil, ceil(x), dy, false);
}
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/cos.cu
@@ -21,5 +21,5 @@

namespace nbla {

NBLA_DEFINE_TRANSFORM_UNARY_CUDA(Cos, cos(x), -dy *sin(x));
NBLA_DEFINE_TRANSFORM_UNARY_CUDA(Cos, cos(x), -dy *sin(x), false);
}
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/cosh.cu
@@ -21,5 +21,5 @@

namespace nbla {

NBLA_DEFINE_TRANSFORM_UNARY_CUDA(Cosh, cosh(x), dy *sinh(x));
NBLA_DEFINE_TRANSFORM_UNARY_CUDA(Cosh, cosh(x), dy *sinh(x), false);
}
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/elu.cu
@@ -22,5 +22,5 @@ namespace nbla {
NBLA_DEFINE_TRANSFORM_UNARY_CUDA_1(ELU,
x >= (T)0 ? x : (T)a0 * (std::exp(x) - (T)1),
x >= (T)0 ? dy : dy * (T)a0 * std::exp(x),
double);
double, false);
}
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/exp.cu
@@ -21,5 +21,5 @@

namespace nbla {

NBLA_DEFINE_TRANSFORM_UNARY_CUDA(Exp, std::exp(x), dy *exp(x));
NBLA_DEFINE_TRANSFORM_UNARY_CUDA(Exp, std::exp(x), dy *exp(x), false);
}
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/floor.cu
@@ -17,5 +17,5 @@

namespace nbla {

NBLA_DEFINE_TRANSFORM_UNARY_CUDA(Floor, floor(x), dy);
NBLA_DEFINE_TRANSFORM_UNARY_CUDA(Floor, floor(x), dy, false);
}
3 changes: 2 additions & 1 deletion src/nbla/cuda/function/generic/gelu.cu
@@ -32,5 +32,6 @@ NBLA_DEFINE_TRANSFORM_UNARY_CUDA(
std::pow(1 / cosh((T)0.797885 * x +
(T)0.0356774 * std::pow(x, (T)3)),
(T)2) +
(T)0.5 * std::tanh((T)0.797885 * x + (T)0.0356774 * std::pow(x, (T)3)));
(T)0.5 * std::tanh((T)0.797885 * x + (T)0.0356774 * std::pow(x, (T)3)),
false);
}
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/hard_sigmoid.cu
@@ -24,5 +24,5 @@
namespace nbla {
NBLA_DEFINE_TRANSFORM_UNARY_CUDA(
HardSigmoid, x > (T)2.5 ? (T)1 : x < (T)-2.5 ? (T)0 : (T)0.2 * x + (T)0.5,
x <= (T)2.5 && (T)-2.5 <= x ? dy * (T)0.2 : (T)0);
x <= (T)2.5 && (T)-2.5 <= x ? dy * (T)0.2 : (T)0, false);
}
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/hard_tanh.cu
@@ -24,5 +24,5 @@
namespace nbla {
NBLA_DEFINE_TRANSFORM_UNARY_CUDA(HardTanh,
x > (T)1 ? (T)1 : x < (T)-1 ? (T)-1 : x,
(T)-1 <= x && x <= (T)1 ? dy : (T)0);
(T)-1 <= x && x <= (T)1 ? dy : (T)0, false);
}
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/log.cu
@@ -21,5 +21,5 @@

namespace nbla {

NBLA_DEFINE_TRANSFORM_UNARY_CUDA(Log, std::log(x), dy / x);
NBLA_DEFINE_TRANSFORM_UNARY_CUDA(Log, std::log(x), dy / x, false);
}
2 changes: 1 addition & 1 deletion src/nbla/cuda/function/generic/log_sigmoid.cu
@@ -23,5 +23,5 @@

namespace nbla {
NBLA_DEFINE_TRANSFORM_UNARY_CUDA(LogSigmoid, -std::log(std::exp(-x) + (T)1),
dy / (std::exp(x) + (T)1));
dy / (std::exp(x) + (T)1), false);
}