From 70b5522c909bd770e2dc3a1c36e310ee08de2b84 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Fri, 7 Jun 2024 07:52:30 +0000
Subject: [PATCH] Add operator MulSigmoid

---
 operators/cuda/cuda_ops.cc          |  3 ++
 operators/cuda/mul_sigmoid.h        | 34 ++++++++++++++
 operators/cuda/mul_sigmoid_impl.cu  | 72 +++++++++++++++++++++++++++++
 operators/cuda/mul_sigmoid_impl.cuh |  9 ++++
 test/cuda/test_cudaops.py           | 56 +++++++++++++++++++++-
 5 files changed, 173 insertions(+), 1 deletion(-)
 create mode 100644 operators/cuda/mul_sigmoid.h
 create mode 100644 operators/cuda/mul_sigmoid_impl.cu
 create mode 100644 operators/cuda/mul_sigmoid_impl.cuh

diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc
index da7009075..f9017436b 100644
--- a/operators/cuda/cuda_ops.cc
+++ b/operators/cuda/cuda_ops.cc
@@ -6,6 +6,7 @@
 #ifdef USE_CUDA
 #include "cuda/add_mul.h"
 #include "cuda/fast_gelu.h"
+#include "cuda/mul_sigmoid.h"
 #include "cuda/negxplus1.h"
 #include "cuda/transpose_cast.h"
 #endif
@@ -30,6 +31,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
       CustomCudaStructV2("AddSharedInput", AddSharedInputFloat32Type),
       CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
       CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
+      CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<float>),
       CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
 
 #if ORT_API_VERSION >= 16
@@ -37,6 +39,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
       CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
       CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
       CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
+      CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<ortc::MFloat16>),
       CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
       CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
       CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
diff --git a/operators/cuda/mul_sigmoid.h b/operators/cuda/mul_sigmoid.h
new file mode 100644
index 000000000..dc2ab203f
--- /dev/null
+++ b/operators/cuda/mul_sigmoid.h
@@ -0,0 +1,34 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "ocos.h"
+#include "mul_sigmoid_impl.cuh"
+#include "ortx_common.h"
+
+namespace contrib {
+
+template <typename T>
+struct MulSigmoid {
+  template <typename TDict>
+  OrtxStatus OnModelAttach(const TDict& /*dict*/) {
+    return {};
+  }
+  OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
+                     const ortc::Tensor<T>& input,
+                     ortc::Tensor<T>& output) const {
+    const T* input_data = input.Data();
+    T* output_data = output.Allocate(input.Shape());
+    auto input_length = input.NumberOfElement();
+    if (0 == input_length) {
+      return {};
+    }
+    LaunchMulSigmoidKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
+                              input_length,
+                              input_data,
+                              output_data);
+    return {};
+  }
+};
+
+}  // namespace contrib
\ No newline at end of file
diff --git a/operators/cuda/mul_sigmoid_impl.cu b/operators/cuda/mul_sigmoid_impl.cu
new file mode 100644
index 000000000..9746cfa32
--- /dev/null
+++ b/operators/cuda/mul_sigmoid_impl.cu
@@ -0,0 +1,72 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "device_prop.cuh"
+#include "utils.cuh"
+#include "mul_sigmoid_impl.cuh"
+#include "cuda_type.h"
+
+#ifndef CUDA_LONG
+#define CUDA_LONG int32_t
+#endif
+
+using namespace Ort::Custom;
+
+template <typename T> __device__ __inline__ T _exp_typed(const T x);
+
+template <> __device__ __inline__ float _exp_typed(const float x) { return expf(x); }
+
+#if __CUDA_ARCH__ < 700
+template <> __device__ __inline__ half _exp_typed(const half x) {
+  return __float2half(expf(__half2float(x)));
+}
+#else
+template <> __device__ __inline__ half _exp_typed(const half x) { return hexp(x); }
+#endif
+
+template <typename T> __device__ __inline__ T sigmoid(const T a) {
+  return a > T(0) ? (T)1 / ((T)1. + _exp_typed(-a))
+                  : (T)1 - (T)1 / ((T)1 + _exp_typed(a));
+}
+
+#if __CUDA_ARCH__ < 700
+template <> __device__ __inline__ half sigmoid(const half a) {
+  return __float2half(sigmoid(__half2float(a)));
+}
+#endif
+
+template <typename T> __device__ __inline__ T mul_sigmoid(const T a) { return a * sigmoid(a); }
+
+#if __CUDA_ARCH__ < 700
+template <> __device__ __inline__ half mul_sigmoid(const half a) {
+  float x = __half2float(a);
+  return __float2half(x * sigmoid(x));
+}
+#endif
+
+template <typename T>
+__global__ void MulSigmoidKernel(T *output_data, const T *input_data, CUDA_LONG N) {
+  CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
+  if (id >= N)
+    return;
+  output_data[id] = mul_sigmoid(input_data[id]);
+}
+
+template <typename T>
+cudaError_t _LaunchMulSigmoidKernel(cudaStream_t stream, int input_length, const T* input, T* output) {
+  constexpr int blockSize = 256;
+  const int gridSize = (input_length + blockSize - 1) / blockSize;
+  using TT = typename contrib::CudaT<T>::MappedType;
+  MulSigmoidKernel<TT><<<gridSize, blockSize, 0, stream>>>(reinterpret_cast<TT*>(output), reinterpret_cast<const TT*>(input), input_length);
+  return cudaGetLastError();
+}
+
+template <>
+cudaError_t LaunchMulSigmoidKernel<float>(cudaStream_t stream, int input_length, const float* input, float* output) {
+  return _LaunchMulSigmoidKernel(stream, input_length, input, output);
+}
+
+template <>
+cudaError_t LaunchMulSigmoidKernel<ortc::MFloat16>(cudaStream_t stream, int input_length, const ortc::MFloat16* input, ortc::MFloat16* output) {
+  return _LaunchMulSigmoidKernel(stream, input_length, input, output);
+}
diff --git a/operators/cuda/mul_sigmoid_impl.cuh b/operators/cuda/mul_sigmoid_impl.cuh
new file mode 100644
index 000000000..8b5f23b33
--- /dev/null
+++ b/operators/cuda/mul_sigmoid_impl.cuh
@@ -0,0 +1,9 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+template <typename T>
+cudaError_t LaunchMulSigmoidKernel(cudaStream_t stream, int input_length, const T* input, T* output);
\ No newline at end of file
diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py
index 000f3796b..e97f83afd 100644
--- a/test/cuda/test_cudaops.py
+++ b/test/cuda/test_cudaops.py
@@ -1,6 +1,6 @@
 import unittest
 import numpy as np
-from numpy.testing import assert_almost_equal
+from numpy.testing import assert_almost_equal, assert_allclose
 from onnx import helper, numpy_helper, onnx_pb as onnx_proto, TensorProto
 from onnx.reference import ReferenceEvaluator
 from onnx.reference.op_run import OpRun
@@ -118,6 +118,60 @@ def test_cuda_fastgelu_f16(self):
         else:
             print("CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.")
 
+    def _mul_sigmoid_cuda(self, itype):
+        model1 = helper.make_model(
+            helper.make_graph(
+                [
+                    helper.make_node("Sigmoid", ["X"], ["sx"]),
+                    helper.make_node("Mul", ["X", "sx"], ["Y"]),
+                ],
+                "nd",
+                [helper.make_tensor_value_info("X", itype, [None, None, None])],
+                [helper.make_tensor_value_info("Y", itype, [None, None, None])],
+            ),
+            opset_imports=[helper.make_opsetid("", 18)],
+            ir_version=9,
+        )
+
+        model2 = helper.make_model(
+            helper.make_graph(
+                [
+                    helper.make_node(
+                        "MulSigmoid",
+                        ["X"],
+                        ["Y"],
+                        domain="ai.onnx.contrib",
+                    )
+                ],
+                "nd",
+                [helper.make_tensor_value_info("X", itype, [None, None, None])],
+                [helper.make_tensor_value_info("Y", itype, [None, None, None])],
+            ),
+            opset_imports=[
+                helper.make_opsetid("", 18),
+                helper.make_opsetid("ai.onnx.contrib", 1),
+            ],
+            ir_version=9,
+        )
+
+        dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
+        x = (np.arange(18) + 1).reshape((3, 2, 3)).astype(dtype)
+
+        feeds1 = dict(X=x)
+        ref = ReferenceEvaluator(model1)
+        expected = ref.run(None, feeds1)[0]
+
+        opts = _ort.SessionOptions()
+        opts.register_custom_ops_library(_get_library_path())
+        sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
+        got = sess.run(None, feeds1)[0]
+        assert_allclose(expected, got, atol=1e-5 if itype == TensorProto.FLOAT else 1e-2)
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_mul_sigmoid_cuda(self):
+        self._mul_sigmoid_cuda(TensorProto.FLOAT)
+        self._mul_sigmoid_cuda(TensorProto.FLOAT16)
+
     def _negxplus1_cuda(self, itype):
         dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
         model1 = helper.make_model(
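
The device-side sigmoid in mul_sigmoid_impl.cu branches on the sign of its argument so that exp() only ever receives a non-positive value, which keeps the half-precision path from overflowing. A minimal host-side sketch of the same formulation, for reference only and not part of the diff above (the helper name mul_sigmoid_reference is illustrative):

    import numpy as np

    def mul_sigmoid_reference(x: np.ndarray) -> np.ndarray:
        # Stable sigmoid: exp() only sees non-positive arguments, mirroring the CUDA kernel.
        pos = 1.0 / (1.0 + np.exp(-np.clip(x, 0.0, None)))       # branch used where x > 0
        neg = 1.0 - 1.0 / (1.0 + np.exp(np.clip(x, None, 0.0)))  # branch used where x <= 0
        return x * np.where(x > 0, pos, neg)

    x = np.linspace(-10.0, 10.0, 7).astype(np.float16)
    print(mul_sigmoid_reference(x))  # expected to match the MulSigmoid op within FP16 tolerance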