diff --git a/operators/cuda/add_mul.h b/operators/cuda/add_mul.h index 5b8f36f21..3f6cc1e64 100644 --- a/operators/cuda/add_mul.h +++ b/operators/cuda/add_mul.h @@ -43,4 +43,41 @@ struct AddOrMulSharedInput { } }; +template +struct AddOrMulTwice { + template + OrtxStatus OnModelAttach(const TDict& /*dict*/) { + return {}; + } + OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, + const ortc::Tensor& tensor_a, + const ortc::Tensor& tensor_b, + const ortc::Tensor& tensor_c, + ortc::Tensor& output) const { + const T* input_data_a = tensor_a.Data(); + const T* input_data_b = tensor_b.Data(); + const T* input_data_c = tensor_c.Data(); + + auto length_a = tensor_a.NumberOfElement(); + auto length_b = tensor_b.NumberOfElement(); + auto length_c = tensor_c.NumberOfElement(); + + T* output_data_ab = output_ab.Allocate( + length_a <= length_b + ? lenght_c <= length_b ? tensor_b.Shape() : tensor_c.Shape() + : lenght_a <= length_b ? tensor_b.Shape() : tensor_a.Shape()); + + if (0 == input_data_a || 0 == input_data_b || 0 == input_data_c) { + return {}; + } + LaunchAddOrMulTwiceKernel(reinterpret_cast(ctx->GetCudaStream()), + input_data_a, input_data_b, input_data_c, + output_data, + length_a, length_b, length_c, + addition); + return {}; + } +}; + + } // namespace contrib \ No newline at end of file diff --git a/operators/cuda/add_mul_impl.cu b/operators/cuda/add_mul_impl.cu index 85f55bc77..0373ddf28 100644 --- a/operators/cuda/add_mul_impl.cu +++ b/operators/cuda/add_mul_impl.cu @@ -12,12 +12,12 @@ using namespace Ort::Custom; -__device__ __forceinline__ void _add3_op(float* ab, float* ac, const float a, const float b, const float c) { +__device__ __forceinline__ void _add3_2_op(float* ab, float* ac, const float a, const float b, const float c) { *ab = a + b; *ac = a + c; } -__device__ __forceinline__ void _add3_op(half* ab, half* ac, const half a, const half b, const half c) { +__device__ __forceinline__ void _add3_2_op(half* ab, half* ac, const half a, const half b, const half c) { #if __CUDA_ARCH__ < 700 *ab = __float2half(__half2float(a) + __half2float(b)); *ac = __float2half(__half2float(a) + __half2float(c)); @@ -27,12 +27,12 @@ __device__ __forceinline__ void _add3_op(half* ab, half* ac, const half a, const #endif } -__device__ __forceinline__ void _mul3_op(float* ab, float* ac, const float a, const float b, const float c) { +__device__ __forceinline__ void _mul3_2_op(float* ab, float* ac, const float a, const float b, const float c) { *ab = a * b; *ac = a * c; } -__device__ __forceinline__ void _mul3_op(half* ab, half* ac, const half a, const half b, const half c) { +__device__ __forceinline__ void _mul3_2_op(half* ab, half* ac, const half a, const half b, const half c) { #if __CUDA_ARCH__ < 700 *ab = __float2half(__half2float(a) * __half2float(b)); *ac = __float2half(__half2float(a) * __half2float(c)); @@ -45,21 +45,21 @@ __device__ __forceinline__ void _mul3_op(half* ab, half* ac, const half a, const template struct Mul3SharedOp { __device__ __forceinline__ void operator()(T* ab, T* ac, const T a, const T b, const T c) const { - _mul3_op(ab, ac, a, b, c); + _mul3_2_op(ab, ac, a, b, c); } }; template struct Add3SharedOp { __device__ __forceinline__ void operator()(T* ab, T* ac, const T a, const T b, const T c) const { - _add3_op(ab, ac, a, b, c); + _add3_2_op(ab, ac, a, b, c); } }; template -__global__ void AddMulKernel(T* output_ab, T* output_ac, const T* pA, const T* pB, - const T* pC, CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC, - CUDA_LONG N, const TFunc func) { +__global__ void AddMulSharedInputKernel(T* output_ab, T* output_ac, const T* pA, const T* pB, + const T* pC, CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC, + CUDA_LONG N, const TFunc func) { CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; CUDA_LONG id = start; #pragma unroll @@ -89,14 +89,14 @@ cudaError_t _LaunchAddOrMulSharedInputKernel(cudaStream_t stream, using TT = typename contrib::CudaT::MappedType; if (addition) { - AddMulKernel, num_threads_per_block, num_elements_per_thread> + AddMulSharedInputKernel, num_threads_per_block, num_elements_per_thread> <<>>( reinterpret_cast(output_ab), reinterpret_cast(output_ac), reinterpret_cast(pA), reinterpret_cast(pB), reinterpret_cast(pC), static_cast(countA), static_cast(countB), static_cast(countC), static_cast(max_count), Add3SharedOp()); } else { - AddMulKernel, num_threads_per_block, num_elements_per_thread> + AddMulSharedInputKernel, num_threads_per_block, num_elements_per_thread> <<>>( reinterpret_cast(output_ab), reinterpret_cast(output_ac), reinterpret_cast(pA), reinterpret_cast(pB), reinterpret_cast(pC), static_cast(countA), @@ -107,15 +107,132 @@ cudaError_t _LaunchAddOrMulSharedInputKernel(cudaStream_t stream, } template <> -cudaError_t LaunchAddOrMulSharedInputKernel(cudaStream_t stream, const float* input_a, const float* input_b, const float* input_c, +cudaError_t LaunchAddOrMulSharedInputKernel(cudaStream_t stream, + const float* input_a, const float* input_b, const float* input_c, float* output_ab, float* output_ac, int64_t length_a, int64_t length_b, int64_t length_c, bool addition) { - return _LaunchAddOrMulSharedInputKernel(stream, input_a, input_b, input_c, output_ab, output_ac, length_a, length_b, length_c, addition); + return _LaunchAddOrMulSharedInputKernel(stream, input_a, input_b, input_c, + output_ab, output_ac, + length_a, length_b, length_c, addition); } template <> -cudaError_t LaunchAddOrMulSharedInputKernel(cudaStream_t stream, const ortc::MFloat16* input_a, const ortc::MFloat16* input_b, const ortc::MFloat16* input_c, +cudaError_t LaunchAddOrMulSharedInputKernel(cudaStream_t stream, + const ortc::MFloat16* input_a, const ortc::MFloat16* input_b, const ortc::MFloat16* input_c, ortc::MFloat16* output_ab, ortc::MFloat16* output_ac, int64_t length_a, int64_t length_b, int64_t length_c, bool addition) { - return _LaunchAddOrMulSharedInputKernel(stream, input_a, input_b, input_c, output_ab, output_ac, length_a, length_b, length_c, addition); + return _LaunchAddOrMulSharedInputKernel(stream, input_a, input_b, input_c, + output_ab, output_ac, + length_a, length_b, length_c, addition); } + +__device__ __forceinline__ void _add3_op(float *address, const float a, const float b, + const float c) { + *address = a + b + c; +} + +__device__ __forceinline__ void _add3_op(half *address, const half a, const half b, + const half c) { +#if __CUDA_ARCH__ < 700 + *address = __float2half(__half2float(a) + __half2float(b) + __half2float(c)); +#else + *address = a + b + c; +#endif +} + +__device__ __forceinline__ void _mul3_op(float *address, const float a, const float b, + const float c) { + *address = a * b * c; +} + +__device__ __forceinline__ void _mul3_op(half *address, const half a, const half b, + const half c) { +#if __CUDA_ARCH__ < 700 + *address = __float2half(__half2float(a) * __half2float(b) * __half2float(c)); +#else + *address = a * b * c; +#endif +} + +template struct Mul3Op { + __device__ __inline__ void operator()(T *address, const T a, const T b, const T c) const { + _mul3_op(address, a, b, c); + } +}; + +template struct Add3Op { + __device__ __inline__ void operator()(T *address, const T a, const T b, const T c) const { + _add3_op(address, a, b, c); + } +}; + +template +__global__ void AddMulTwiceKernel(T* output, const T* pA, const T* pB, + const T* pC, CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC, + CUDA_LONG N, const TFunc func) { + CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; + CUDA_LONG id = start; +#pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (id < N) { + func(output_ab, pA[id % nA], pB[id % nB], pC[id % nC]); + id += NumThreadsPerBlock; + } + } +} + +template +cudaError_t _LaunchAddOrMulTwiceKernel(cudaStream_t stream, + const T* pA, const T* pB, const T* pC, + T* output, + int64_t countA, int64_t countB, int64_t countC, bool addition) { + int64_t max_count = std::max(std::max(countA, countB), countC); + if (max_count == 0) // special case where there's a dim value of 0 in the output shape + return cudaGetLastError(); + + const int num_elements_per_thread = 4; + const int num_threads_per_block = 256; + const int num_el_th = num_threads_per_block * num_elements_per_thread; + + int blocksPerGrid = (max_count + num_el_th - 1) / num_el_th; + + using TT = typename contrib::CudaT::MappedType; + + if (addition) { + AddMulTwiceKernel, num_threads_per_block, num_elements_per_thread> + <<>>( + reinterpret_cast(output), + reinterpret_cast(pA), reinterpret_cast(pB), reinterpret_cast(pC), + static_cast(countA), static_cast(countB), static_cast(countC), + static_cast(max_count), Add3SharedOp()); + } else { + AddMulTwiceKernel, num_threads_per_block, num_elements_per_thread> + <<>>( + reinterpret_cast(output), + reinterpret_cast(pA), reinterpret_cast(pB), reinterpret_cast(pC), static_cast(countA), + static_cast(countB), static_cast(countC), + static_cast(max_count), Mul3SharedOp()); + } + return cudaGetLastError(); +} + +template <> +cudaError_t LaunchAddOrMulSharedInputKernel(cudaStream_t stream, + const float* input_a, const float* input_b, const float* input_c, + float* output, + int64_t length_a, int64_t length_b, int64_t length_c, bool addition) { + return _LaunchAddOrMulSharedInputKernel(stream, input_a, input_b, input_c, + output, + length_a, length_b, length_c, addition); +} + +template <> +cudaError_t LaunchAddOrMulSharedInputKernel(cudaStream_t stream, + const ortc::MFloat16* input_a, const ortc::MFloat16* input_b, const ortc::MFloat16* input_c, + ortc::MFloat16* output, + int64_t length_a, int64_t length_b, int64_t length_c, bool addition) { + return _LaunchAddOrMulSharedInputKernel(stream, input_a, input_b, input_c, + output, + length_a, length_b, length_c, addition); +} + diff --git a/operators/cuda/add_mul_impl.cuh b/operators/cuda/add_mul_impl.cuh index 9bf3ad853..6fdf9369e 100644 --- a/operators/cuda/add_mul_impl.cuh +++ b/operators/cuda/add_mul_impl.cuh @@ -8,4 +8,8 @@ template cudaError_t LaunchAddOrMulSharedInputKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c, T* output_ab, T* output_ac, - int64_t length_a, int64_t length_b, int64_t length_c, bool addition); \ No newline at end of file + int64_t length_a, int64_t length_b, int64_t length_c, bool addition); + +template +cudaError_t LaunchAddOrMulTwiceKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c, + T* output, int64_t length_a, int64_t length_b, int64_t length_c, bool addition); \ No newline at end of file diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc index f8269302a..8c4808a5a 100644 --- a/operators/cuda/cuda_ops.cc +++ b/operators/cuda/cuda_ops.cc @@ -14,9 +14,15 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { using AddSharedInputFloat32Type = typename contrib::AddOrMulSharedInput; using MulSharedInputFloat32Type = typename contrib::AddOrMulSharedInput; + using AddTwiceFloat32Type = typename contrib::AddOrMulTwice; + using MulTwiceFloat32Type = typename contrib::AddOrMulTwice; + #if ORT_API_VERSION >= 16 using AddSharedInputFloat16Type = typename contrib::AddOrMulSharedInput; using MulSharedInputFloat16Type = typename contrib::AddOrMulSharedInput; + + using AddTwiceFloat16Type = typename contrib::AddOrMulTwice; + using MulTwiceFloat16Type = typename contrib::AddOrMulTwice; #endif @@ -25,15 +31,19 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { #ifdef USE_CUDA , CustomCudaStructV2("AddSharedInput", AddSharedInputFloat32Type), + CustomCudaStructV2("AddTwice", AddTwiceFloat32Type), CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type), + CustomCudaStructV2("MulTwice", MulTwiceFloat32Type), CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1), #if ORT_API_VERSION >= 16 CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type), + CustomCudaStructV2("AddTwice", AddTwiceFloat16Type), CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type), + CustomCudaStructV2("MulTwice", MulTwiceFloat16Type), CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1) #endif #endif diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py index df1ab2c47..7955795db 100644 --- a/test/cuda/test_cudaops.py +++ b/test/cuda/test_cudaops.py @@ -22,6 +22,83 @@ def _run(self, X): class TestCudaOps(unittest.TestCase): + def _addaddmulmul_cuda(self, itype, op_type, broad=False): + model1 = helper.make_model( + helper.make_graph( + [ + helper.make_node(op_type, ["X", "Y"], ["xy"]), + helper.make_node(op_type, ["xy", "Z"], ["final"]), + ], + "nd", + [ + helper.make_tensor_value_info("X", itype, [None, None, None]), + helper.make_tensor_value_info("Y", itype, [None, None, None]), + helper.make_tensor_value_info("Z", itype, [None, None, None]), + ], + [helper.make_tensor_value_info("final", itype, [None, None, None])], + ), + opset_imports=[helper.make_opsetid("", 18)], + ir_version=9, + ) + + model2 = helper.make_model( + helper.make_graph( + [ + helper.make_node( + f"{op_type}{op_type}", + ["X", "Y", "Z"], + ["final"], + domain="ai.onnx.contrib", + ) + ], + "nd", + [ + helper.make_tensor_value_info("X", itype, [None, None, None]), + helper.make_tensor_value_info("Y", itype, [None, None, None]), + helper.make_tensor_value_info("Z", itype, [None, None, None]), + ], + [helper.make_tensor_value_info("final", itype, [None, None, None])], + ), + opset_imports=[ + helper.make_opsetid("", 18), + helper.make_opsetid("ai.onnx.contrib", 1), + ], + ir_version=9, + ) + + dtype = np.float32 if itype == TensorProto.FLOAT else np.float16 + shapex = (1, 2, 3) if broad else (3, 2, 3) + shapey = (3, 2, 3) + shapez = (1, 2, 3) if broad else (3, 2, 3) + x = (np.arange(np.prod(shapex)) + 1).reshape(shapex).astype(dtype) + y = (np.arange(np.prod(shapey)) + 1).reshape(shapey).astype(dtype) + z = (np.arange(np.prod(shapez)) + 1).reshape(shapez).astype(dtype) + + feeds1 = dict(X=x, Y=y, Z=z) + ref = ReferenceEvaluator(model1) + expected = ref.run(None, feeds1)[0] + + opts = _ort.SessionOptions() + opts.register_custom_ops_library(_get_library_path()) + sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"]) + got = sess.run(None, feeds1)[0] + assert_almost_equal(expected, got) + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_mulmul_cuda(self): + self._addaddmulmul_cuda(TensorProto.FLOAT, "Mul") + self._addaddmulmul_cuda(TensorProto.FLOAT16, "Mul") + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_mulmul_cuda_broadcast(self): + self._addaddmulmul_cuda(TensorProto.FLOAT, "Mul", True) + self._addaddmulmul_cuda(TensorProto.FLOAT16, "Mul", True) + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_addadd_cuda(self): + self._addaddmulmul_cuda(TensorProto.FLOAT, "Add") + self._addaddmulmul_cuda(TensorProto.FLOAT16, "Add") + @staticmethod def _create_negpos_test_model(domain="ai.onnx.contrib"): nodes = [ @@ -151,7 +228,7 @@ def test_cuda_negxplus1(self): self._negxplus1_cuda(TensorProto.FLOAT16) def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, 2, 3), shapec=(3, 2, 3)): - from onnx_extended.ortops.optim.cuda import get_ort_ext_libs + from ai.onnx.contrib import get_ort_ext_libs model1 = helper.make_model( helper.make_graph( @@ -181,7 +258,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, f"{op_type}SharedInput", ["X", "Y", "Z"], ["XY", "XZ"], - domain="onnx_extended.ortops.optim.cuda", + domain="ai.onnx.contrib", ) ], "nd", @@ -197,7 +274,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, ), opset_imports=[ helper.make_opsetid("", 18), - helper.make_opsetid("onnx_extended.ortops.optim.cuda", 1), + helper.make_opsetid("ai.onnx.contrib", 1), ], ir_version=9, )