From d58c35c6afcae5a2c6beed214ba528285992d6c3 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 6 Jun 2024 15:12:06 +0000 Subject: [PATCH] Add custom ops ReplaceZero --- operators/cuda/cuda_ops.cc | 5 +- operators/cuda/replace_zero.h | 42 +++++++++++++++++ operators/cuda/replace_zero_impl.cu | 64 ++++++++++++++++++++++++++ operators/cuda/replace_zero_impl.cuh | 9 ++++ test/cuda/test_cudaops.py | 68 ++++++++++++++++++++++++++-- 5 files changed, 183 insertions(+), 5 deletions(-) create mode 100644 operators/cuda/replace_zero.h create mode 100644 operators/cuda/replace_zero_impl.cu create mode 100644 operators/cuda/replace_zero_impl.cuh diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc index f8269302a..f1d3f041a 100644 --- a/operators/cuda/cuda_ops.cc +++ b/operators/cuda/cuda_ops.cc @@ -7,6 +7,7 @@ #include "cuda/add_mul.h" #include "cuda/fast_gelu.h" #include "cuda/negxplus1.h" +#include "cuda/replace_zero.h" #endif FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { @@ -28,13 +29,15 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type), CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1), + CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero), #if ORT_API_VERSION >= 16 CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type), CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type), - CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1) + CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1), + CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero) #endif #endif ); diff --git a/operators/cuda/replace_zero.h b/operators/cuda/replace_zero.h new file mode 100644 index 000000000..d75f7dd20 --- /dev/null +++ b/operators/cuda/replace_zero.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "ocos.h" +#include "replace_zero_impl.cuh" +#include "ortx_common.h" + +namespace contrib { + +template +struct ReplaceZero { + template + OrtxStatus OnModelAttach(const TDict& dict) { + float default_value=0; + by_ = dict.TryToGetAttributeWithDefault("by", default_value); + return {}; + } + OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, + const ortc::Tensor& input, + ortc::Tensor& output) const { + const T* input_data = input.Data(); + auto input_shape = input.Shape(); + T* output_data = output.Allocate(input_shape); + auto input_length = input.NumberOfElement(); + if (0 == input_length) { + return {}; + } + + LaunchReplaceZeroKernel(reinterpret_cast(ctx->GetCudaStream()), + input_length, + input_data, + output_data, + by_); + return {}; + } + + private: + float by_; +}; + +} // namespace contrib \ No newline at end of file diff --git a/operators/cuda/replace_zero_impl.cu b/operators/cuda/replace_zero_impl.cu new file mode 100644 index 000000000..f0c1414f2 --- /dev/null +++ b/operators/cuda/replace_zero_impl.cu @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "device_prop.cuh" +#include "utils.cuh" +#include "replace_zero_impl.cuh" +#include "cuda_type.h" + +#ifndef CUDA_LONG +#define CUDA_LONG int32_t +#endif + +using namespace Ort::Custom; + +template __device__ __inline__ T _replace_zero(const T x, const T by) { + return x == (T)0 ? by : x; +} + +template <> __device__ __inline__ half _replace_zero(const half x, const half by) { +#if __CUDA_ARCH__ < 700 + return __half2float(x) == 0 ? by : x; +#else + return x == (half)0 ? by : x; +#endif +} + +template +__global__ void ReplaceZeroKernel(T *output_data, const T *input_data, CUDA_LONG N, const T by) { + CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x; + if (id >= N) + return; + output_data[id] = _replace_zero(input_data[id], by); +} + +template T _cvt(float value) { return (T)value; } + +template <> half _cvt(float value) { return __float2half(value); } + +template +cudaError_t _LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const T* input_data, T* output_data, float by) { + if (input_length == 0) + return cudaGetLastError(); + using TT = typename contrib::CudaT::MappedType; + + CUDA_LONG N = static_cast(input_length); + + const int num_threads_per_block = 256; + const int num_elements_per_thread = (N + num_threads_per_block - 1) / num_threads_per_block; + + TT cby = _cvt(by); + ReplaceZeroKernel<<>>( + reinterpret_cast(output_data), reinterpret_cast(input_data), N, cby); + return cudaGetLastError(); +} + +template <> +cudaError_t LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const float* input_data, float* output_data, float by) { + return _LaunchReplaceZeroKernel(stream, input_length, input_data, output_data, by); +} + +template <> +cudaError_t LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const ortc::MFloat16* input_data, ortc::MFloat16* output_data, float by) { + return _LaunchReplaceZeroKernel(stream, input_length, input_data, output_data, by); +} diff --git a/operators/cuda/replace_zero_impl.cuh b/operators/cuda/replace_zero_impl.cuh new file mode 100644 index 000000000..048a39363 --- /dev/null +++ b/operators/cuda/replace_zero_impl.cuh @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include + +template +cudaError_t LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const T* input_data, T* output_data, float by); diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py index df1ab2c47..57fa26f0a 100644 --- a/test/cuda/test_cudaops.py +++ b/test/cuda/test_cudaops.py @@ -1,6 +1,6 @@ import unittest import numpy as np -from numpy.testing import assert_almost_equal +from numpy.testing import assert_almost_equal, assert_allclose from onnx import helper, numpy_helper, onnx_pb as onnx_proto, TensorProto from onnx.reference import ReferenceEvaluator from onnx.reference.op_run import OpRun @@ -151,7 +151,7 @@ def test_cuda_negxplus1(self): self._negxplus1_cuda(TensorProto.FLOAT16) def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, 2, 3), shapec=(3, 2, 3)): - from onnx_extended.ortops.optim.cuda import get_ort_ext_libs + from ai.onnx.contrib import get_ort_ext_libs model1 = helper.make_model( helper.make_graph( @@ -181,7 +181,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, f"{op_type}SharedInput", ["X", "Y", "Z"], ["XY", "XZ"], - domain="onnx_extended.ortops.optim.cuda", + domain="ai.onnx.contrib", ) ], "nd", @@ -197,7 +197,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, ), opset_imports=[ helper.make_opsetid("", 18), - helper.make_opsetid("onnx_extended.ortops.optim.cuda", 1), + helper.make_opsetid("ai.onnx.contrib", 1), ], ir_version=9, ) @@ -262,6 +262,66 @@ def test_add_shared_input_cuda_broadcast2(self): shapec=(3, 2, 3), ) + def _replace_zero_cuda(self, itype): + dtype = np.float32 if itype == TensorProto.FLOAT else np.float16 + model1 = helper.make_model( + helper.make_graph( + [ + helper.make_node("Equal", ["X", "zero"], ["cond"]), + helper.make_node("Where", ["cond", "cst", "X"], ["Y"]), + ], + "nd", + [helper.make_tensor_value_info("X", itype, [None, None, None])], + [helper.make_tensor_value_info("Y", itype, [None, None, None])], + [ + numpy_helper.from_array(np.array([0], dtype=dtype), name="zero"), + numpy_helper.from_array(np.array([1.67], dtype=dtype), name="cst"), + ], + ), + opset_imports=[helper.make_opsetid("", 18)], + ir_version=9, + ) + + model2 = helper.make_model( + helper.make_graph( + [ + helper.make_node( + "ReplaceZero", + ["X"], + ["Y"], + by=1.67, + domain="ai.onnx.contrib", + ) + ], + "nd", + [helper.make_tensor_value_info("X", itype, [None, None, None])], + [helper.make_tensor_value_info("Y", itype, [None, None, None])], + ), + opset_imports=[ + helper.make_opsetid("", 18), + helper.make_opsetid("ai.onnx.contrib", 1), + ], + ir_version=9, + ) + + dtype = np.float32 if itype == TensorProto.FLOAT else np.float16 + x = (np.arange(18) - 4).reshape((3, 2, 3)).astype(dtype) + + feeds1 = dict(X=x) + ref = ReferenceEvaluator(model1) + expected = ref.run(None, feeds1)[0] + + opts = _ort.SessionOptions() + opts.register_custom_ops_library(_get_library_path()) + sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"]) + got = sess.run(None, feeds1)[0] + assert_allclose(expected, got, atol=1e-5) + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_replace_zero_cuda(self): + self._replace_zero_cuda(TensorProto.FLOAT) + self._replace_zero_cuda(TensorProto.FLOAT16) + if __name__ == "__main__": unittest.main()