From cf1e47f7878a8d7043688ac66c73cb9d12355073 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Mon, 10 Jun 2024 11:00:20 +0000
Subject: [PATCH] fix compilation

---
 operators/cuda/add_mul.h       | 124 ++++++++++++--------------
 operators/cuda/add_mul_impl.cu | 158 +++++++++++++++++----------------
 operators/cuda/cuda_ops.cc     |  16 ++--
 test/cuda/test_cudaops.py      |  10 +--
 4 files changed, 149 insertions(+), 159 deletions(-)

diff --git a/operators/cuda/add_mul.h b/operators/cuda/add_mul.h
index a23e07b8c..4995642e7 100644
--- a/operators/cuda/add_mul.h
+++ b/operators/cuda/add_mul.h
@@ -8,6 +8,27 @@
 
 namespace contrib {
 
+inline void _FillOutputShape3Op(std::vector<int64_t>& dimsA,
+                                std::vector<int64_t>& dimsB,
+                                std::vector<int64_t>& dimsC,
+                                std::vector<int64_t>& output_dims) {
+  auto max_rank = std::max(dimsA.size(), std::max(dimsB.size(), dimsC.size()));
+  while (dimsA.size() < max_rank)
+    dimsA.insert(dimsA.begin(), 1);
+  while (dimsB.size() < max_rank)
+    dimsB.insert(dimsB.begin(), 1);
+  while (dimsC.size() < max_rank)
+    dimsC.insert(dimsC.begin(), 1);
+
+  output_dims.resize(dimsA.size());
+  for (size_t i = 0; i < dimsA.size(); ++i) {
+    output_dims[i] = std::max(std::max(dimsA[i], dimsB[i]), dimsC[i]);
+    if (output_dims[i] == 0) {
+      ORTX_CXX_API_THROW("One of the input dimensions is null.", ORT_RUNTIME_EXCEPTION);
+    }
+  }
+}
+
 template <typename T, bool addition>
 struct AddOrMulSharedInput {
   template <typename TDict>
@@ -20,22 +41,19 @@ struct AddOrMulSharedInput {
                      const ortc::Tensor<T>& tensor_c,
                      ortc::Tensor<T>& output_ab,
                      ortc::Tensor<T>& output_ac) const {
-    const T* input_data_a = tensor_a.Data();
-    const T* input_data_b = tensor_b.Data();
-    const T* input_data_c = tensor_c.Data();
-
     auto length_a = tensor_a.NumberOfElement();
     auto length_b = tensor_b.NumberOfElement();
     auto length_c = tensor_c.NumberOfElement();
 
-    T* output_data_ab = output_ab.Allocate(length_a <= length_b ? tensor_b.Shape() : tensor_a.Shape());
-    T* output_data_ac = output_ab.Allocate(length_a <= length_c ? tensor_c.Shape() : tensor_a.Shape());
-
-    if (0 == input_data_a || 0 == input_data_b || 0 == input_data_c) {
+    if (0 == length_a || 0 == length_b || 0 == length_c) {
       return {};
     }
+
+    T* output_data_ab = output_ab.Allocate(length_a <= length_b ? tensor_b.Shape() : tensor_a.Shape());
+    T* output_data_ac = output_ac.Allocate(length_a <= length_c ? tensor_c.Shape() : tensor_a.Shape());
+
     LaunchAddOrMulSharedInputKernel(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
-                                    input_data_a, input_data_b, input_data_c,
+                                    tensor_a.Data(), tensor_b.Data(), tensor_c.Data(),
                                     output_data_ab, output_data_ac,
                                     length_a, length_b, length_c,
                                     addition);
@@ -54,25 +72,24 @@ struct AddOrMulTwice {
                     const ortc::Tensor<T>& tensor_b,
                     const ortc::Tensor<T>& tensor_c,
                     ortc::Tensor<T>& output) const {
-    const T* input_data_a = tensor_a.Data();
-    const T* input_data_b = tensor_b.Data();
-    const T* input_data_c = tensor_c.Data();
-
     auto length_a = tensor_a.NumberOfElement();
     auto length_b = tensor_b.NumberOfElement();
     auto length_c = tensor_c.NumberOfElement();
 
-    T* output_data_ab = output_ab.Allocate(
-        length_a <= length_b
-            ? lenght_c <= length_b ? tensor_b.Shape() : tensor_c.Shape()
-            : lenght_a <= length_b ? tensor_b.Shape()
-                                   : tensor_a.Shape());
-
-    if (0 == input_data_a || 0 == input_data_b || 0 == input_data_c) {
+    if (0 == length_a || 0 == length_b || 0 == length_c) {
       return {};
     }
+
+    std::vector<int64_t> dimsA = tensor_a.Shape();
+    std::vector<int64_t> dimsB = tensor_b.Shape();
+    std::vector<int64_t> dimsC = tensor_c.Shape();
+    std::vector<int64_t> output_dims;
+    _FillOutputShape3Op(dimsA, dimsB, dimsC, output_dims);
+
+    T* output_data = output.Allocate(output_dims);
+
     LaunchAddOrMulTwiceKernel(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
-                              input_data_a, input_data_b, input_data_c,
+                              tensor_a.Data(), tensor_b.Data(), tensor_c.Data(),
                               output_data,
                               length_a, length_b, length_c,
                               addition);
@@ -84,42 +101,27 @@ template <typename T, bool addition_first>
 struct AddAndMul {
   template <typename TDict>
   OrtxStatus OnModelAttach(const TDict& dict) {
-    return {};
+    int64_t default_value = 0;
+    switchMiddelAxis_ = dict.TryToGetAttributeWithDefault("switchMiddleAxis", default_value) == 1;
   }
 
   OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
                      const ortc::Tensor<T>& tensor_a,
                      const ortc::Tensor<T>& tensor_b,
                      const ortc::Tensor<T>& tensor_c,
                      ortc::Tensor<T>& output) const {
-    const T* input_data_a = tensor_a.Data();
-    const T* input_data_b = tensor_b.Data();
-    const T* input_data_c = tensor_c.Data();
-
     auto length_a = tensor_a.NumberOfElement();
     auto length_b = tensor_b.NumberOfElement();
     auto length_c = tensor_c.NumberOfElement();
 
-    if (0 == input_data_a || 0 == input_data_b || 0 == input_data_c) {
+
+    if (0 == length_a || 0 == length_b || 0 == length_c) {
       return {};
     }
 
     std::vector<int64_t> dimsA = tensor_a.Shape();
     std::vector<int64_t> dimsB = tensor_b.Shape();
     std::vector<int64_t> dimsC = tensor_c.Shape();
-
-    auto max_length = std::max(length_a, std::max(length_b, length_c));
-
-    auto max_rank = std::max(dimsA.size(), std::max(dimsB.size(), dimsC.size()));
-    while (dimsA.size() < max_rank)
-      dimsA.insert(dimsA.begin(), 1);
-    while (dimsB.size() < max_rank)
-      dimsB.insert(dimsB.begin(), 1);
-    while (dimsC.size() < max_rank)
-      dimsC.insert(dimsC.begin(), 1);
-
-    std::vector<int64_t> output_dims(dimsA.size());
-    for (size_t i = 0; i < dimsA.size(); ++i) {
-      output_dims[i] = std::max(std::max(dimsA[i], dimsB[i]), dimsC[i]);
-    }
+    std::vector<int64_t> output_dims;
+    _FillOutputShape3Op(dimsA, dimsB, dimsC, output_dims);
 
     if (switchMiddelAxis_) {
       if (output_dims.size() != 4) {
@@ -130,15 +132,16 @@ struct AddAndMul {
       int64_t d2 = output_dims[output_dims.size() - 3];
       output_dims[1] = d3;
       output_dims[2] = d2;
+      T* output_data = output.Allocate(output_dims);
       LaunchAddAndMulSwitchMiddleAxesKernel(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
-                                            input_data_a, input_data_b, input_data_c,
+                                            tensor_a.Data(), tensor_b.Data(), tensor_c.Data(),
                                             output_data,
                                             length_a, length_b, length_c,
                                             addition_first, d2, d3, d4);
     } else {
-      T* output_data_ab = output_ab.Allocate(output_dims);
+      T* output_data = output.Allocate(output_dims);
       LaunchAddAndMulKernel(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
-                            input_data_a, input_data_b, input_data_c,
+                            tensor_a.Data(), tensor_b.Data(), tensor_c.Data(),
                             output_data,
                             length_a, length_b, length_c,
                             addition_first);
@@ -154,46 +157,31 @@ template <typename T, bool subtract_first>
 struct SubAndMul {
   template <typename TDict>
   OrtxStatus OnModelAttach(const TDict& dict) {
-    return {};
+    //int64_t default_value = 0;
+    //negative_ = dict.TryToGetAttributeWithDefault("negative", default_value) == 1;
+    negative_ = false;
  }
 
   OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
                      const ortc::Tensor<T>& tensor_a,
                      const ortc::Tensor<T>& tensor_b,
                      const ortc::Tensor<T>& tensor_c,
                      ortc::Tensor<T>& output) const {
-    const T* input_data_a = tensor_a.Data();
-    const T* input_data_b = tensor_b.Data();
-    const T* input_data_c = tensor_c.Data();
-
     auto length_a = tensor_a.NumberOfElement();
     auto length_b = tensor_b.NumberOfElement();
     auto length_c = tensor_c.NumberOfElement();
 
-    if (0 == input_data_a || 0 == input_data_b || 0 == input_data_c) {
+    if (0 == length_a || 0 == length_b || 0 == length_c) {
       return {};
     }
 
     std::vector<int64_t> dimsA = tensor_a.Shape();
     std::vector<int64_t> dimsB = tensor_b.Shape();
     std::vector<int64_t> dimsC = tensor_c.Shape();
+    std::vector<int64_t> output_dims;
+    _FillOutputShape3Op(dimsA, dimsB, dimsC, output_dims);
+    T* output_data = output.Allocate(output_dims);
 
-    auto max_length = std::max(length_a, std::max(length_b, length_c));
-
-    auto max_rank = std::max(dimsA.size(), std::max(dimsB.size(), dimsC.size()));
-    while (dimsA.size() < max_rank)
-      dimsA.insert(dimsA.begin(), 1);
-    while (dimsB.size() < max_rank)
-      dimsB.insert(dimsB.begin(), 1);
-    while (dimsC.size() < max_rank)
-      dimsC.insert(dimsC.begin(), 1);
-
-    std::vector<int64_t> output_dims(dimsA.size());
-    for (size_t i = 0; i < dimsA.size(); ++i) {
-      output_dims[i] = std::max(std::max(dimsA[i], dimsB[i]), dimsC[i]);
-    }
-
-    T* output_data_ab = output_ab.Allocate(output_dims);
     LaunchSubAndMulKernel(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
-                          input_data_a, input_data_b, input_data_c,
+                          tensor_a.Data(), tensor_b.Data(), tensor_c.Data(),
                           output_data,
                           length_a, length_b, length_c,
                           subtract_first, negative_);

diff --git a/operators/cuda/add_mul_impl.cu b/operators/cuda/add_mul_impl.cu
index 14bf73b62..20919c467 100644
--- a/operators/cuda/add_mul_impl.cu
+++ b/operators/cuda/add_mul_impl.cu
@@ -169,15 +169,15 @@ struct Add3Op {
 };
 
 template <typename T, typename TFunc, int NumThreadsPerBlock, int NumElementsPerThread>
-__global__ void AddMulTwiceKernel(T* output, const T* pA, const T* pB,
-                                  const T* pC, CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC,
-                                  CUDA_LONG N, const TFunc func) {
+__global__ void AddMulTwiceKernel(T* output_data, const T* pA, const T* pB, const T* pC,
+                                  CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC, CUDA_LONG N,
+                                  const TFunc func) {
   CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
   CUDA_LONG id = start;
 #pragma unroll
   for (int i = 0; i < NumElementsPerThread; i++) {
     if (id < N) {
-      func(output_ab, pA[id % nA], pB[id % nB], pC[id % nC]);
+      func(output_data + id, pA[id % nA], pB[id % nB], pC[id % nC]);
       id += NumThreadsPerBlock;
     }
   }
 }
@@ -206,36 +206,39 @@ cudaError_t _LaunchAddOrMulTwiceKernel(cudaStream_t stream,
             reinterpret_cast(output), reinterpret_cast(pA),
             reinterpret_cast(pB), reinterpret_cast(pC),
             static_cast<CUDA_LONG>(countA), static_cast<CUDA_LONG>(countB), static_cast<CUDA_LONG>(countC),
-            static_cast<CUDA_LONG>(max_count), Add3SharedOp());
+            static_cast<CUDA_LONG>(max_count), Add3Op());
   } else {
     AddMulTwiceKernel, num_threads_per_block, num_elements_per_thread>
         <<>>(
             reinterpret_cast(output), reinterpret_cast(pA),
             reinterpret_cast(pB), reinterpret_cast(pC),
             static_cast<CUDA_LONG>(countA), static_cast<CUDA_LONG>(countB), static_cast<CUDA_LONG>(countC),
-            static_cast<CUDA_LONG>(max_count), Mul3SharedOp());
+            static_cast<CUDA_LONG>(max_count), Mul3Op());
   }
   return cudaGetLastError();
 }
 
 template <>
-cudaError_t LaunchAddOrMulSharedInputKernel(cudaStream_t stream,
-                                            const float* input_a, const float* input_b, const float* input_c,
-                                            float* output,
-                                            int64_t length_a, int64_t length_b, int64_t length_c, bool addition) {
-  return _LaunchAddOrMulSharedInputKernel(stream, input_a, input_b, input_c,
-                                          output,
-                                          length_a, length_b, length_c, addition);
+cudaError_t LaunchAddOrMulTwiceKernel(cudaStream_t stream,
+                                      const float* input_a, const float* input_b, const float* input_c,
+                                      float* output,
+                                      int64_t length_a, int64_t length_b, int64_t length_c,
+                                      bool addition) {
+  return _LaunchAddOrMulTwiceKernel(stream, input_a, input_b, input_c,
+                                    output,
+                                    length_a, length_b, length_c, addition);
 }
 
 template <>
-cudaError_t LaunchAddOrMulSharedInputKernel(cudaStream_t stream,
-                                            const ortc::MFloat16* input_a, const ortc::MFloat16* input_b, const ortc::MFloat16* input_c,
-                                            ortc::MFloat16* output,
-                                            int64_t length_a, int64_t length_b, int64_t length_c, bool addition) {
-  return _LaunchAddOrMulSharedInputKernel(stream, input_a, input_b, input_c,
-                                          output,
-                                          length_a, length_b, length_c, addition);
+cudaError_t LaunchAddOrMulTwiceKernel(cudaStream_t stream,
+                                      const ortc::MFloat16* input_a, const ortc::MFloat16* input_b,
+                                      const ortc::MFloat16* input_c,
+                                      ortc::MFloat16* output,
+                                      int64_t length_a, int64_t length_b, int64_t length_c,
+                                      bool addition) {
+  return _LaunchAddOrMulTwiceKernel(stream, input_a, input_b, input_c,
+                                    output,
+                                    length_a, length_b, length_c, addition);
 }
 
 __device__ __forceinline__ void _addmul_op(float* address, const float a, const float b,
@@ -281,7 +284,7 @@ struct MulAdd {
 };
 
 template <typename T, typename TFunc, int NumThreadsPerBlock, int NumElementsPerThread>
-__global__ void _AddAndMulKernel(T* output_data, const T* pA, const T* pB, const T* pC,
+__global__ void AddAndMulKernel(T* output_data, const T* pA, const T* pB, const T* pC,
                                  CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC, CUDA_LONG N,
                                  const TFunc func) {
   CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
@@ -296,7 +299,7 @@ __global__ void _AddAndMulKernel(T* output_data, const T* pA, const T* pB, const
 }
 
 template <typename T, typename TFunc, int NumThreadsPerBlock, int NumElementsPerThread>
-__global__ void _AddAndMulSwitchMiddleAxesKernel(T* output_data, const T* pA, const T* pB,
+__global__ void AddAndMulSwitchMiddleAxesKernel(T* output_data, const T* pA, const T* pB,
                                                  const T* pC, CUDA_LONG nA, CUDA_LONG nB,
                                                  CUDA_LONG nC, CUDA_LONG N,
                                                  const TFunc func, CUDA_LONG d2,
@@ -337,23 +340,27 @@ cudaError_t _LaunchAddAndMulKernel(cudaStream_t stream,
   if (addition_first) {
     AddAndMulKernel, num_threads_per_block, num_elements_per_thread>
         <<>>(
-            cuda_stream,
            reinterpret_cast(output), reinterpret_cast(pA),
            reinterpret_cast(pB), reinterpret_cast(pC),
-            countA, countB, countC,
-            max_size, AddMul());
+            static_cast<CUDA_LONG>(countA),
+            static_cast<CUDA_LONG>(countB),
+            static_cast<CUDA_LONG>(countC),
+            static_cast<CUDA_LONG>(max_count),
+            AddMul());
   } else {
     AddAndMulKernel, num_threads_per_block, num_elements_per_thread>
         <<>>(
-            cuda_stream,
            reinterpret_cast(output), reinterpret_cast(pA),
            reinterpret_cast(pB), reinterpret_cast(pC),
-            countA, countB, countC,
-            max_size, MulAdd());
+            static_cast<CUDA_LONG>(countA),
+            static_cast<CUDA_LONG>(countB),
+            static_cast<CUDA_LONG>(countC),
+            static_cast<CUDA_LONG>(max_count),
+            MulAdd());
   }
   return cudaGetLastError();
 }
 
@@ -361,8 +368,8 @@ cudaError_t _LaunchAddAndMulKernel(cudaStream_t stream,
 template <>
 cudaError_t LaunchAddAndMulKernel(cudaStream_t stream, const float* input_a, const float* input_b, const float* input_c,
                                   float* output, int64_t length_a, int64_t length_b, int64_t length_c,
-                                  bool addition) {
-  return _LaunchAddAndMulKernel(stream, pA, pB, pC, output, countA, countB, countC, addition_first);
+                                  bool addition_first) {
+  return _LaunchAddAndMulKernel(stream, input_a, input_b, input_c, output, length_a, length_b, length_c, addition_first);
 }
 
 template <>
@@ -370,8 +377,8 @@ cudaError_t LaunchAddAndMulKernel(cudaStream_t stream,
                                   const ortc::MFloat16* input_a, const ortc::MFloat16* input_b, const ortc::MFloat16* input_c,
                                   ortc::MFloat16* output, int64_t length_a, int64_t length_b, int64_t length_c,
-                                  bool addition) {
-  return _LaunchAddAndMulKernel(stream, pA, pB, pC, output, countA, countB, countC, addition_first);
+                                  bool addition_first) {
+  return _LaunchAddAndMulKernel(stream, input_a, input_b, input_c, output, length_a, length_b, length_c, addition_first);
 }
 
 template <typename T>
@@ -395,23 +402,27 @@ cudaError_t _LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream,
   if (addition_first) {
     AddAndMulSwitchMiddleAxesKernel, num_threads_per_block, num_elements_per_thread>
         <<>>(
-            cuda_stream,
            reinterpret_cast(output), reinterpret_cast(pA),
            reinterpret_cast(pB), reinterpret_cast(pC),
-            countA, countB, countC,
-            max_size, AddMul());
+            static_cast<CUDA_LONG>(countA),
+            static_cast<CUDA_LONG>(countB),
+            static_cast<CUDA_LONG>(countC),
+            static_cast<CUDA_LONG>(max_count),
+            AddMul(), d2, d3, d4);
   } else {
     AddAndMulSwitchMiddleAxesKernel, num_threads_per_block, num_elements_per_thread>
         <<>>(
-            cuda_stream,
            reinterpret_cast(output), reinterpret_cast(pA),
            reinterpret_cast(pB), reinterpret_cast(pC),
-            countA, countB, countC,
-            max_size, MulAdd());
+            static_cast<CUDA_LONG>(countA),
+            static_cast<CUDA_LONG>(countB),
+            static_cast<CUDA_LONG>(countC),
+            static_cast<CUDA_LONG>(max_count),
+            MulAdd(), d2, d3, d4);
   }
   return cudaGetLastError();
 }
 
@@ -419,9 +430,9 @@ cudaError_t _LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream,
 template <>
 cudaError_t LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream, const float* input_a, const float* input_b, const float* input_c,
                                                   float* output, int64_t length_a, int64_t length_b, int64_t length_c,
-                                                  bool addition,
+                                                  bool addition_first,
                                                   int64_t d2, int64_t d3, int64_t d4) {
-  return _LaunchAddAndMulSwitchMiddleAxesKernel(stream, pA, pB, pC, output, countA, countB, countC,
+  return _LaunchAddAndMulSwitchMiddleAxesKernel(stream, input_a, input_b, input_c, output, length_a, length_b, length_c,
                                                 addition_first, d2, d3, d4);
 }
 
@@ -429,9 +440,9 @@ cudaError_t LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream,
 template <>
 cudaError_t LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream,
                                                   const ortc::MFloat16* input_a, const ortc::MFloat16* input_b, const ortc::MFloat16* input_c,
                                                   ortc::MFloat16* output, int64_t length_a, int64_t length_b, int64_t length_c,
-                                                  bool addition,
+                                                  bool addition_first,
                                                   int64_t d2, int64_t d3, int64_t d4) {
-  return _LaunchAddAndMulSwitchMiddleAxesKernel(stream, pA, pB, pC, output, countA, countB, countC,
+  return _LaunchAddAndMulSwitchMiddleAxesKernel(stream, input_a, input_b, input_c, output, length_a, length_b, length_c,
                                                 addition_first, d2, d3, d4);
 }
 
@@ -519,27 +530,12 @@ struct MulSubNeg {
   }
 };
 
-template <typename T, typename TFunc, int NumThreadsPerBlock, int NumElementsPerThread>
-__global__ void _MulSubKernel(T* output_data, const T* pA, const T* pB, const T* pC,
-                              CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC, CUDA_LONG N,
-                              const TFunc func) {
-  CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
-  CUDA_LONG id = start;
-#pragma unroll
-  for (int i = 0; i < NumElementsPerThread; i++) {
-    if (id < N) {
-      func(output_data + id, pA[id % nA], pB[id % nB], pC[id % nC]);
-      id += NumThreadsPerBlock;
-    }
-  }
-}
-
 template <typename T>
 cudaError_t _LaunchSubAndMulKernel(cudaStream_t stream,
                                    const T* pA, const T* pB, const T* pC,
                                    T* output,
                                    int64_t countA, int64_t countB, int64_t countC,
-                                   bool addition_first) {
+                                   bool addition_first, bool negative) {
   int64_t max_count = std::max(std::max(countA, countB), countC);
   if (max_count == 0)  // special case where there's a dim value of 0 in the output shape
     return cudaGetLastError();
@@ -554,47 +550,55 @@ cudaError_t _LaunchSubAndMulKernel(cudaStream_t stream,
   if (addition_first) {
     if (negative) {
-      SubAndMulKernel, num_threads_per_block, num_elements_per_thread>
+      AddAndMulKernel, num_threads_per_block, num_elements_per_thread>
          <<>>(
-             cuda_stream,
             reinterpret_cast(output), reinterpret_cast(pA),
             reinterpret_cast(pB), reinterpret_cast(pC),
-             countA, countB, countC,
-             max_size, SubMulNEg());
+             static_cast<CUDA_LONG>(countA),
+             static_cast<CUDA_LONG>(countB),
+             static_cast<CUDA_LONG>(countC),
+             static_cast<CUDA_LONG>(max_count),
+             SubMulNeg());
     } else {
-      SubAndMulKernel, num_threads_per_block, num_elements_per_thread>
+      AddAndMulKernel, num_threads_per_block, num_elements_per_thread>
          <<>>(
-             cuda_stream,
             reinterpret_cast(output), reinterpret_cast(pA),
             reinterpret_cast(pB), reinterpret_cast(pC),
-             countA, countB, countC,
-             max_size, SubMul());
+             static_cast<CUDA_LONG>(countA),
+             static_cast<CUDA_LONG>(countB),
+             static_cast<CUDA_LONG>(countC),
+             static_cast<CUDA_LONG>(max_count),
+             SubMul());
     }
   } else {
     if (negative) {
-      SubAndMulKernel, num_threads_per_block, num_elements_per_thread>
+      AddAndMulKernel, num_threads_per_block, num_elements_per_thread>
          <<>>(
-             cuda_stream,
             reinterpret_cast(output), reinterpret_cast(pA),
             reinterpret_cast(pB), reinterpret_cast(pC),
-             countA, countB, countC,
-             max_size, MulSubNeg());
+             static_cast<CUDA_LONG>(countA),
+             static_cast<CUDA_LONG>(countB),
+             static_cast<CUDA_LONG>(countC),
+             static_cast<CUDA_LONG>(max_count),
+             MulSubNeg());
     } else {
-      SubAndMulKernel, num_threads_per_block, num_elements_per_thread>
+      AddAndMulKernel, num_threads_per_block, num_elements_per_thread>
          <<>>(
-             cuda_stream,
             reinterpret_cast(output), reinterpret_cast(pA),
             reinterpret_cast(pB), reinterpret_cast(pC),
-             countA, countB, countC,
-             max_size, MulSub());
+             static_cast<CUDA_LONG>(countA),
+             static_cast<CUDA_LONG>(countB),
+             static_cast<CUDA_LONG>(countC),
+             static_cast<CUDA_LONG>(max_count),
+             MulSub());
     }
   }
   return cudaGetLastError();
 }
 
@@ -604,7 +608,7 @@ template <>
 cudaError_t LaunchSubAndMulKernel(cudaStream_t stream,
                                   const float* input_a, const float* input_b, const float* input_c, float* output,
                                   int64_t length_a, int64_t length_b, int64_t length_c, bool subtract_first, bool negative) {
-  return _LaunchSubAndMulKernel(stream, pA, pB, pC, output, countA, countB, countC, subtract_first, negative);
+  return _LaunchSubAndMulKernel(stream, input_a, input_b, input_c, output, length_a, length_b, length_c, subtract_first, negative);
 }
 
 template <>
@@ -612,6 +616,6 @@ cudaError_t LaunchSubAndMulKernel(cudaStream_t stream,
                                   const ortc::MFloat16* input_a, const ortc::MFloat16* input_b, const ortc::MFloat16* input_c,
                                   ortc::MFloat16* output,
                                   int64_t length_a, int64_t length_b, int64_t length_c,
-                                  bool subtract_first, negative) {
-  return _LaunchSubAndMulKernel(stream, pA, pB, pC, output, countA, countB, countC, subtract_first, negative);
+                                  bool subtract_first, bool negative) {
+  return _LaunchSubAndMulKernel(stream, input_a, input_b, input_c, output, length_a, length_b, length_c, subtract_first, negative);
 }
diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc
index 21feec2e2..0587fc6f9 100644
--- a/operators/cuda/cuda_ops.cc
+++ b/operators/cuda/cuda_ops.cc
@@ -29,38 +29,38 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
   using AddTwiceFloat16Type = typename contrib::AddOrMulTwice;
   using MulTwiceFloat16Type = typename contrib::AddOrMulTwice;
 
-  using AddAndMulFloat32Type = typename contrib::AddAndMul;
-  using MulAndAddFloat32Type = typename contrib::AddAndMul;
+  using AddAndMulFloat16Type = typename contrib::AddAndMul;
+  using MulAndAddFloat16Type = typename contrib::AddAndMul;
 
-  using SubAndMulFloat32Type = typename contrib::SubAndMul;
-  using MulAndSubFloat32Type = typename contrib::SubAndMul;
+  using SubAndMulFloat16Type = typename contrib::SubAndMul;
+  using MulAndSubFloat16Type = typename contrib::SubAndMul;
 #endif
 
   static OrtOpLoader op_loader(
       []() { return nullptr; }
 #ifdef USE_CUDA
       ,
+      CustomCudaStructV2("AddAdd", AddTwiceFloat32Type),
       CustomCudaStructV2("AddMul", AddAndMulFloat32Type),
       CustomCudaStructV2("AddSharedInput", AddSharedInputFloat32Type),
-      CustomCudaStructV2("AddTwice", AddTwiceFloat32Type),
       CustomCudaStructV2("FastGelu", contrib::FastGelu),
       CustomCudaStructV2("MulAdd", MulAndAddFloat32Type),
+      CustomCudaStructV2("MulMul", MulTwiceFloat32Type),
       CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
       CustomCudaStructV2("MulSub", MulAndSubFloat32Type),
-      CustomCudaStructV2("MulTwice", MulTwiceFloat32Type),
       CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1),
       CustomCudaStructV2("SubMul", SubAndMulFloat32Type),
 #if ORT_API_VERSION >= 16
+      CustomCudaStructV2("AddAdd", AddTwiceFloat16Type),
       CustomCudaStructV2("AddMul", AddAndMulFloat16Type),
       CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type),
-      CustomCudaStructV2("AddTwice", AddTwiceFloat16Type),
       CustomCudaStructV2("FastGelu", contrib::FastGelu),
       CustomCudaStructV2("FastGelu", contrib::FastGelu),
       CustomCudaStructV2("MulAdd", MulAndAddFloat16Type),
+      CustomCudaStructV2("MulMul", MulTwiceFloat16Type),
       CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
       CustomCudaStructV2("MulSub", MulAndSubFloat16Type),
-      CustomCudaStructV2("MulTwice", MulTwiceFloat16Type),
       CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1),
       CustomCudaStructV2("SubMul", SubAndMulFloat16Type)
 #endif
diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py
index ca94646fc..ab7b1f4a3 100644
--- a/test/cuda/test_cudaops.py
+++ b/test/cuda/test_cudaops.py
@@ -71,8 +71,8 @@ def _addaddmulmul_cuda(self, itype, op_type, broad=False):
         shapey = (3, 2, 3)
         shapez = (1, 2, 3) if broad else (3, 2, 3)
         x = (np.arange(np.prod(shapex)) + 1).reshape(shapex).astype(dtype)
-        y = (np.arange(np.prod(shapey)) + 1).reshape(shapey).astype(dtype)
-        z = (np.arange(np.prod(shapez)) + 1).reshape(shapez).astype(dtype)
+        y = (np.arange(np.prod(shapey)) + 10).reshape(shapey).astype(dtype)
+        z = (np.arange(np.prod(shapez)) + 100).reshape(shapez).astype(dtype)
         feeds1 = dict(X=x, Y=y, Z=z)
 
         ref = ReferenceEvaluator(model1)
@@ -228,8 +228,6 @@ def test_cuda_negxplus1(self):
         self._negxplus1_cuda(TensorProto.FLOAT16)
 
     def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, 2, 3), shapec=(3, 2, 3)):
-        from ai.onnx.contrib import get_ort_ext_libs
-
         model1 = helper.make_model(
             helper.make_graph(
                 [
@@ -289,7 +287,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3,
         expected = ref.run(None, feeds1)
 
         opts = _ort.SessionOptions()
-        opts.register_custom_ops_library(get_ort_ext_libs()[0])
+        opts.register_custom_ops_library(_get_library_path())
         sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
         got = sess.run(None, feeds1)
         for i in range(2):
@@ -445,4 +443,4 @@ def test_mulsub_cuda_negative(self):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main(verbosity=2)