diff --git a/operators/cuda/add_mul.h b/operators/cuda/add_mul.h index 43a1b97e2..a23e07b8c 100644 --- a/operators/cuda/add_mul.h +++ b/operators/cuda/add_mul.h @@ -141,7 +141,7 @@ struct AddAndMul { input_data_a, input_data_b, input_data_c, output_data, length_a, length_b, length_c, - addition_first, switchMiddelAxis_); + addition_first); } return {}; } @@ -150,4 +150,58 @@ struct AddAndMul { bool switchMiddelAxis_; }; +template +struct SubAndMul { + template + OrtxStatus OnModelAttach(const TDict& dict) { + return {}; + } + OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, + const ortc::Tensor& tensor_a, + const ortc::Tensor& tensor_b, + const ortc::Tensor& tensor_c, + ortc::Tensor& output) const { + const T* input_data_a = tensor_a.Data(); + const T* input_data_b = tensor_b.Data(); + const T* input_data_c = tensor_c.Data(); + + auto length_a = tensor_a.NumberOfElement(); + auto length_b = tensor_b.NumberOfElement(); + auto length_c = tensor_c.NumberOfElement(); + if (0 == input_data_a || 0 == input_data_b || 0 == input_data_c) { + return {}; + } + + std::vector dimsA = tensor_a.Shape(); + std::vector dimsB = tensor_b.Shape(); + std::vector dimsC = tensor_c.Shape(); + + auto max_length = std::max(length_a, std::max(length_b, length_c)); + + auto max_rank = std::max(dimsA.size(), std::max(dimsB.size(), dimsC.size())); + while (dimsA.size() < max_rank) + dimsA.insert(dimsA.begin(), 1); + while (dimsB.size() < max_rank) + dimsB.insert(dimsB.begin(), 1); + while (dimsC.size() < max_rank) + dimsC.insert(dimsC.begin(), 1); + + std::vector output_dims(dimsA.size()); + for (size_t i = 0; i < dimsA.size(); ++i) { + output_dims[i] = std::max(std::max(dimsA[i], dimsB[i]), dimsC[i]); + } + + T* output_data_ab = output_ab.Allocate(output_dims); + LaunchSubAndMulKernel(reinterpret_cast(ctx->GetCudaStream()), + input_data_a, input_data_b, input_data_c, + output_data, + length_a, length_b, length_c, + subtract_first, negative_); + return {}; + } + + private: + bool negative_; +}; + } // namespace contrib \ No newline at end of file diff --git a/operators/cuda/add_mul_impl.cu b/operators/cuda/add_mul_impl.cu index ff52e00cb..14bf73b62 100644 --- a/operators/cuda/add_mul_impl.cu +++ b/operators/cuda/add_mul_impl.cu @@ -321,7 +321,7 @@ cudaError_t _LaunchAddAndMulKernel(cudaStream_t stream, const T* pA, const T* pB, const T* pC, T* output, int64_t countA, int64_t countB, int64_t countC, - bool addition_first, bool switchMiddleAxes) { + bool addition_first) { int64_t max_count = std::max(std::max(countA, countB), countC); if (max_count == 0) // special case where there's a dim value of 0 in the output shape return cudaGetLastError(); @@ -358,6 +358,22 @@ cudaError_t _LaunchAddAndMulKernel(cudaStream_t stream, return cudaGetLastError(); } +template <> +cudaError_t LaunchAddAndMulKernel(cudaStream_t stream, const float* input_a, const float* input_b, const float* input_c, + float* output, int64_t length_a, int64_t length_b, int64_t length_c, + bool addition) { + return _LaunchAddAndMulKernel(stream, pA, pB, pC, output, countA, countB, countC, addition_first); +} + +template <> +cudaError_t LaunchAddAndMulKernel(cudaStream_t stream, + const ortc::MFloat16* input_a, const ortc::MFloat16* input_b, + const ortc::MFloat16* input_c, + ortc::MFloat16* output, int64_t length_a, int64_t length_b, int64_t length_c, + bool addition) { + return _LaunchAddAndMulKernel(stream, pA, pB, pC, output, countA, countB, countC, addition_first); +} + template cudaError_t _LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream, const T* pA, const T* pB, const T* pC, @@ -399,3 +415,203 @@ cudaError_t _LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream, } return cudaGetLastError(); } + +template <> +cudaError_t LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream, const float* input_a, const float* input_b, const float* input_c, + float* output, int64_t length_a, int64_t length_b, int64_t length_c, + bool addition, + int64_t d2, int64_t d3, int64_t d4) { + return _LaunchAddAndMulSwitchMiddleAxesKernel(stream, pA, pB, pC, output, countA, countB, countC, + addition_first, d2, d3, d4); +} + +template <> +cudaError_t LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream, const ortc::MFloat16* input_a, + const ortc::MFloat16* input_b, const ortc::MFloat16* input_c, + ortc::MFloat16* output, int64_t length_a, int64_t length_b, int64_t length_c, + bool addition, + int64_t d2, int64_t d3, int64_t d4) { + return _LaunchAddAndMulSwitchMiddleAxesKernel(stream, pA, pB, pC, output, countA, countB, countC, + addition_first, d2, d3, d4); +} + +__device__ __forceinline__ void _submul_op(float* address, const float a, const float b, + const float c) { + *address = (a - b) * c; +} + +__device__ __forceinline__ void _submul_op(half* address, const half a, const half b, + const half c) { +#if __CUDA_ARCH__ < 700 + *address = __float2half((__half2float(a) - __half2float(b)) * __half2float(c)); +#else + *address = (a - b) * c; +#endif +} + +__device__ __forceinline__ void _submul_neg_op(float* address, const float a, const float b, + const float c) { + *address = (b - a) * c; +} + +__device__ __forceinline__ void _submul_neg_op(half* address, const half a, const half b, + const half c) { +#if __CUDA_ARCH__ < 700 + *address = __float2half((__half2float(b) - __half2float(a)) * __half2float(c)); +#else + *address = (b - a) * c; +#endif +} + +__device__ __forceinline__ void _mulsub_op(float* address, const float a, const float b, + const float c) { + *address = a * b - c; +} + +__device__ __forceinline__ void _mulsub_op(half* address, const half a, const half b, + const half c) { +#if __CUDA_ARCH__ < 700 + *address = __float2half(__half2float(a) * __half2float(b) - __half2float(c)); +#else + *address = a * b - c; +#endif +} + +__device__ __forceinline__ void _mulsub_neg_op(float* address, const float a, const float b, + const float c) { + *address = c - a * b; +} + +__device__ __forceinline__ void _mulsub_neg_op(half* address, const half a, const half b, + const half c) { +#if __CUDA_ARCH__ < 700 + *address = __float2half(__half2float(c) - __half2float(a) * __half2float(b)); +#else + *address = c - a * b; +#endif +} + +template +struct SubMul { + __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const { + _submul_op(address, a, b, c); + } +}; + +template +struct MulSub { + __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const { + _mulsub_op(address, a, b, c); + } +}; + +template +struct SubMulNeg { + __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const { + _submul_neg_op(address, a, b, c); + } +}; + +template +struct MulSubNeg { + __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const { + _mulsub_neg_op(address, a, b, c); + } +}; + +template +__global__ void _MulSubKernel(T* output_data, const T* pA, const T* pB, const T* pC, + CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC, CUDA_LONG N, + const TFunc func) { + CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; + CUDA_LONG id = start; +#pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (id < N) { + func(output_data + id, pA[id % nA], pB[id % nB], pC[id % nC]); + id += NumThreadsPerBlock; + } + } +} + +template +cudaError_t _LaunchSubAndMulKernel(cudaStream_t stream, + const T* pA, const T* pB, const T* pC, + T* output, + int64_t countA, int64_t countB, int64_t countC, + bool addition_first) { + int64_t max_count = std::max(std::max(countA, countB), countC); + if (max_count == 0) // special case where there's a dim value of 0 in the output shape + return cudaGetLastError(); + + const int num_elements_per_thread = 4; + const int num_threads_per_block = 256; + const int num_el_th = num_threads_per_block * num_elements_per_thread; + + int blocksPerGrid = (max_count + num_el_th - 1) / num_el_th; + + using TT = typename contrib::CudaT::MappedType; + + if (addition_first) { + if (negative) { + SubAndMulKernel, num_threads_per_block, num_elements_per_thread> + <<>>( + cuda_stream, + reinterpret_cast(output), + reinterpret_cast(pA), + reinterpret_cast(pB), + reinterpret_cast(pC), + countA, countB, countC, + max_size, SubMulNEg()); + } else { + SubAndMulKernel, num_threads_per_block, num_elements_per_thread> + <<>>( + cuda_stream, + reinterpret_cast(output), + reinterpret_cast(pA), + reinterpret_cast(pB), + reinterpret_cast(pC), + countA, countB, countC, + max_size, SubMul()); + } + } else { + if (negative) { + SubAndMulKernel, num_threads_per_block, num_elements_per_thread> + <<>>( + cuda_stream, + reinterpret_cast(output), + reinterpret_cast(pA), + reinterpret_cast(pB), + reinterpret_cast(pC), + countA, countB, countC, + max_size, MulSubNeg()); + } else { + SubAndMulKernel, num_threads_per_block, num_elements_per_thread> + <<>>( + cuda_stream, + reinterpret_cast(output), + reinterpret_cast(pA), + reinterpret_cast(pB), + reinterpret_cast(pC), + countA, countB, countC, + max_size, MulSub()); + } + } + return cudaGetLastError(); +} + +template <> +cudaError_t LaunchSubAndMulKernel(cudaStream_t stream, const float* input_a, const float* input_b, const float* input_c, + float* output, int64_t length_a, int64_t length_b, int64_t length_c, + bool subtract_first, bool negative) { + return _LaunchSubAndMulKernel(stream, pA, pB, pC, output, countA, countB, countC, subtract_first, negative); +} + +template <> +cudaError_t LaunchSubAndMulKernel(cudaStream_t stream, + const ortc::MFloat16* input_a, const ortc::MFloat16* input_b, + const ortc::MFloat16* input_c, + ortc::MFloat16* output, int64_t length_a, int64_t length_b, int64_t length_c, + bool subtract_first, negative) { + return _LaunchSubAndMulKernel(stream, pA, pB, pC, output, countA, countB, countC, subtract_first, negative); +} diff --git a/operators/cuda/add_mul_impl.cuh b/operators/cuda/add_mul_impl.cuh index 98bb06169..2c5b000ec 100644 --- a/operators/cuda/add_mul_impl.cuh +++ b/operators/cuda/add_mul_impl.cuh @@ -17,10 +17,15 @@ cudaError_t LaunchAddOrMulTwiceKernel(cudaStream_t stream, const T* input_a, con template cudaError_t LaunchAddAndMulKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c, T* output, int64_t length_a, int64_t length_b, int64_t length_c, - bool addition, bool switchMiddleAxis); + bool addition); template cudaError_t LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c, T* output, int64_t length_a, int64_t length_b, int64_t length_c, bool addition, int64_t d2, int64_t d3, int64_t d4); + +template +cudaError_t LaunchSubAndMulKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c, + T* output, int64_t length_a, int64_t length_b, int64_t length_c, + bool addition, bool negative); diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc index 9ba3a7329..21feec2e2 100644 --- a/operators/cuda/cuda_ops.cc +++ b/operators/cuda/cuda_ops.cc @@ -16,8 +16,11 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { using AddTwiceFloat32Type = typename contrib::AddOrMulTwice; using MulTwiceFloat32Type = typename contrib::AddOrMulTwice; - using AddAndMulFloat32Type = typename contrib::AddOrMulTwice; - using MulAndAddFloat32Type = typename contrib::AddOrMulTwice; + using AddAndMulFloat32Type = typename contrib::AddAndMul; + using MulAndAddFloat32Type = typename contrib::AddAndMul; + + using SubAndMulFloat32Type = typename contrib::SubAndMul; + using MulAndSubFloat32Type = typename contrib::SubAndMul; #if ORT_API_VERSION >= 16 using AddSharedInputFloat16Type = typename contrib::AddOrMulSharedInput; @@ -26,8 +29,11 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { using AddTwiceFloat16Type = typename contrib::AddOrMulTwice; using MulTwiceFloat16Type = typename contrib::AddOrMulTwice; - using AddAndMulFloat32Type = typename contrib::AddOrMulTwice; - using MulAndAddFloat32Type = typename contrib::AddOrMulTwice; + using AddAndMulFloat32Type = typename contrib::AddAndMul; + using MulAndAddFloat32Type = typename contrib::AddAndMul; + + using SubAndMulFloat32Type = typename contrib::SubAndMul; + using MulAndSubFloat32Type = typename contrib::SubAndMul; #endif static OrtOpLoader op_loader( @@ -40,8 +46,10 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("MulAdd", MulAndAddFloat32Type), CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type), + CustomCudaStructV2("MulSub", MulAndSubFloat32Type), CustomCudaStructV2("MulTwice", MulTwiceFloat32Type), CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1), + CustomCudaStructV2("SubMul", SubAndMulFloat32Type), #if ORT_API_VERSION >= 16 CustomCudaStructV2("AddMul", AddAndMulFloat16Type), @@ -51,8 +59,10 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("MulAdd", MulAndAddFloat16Type), CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type), + CustomCudaStructV2("MulSub", MulAndSubFloat16Type), CustomCudaStructV2("MulTwice", MulTwiceFloat16Type), - CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1) + CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1), + CustomCudaStructV2("SubMul", SubAndMulFloat16Type) #endif #endif ); diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py index db11b212c..ca94646fc 100644 --- a/test/cuda/test_cudaops.py +++ b/test/cuda/test_cudaops.py @@ -423,6 +423,26 @@ def test_submul_cuda(self): self._addmul_cuda(TensorProto.FLOAT, "Sub", "Mul") self._addmul_cuda(TensorProto.FLOAT16, "Sub", "Mul") + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_submul_cuda_negative(self): + self._addmul_cuda(TensorProto.FLOAT, "Sub", "Mul", negative=True) + self._addmul_cuda(TensorProto.FLOAT16, "Sub", "Mul", negative=True) + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_submul_cuda_broadcast(self): + self._addmul_cuda(TensorProto.FLOAT, "Sub", "Mul", True) + self._addmul_cuda(TensorProto.FLOAT16, "Sub", "Mul", True) + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_mulsub_cuda(self): + self._addmul_cuda(TensorProto.FLOAT, "Mul", "Sub") + self._addmul_cuda(TensorProto.FLOAT16, "Mul", "Sub") + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_mulsub_cuda_negative(self): + self._addmul_cuda(TensorProto.FLOAT, "Mul", "Sub", negative=True) + self._addmul_cuda(TensorProto.FLOAT16, "Mul", "Sub", negative=True) + if __name__ == "__main__": unittest.main()