diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc index f9017436b..e6f012edd 100644 --- a/operators/cuda/cuda_ops.cc +++ b/operators/cuda/cuda_ops.cc @@ -31,6 +31,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { CustomCudaStructV2("AddSharedInput", AddSharedInputFloat32Type), CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type), + CustomCudaStructV2("MulMulSigmoid", contrib::MulMulSigmoid), CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid), CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1), #if ORT_API_VERSION >= 16 @@ -39,6 +40,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type), + CustomCudaStructV2("MulMulSigmoid", contrib::MulMulSigmoid), CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid), CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1), CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type), diff --git a/operators/cuda/mul_sigmoid.h b/operators/cuda/mul_sigmoid.h index dc2ab203f..433e4e8ba 100644 --- a/operators/cuda/mul_sigmoid.h +++ b/operators/cuda/mul_sigmoid.h @@ -31,4 +31,32 @@ struct MulSigmoid { } }; +template +struct MulMulSigmoid { + template + OrtxStatus OnModelAttach(const TDict& /*dict*/) { + return {}; + } + OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, + const ortc::Tensor& input_x, + const ortc::Tensor& input_y, + ortc::Tensor& output) const { + const T* input_data_x = input_x.Data(); + const T* input_data_y = input_y.Data(); + auto input_length_x = input_x.NumberOfElement(); + auto input_length_y = input_y.NumberOfElement(); + if (0 == input_length_x || 0 == input_data_y) { + return {}; + } + T* output_data = output.Allocate(input_length_x > input_length_y ? input_x.Shape() : input_y.Shape()); + LaunchMulMulSigmoidKernel(reinterpret_cast(ctx->GetCudaStream()), + input_length_x, + input_length_y, + input_data_x, + input_data_y, + output_data); + return {}; + } +}; + } // namespace contrib \ No newline at end of file diff --git a/operators/cuda/mul_sigmoid_impl.cu b/operators/cuda/mul_sigmoid_impl.cu index 9746cfa32..c2c14fe0f 100644 --- a/operators/cuda/mul_sigmoid_impl.cu +++ b/operators/cuda/mul_sigmoid_impl.cu @@ -44,6 +44,17 @@ template <> __device__ __inline__ half mul_sigmoid(const half a) { } #endif +template __device__ __inline__ T mul_mul_sigmoid(const T x, const T y) { + return x * y * sigmoid(y); +} + +#if __CUDA_ARCH__ < 700 +template <> __device__ __inline__ half mul_mul_sigmoid(const half x, const half y) { + float hy = __half2float(y); + return __float2half(__half2float(x) * hy * sigmoid(hy)); +} +#endif + template __global__ void MulSigmoidKernel(T *output_data, const T *input_data, CUDA_LONG N) { CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x; @@ -52,6 +63,15 @@ __global__ void MulSigmoidKernel(T *output_data, const T *input_data, CUDA_LONG output_data[id] = mul_sigmoid(input_data[id]); } +template +__global__ void MulMulSigmoidKernel(T *output_data, const T *px, const T *py, CUDA_LONG N, + CUDA_LONG Nx, CUDA_LONG Ny) { + CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x; + if (id >= N) + return; + output_data[id] = mul_mul_sigmoid(px[id % Nx], py[id % Ny]); +} + template cudaError_t _LaunchMulSigmoidKernel(cudaStream_t stream, int input_length, const T* input, T* output) { constexpr int blockSize = 256; @@ -70,3 +90,30 @@ template <> cudaError_t LaunchMulSigmoidKernel(cudaStream_t stream, int input_length, const ortc::MFloat16* input, ortc::MFloat16* output) { return _LaunchMulSigmoidKernel(stream, input_length, input, output); } + +template +cudaError_t _LaunchMulMulSigmoidKernel(cudaStream_t stream, int input_length_x, int input_length_y, + const T* input_data_x, const T* input_data_y, T* output) { + int input_length = std::max(input_length_x, input_length_y); + constexpr int blockSize = 256; + const int gridSize = (input_length + blockSize - 1) / blockSize; + using TT = typename contrib::CudaT::MappedType; + MulMulSigmoidKernel<<>>(reinterpret_cast(output), + reinterpret_cast(input_data_x), + reinterpret_cast(input_data_y), + input_length, input_length_x, input_length_y); + return cudaGetLastError(); +} + +template <> +cudaError_t LaunchMulMulSigmoidKernel(cudaStream_t stream, int input_length_x, int input_length_y, + const float* input_data_x, const float* input_data_y, float* output) { + return _LaunchMulMulSigmoidKernel(stream, input_length_x, input_length_y, input_data_x, input_data_y, output); +} + +template <> +cudaError_t LaunchMulMulSigmoidKernel(cudaStream_t stream, int input_length_x, int input_length_y, + const ortc::MFloat16* input_data_x, const ortc::MFloat16* input_data_y, + ortc::MFloat16* output) { + return _LaunchMulMulSigmoidKernel(stream, input_length_x, input_length_y, input_data_x, input_data_y, output); +} diff --git a/operators/cuda/mul_sigmoid_impl.cuh b/operators/cuda/mul_sigmoid_impl.cuh index 8b5f23b33..661c5a04c 100644 --- a/operators/cuda/mul_sigmoid_impl.cuh +++ b/operators/cuda/mul_sigmoid_impl.cuh @@ -6,4 +6,8 @@ #include template -cudaError_t LaunchMulSigmoidKernel(cudaStream_t stream, int input_length, const T* input, T* output); \ No newline at end of file +cudaError_t LaunchMulSigmoidKernel(cudaStream_t stream, int input_length, const T* input, T* output); + +template +cudaError_t LaunchMulMulSigmoidKernel(cudaStream_t stream, int input_length_x, int input_length_y, + const T* input_data_x, const T* input_data_y, T* output); diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py index e97f83afd..6fffd0cc6 100644 --- a/test/cuda/test_cudaops.py +++ b/test/cuda/test_cudaops.py @@ -118,6 +118,77 @@ def test_cuda_fastgelu_f16(self): else: print("CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.") + def _mulmulsigmoid_cuda(self, itype, broad=False, atol=1e-5, rtol=1e-3): + model1 = helper.make_model( + helper.make_graph( + [ + helper.make_node("Mul", ["X", "Y"], ["xy"]), + helper.make_node("Sigmoid", ["Y"], ["sy"]), + helper.make_node("Mul", ["xy", "sy"], ["final"]), + ], + "nd", + [ + helper.make_tensor_value_info("X", itype, [None, None, None]), + helper.make_tensor_value_info("Y", itype, [None, None, None]), + ], + [helper.make_tensor_value_info("final", itype, [None, None, None])], + ), + opset_imports=[helper.make_opsetid("", 18)], + ir_version=9, + ) + + model2 = helper.make_model( + helper.make_graph( + [ + helper.make_node( + "MulMulSigmoid", + ["X", "Y"], + ["final"], + domain="ai.onnx.contrib", + ) + ], + "nd", + [ + helper.make_tensor_value_info("X", itype, [None, None, None]), + helper.make_tensor_value_info("Y", itype, [None, None, None]), + ], + [helper.make_tensor_value_info("final", itype, [None, None, None])], + ), + opset_imports=[ + helper.make_opsetid("", 18), + helper.make_opsetid("ai.onnx.contrib", 1), + ], + ir_version=9, + ) + + dtype = np.float32 if itype == TensorProto.FLOAT else np.float16 + shapex = (1, 2, 3) if broad else (3, 2, 3) + shapey = (3, 2, 3) + x = (np.arange(np.prod(shapex)) + 1).reshape(shapex).astype(dtype) + y = (np.arange(np.prod(shapey)) + 2).reshape(shapey).astype(dtype) + x /= x.size + y /= y.size + + feeds1 = dict(X=x, Y=y) + ref = ReferenceEvaluator(model1) + expected = ref.run(None, feeds1)[0] + + opts = _ort.SessionOptions() + opts.register_custom_ops_library(_get_library_path()) + sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"]) + got = sess.run(None, feeds1)[0] + assert_allclose(expected, got, atol=atol, rtol=rtol) + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_mulmulsigmoid_cuda(self): + self._mulmulsigmoid_cuda(TensorProto.FLOAT) + self._mulmulsigmoid_cuda(TensorProto.FLOAT16) + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_mulmulsigmoid_cuda_broadcast(self): + self._mulmulsigmoid_cuda(TensorProto.FLOAT, True) + self._mulmulsigmoid_cuda(TensorProto.FLOAT16, True) + def _mul_sigmoid_cuda(self, itype): model1 = helper.make_model( helper.make_graph(