Skip to content

Commit

Permalink
Add operator MulSigmoid
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Jun 7, 2024
1 parent 79f3b04 commit 70b5522
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 1 deletion.
3 changes: 3 additions & 0 deletions operators/cuda/cuda_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#ifdef USE_CUDA
#include "cuda/add_mul.h"
#include "cuda/fast_gelu.h"
#include "cuda/mul_sigmoid.h"
#include "cuda/negxplus1.h"
#include "cuda/transpose_cast.h"
#endif
Expand All @@ -30,13 +31,15 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
CustomCudaStructV2("AddSharedInput", AddSharedInputFloat32Type),
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<float>),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
#if ORT_API_VERSION >= 16

CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<ortc::MFloat16>),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
Expand Down
34 changes: 34 additions & 0 deletions operators/cuda/mul_sigmoid.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "ocos.h"
#include "mul_sigmoid_impl.cuh"
#include "ortx_common.h"

namespace contrib {

template <typename T>
struct MulSigmoid {
template <typename TDict>
OrtxStatus OnModelAttach(const TDict& /*dict*/) {
return {};
}
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
const ortc::Tensor<T>& input,
ortc::Tensor<T>& output) const {
const T* input_data = input.Data();
T* output_data = output.Allocate(input.Shape());
auto input_length = input.NumberOfElement();
if (0 == input_length) {
return {};
}
LaunchMulSigmoidKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
input_length,
input_data,
output_data);
return {};
}
};

} // namespace contrib
72 changes: 72 additions & 0 deletions operators/cuda/mul_sigmoid_impl.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "device_prop.cuh"
#include "utils.cuh"
#include "mul_sigmoid_impl.cuh"
#include "cuda_type.h"

#ifndef CUDA_LONG
#define CUDA_LONG int32_t
#endif

using namespace Ort::Custom;

template <typename T> __device__ __inline__ T _exp_typed(const T x);

template <> __device__ __inline__ float _exp_typed(const float x) { return expf(x); }

#if __CUDA_ARCH__ < 700
template <> __device__ __inline__ half _exp_typed(const half x) {
return __float2half(expf(__half2float(x)));
}
#else
template <> __device__ __inline__ half _exp_typed(const half x) { return hexp(x); }
#endif

template <typename T> __device__ __inline__ T sigmoid(const T a) {
return a > T(0) ? (T)1 / ((T)1. + _exp_typed<T>(-a))
: (T)1 - (T)1 / ((T)1 + _exp_typed<T>(a));
}

#if __CUDA_ARCH__ < 700
template <> __device__ __inline__ half sigmoid(const half a) {
return __float2half(sigmoid(__half2float(a)));
}
#endif

template <typename T> __device__ __inline__ T mul_sigmoid(const T a) { return a * sigmoid(a); }

#if __CUDA_ARCH__ < 700
template <> __device__ __inline__ half mul_sigmoid(const half a) {
float x = __half2float(a);
return __float2half(x * sigmoid(x));
}
#endif

template <typename T>
__global__ void MulSigmoidKernel(T *output_data, const T *input_data, CUDA_LONG N) {
CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
if (id >= N)
return;
output_data[id] = mul_sigmoid(input_data[id]);
}

template <typename T>
cudaError_t _LaunchMulSigmoidKernel(cudaStream_t stream, int input_length, const T* input, T* output) {
constexpr int blockSize = 256;
const int gridSize = (input_length + blockSize - 1) / blockSize;
using TT = typename contrib::CudaT<T>::MappedType;
MulSigmoidKernel<TT><<<gridSize, blockSize, 0, stream>>>(reinterpret_cast<TT*>(output), reinterpret_cast<const TT*>(input), input_length);
return cudaGetLastError();
}

template <>
cudaError_t LaunchMulSigmoidKernel<float>(cudaStream_t stream, int input_length, const float* input, float* output) {
return _LaunchMulSigmoidKernel(stream, input_length, input, output);
}

template <>
cudaError_t LaunchMulSigmoidKernel<ortc::MFloat16>(cudaStream_t stream, int input_length, const ortc::MFloat16* input, ortc::MFloat16* output) {
return _LaunchMulSigmoidKernel(stream, input_length, input, output);
}
9 changes: 9 additions & 0 deletions operators/cuda/mul_sigmoid_impl.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <cuda.h>
#include <cuda_runtime.h>

template <typename T>
cudaError_t LaunchMulSigmoidKernel(cudaStream_t stream, int input_length, const T* input, T* output);
56 changes: 55 additions & 1 deletion test/cuda/test_cudaops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
import numpy as np
from numpy.testing import assert_almost_equal
from numpy.testing import assert_almost_equal, assert_allclose
from onnx import helper, numpy_helper, onnx_pb as onnx_proto, TensorProto
from onnx.reference import ReferenceEvaluator
from onnx.reference.op_run import OpRun
Expand Down Expand Up @@ -118,6 +118,60 @@ def test_cuda_fastgelu_f16(self):
else:
print("CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.")

def _mul_sigmoid_cuda(self, itype):
model1 = helper.make_model(
helper.make_graph(
[
helper.make_node("Sigmoid", ["X"], ["sx"]),
helper.make_node("Mul", ["X", "sx"], ["Y"]),
],
"nd",
[helper.make_tensor_value_info("X", itype, [None, None, None])],
[helper.make_tensor_value_info("Y", itype, [None, None, None])],
),
opset_imports=[helper.make_opsetid("", 18)],
ir_version=9,
)

model2 = helper.make_model(
helper.make_graph(
[
helper.make_node(
"MulSigmoid",
["X"],
["Y"],
domain="ai.onnx.contrib",
)
],
"nd",
[helper.make_tensor_value_info("X", itype, [None, None, None])],
[helper.make_tensor_value_info("Y", itype, [None, None, None])],
),
opset_imports=[
helper.make_opsetid("", 18),
helper.make_opsetid("ai.onnx.contrib", 1),
],
ir_version=9,
)

dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
x = (np.arange(18) + 1).reshape((3, 2, 3)).astype(dtype)

feeds1 = dict(X=x)
ref = ReferenceEvaluator(model1)
expected = ref.run(None, feeds1)[0]

opts = _ort.SessionOptions()
opts.register_custom_ops_library(_get_library_path())
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
got = sess.run(None, feeds1)[0]
assert_allclose(expected, got, atol=1e-5 if itype == TensorProto.FLOAT else 1e-2)

@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_mul_sigmoid_cuda(self):
self._mul_sigmoid_cuda(TensorProto.FLOAT)
self._mul_sigmoid_cuda(TensorProto.FLOAT16)

def _negxplus1_cuda(self, itype):
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
model1 = helper.make_model(
Expand Down

0 comments on commit 70b5522

Please sign in to comment.