diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 2078e57bf3555..cad54acc2d185 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -76,10 +76,10 @@ IMPLEMENT_GRADIENT_BUILDER(GetCastGradient) { } IMPLEMENT_GRADIENT_BUILDER(GetSinGradient) { - return std::vector<NodeDef>{ - NodeDef("SinGrad", - {GO(0), I(0)}, - {GI(0)})}; + std::vector<NodeDef> result; + result.push_back(NodeDef("Cos", {I(0)}, {IA("Cos_O0")})); + result.push_back(NodeDef("Mul", {GO(0), IA("Cos_O0")}, {GI(0)})); + return result; } IMPLEMENT_GRADIENT_BUILDER(GetLogGradient) { diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index c18a1000c8167..7fba133954637 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -2476,24 +2476,6 @@ Example 4: propagateElemTypeFromAttributeToOutput(ctx, "to", 0); }); - ONNX_CONTRIB_OPERATOR_SCHEMA(SinGrad) - .SetDomain(kOnnxDomain) - .SinceVersion(9) - .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL) - .SetDoc("Gradient function for Sin") - .AllowUncheckedAttributes() - .Input(0, "dY", "Sin output's grad", "T") - .Input(1, "X", "Input tensor", "T") - .Output(0, "dX", "Sin input's grad", "T") - .TypeConstraint( - "T", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input and output types to all numeric tensors.") - .FunctionBody(ONNX_NAMESPACE::FunctionBodyHelper::BuildNodes( - {// nodes: {outputs, op, inputs, attributes} - {{"X_1"}, "Cos", {"X"}}, - {{"dX"}, "Mul", {"X_1", "dY"}}})); - ONNX_CONTRIB_OPERATOR_SCHEMA(SummaryScalar) .SetDomain(kMSDomain) .SinceVersion(1) diff --git a/orttraining/orttraining/test/training_ops/cpu/math/element_wise_op_grad_test.cc b/orttraining/orttraining/test/training_ops/cpu/math/element_wise_op_grad_test.cc index 
3390df8e407fb..8cccfb4d1f301 100644 --- a/orttraining/orttraining/test/training_ops/cpu/math/element_wise_op_grad_test.cc +++ b/orttraining/orttraining/test/training_ops/cpu/math/element_wise_op_grad_test.cc @@ -7,14 +7,5 @@ namespace onnxruntime { namespace test { -TEST(ElementWiseOpGrad, SinGrad) { - OpTester test("SinGrad", 9); - - test.AddInput<float>("dY", {3}, {0, 1, 2}); - test.AddInput<float>("X", {3}, {-1, 0, 1}); - - test.AddOutput<float>("dX", {3}, {std::cos(-1.0f) * 0, std::cos(0.0f) * 1, std::cos(1.0f) * 2}); - test.Run(); -} } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc index c30851425e6b3..66a28f61c17c1 100644 --- a/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cpu/cpu_training_kernels.cc @@ -37,7 +37,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossInternal); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int32_t, SoftmaxCrossEntropyLossInternalGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossInternalGrad); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SinGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ConvGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReluGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxGrad); @@ -151,7 +150,6 @@ Status RegisterCpuTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossInternal)>, BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int32_t, SoftmaxCrossEntropyLossInternalGrad)>, BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float_int64_t, SoftmaxCrossEntropyLossInternalGrad)>, -BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, SinGrad)>, BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ConvGrad)>, BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, ReluGrad)>, BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, SoftmaxGrad)>, 
diff --git a/orttraining/orttraining/training_ops/cpu/op_gradients.cc b/orttraining/orttraining/training_ops/cpu/op_gradients.cc index e14c12bc01f47..bbc018a83124c 100644 --- a/orttraining/orttraining/training_ops/cpu/op_gradients.cc +++ b/orttraining/orttraining/training_ops/cpu/op_gradients.cc @@ -15,21 +15,6 @@ namespace onnxruntime { namespace contrib { -ONNX_CPU_OPERATOR_KERNEL( - SinGrad, - 9, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()), - SinGrad<float>); - -template <typename T> -Status SinGrad<T>::Compute(OpKernelContext* context) const { - auto& dY = *context->Input<Tensor>(0); - auto& X = *context->Input<Tensor>(1); - auto& dX = *context->Output(0, X.Shape()); - MakeEigenArrayMap<float>(dX) = MakeEigenArrayMap<float>(dY) * MakeEigenArrayMap<float>(X).cos(); - return Status::OK(); -} - ONNX_OPERATOR_KERNEL_EX( ReluGrad, kMSDomain, diff --git a/orttraining/orttraining/training_ops/cpu/op_gradients.h b/orttraining/orttraining/training_ops/cpu/op_gradients.h index 4e519a7622851..4a6ec568b4da6 100644 --- a/orttraining/orttraining/training_ops/cpu/op_gradients.h +++ b/orttraining/orttraining/training_ops/cpu/op_gradients.h @@ -9,18 +9,6 @@ namespace onnxruntime { namespace contrib { -template <typename T> -class SinGrad final : public OpKernel { - public: - explicit SinGrad(const OpKernelInfo& info) : OpKernel(info) { - } - - Status Compute(OpKernelContext* context) const override; - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SinGrad); -}; - template <typename T> class ReluGrad final : public OpKernel { public: