From 79f3b048d4d195b6684f2d1b6ca5bfe1ab9ea8d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 6 Jun 2024 17:44:21 +0200 Subject: [PATCH] Add custom op Transpose2DCast (#737) * Add custom op Transpose2DCast * fix compilation issues * fix compilation issues --- operators/cuda/add_mul.h | 2 +- operators/cuda/cuda_ops.cc | 7 ++- operators/cuda/transpose_cast.h | 39 +++++++++++++ operators/cuda/transpose_cast_impl.cu | 56 ++++++++++++++++++ operators/cuda/transpose_cast_impl.cuh | 9 +++ test/cuda/test_cudaops.py | 80 ++++++++++++++++++++++++-- 6 files changed, 185 insertions(+), 8 deletions(-) create mode 100644 operators/cuda/transpose_cast.h create mode 100644 operators/cuda/transpose_cast_impl.cu create mode 100644 operators/cuda/transpose_cast_impl.cuh diff --git a/operators/cuda/add_mul.h b/operators/cuda/add_mul.h index 5b8f36f21..a4f1fabaf 100644 --- a/operators/cuda/add_mul.h +++ b/operators/cuda/add_mul.h @@ -29,7 +29,7 @@ struct AddOrMulSharedInput { auto length_c = tensor_c.NumberOfElement(); T* output_data_ab = output_ab.Allocate(length_a <= length_b ? tensor_b.Shape() : tensor_a.Shape()); - T* output_data_ac = output_ab.Allocate(length_a <= length_c ? tensor_c.Shape() : tensor_a.Shape()); + T* output_data_ac = output_ac.Allocate(length_a <= length_c ? tensor_c.Shape() : tensor_a.Shape()); if (0 == input_data_a || 0 == input_data_b || 0 == input_data_c) { return {}; diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc index f8269302a..da7009075 100644 --- a/operators/cuda/cuda_ops.cc +++ b/operators/cuda/cuda_ops.cc @@ -7,6 +7,7 @@ #include "cuda/add_mul.h" #include "cuda/fast_gelu.h" #include "cuda/negxplus1.h" +#include "cuda/transpose_cast.h" #endif FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { @@ -17,6 +18,8 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { #if ORT_API_VERSION >= 16 using AddSharedInputFloat16Type = typename contrib::AddOrMulSharedInput; using MulSharedInputFloat16Type = typename contrib::AddOrMulSharedInput; + using Transpose2DCastFloat32ToFloat16Type = typename contrib::Transpose2DCast; + using Transpose2DCastFloat16ToFloat32Type = typename contrib::Transpose2DCast; #endif @@ -34,7 +37,9 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& { CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("FastGelu", contrib::FastGelu), CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type), - CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1) + CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1), + CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type), + CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type) #endif #endif ); diff --git a/operators/cuda/transpose_cast.h b/operators/cuda/transpose_cast.h new file mode 100644 index 000000000..6ffae51c2 --- /dev/null +++ b/operators/cuda/transpose_cast.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "ocos.h" +#include "transpose_cast_impl.cuh" +#include "ortx_common.h" + +namespace contrib { + +template +struct Transpose2DCast { + template + OrtxStatus OnModelAttach(const TDict& /*dict*/) { + return {}; + } + OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, + const ortc::Tensor& input, + ortc::Tensor& output) const { + const TIN* input_data = input.Data(); + auto shape = input.Shape(); + if (shape.size() != 2) { + ORTX_CXX_API_THROW("Input must be a 2D tensor", ORT_RUNTIME_EXCEPTION); + } + int n_rows = static_cast(shape[0]); + int n_cols = static_cast(shape[1]); + + std::vector new_shape{static_cast(n_cols), static_cast(n_rows)}; + TOUT* output_data = output.Allocate(new_shape); + if (0 == n_rows || 0 == n_cols) { + return {}; + } + LaunchTranspose2DCastKernel(reinterpret_cast(ctx->GetCudaStream()), + n_rows, n_cols, input_data, output_data); + return {}; + } +}; + +} // namespace contrib \ No newline at end of file diff --git a/operators/cuda/transpose_cast_impl.cu b/operators/cuda/transpose_cast_impl.cu new file mode 100644 index 000000000..cdd6a177b --- /dev/null +++ b/operators/cuda/transpose_cast_impl.cu @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "device_prop.cuh" +#include "utils.cuh" +#include "transpose_cast_impl.cuh" +#include "cuda_type.h" + +using namespace Ort::Custom; + +#define TILE_DIM 32 +#define BLOCK_ROWS 8 + +template +__global__ void Transpose2DCastKernel(TOUT *output_data, const TIN *input_data, int n_rows, int n_cols) { + __shared__ TIN tile[TILE_DIM][TILE_DIM + 1]; + + int x = blockIdx.x * TILE_DIM + threadIdx.x; + int y = blockIdx.y * TILE_DIM + threadIdx.y; + // int width = gridDim.x * TILE_DIM; + + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) + tile[threadIdx.y + j][threadIdx.x] = input_data[(y + j) * n_cols + x]; + + __syncthreads(); + + x = blockIdx.y * TILE_DIM + threadIdx.x; // transpose block offset + y = blockIdx.x * TILE_DIM + threadIdx.y; + + for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) + output_data[(y + j) * n_rows + x] = (TOUT)(tile[threadIdx.x][threadIdx.y + j]); +} + +template +cudaError_t _LaunchTranspose2DCastKernel(cudaStream_t stream, int n_rows, int n_cols, + const TIN* input, TOUT* output) { + dim3 dimGrid((n_cols + TILE_DIM - 1) / TILE_DIM, (n_rows + TILE_DIM - 1) / TILE_DIM, 1); + dim3 dimBlock(TILE_DIM, BLOCK_ROWS, 1); + using TTIN = typename contrib::CudaT::MappedType; + using TTOUT = typename contrib::CudaT::MappedType; + Transpose2DCastKernel<<>>( + reinterpret_cast(output), reinterpret_cast(input), n_rows, n_cols); + return cudaGetLastError(); +} + +template <> +cudaError_t LaunchTranspose2DCastKernel(cudaStream_t stream, int n_rows, int n_cols, + const float* input, ortc::MFloat16* output) { + return _LaunchTranspose2DCastKernel(stream, n_rows, n_cols, input, output); +} + +template <> +cudaError_t LaunchTranspose2DCastKernel(cudaStream_t stream, int n_rows, int n_cols, + const ortc::MFloat16* input, float* output) { + return _LaunchTranspose2DCastKernel(stream, n_rows, n_cols, input, output); +} diff --git a/operators/cuda/transpose_cast_impl.cuh b/operators/cuda/transpose_cast_impl.cuh new file mode 100644 index 000000000..b3fb2c44f --- /dev/null +++ b/operators/cuda/transpose_cast_impl.cuh @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include + +template +cudaError_t LaunchTranspose2DCastKernel(cudaStream_t stream, int n_rows, int n_cols, const TIN* input, TOUT* output); \ No newline at end of file diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py index df1ab2c47..000f3796b 100644 --- a/test/cuda/test_cudaops.py +++ b/test/cuda/test_cudaops.py @@ -21,6 +21,20 @@ def _run(self, X): return (1 - X,) +class Transpose2DCastFP16(OpRun): + op_domain = "ai.onnx.contrib" + + def _run(self, X): + return (X.T.to(np.float16),) + + +class Transpose2DCastFP32(OpRun): + op_domain = "ai.onnx.contrib" + + def _run(self, X): + return (X.T.to(np.float32),) + + class TestCudaOps(unittest.TestCase): @staticmethod def _create_negpos_test_model(domain="ai.onnx.contrib"): @@ -151,8 +165,6 @@ def test_cuda_negxplus1(self): self._negxplus1_cuda(TensorProto.FLOAT16) def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, 2, 3), shapec=(3, 2, 3)): - from onnx_extended.ortops.optim.cuda import get_ort_ext_libs - model1 = helper.make_model( helper.make_graph( [ @@ -181,7 +193,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, f"{op_type}SharedInput", ["X", "Y", "Z"], ["XY", "XZ"], - domain="onnx_extended.ortops.optim.cuda", + domain="ai.onnx.contrib", ) ], "nd", @@ -197,7 +209,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, ), opset_imports=[ helper.make_opsetid("", 18), - helper.make_opsetid("onnx_extended.ortops.optim.cuda", 1), + helper.make_opsetid("ai.onnx.contrib", 1), ], ir_version=9, ) @@ -212,7 +224,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, expected = ref.run(None, feeds1) opts = _ort.SessionOptions() - opts.register_custom_ops_library(get_ort_ext_libs()[0]) + opts.register_custom_ops_library(_get_library_path()) sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"]) got = sess.run(None, feeds1) for i in range(2): @@ -262,6 +274,62 @@ def test_add_shared_input_cuda_broadcast2(self): shapec=(3, 2, 3), ) + def _transpose_cast_cuda(self, itype): + dtype = np.float32 if itype == TensorProto.FLOAT else np.float16 + itype2 = TensorProto.FLOAT if itype == TensorProto.FLOAT16 else TensorProto.FLOAT16 + model1 = helper.make_model( + helper.make_graph( + [ + helper.make_node("Transpose", ["X"], ["t"], perm=[1, 0]), + helper.make_node("Cast", ["t"], ["Y"], to=itype2), + ], + "nd", + [helper.make_tensor_value_info("X", itype, [None, None])], + [helper.make_tensor_value_info("Y", itype2, [None, None])], + ), + opset_imports=[helper.make_opsetid("", 18)], + ir_version=9, + ) + + model2 = helper.make_model( + helper.make_graph( + [ + helper.make_node( + ("Transpose2DCastFP16" if itype2 == TensorProto.FLOAT16 else "Transpose2DCastFP32"), + ["X"], + ["Y"], + domain="ai.onnx.contrib", + ) + ], + "nd", + [helper.make_tensor_value_info("X", itype, [None, None])], + [helper.make_tensor_value_info("Y", itype2, [None, None])], + ), + opset_imports=[ + helper.make_opsetid("", 18), + helper.make_opsetid("ai.onnx.contrib", 1), + ], + ir_version=9, + ) + + dtype = np.float32 if itype == TensorProto.FLOAT else np.float16 + x = (np.arange(32 * 32 * 3) + 1).reshape((32, 32 * 3)).astype(dtype) + + feeds1 = dict(X=x) + ref = ReferenceEvaluator(model1, new_ops=[Transpose2DCastFP16, Transpose2DCastFP32]) + expected = ref.run(None, feeds1)[0] + + opts = _ort.SessionOptions() + opts.register_custom_ops_library(_get_library_path()) + sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"]) + got = sess.run(None, feeds1)[0] + assert_almost_equal(expected, got, decimal=5) + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_transpose_cast_cuda(self): + self._transpose_cast_cuda(TensorProto.FLOAT) + self._transpose_cast_cuda(TensorProto.FLOAT16) + if __name__ == "__main__": - unittest.main() + unittest.main(verbosity=2)