-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixes to rel-1.9.0 to compile and pass for AMD ROCm (#9144)
* Revert "Fix nightly CI pipeline to generate ROCm 4.2 wheels and add ROCm 4.3.1 wheels (#9101)" This reverts commit 4788839. * Add BatchNorm kernel for ROCm (#9014) * Add BatchNorm kernel for ROCm, update BN test * correct epsilon_ setting; limit min epsilon * Upgrade ROCm CI pipeline for ROCm 4.3.1 and permit run inside container (#9070) * try to run inside 4.3.1 container * no \ in container run command * remove networking options * try with adding video render groups * add job to build docker image * try without 1st stage * change alpha, beta to float * try adding service connection * retain huggingface directory * static video and render gid * use runtime expression for variables * install torch-ort * pin sacrebleu==1.5.1 * update curves for rocm 4.3.1 * try again * disable determinism and only check tail of loss curve and with a much larger threshold of 0.05 * disable RoBERTa due to high run variablity on ROCm 4.3.1 * put reduction unit tests back in * Fix nightly CI pipeline to generate ROCm 4.2 wheels and add ROCm 4.3.1 wheels (#9101) * make work for both rocm 4.2 and rocm 4.3.1 * fix rocm 4.3.1 docker image reference * fix CUDA_VERSION to ROCM_VERSION * fix ReduceConsts conflict def * add ifdef to miopen_common.h as well * trailing ws Co-authored-by: wangye <[email protected]> Co-authored-by: mindest <[email protected]>
- Loading branch information
1 parent
66b3c31
commit 4daa14b
Showing
21 changed files
with
1,102 additions
and
64 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
137 changes: 137 additions & 0 deletions
137
orttraining/orttraining/training_ops/rocm/nn/batch_norm_grad.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "batch_norm_grad.h" | ||
#include "core/providers/common.h" | ||
#include "core/providers/rocm/miopen_common.h" | ||
#include "core/providers/cpu/nn/batch_norm_helper.h" | ||
#include "core/providers/rocm/math/unary_elementwise_ops_impl.h" | ||
|
||
using namespace std; | ||
namespace onnxruntime { | ||
namespace rocm { | ||
|
||
#define REGISTER_GRADIENT_KERNEL_TYPED(T, T1, T2) \ | ||
ONNX_OPERATOR_TYPED_KERNEL_EX( \ | ||
BatchNormalizationGrad, \ | ||
kMSDomain, \ | ||
1, \ | ||
T##_##T1##_##T2, \ | ||
kRocmExecutionProvider, \ | ||
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()) \ | ||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T1>()) \ | ||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<T2>()), \ | ||
BatchNormalizationGrad<T, T1, T2>); | ||
|
||
template <typename T, typename T1, typename T2> | ||
Status BatchNormalizationGrad<T, T1, T2>::ComputeInternal(OpKernelContext* ctx) const { | ||
typedef typename ToHipType<T>::MappedType HipT; | ||
typedef typename ToHipType<T1>::MappedType HipT1; | ||
typedef typename ToHipType<T2>::MappedType HipT2; | ||
|
||
const Tensor* dY = ctx->Input<Tensor>(0); | ||
const Tensor* X = ctx->Input<Tensor>(1); | ||
const Tensor* Scale = ctx->Input<Tensor>(2); | ||
const Tensor* saved_mean = ctx->Input<Tensor>(3); | ||
// miopenBatchNormalizationBackward() claims to use `savedInvVariance`, but the value | ||
// is actually equal to the batch inv_std, so we use name `saved_inv_std` here. | ||
const Tensor* saved_inv_std = ctx->Input<Tensor>(4); | ||
const TensorShape input_shape = X->Shape(); | ||
const TensorShape channel_shape = saved_mean->Shape(); | ||
|
||
// no B here, but B has same size as Scale, so can validate inputs for gradient with this substitute | ||
ORT_RETURN_IF_ERROR(BatchNormHelper::ValidateInputs(X, Scale, Scale, saved_mean, saved_inv_std)); | ||
|
||
auto dY_data = reinterpret_cast<const HipT*>(dY->template Data<T>()); | ||
auto X_data = reinterpret_cast<const HipT*>(X->template Data<T>()); | ||
auto Scale_data = reinterpret_cast<const HipT1*>(Scale->template Data<T1>()); | ||
auto saved_mean_data = reinterpret_cast<const HipT2*>(saved_mean->template Data<T2>()); | ||
auto saved_inv_std_data = reinterpret_cast<const HipT2*>(saved_inv_std->template Data<T2>()); | ||
|
||
auto dX_data = reinterpret_cast<HipT*>(ctx->Output(0, input_shape)->template MutableData<T>()); | ||
auto dScale_data = reinterpret_cast<HipT1*>(ctx->Output(1, channel_shape)->template MutableData<T1>()); | ||
auto dBias_data = reinterpret_cast<HipT1*>(ctx->Output(2, channel_shape)->template MutableData<T1>()); | ||
|
||
const auto alpha = Consts<HipT>::One; | ||
const auto beta = Consts<HipT>::Zero; | ||
|
||
MiopenTensor input_tensor, scale_bias_tensor; | ||
vector<int64_t> new_dims; | ||
BatchNormHelper::NormalizeDims(input_shape, new_dims); | ||
ORT_RETURN_IF_ERROR(input_tensor.Set(new_dims, MiopenTensor::GetDataType<HipT>())); | ||
ORT_RETURN_IF_ERROR(scale_bias_tensor.Set(input_tensor, miopen_batch_norm_mode_)); | ||
|
||
const int64_t C = new_dims[1]; | ||
auto p_scale = reinterpret_cast<const void*>(Scale_data); | ||
auto p_saved_mean = reinterpret_cast<const void*>(saved_mean_data); | ||
auto p_saved_inv_std = reinterpret_cast<const void*>(saved_inv_std_data); | ||
auto p_dScale = reinterpret_cast<void*>(dScale_data); | ||
auto p_dBias = reinterpret_cast<void*>(dBias_data); | ||
|
||
IAllocatorUniquePtr<float> p_f_scale, p_f_dScale, p_f_dBias, p_f_saved_mean, p_f_saved_inv_std; | ||
|
||
if (std::is_same<T1, MLFloat16>::value) { | ||
p_f_scale = GetScratchBuffer<float>(C); | ||
p_f_dScale = GetScratchBuffer<float>(C); | ||
p_f_dBias = GetScratchBuffer<float>(C); | ||
|
||
Impl_Cast<HipT1, float>(Stream(), Scale_data, p_f_scale.get(), C); | ||
|
||
p_scale = p_f_scale.get(); | ||
p_dScale = p_f_dScale.get(); | ||
p_dBias = p_f_dBias.get(); | ||
} | ||
|
||
if (std::is_same<T2, MLFloat16>::value) { | ||
p_f_saved_mean = GetScratchBuffer<float>(C); | ||
p_f_saved_inv_std = GetScratchBuffer<float>(C); | ||
|
||
Impl_Cast<HipT2, float>(Stream(), saved_mean_data, p_f_saved_mean.get(), C); | ||
Impl_Cast<HipT2, float>(Stream(), saved_inv_std_data, p_f_saved_inv_std.get(), C); | ||
|
||
p_saved_mean = p_f_saved_mean.get(); | ||
p_saved_inv_std = p_f_saved_inv_std.get(); | ||
} | ||
|
||
MIOPEN_RETURN_IF_ERROR(miopenBatchNormalizationBackward( | ||
MiopenHandle(), | ||
miopen_batch_norm_mode_, | ||
&alpha, | ||
&beta, | ||
&alpha, | ||
&beta, | ||
input_tensor, | ||
X_data, | ||
input_tensor, | ||
dY_data, | ||
input_tensor, | ||
dX_data, | ||
scale_bias_tensor, | ||
p_scale, | ||
p_dScale, | ||
p_dBias, | ||
epsilon_, | ||
p_saved_mean, | ||
p_saved_inv_std)); | ||
|
||
if (std::is_same<T1, MLFloat16>::value) { | ||
Impl_Cast<float, HipT1>(Stream(), reinterpret_cast<float*>(p_dScale), dScale_data, C); | ||
Impl_Cast<float, HipT1>(Stream(), reinterpret_cast<float*>(p_dBias), dBias_data, C); | ||
} | ||
|
||
return Status::OK(); | ||
} | ||
|
||
#define SPECIALIZED_GRADIENT(T, T1, T2) \ | ||
REGISTER_GRADIENT_KERNEL_TYPED(T, T1, T2) \ | ||
template Status BatchNormalizationGrad<T, T1, T2>::ComputeInternal(OpKernelContext* ctx) const; | ||
|
||
SPECIALIZED_GRADIENT(float, float, float) | ||
// MIOpen kernel does not support double, disable for now. | ||
// SPECIALIZED_GRADIENT(double, double, double) | ||
SPECIALIZED_GRADIENT(MLFloat16, MLFloat16, MLFloat16) | ||
SPECIALIZED_GRADIENT(MLFloat16, MLFloat16, float) | ||
SPECIALIZED_GRADIENT(MLFloat16, float, float) | ||
|
||
} // namespace rocm | ||
} // namespace onnxruntime |
43 changes: 43 additions & 0 deletions
43
orttraining/orttraining/training_ops/rocm/nn/batch_norm_grad.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include "gsl/gsl" | ||
#include "core/providers/rocm/rocm_kernel.h" | ||
#include "core/providers/rocm/miopen_common.h" | ||
|
||
namespace onnxruntime { | ||
namespace rocm { | ||
|
||
template <typename T, typename T1, typename T2> | ||
class BatchNormalizationGrad final : public RocmKernel { | ||
public: | ||
BatchNormalizationGrad(const OpKernelInfo& info) | ||
: RocmKernel{info}, | ||
miopen_batch_norm_mode_(miopenBNSpatial) { | ||
float tmp_epsilon; | ||
ORT_ENFORCE(info.GetAttr<float>("epsilon", &tmp_epsilon).IsOK()); | ||
epsilon_ = ClampMiopenBatchNormEpsilon(static_cast<double>(tmp_epsilon)); | ||
|
||
// spatial or not | ||
int64_t tmp_spatial; | ||
if (info.GetAttr<int64_t>("spatial", &tmp_spatial).IsOK()) { | ||
spatial_ = tmp_spatial; | ||
} | ||
|
||
if (spatial_ == 0) { | ||
miopen_batch_norm_mode_ = miopenBNPerActivation; | ||
} | ||
} | ||
|
||
Status ComputeInternal(OpKernelContext* context) const override; | ||
|
||
private: | ||
double epsilon_; | ||
int64_t spatial_ = 1; // default as per spec | ||
miopenBatchNormMode_t miopen_batch_norm_mode_; | ||
}; | ||
|
||
} // namespace rocm | ||
} // namespace onnxruntime |
Oops, something went wrong.