Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BatchNorm kernel for ROCm #9014

Merged
merged 2 commits into from
Sep 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions onnxruntime/core/providers/rocm/miopen_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "core/framework/tensor.h"
#include <cfloat>

// Minimum epsilon MIOpen's batch-norm accepts; smaller user-provided values
// are clamped up to this (see ClampMiopenBatchNormEpsilon).
// constexpr: compile-time constant, internal linkage at namespace scope, so
// including this header from multiple translation units is ODR-safe.
constexpr double MIOPEN_BN_MIN_EPSILON = 1e-5;

namespace onnxruntime {
namespace rocm {

Expand Down Expand Up @@ -56,5 +58,14 @@ struct ReduceConsts {
static const ElemType One;
};

// Clamps a user-provided epsilon up to MIOPEN_BN_MIN_EPSILON, the smallest
// value MIOpen's batch-norm accepts. A warning is logged only when the
// shortfall exceeds FLT_EPSILON, i.e. when the clamping is more than
// float-rounding noise (the attribute arrives as a float).
inline double ClampMiopenBatchNormEpsilon(double epsilon) {
  if (epsilon < MIOPEN_BN_MIN_EPSILON) {
    // Fixed copy-paste from the CUDA provider: this is the MIOpen minimum,
    // not cuDNN's, so the message must name MIOPEN_BN_MIN_EPSILON.
    if (MIOPEN_BN_MIN_EPSILON - epsilon > FLT_EPSILON)
      LOGS_DEFAULT(WARNING) << "Provided epsilon is smaller than MIOPEN_BN_MIN_EPSILON. Setting it to MIOPEN_BN_MIN_EPSILON";
    return MIOPEN_BN_MIN_EPSILON;
  }
  return epsilon;
}

} // namespace rocm
} // namespace onnxruntime
1 change: 1 addition & 0 deletions onnxruntime/core/providers/rocm/rocm_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#pragma once

#include "core/providers/shared_library/provider_api.h"
#include "core/common/status.h"
#include "core/providers/rocm/rocm_pch.h"
#include "core/providers/rocm/shared_inc/rocm_call.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,13 @@
#include "gtest/gtest.h"
#include "gmock/gmock.h"

using namespace std;

namespace onnxruntime {
namespace contrib {
namespace test {

using namespace onnxruntime::test;

#ifdef USE_CUDA
#if USE_CUDA || USE_ROCM
static void TestBatchNormInternal(bool test_double = false, bool T_is_half = false,
bool T1_is_half = false, bool T2_is_half = false,
const std::vector<int64_t>& input_output_dims = {2, 2, 2, 2}) {
Expand Down Expand Up @@ -138,9 +136,11 @@ TEST(CudaKernelTest, BNInternalBasic) { // float case
TestBatchNormInternal();
}

#ifndef USE_ROCM // MIOpen does not support double type
TEST(CudaKernelTest, BNInternalDouble) { // double case
TestBatchNormInternal(true);
}
#endif // ndef USE_ROCM

TEST(CudaKernelTest, BNInternalHalf) { // half case
TestBatchNormInternal(false, true, true, true);
Expand Down Expand Up @@ -195,7 +195,7 @@ TEST(CudaKernelTest, BNInternal1DInput) { // float case, 1d input
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{kCpuExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
#endif // USE_CUDA
#endif // USE_CUDA || USE_ROCM

} // namespace test
} // namespace contrib
Expand Down
137 changes: 137 additions & 0 deletions orttraining/orttraining/training_ops/rocm/nn/batch_norm_grad.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "batch_norm_grad.h"
#include "core/providers/common.h"
#include "core/providers/rocm/miopen_common.h"
#include "core/providers/cpu/nn/batch_norm_helper.h"
#include "core/providers/rocm/math/unary_elementwise_ops_impl.h"

using namespace std;
namespace onnxruntime {
namespace rocm {

// Registers BatchNormalizationGrad with the ROCm execution provider for one
// (T, T1, T2) combination: T is the data tensor type (X/dY/dX), T1 the
// scale/bias gradient type, and T2 the saved mean / inv-std type. The kernel
// name is suffixed T_T1_T2 via token pasting to keep registrations distinct.
#define REGISTER_GRADIENT_KERNEL_TYPED(T, T1, T2)                                         \
  ONNX_OPERATOR_TYPED_KERNEL_EX(                                                          \
      BatchNormalizationGrad,                                                             \
      kMSDomain,                                                                          \
      1,                                                                                  \
      T##_##T1##_##T2,                                                                    \
      kRocmExecutionProvider,                                                             \
      (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \
          .TypeConstraint("T1", DataTypeImpl::GetTensorType<T1>())                        \
          .TypeConstraint("T2", DataTypeImpl::GetTensorType<T2>()),                       \
      BatchNormalizationGrad<T, T1, T2>);

// Computes BatchNormalization gradients (dX, dScale, dBias) on ROCm via
// miopenBatchNormalizationBackward.
//
// Inputs:  0: dY (gradient w.r.t. output), 1: X, 2: Scale,
//          3: saved_mean, 4: saved_inv_std (per-channel batch statistics
//          from the forward pass).
// Outputs: 0: dX (same shape as X), 1: dScale, 2: dBias (channel shape).
//
// When T1/T2 is MLFloat16, the per-channel side tensors are staged through
// float scratch buffers, since the MIOpen call operates on float there
// (results are cast back to half afterwards).
template <typename T, typename T1, typename T2>
Status BatchNormalizationGrad<T, T1, T2>::ComputeInternal(OpKernelContext* ctx) const {
  typedef typename ToHipType<T>::MappedType HipT;
  typedef typename ToHipType<T1>::MappedType HipT1;
  typedef typename ToHipType<T2>::MappedType HipT2;

  const Tensor* dY = ctx->Input<Tensor>(0);
  const Tensor* X = ctx->Input<Tensor>(1);
  const Tensor* Scale = ctx->Input<Tensor>(2);
  const Tensor* saved_mean = ctx->Input<Tensor>(3);
  // miopenBatchNormalizationBackward() claims to use `savedInvVariance`, but the value
  // is actually equal to the batch inv_std, so we use name `saved_inv_std` here.
  const Tensor* saved_inv_std = ctx->Input<Tensor>(4);
  const TensorShape input_shape = X->Shape();
  const TensorShape channel_shape = saved_mean->Shape();

  // no B here, but B has same size as Scale, so can validate inputs for gradient with this substitute
  ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, Scale, Scale, saved_mean, saved_inv_std));

  auto dY_data = reinterpret_cast<const HipT*>(dY->template Data<T>());
  auto X_data = reinterpret_cast<const HipT*>(X->template Data<T>());
  auto Scale_data = reinterpret_cast<const HipT1*>(Scale->template Data<T1>());
  auto saved_mean_data = reinterpret_cast<const HipT2*>(saved_mean->template Data<T2>());
  auto saved_inv_std_data = reinterpret_cast<const HipT2*>(saved_inv_std->template Data<T2>());

  auto dX_data = reinterpret_cast<HipT*>(ctx->Output(0, input_shape)->template MutableData<T>());
  auto dScale_data = reinterpret_cast<HipT1*>(ctx->Output(1, channel_shape)->template MutableData<T1>());
  auto dBias_data = reinterpret_cast<HipT1*>(ctx->Output(2, channel_shape)->template MutableData<T1>());

  const auto alpha = Consts<HipT>::One;
  const auto beta = Consts<HipT>::Zero;

  MiopenTensor input_tensor, scale_bias_tensor;
  // Qualified std::vector — avoid relying on the file-scope `using namespace std;`.
  std::vector<int64_t> new_dims;
  // Pad/reshape dims to the rank MIOpen expects (e.g. 1D input -> NCHW-like).
  BatchNormHelper::NormalizeDims(input_shape, new_dims);
  ORT_RETURN_IF_ERROR(input_tensor.Set(new_dims, MiopenTensor::GetDataType<HipT>()));
  ORT_RETURN_IF_ERROR(scale_bias_tensor.Set(input_tensor, miopen_batch_norm_mode_));

  const int64_t C = new_dims[1];  // channel count
  auto p_scale = reinterpret_cast<const void*>(Scale_data);
  auto p_saved_mean = reinterpret_cast<const void*>(saved_mean_data);
  auto p_saved_inv_std = reinterpret_cast<const void*>(saved_inv_std_data);
  auto p_dScale = reinterpret_cast<void*>(dScale_data);
  auto p_dBias = reinterpret_cast<void*>(dBias_data);

  IAllocatorUniquePtr<float> p_f_scale, p_f_dScale, p_f_dBias, p_f_saved_mean, p_f_saved_inv_std;

  // Half-precision scale: MIOpen works in float here, so cast the input scale
  // up and redirect the scale/dScale/dBias pointers to float scratch buffers.
  if (std::is_same<T1, MLFloat16>::value) {
    p_f_scale = GetScratchBuffer<float>(C);
    p_f_dScale = GetScratchBuffer<float>(C);
    p_f_dBias = GetScratchBuffer<float>(C);

    Impl_Cast<HipT1, float>(Stream(), Scale_data, p_f_scale.get(), C);

    p_scale = p_f_scale.get();
    p_dScale = p_f_dScale.get();
    p_dBias = p_f_dBias.get();
  }

  // Same staging for half-precision saved statistics.
  if (std::is_same<T2, MLFloat16>::value) {
    p_f_saved_mean = GetScratchBuffer<float>(C);
    p_f_saved_inv_std = GetScratchBuffer<float>(C);

    Impl_Cast<HipT2, float>(Stream(), saved_mean_data, p_f_saved_mean.get(), C);
    Impl_Cast<HipT2, float>(Stream(), saved_inv_std_data, p_f_saved_inv_std.get(), C);

    p_saved_mean = p_f_saved_mean.get();
    p_saved_inv_std = p_f_saved_inv_std.get();
  }

  MIOPEN_RETURN_IF_ERROR(miopenBatchNormalizationBackward(
      MiopenHandle(),
      miopen_batch_norm_mode_,
      &alpha,
      &beta,
      &alpha,
      &beta,
      input_tensor,
      X_data,
      input_tensor,
      dY_data,
      input_tensor,
      dX_data,
      scale_bias_tensor,
      p_scale,
      p_dScale,
      p_dBias,
      epsilon_,
      p_saved_mean,
      p_saved_inv_std));

  // Cast the float gradients back to half for the actual outputs.
  // static_cast is the correct named cast for void* -> float*.
  if (std::is_same<T1, MLFloat16>::value) {
    Impl_Cast<float, HipT1>(Stream(), static_cast<float*>(p_dScale), dScale_data, C);
    Impl_Cast<float, HipT1>(Stream(), static_cast<float*>(p_dBias), dBias_data, C);
  }

  return Status::OK();
}

// Registers the kernel and explicitly instantiates ComputeInternal for one
// (T, T1, T2) type combination so the definition in this .cc is emitted.
#define SPECIALIZED_GRADIENT(T, T1, T2)     \
  REGISTER_GRADIENT_KERNEL_TYPED(T, T1, T2) \
  template Status BatchNormalizationGrad<T, T1, T2>::ComputeInternal(OpKernelContext* ctx) const;

SPECIALIZED_GRADIENT(float, float, float)
// MIOpen kernel does not support double, disable for now.
// SPECIALIZED_GRADIENT(double, double, double)
SPECIALIZED_GRADIENT(MLFloat16, MLFloat16, MLFloat16)
SPECIALIZED_GRADIENT(MLFloat16, MLFloat16, float)
SPECIALIZED_GRADIENT(MLFloat16, float, float)

} // namespace rocm
} // namespace onnxruntime
43 changes: 43 additions & 0 deletions orttraining/orttraining/training_ops/rocm/nn/batch_norm_grad.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "gsl/gsl"
#include "core/providers/rocm/rocm_kernel.h"
#include "core/providers/rocm/miopen_common.h"

namespace onnxruntime {
namespace rocm {

template <typename T, typename T1, typename T2>
class BatchNormalizationGrad final : public RocmKernel {
public:
BatchNormalizationGrad(const OpKernelInfo& info)
: RocmKernel{info},
miopen_batch_norm_mode_(miopenBNSpatial) {
float tmp_epsilon;
ORT_ENFORCE(info.GetAttr<float>("epsilon", &tmp_epsilon).IsOK());
epsilon_ = ClampMiopenBatchNormEpsilon(static_cast<double>(tmp_epsilon));

// spatial or not
int64_t tmp_spatial;
if (info.GetAttr<int64_t>("spatial", &tmp_spatial).IsOK()) {
spatial_ = tmp_spatial;
}

if (spatial_ == 0) {
miopen_batch_norm_mode_ = miopenBNPerActivation;
}
}

Status ComputeInternal(OpKernelContext* context) const override;

private:
double epsilon_;
int64_t spatial_ = 1; // default as per spec
miopenBatchNormMode_t miopen_batch_norm_mode_;
};

} // namespace rocm
} // namespace onnxruntime
Loading