From 26056749e11b1c0503224fb5ee4b3cfea7a5abf4 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Fri, 7 Jun 2024 09:07:00 +0000
Subject: [PATCH] add MaskedScatterNdOfShape

---
 operators/cuda/cuda_ops.cc                  |   2 +
 operators/cuda/scatter_nd_of_shape.h        |  70 +++++++++
 operators/cuda/scatter_nd_of_shape_impl.cu  | 155 ++++++++++++++++++--
 operators/cuda/scatter_nd_of_shape_impl.cuh |  10 ++
 test/cuda/test_cudaops.py                   | 149 +++++++++++++++++--
 5 files changed, 361 insertions(+), 25 deletions(-)

diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc
index 019cf2e05..68097b287 100644
--- a/operators/cuda/cuda_ops.cc
+++ b/operators/cuda/cuda_ops.cc
@@ -15,12 +15,14 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
 #ifdef USE_CUDA
       ,
       CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
+      CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<float>),
       CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
       CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
 #if ORT_API_VERSION >= 16
 
       CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
       CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
+      CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<ortc::MFloat16>),
       CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
       CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>)
 #endif
diff --git a/operators/cuda/scatter_nd_of_shape.h b/operators/cuda/scatter_nd_of_shape.h
index 3a1915ed4..239c2b5e6 100644
--- a/operators/cuda/scatter_nd_of_shape.h
+++ b/operators/cuda/scatter_nd_of_shape.h
@@ -70,4 +70,74 @@ struct ScatterNDOfShape {
   ScatterReduction reduction_;
 };
 
+
+template <typename T>
+struct MaskedScatterNDOfShape {
+  OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
+    std::string value;
+    OrtStatusPtr status = OrtW::GetOpAttribute(info, "reduction", value);
+    if (status != nullptr)
+      return status;
+
+    if (value == "add")
+      reduction_ = ScatterReduction::Add;
+    else if (value == "mul")
+      reduction_ = ScatterReduction::Mul;
+    else if (value == "min")
+      reduction_ = ScatterReduction::Min;
+    else if (value == "max")
+      reduction_ = ScatterReduction::Max;
+    else
+      ORTX_CXX_API_THROW("Unexpected reduction, only Add is implemented.", ORT_RUNTIME_EXCEPTION);
+
+    status = OrtW::GetOpAttribute(info, "maskedValue", masked_value_);
+    if (status != nullptr)
+      return status;
+
+    return nullptr;
+  }
+
+  OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
+                       const ortc::Tensor<int64_t>& output_shape,
+                       const ortc::Tensor<int64_t>& indices,
+                       const ortc::Tensor<T>& updates,
+                       ortc::Tensor<T>& output) const {
+    auto& output_shape_shape = output_shape.Shape();
+    auto& indices_shape = indices.Shape();
+    auto& updates_shape = updates.Shape();
+
+    if (output_shape_shape.size() != 1 || output_shape_shape[0] == 0) {
+      ORTX_CXX_API_THROW("output shape must be a 1D tensor", ORT_RUNTIME_EXCEPTION);
+    }
+    if (indices_shape[indices_shape.size() - 1] != 1) {
+      ORTX_CXX_API_THROW("last dimension of the indices tensor should be one", ORT_RUNTIME_EXCEPTION);
+    }
+
+    const int64_t* shape_data = output_shape.Data();  // CPU pointer
+    const int64_t* indices_data = indices.Data();     // GPU pointer
+    const T* updates_data = updates.Data();           // GPU pointer
+    std::vector<int64_t> voutput_shape(shape_data, shape_data + output_shape_shape[0]);
+    T* output_data = output.Allocate(voutput_shape);  // GPU pointer
+    LaunchMaskedScatterNDOfShapeKernel(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
+                                       voutput_shape,
+                                       indices_shape,
+                                       indices_data,
+                                       updates_data,
+                                       output_data,
+                                       reduction_,
+                                       masked_value_);
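+    // Note: the cudaError_t returned by LaunchMaskedScatterNDOfShapeKernel (from
+    // cudaGetLastError in the .cu file) is not propagated here; Compute reports
+    // success regardless of the launch status.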
+    return nullptr;
+  }
+
+  static OrtMemType GetInputMemoryType(size_t input_index) {
+    if (input_index == 0)  // shape
+      return OrtMemType::OrtMemTypeCPUInput;
+    return OrtMemType::OrtMemTypeDefault;
+  }
+
+ private:
+  ScatterReduction reduction_;
+  int64_t masked_value_;
+};
+
 }  // namespace contrib
diff --git a/operators/cuda/scatter_nd_of_shape_impl.cu b/operators/cuda/scatter_nd_of_shape_impl.cu
index 255dc6b81..2cd727548 100644
--- a/operators/cuda/scatter_nd_of_shape_impl.cu
+++ b/operators/cuda/scatter_nd_of_shape_impl.cu
@@ -52,15 +52,65 @@ addition_inplace_kernel(T* __restrict__ output_data, const int64_t* __restrict__
 }
 
 template <typename T>
-cudaError_t _ComputeNoAtomic(cudaStream_t stream, T* output_data,
-                             const int64_t* indices_data, const T* updates_data,
-                             int threads_per_block, int blocks_per_grid, size_t indice_size, size_t nrows, size_t stride) {
-  dim3 threads(threads_per_block);
-  dim3 blocks(blocks_per_grid);
-  using TT = typename CudaT<T>::MappedType;
-  addition_inplace_kernel<<<blocks, threads, 0, stream>>>((TT*)output_data, indices_data,
-                                                          (TT*)updates_data, indice_size, nrows, stride);
-  return cudaGetLastError();
+__global__ void masked_addition_inplace_kernel(T *__restrict__ output_data,
+                                               const int64_t *__restrict__ indices_data,
+                                               const T *__restrict__ updates_data,
+                                               const CUDA_LONG indice_size,
+                                               const CUDA_LONG nrows, const CUDA_LONG stride,
+                                               const int64_t masked_value) {
+  auto id = blockDim.x * blockIdx.x + threadIdx.x;
+  if (id >= stride)
+    return;
+
+  for (size_t i = 0; i < nrows; ++i) {
+    output_data[i * stride + id] = 0;
+  }
+
+  for (size_t i = 0; i < indice_size; ++i) {
+    if (indices_data[i] == masked_value)
+      continue;
+    _add_inplace(output_data[indices_data[i] * stride + id], updates_data[i * stride + id]);
+  }
+}
+
+template <typename T, int NTHREAD>
+__global__ void masked_addition_inplace_kernelN(T *__restrict__ output_data,
+                                                const int64_t *__restrict__ indices_data,
+                                                const T *__restrict__ updates_data,
+                                                const CUDA_LONG indice_size,
+                                                const CUDA_LONG nrows, const CUDA_LONG stride,
+                                                const int64_t masked_value) {
+  __shared__ int64_t shared_indices[NTHREAD];
+
+  CUDA_LONG tid = threadIdx.x;
+  CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
+
+  for (size_t i = 0; i < nrows; ++i) {
+    output_data[i * stride + id] = 0;
+  }
+
+  int begin = 0;
+  int end = std::min(begin + NTHREAD, indice_size);
+  while (begin < end && (end == begin + NTHREAD)) {
+    shared_indices[tid] = indices_data[tid + begin];
+    __syncthreads();
+
+    for (size_t i = begin; i < end; ++i) {
+      if (shared_indices[tid] == masked_value)
+        continue;
+      _add_inplace(output_data[shared_indices[tid] * stride + id],
+                   updates_data[i * stride + id]);
+    }
+
+    begin = end;
+    end = std::min(begin + NTHREAD, indice_size);
+  }
+
+  for (size_t i = begin; i < indice_size; ++i) {
+    if (indices_data[i] == masked_value)
+      continue;
+    _add_inplace(output_data[indices_data[i] * stride + id], updates_data[i * stride + id]);
+  }
 }
 
 template <typename T>
@@ -89,7 +139,54 @@ cudaError_t ScatterNDOfShapeKernel(cudaStream_t stream,
 
   int threads_per_block = 256;
   int blocks_per_grid = (stride + threads_per_block - 1) / threads_per_block;
-  return _ComputeNoAtomic(stream, output_data, indices_data, updates_data, threads_per_block, blocks_per_grid, indice_size, nrows, stride);
+
+  dim3 threads(threads_per_block);
+  dim3 blocks(blocks_per_grid);
+  using TT = typename CudaT<T>::MappedType;
+  addition_inplace_kernel<<<blocks, threads, 0, stream>>>(reinterpret_cast<TT*>(output_data), indices_data,
+                                                          reinterpret_cast<const TT*>(updates_data),
+                                                          indice_size, nrows, stride);
+  return cudaGetLastError();
+}
+
+template <typename T>
+cudaError_t
+MaskedScatterNDOfShapeKernel(cudaStream_t stream, const std::vector<int64_t> &input_shape,
+                             const std::vector<int64_t> &indices_shape,
+                             const int64_t *indices_data, const T *updates_data,
+                             T *output_data,
+                             ScatterReduction reduction, int64_t masked_value) {
+  if (reduction != ScatterReduction::Add)
+    ORTX_CXX_API_THROW("Only reduction 'add' is implemented.", ORT_RUNTIME_EXCEPTION);
+  size_t indice_size = static_cast<size_t>(flattened_dimension(indices_shape));
+  size_t input_size = static_cast<size_t>(flattened_dimension(input_shape));
+  size_t stride = input_shape[input_shape.size() - 1];
+  size_t nrows = input_size / stride;
+
+  std::vector<size_t> next_batch(indice_size);
+  std::vector<size_t> processed(input_shape[0], 0);
+  std::vector<size_t> processed_once(input_shape[0], 0);
+
+  int threads_per_block = 256;
+  bool split = stride / threads_per_block <= 32;
+
+  int blocks_per_grid = (stride + threads_per_block - 1) / threads_per_block;
+  dim3 threads(threads_per_block);
+  dim3 blocks(blocks_per_grid);
+
+  using TT = typename CudaT<T>::MappedType;
+
+  if (split && stride >= 256 && threads_per_block == 256) {
+    masked_addition_inplace_kernelN<TT, 256><<<blocks, threads, 0, stream>>>(
+        reinterpret_cast<TT*>(output_data), indices_data,
+        reinterpret_cast<const TT*>(updates_data),
+        indice_size, nrows, stride, masked_value);
+  } else {
+    masked_addition_inplace_kernel<<<blocks, threads, 0, stream>>>(
+        reinterpret_cast<TT*>(output_data), indices_data,
+        reinterpret_cast<const TT*>(updates_data),
+        indice_size, nrows, stride, masked_value);
+  }
+  return cudaGetLastError();
 }
 
 template <>
@@ -126,4 +223,42 @@ cudaError_t LaunchScatterNDOfShapeKernel(cudaStream_t stream,
                                          reduction);
 }
 
+template <>
+cudaError_t LaunchMaskedScatterNDOfShapeKernel(cudaStream_t stream,
+                                               const std::vector<int64_t>& output_shape,
+                                               const std::vector<int64_t>& indices_shape,
+                                               const int64_t* indices,
+                                               const float* updates,
+                                               float* output,
+                                               ScatterReduction reduction,
+                                               int64_t masked_value) {
+  return MaskedScatterNDOfShapeKernel(stream,
+                                      output_shape,
+                                      indices_shape,
+                                      indices,
+                                      updates,
+                                      output,
+                                      reduction,
+                                      masked_value);
+}
+
+template <>
+cudaError_t LaunchMaskedScatterNDOfShapeKernel(cudaStream_t stream,
+                                               const std::vector<int64_t>& output_shape,
+                                               const std::vector<int64_t>& indices_shape,
+                                               const int64_t* indices,
+                                               const ortc::MFloat16* updates,
+                                               ortc::MFloat16* output,
+                                               ScatterReduction reduction,
+                                               int64_t masked_value) {
+  return MaskedScatterNDOfShapeKernel(stream,
+                                      output_shape,
+                                      indices_shape,
+                                      indices,
+                                      updates,
+                                      output,
+                                      reduction,
+                                      masked_value);
+}
+
 }  // namespace contrib
diff --git a/operators/cuda/scatter_nd_of_shape_impl.cuh b/operators/cuda/scatter_nd_of_shape_impl.cuh
index a010f210c..75a5479ab 100644
--- a/operators/cuda/scatter_nd_of_shape_impl.cuh
+++ b/operators/cuda/scatter_nd_of_shape_impl.cuh
@@ -24,4 +24,14 @@ cudaError_t LaunchScatterNDOfShapeKernel(cudaStream_t stream,
                                          T* output,
                                          ScatterReduction reduction);
 
+template <typename T>
+cudaError_t LaunchMaskedScatterNDOfShapeKernel(cudaStream_t stream,
+                                               const std::vector<int64_t>& output_shape,
+                                               const std::vector<int64_t>& indices_shape,
+                                               const int64_t* indices,
+                                               const T* updates,
+                                               T* output,
+                                               ScatterReduction reduction,
+                                               int64_t masked_value);
+
 }  // namespace contrib
diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py
index b53be201d..3ff2358bd 100644
--- a/test/cuda/test_cudaops.py
+++ b/test/cuda/test_cudaops.py
@@ -11,6 +11,10 @@
 import onnxruntime as _ort
 
 
+def has_cuda():
+    return "CUDAExecutionProvider" in _ort.get_available_providers()
+
+
 class ScatterNDOfShape(OpRun):
     op_domain = "ai.onnx.contrib"
 
@@ -151,11 +155,10 @@ def _negxplus1_cuda(self, itype):
         got = sess.run(None, feeds1)[0]
assert_almost_equal(expected, got, decimal=5) + @unittest.skipIf(not has_cuda(), reason="cuda not available") def test_cuda_negxplus1(self): - eps = _ort.get_available_providers() - if "CUDAExecutionProvider" in eps: - self._negxplus1_cuda(TensorProto.FLOAT) - self._negxplus1_cuda(TensorProto.FLOAT16) + self._negxplus1_cuda(TensorProto.FLOAT) + self._negxplus1_cuda(TensorProto.FLOAT16) def _scatternd_of_shape_optimize_cuda(self, optimize, dim3, itype): indices_shape = ["i", "j", 1] if dim3 else ["j", 1] @@ -214,19 +217,135 @@ def _scatternd_of_shape_optimize_cuda(self, optimize, dim3, itype): got = sess.run(None, feeds, ro)[0] self.assertEqual(expected.tolist(), got.tolist()) + @unittest.skipIf(not has_cuda(), reason="cuda not available") def test_scatternd_of_shape_optimize_cuda(self): - eps = _ort.get_available_providers() - if "CUDAExecutionProvider" in eps: - with self.subTest(optimize=True, dim3=True): - self._scatternd_of_shape_optimize_cuda(True, True, TensorProto.FLOAT) - self._scatternd_of_shape_optimize_cuda(False, False, TensorProto.FLOAT) - self._scatternd_of_shape_optimize_cuda(False, True, TensorProto.FLOAT) - with self.subTest(optimize=True, dim3=False): - self._scatternd_of_shape_optimize_cuda(True, False, TensorProto.FLOAT) - with self.subTest(optimize=True, dim3=True, itype=TensorProto.FLOAT16): - self._scatternd_of_shape_optimize_cuda(True, True, TensorProto.FLOAT16) + with self.subTest(optimize=True, dim3=True): + self._scatternd_of_shape_optimize_cuda(True, True, TensorProto.FLOAT) + self._scatternd_of_shape_optimize_cuda(False, False, TensorProto.FLOAT) + self._scatternd_of_shape_optimize_cuda(False, True, TensorProto.FLOAT) + with self.subTest(optimize=True, dim3=False): + self._scatternd_of_shape_optimize_cuda(True, False, TensorProto.FLOAT) + with self.subTest(optimize=True, dim3=True, itype=TensorProto.FLOAT16): + self._scatternd_of_shape_optimize_cuda(True, True, TensorProto.FLOAT16) + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_scatternd_of_shape_standalone_cuda(self): + self._scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT) + self._scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT16) + self._scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT) + self._scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16) + + def _masked_scatternd_of_shape_cuda(self, reduction, line, itype, big): + dtype = np.float32 if itype == TensorProto.FLOAT else np.float16 + + model1 = helper.make_model( + helper.make_graph( + [ + helper.make_node("Equal", ["indices", "mone"], ["masked_indices"]), + helper.make_node( + "Where", + ["masked_indices", "zero", "updates"], + ["masked_updates"], + ), + helper.make_node( + "ScatterND", + inputs=["data", "indices", "masked_updates"], + outputs=["y"], + reduction=reduction, + ), + ], + "nd", + [ + helper.make_tensor_value_info("data", itype, [None, None]), + helper.make_tensor_value_info("indices", TensorProto.INT64, [None, None, 1]), + helper.make_tensor_value_info("updates", itype, [None, None, None]), + ], + [helper.make_tensor_value_info("y", itype, [None, None])], + [ + numpy_helper.from_array(np.array([-1], dtype=np.int64), name="mone"), + numpy_helper.from_array(np.array([0], dtype=dtype), name="zero"), + ], + ), + opset_imports=[helper.make_opsetid("", 18)], + ir_version=9, + ) + + model2 = helper.make_model( + helper.make_graph( + [ + helper.make_node( + "MaskedScatterNDOfShape", + inputs=["shape", "indices", "updates"], + outputs=["y"], + reduction=reduction, + maskedValue=-1, + 
domain="ai.onnx.contrib", + ) + ], + "nd", + [ + helper.make_tensor_value_info("shape", TensorProto.INT64, [None]), + helper.make_tensor_value_info("indices", TensorProto.INT64, [None, None, 1]), + helper.make_tensor_value_info("updates", itype, [None, None, None]), + ], + [helper.make_tensor_value_info("y", itype, [None, None])], + ), + opset_imports=[ + helper.make_opsetid("", 18), + helper.make_opsetid("ai.onnx.contrib", 1), + ], + ir_version=9, + ) + + if big: + data = np.zeros((2048, 4096), dtype=dtype) + indices = np.ones((2, 1024), dtype=np.int64) + indices = indices[..., np.newaxis] + shape = tuple(indices.shape[:2]) + (data.shape[-1],) + updates = (np.arange(np.prod(shape)).reshape(shape) / np.prod(shape)).astype(dtype) else: - print("CUDAExecutionProvider not available, test_cuda_scatternd_of_shape skipped.") + data = np.zeros((32, 16), dtype=dtype) + indices = np.array( + [ + [0, 1, 2], + [2, 3, 4], + [-1, 30, 31], + [-1, 7, 8], + [10, 11, -1], + [20, -1, 21], + ], + dtype=np.int64, + ) + indices = indices[..., np.newaxis] + shape = (6, 3, data.shape[-1]) + updates = (np.arange(np.prod(shape)).reshape(shape) + 1).astype(dtype) + + feeds1 = dict(data=data, indices=indices, updates=updates) + feeds2 = dict(shape=np.array(data.shape, dtype=np.int64), indices=indices, updates=updates) + ref = ReferenceEvaluator(model1) + expected = ref.run(None, feeds1)[0] + + opts = _ort.SessionOptions() + opts.register_custom_ops_library(_get_library_path()) + # opts.log_severity_level = 0 + # opts.log_verbosity_level = 0 + sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"]) + got = sess.run(None, feeds2)[0] + assert_almost_equal(expected.tolist(), got.tolist()) + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_masked_scatternd_of_shape_standalone_cuda_small(self): + self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT, False) + self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT16, False) + self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT, False) + self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16, False) + + @unittest.skipIf(not has_cuda(), reason="cuda not available") + def test_masked_scatternd_of_shape_standalone_cuda_big(self): + self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT, True) + self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT16, True) + self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT, True) + self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16, True) if __name__ == "__main__":