convert singrad to function op and remove cpu kernel (#13263)
### Description
Implemented the gradient of Sin as a function op: the gradient builder now emits a `Cos` node followed by a `Mul` node instead of a dedicated `SinGrad` op.

### Motivation and Context
The Sin gradient was previously implemented only as a CPU kernel, which could hurt performance. Composing it from the existing `Cos` and `Mul` ops lets any execution provider that implements those ops run it.
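
For context, the replacement follows from the chain rule: since $\frac{d}{dx}\sin(x) = \cos(x)$, the gradient flowing into `Sin`'s input is

$$
\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} \odot \cos(X),
$$

an elementwise product of the incoming gradient (`GO(0)` in the builder) with `Cos(X)`, which is exactly the `Cos` → `Mul` pair emitted below.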

### Testing
Built ORT from source:
```bash
./build.sh --config RelWithDebInfo --enable_training --use_cuda --cuda_home /usr/local/cuda --cudnn_home /usr/local/cuda --build_wheel --parallel --skip_tests
```
Tested the SinGrad implementation:
```bash
cd build/Linux/RelWithDebInfo/ && ./onnxruntime_test_all --gtest_filter=GradientCheckerTest.SinGrad
```

Co-authored-by: Prathik Rao <[email protected]>
Co-authored-by: Baiju Meswani <[email protected]>
3 people authored Oct 11, 2022
1 parent cd2e8b3 commit 05acd20
Showing 6 changed files with 4 additions and 60 deletions.
8 changes: 4 additions & 4 deletions orttraining/orttraining/core/graph/gradient_builder.cc
```diff
@@ -76,10 +76,10 @@ IMPLEMENT_GRADIENT_BUILDER(GetCastGradient) {
 }
 
 IMPLEMENT_GRADIENT_BUILDER(GetSinGradient) {
-  return std::vector<NodeDef>{
-      NodeDef("SinGrad",
-              {GO(0), I(0)},
-              {GI(0)})};
+  std::vector<NodeDef> result;
+  result.push_back(NodeDef("Cos", {I(0)}, {IA("Cos_O0")}));
+  result.push_back(NodeDef("Mul", {GO(0), IA("Cos_O0")}, {GI(0)}));
+  return result;
 }
 
 IMPLEMENT_GRADIENT_BUILDER(GetLogGradient) {
```
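
For illustration, a minimal standalone sketch of the elementwise computation the new `Cos` → `Mul` pair performs (plain C++ with a hypothetical helper name, not ORT code):

```cpp
#include <cmath>
#include <vector>

// Reference computation for the Cos -> Mul decomposition above:
// dX[i] = dY[i] * cos(X[i]).  (Hypothetical helper; not ORT code.)
std::vector<float> SinGradReference(const std::vector<float>& dY,
                                    const std::vector<float>& X) {
  std::vector<float> dX(X.size());
  for (size_t i = 0; i < X.size(); ++i) {
    dX[i] = dY[i] * std::cos(X[i]);
  }
  return dX;
}
```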
18 changes: 0 additions & 18 deletions orttraining/orttraining/core/graph/training_op_defs.cc
```diff
@@ -2476,24 +2476,6 @@ Example 4:
       propagateElemTypeFromAttributeToOutput(ctx, "to", 0);
     });
 
-  ONNX_CONTRIB_OPERATOR_SCHEMA(SinGrad)
-      .SetDomain(kOnnxDomain)
-      .SinceVersion(9)
-      .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
-      .SetDoc("Gradient function for Sin")
-      .AllowUncheckedAttributes()
-      .Input(0, "dY", "Sin output's grad", "T")
-      .Input(1, "X", "Input tensor", "T")
-      .Output(0, "dX", "Sin input's grad", "T")
-      .TypeConstraint(
-          "T",
-          {"tensor(float16)", "tensor(float)", "tensor(double)"},
-          "Constrain input and output types to all numeric tensors.")
-      .FunctionBody(ONNX_NAMESPACE::FunctionBodyHelper::BuildNodes(
-          {// nodes: {outputs, op, inputs, attributes}
-           {{"X_1"}, "Cos", {"X"}},
-           {{"dX"}, "Mul", {"X_1", "dY"}}}));
-
   ONNX_CONTRIB_OPERATOR_SCHEMA(SummaryScalar)
       .SetDomain(kMSDomain)
       .SinceVersion(1)
```
```diff
@@ -7,14 +7,5 @@
 namespace onnxruntime {
 namespace test {
 
-TEST(ElementWiseOpGrad, SinGrad) {
-  OpTester test("SinGrad", 9);
-
-  test.AddInput<float>("dY", {3}, {0, 1, 2});
-  test.AddInput<float>("X", {3}, {-1, 0, 1});
-
-  test.AddOutput<float>("dX", {3}, {std::cos(-1.0f) * 0, std::cos(0.0f) * 1, std::cos(1.0f) * 2});
-  test.Run();
-}
 } // namespace test
 } // namespace onnxruntime
```
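
With the standalone `SinGrad` op gone, the op-level test above no longer applies; coverage now comes from the gradient checker run listed under Testing (`GradientCheckerTest.SinGrad`), which compares the built gradient graph against a numerical estimate. As an illustration only (this is not the ORT `GradientChecker` API), the core identity it verifies can be checked with a central finite difference:

```cpp
#include <cassert>
#include <cmath>

// Compares the analytic gradient d(sin x)/dx = cos(x) against a central
// finite difference, mirroring what a numerical gradient checker verifies.
int main() {
  const float h = 1e-3f;
  for (float x : {-1.0f, 0.0f, 1.0f}) {
    float analytic = std::cos(x);
    float numeric = (std::sin(x + h) - std::sin(x - h)) / (2.0f * h);
    assert(std::fabs(analytic - numeric) < 1e-4f);
  }
  return 0;
}
```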
```diff
@@ -37,7 +37,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossInternal);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int32_t, SoftmaxCrossEntropyLossInternalGrad);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossInternalGrad);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SinGrad);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ConvGrad);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReluGrad);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxGrad);
@@ -151,7 +150,6 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossInternal)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int32_t, SoftmaxCrossEntropyLossInternalGrad)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossInternalGrad)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SinGrad)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ConvGrad)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReluGrad)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxGrad)>,
```
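
No replacement registration is needed: since the gradient is now a composition of `Cos` and `Mul`, every execution provider that already implements those ops (CPU, CUDA, and others) runs it without a dedicated `SinGrad` kernel.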
15 changes: 0 additions & 15 deletions orttraining/orttraining/training_ops/cpu/op_gradients.cc
```diff
@@ -15,21 +15,6 @@
 namespace onnxruntime {
 namespace contrib {
 
-ONNX_CPU_OPERATOR_KERNEL(
-    SinGrad,
-    9,
-    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
-    SinGrad<float>);
-
-template <typename T>
-Status SinGrad<T>::Compute(OpKernelContext* context) const {
-  auto& dY = *context->Input<Tensor>(0);
-  auto& X = *context->Input<Tensor>(1);
-  auto& dX = *context->Output(0, X.Shape());
-  MakeEigenArrayMap<float>(dX) = MakeEigenArrayMap<float>(dY) * MakeEigenArrayMap<float>(X).cos();
-  return Status::OK();
-}
-
 ONNX_OPERATOR_KERNEL_EX(
     ReluGrad,
     kMSDomain,
```
12 changes: 0 additions & 12 deletions orttraining/orttraining/training_ops/cpu/op_gradients.h
```diff
@@ -9,18 +9,6 @@
 namespace onnxruntime {
 namespace contrib {
 
-template <typename T>
-class SinGrad final : public OpKernel {
- public:
-  explicit SinGrad(const OpKernelInfo& info) : OpKernel(info) {
-  }
-
-  Status Compute(OpKernelContext* context) const override;
-
- private:
-  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SinGrad);
-};
-
 template <typename T>
 class ReluGrad final : public OpKernel {
  public:
```
