From f5055466d5376059c2ea74e3cea46e16a537bc0d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Tue, 11 Jun 2024 09:59:46 +0200
Subject: [PATCH] Add custom kernel ScatterNDOfShape (#705)

* first draft

* clang

* Draft for ScatterNDOfShape

* fix build

* disable test when cuda is missing

* fix implementation

* update test

* add MaskedScatterNDOfShape

* fix merge conflicts
---
 operators/cuda/cuda_ops.cc                  |   5 +
 operators/cuda/scatter_nd_of_shape.h        | 143 +++++++++++
 operators/cuda/scatter_nd_of_shape_impl.cu  | 264 ++++++++++++++++++++
 operators/cuda/scatter_nd_of_shape_impl.cuh |  37 +++
 test/cuda/test_cudaops.py                   | 197 +++++++++++++++
 5 files changed, 646 insertions(+)
 create mode 100644 operators/cuda/scatter_nd_of_shape.h
 create mode 100644 operators/cuda/scatter_nd_of_shape_impl.cu
 create mode 100644 operators/cuda/scatter_nd_of_shape_impl.cuh
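
Usage note (commentary only, not part of the diff): the snippet below is a minimal sketch of how the new ScatterNDOfShape contrib op can be invoked once this patch is built, adapted from the tests added in test/cuda/test_cudaops.py. It assumes a CUDA-enabled onnxruntime and the extensions shared library registered through get_library_path(); MaskedScatterNDOfShape is driven the same way, with an additional maskedValue attribute and indices of shape [d1, d2, 1].

# Hypothetical example, not part of the patch: run ScatterNDOfShape through onnxruntime
# with the extensions library registered (requires a CUDA build of both packages).
import numpy as np
import onnxruntime as ort
from onnx import TensorProto, helper
from onnxruntime_extensions import get_library_path

model = helper.make_model(
    helper.make_graph(
        [
            helper.make_node(
                "ScatterNDOfShape",
                ["shape", "indices", "updates"],
                ["y"],
                reduction="add",
                domain="ai.onnx.contrib",
            )
        ],
        "example",
        [
            helper.make_tensor_value_info("shape", TensorProto.INT64, [2]),
            helper.make_tensor_value_info("indices", TensorProto.INT64, [None, 1]),
            helper.make_tensor_value_info("updates", TensorProto.FLOAT, [None, None]),
        ],
        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [None, None])],
    ),
    opset_imports=[helper.make_opsetid("", 18), helper.make_opsetid("ai.onnx.contrib", 1)],
    ir_version=9,
)

opts = ort.SessionOptions()
opts.register_custom_ops_library(get_library_path())
sess = ort.InferenceSession(model.SerializeToString(), opts, providers=["CUDAExecutionProvider"])

feeds = {
    "shape": np.array([8, 4], dtype=np.int64),              # shape of the zero-initialized output
    "indices": np.array([[1], [3], [1]], dtype=np.int64),   # target row for each update row
    "updates": np.ones((3, 4), dtype=np.float32),           # rows accumulated into y at those indices
}
y = sess.run(None, feeds)[0]  # rows 1 and 3 of y hold the accumulated updates, other rows stay zero

The shape input is read on the host (input 0 is pinned to OrtMemTypeCPUInput below), while indices and updates stay on the CUDA device; the output is zero-initialized and updates are accumulated with the add reduction, the only one the kernels implement.
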
diff --git a/operators/cuda/cuda_ops.cc b/operators/cuda/cuda_ops.cc
index da7009075..25913b8bf 100644
--- a/operators/cuda/cuda_ops.cc
+++ b/operators/cuda/cuda_ops.cc
@@ -7,6 +7,7 @@
 #include "cuda/add_mul.h"
 #include "cuda/fast_gelu.h"
 #include "cuda/negxplus1.h"
+#include "cuda/scatter_nd_of_shape.h"
 #include "cuda/transpose_cast.h"
 #endif
 
@@ -29,15 +30,19 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
       ,
       CustomCudaStructV2("AddSharedInput", AddSharedInputFloat32Type),
       CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
+      CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<float>),
       CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
       CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
+      CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
 #if ORT_API_VERSION >= 16
       CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type),
       CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
       CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
+      CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<ortc::MFloat16>),
       CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
       CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
+      CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>),
       CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
       CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
 #endif
diff --git a/operators/cuda/scatter_nd_of_shape.h b/operators/cuda/scatter_nd_of_shape.h
new file mode 100644
index 000000000..239c2b5e6
--- /dev/null
+++ b/operators/cuda/scatter_nd_of_shape.h
@@ -0,0 +1,143 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "ocos.h"
+#include "string_utils.h"
+#include "scatter_nd_of_shape_impl.cuh"
+
+namespace contrib {
+
+template <typename T>
+struct ScatterNDOfShape {
+  OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
+    std::string value;
+    OrtStatusPtr status = OrtW::GetOpAttribute(info, "reduction", value);
+    if (status != nullptr)
+      return status;
+
+    if (value == "add")
+      reduction_ = ScatterReduction::Add;
+    else if (value == "mul")
+      reduction_ = ScatterReduction::Mul;
+    else if (value == "min")
+      reduction_ = ScatterReduction::Min;
+    else if (value == "max")
+      reduction_ = ScatterReduction::Max;
+    else
+      ORTX_CXX_API_THROW("Unexpected reduction, only Add is implemented.", ORT_RUNTIME_EXCEPTION);
+
+    return nullptr;
+  }
+
+  OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
+                       const ortc::Tensor<int64_t>& output_shape,
+                       const ortc::Tensor<int64_t>& indices,
+                       const ortc::Tensor<T>& updates,
+                       ortc::Tensor<T>& output) const {
+    auto& output_shape_shape = output_shape.Shape();
+    auto& indices_shape = indices.Shape();
+    auto& updates_shape = updates.Shape();
+
+    if (output_shape_shape.size() != 1 || output_shape_shape[0] == 0) {
+      ORTX_CXX_API_THROW("output shape must be a 1D tensor", ORT_RUNTIME_EXCEPTION);
+    }
+    if (indices_shape[indices_shape.size() - 1] != 1) {
+      ORTX_CXX_API_THROW("last dimension of the indices tensor should be one", ORT_RUNTIME_EXCEPTION);
+    }
+
+    const int64_t* shape_data = output_shape.Data();  // CPU pointer
+    const int64_t* indices_data = indices.Data();     // GPU pointer
+    const T* updates_data = updates.Data();           // GPU pointer
+    std::vector<int64_t> voutput_shape(shape_data, shape_data + output_shape_shape[0]);
+    T* output_data = output.Allocate(voutput_shape);  // GPU pointer
+    LaunchScatterNDOfShapeKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
+                                    voutput_shape,
+                                    indices_shape,
+                                    indices_data,
+                                    updates_data,
+                                    output_data,
+                                    reduction_);
+    return nullptr;
+  }
+
+  static OrtMemType GetInputMemoryType(size_t input_index) {
+    if (input_index == 0)  // shape
+      return OrtMemType::OrtMemTypeCPUInput;
+    return OrtMemType::OrtMemTypeDefault;
+  }
+
+  ScatterReduction reduction_;
+};
+
+
+template <typename T>
+struct MaskedScatterNDOfShape {
+  OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
+    std::string value;
+    OrtStatusPtr status = OrtW::GetOpAttribute(info, "reduction", value);
+    if (status != nullptr)
+      return status;
+
+    if (value == "add")
+      reduction_ = ScatterReduction::Add;
+    else if (value == "mul")
+      reduction_ = ScatterReduction::Mul;
+    else if (value == "min")
+      reduction_ = ScatterReduction::Min;
+    else if (value == "max")
+      reduction_ = ScatterReduction::Max;
+    else
+      ORTX_CXX_API_THROW("Unexpected reduction, only Add is implemented.", ORT_RUNTIME_EXCEPTION);
+
+    status = OrtW::GetOpAttribute(info, "maskedValue", masked_value_);
+    if (status != nullptr)
+      return status;
+
+    return nullptr;
+  }
+
+  OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
+                       const ortc::Tensor<int64_t>& output_shape,
+                       const ortc::Tensor<int64_t>& indices,
+                       const ortc::Tensor<T>& updates,
+                       ortc::Tensor<T>& output) const {
+    auto& output_shape_shape = output_shape.Shape();
+    auto& indices_shape = indices.Shape();
+    auto& updates_shape = updates.Shape();
+
+    if (output_shape_shape.size() != 1 || output_shape_shape[0] == 0) {
+      ORTX_CXX_API_THROW("output shape must be a 1D tensor", ORT_RUNTIME_EXCEPTION);
+    }
+    if (indices_shape[indices_shape.size() - 1] != 1) {
+      ORTX_CXX_API_THROW("last dimension of the indices tensor should be one", ORT_RUNTIME_EXCEPTION);
+    }
+
+    const int64_t* shape_data = output_shape.Data();  // CPU pointer
+    const int64_t* indices_data = indices.Data();     // GPU pointer
+    const T* updates_data = updates.Data();           // GPU pointer
+    std::vector<int64_t> voutput_shape(shape_data, shape_data + output_shape_shape[0]);
+    T* output_data = output.Allocate(voutput_shape);  // GPU pointer
+    LaunchMaskedScatterNDOfShapeKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
+                                          voutput_shape,
+                                          indices_shape,
+                                          indices_data,
+                                          updates_data,
+                                          output_data,
+                                          reduction_,
+                                          masked_value_);
+    return nullptr;
+  }
+
+  static OrtMemType GetInputMemoryType(size_t input_index) {
+    if (input_index == 0)  // shape
+      return OrtMemType::OrtMemTypeCPUInput;
+    return OrtMemType::OrtMemTypeDefault;
+  }
+
+ private:
+  ScatterReduction reduction_;
+  int64_t masked_value_;
+};
+
+}  // namespace contrib
diff --git a/operators/cuda/scatter_nd_of_shape_impl.cu b/operators/cuda/scatter_nd_of_shape_impl.cu
new file mode 100644
index 000000000..2cd727548
--- /dev/null
+++ b/operators/cuda/scatter_nd_of_shape_impl.cu
@@ -0,0 +1,264 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "device_prop.cuh"
+#include "utils.cuh"
+#include "scatter_nd_of_shape_impl.cuh"
+#include "cuda_type.h"
+
+namespace contrib {
+
+#define _ENFORCE(cond, msg) \
+  if (!(cond)) ORTX_CXX_API_THROW(msg, ORT_RUNTIME_EXCEPTION);
+
+#ifndef HIP_LONG
+#define HIP_LONG int32_t
+#endif
+
+#ifndef CUDA_LONG
+#define CUDA_LONG int32_t
+#endif
+
+template <typename T>
+__device__ __forceinline__ void _add_inplace(T& x, const T a) { x += a; }
+
+template <>
+__device__ __forceinline__ void _add_inplace(half& x, const half a) {
+#if __CUDA_ARCH__ < 700
+  x = __float2half(__half2float(x) + __half2float(a));
+#else
+  x += a;
+#endif
+}
+
+template <typename T>
+__global__ void
+addition_inplace_kernel(T* __restrict__ output_data, const int64_t* __restrict__ indices_data,
+                        const T* __restrict__ updates_data, const CUDA_LONG indice_size,
+                        const CUDA_LONG nrows, const CUDA_LONG stride) {
+  HIP_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
+  if (id >= stride)
+    return;
+
+  for (size_t i = 0; i < nrows; ++i) {
+    output_data[i * stride + id] = 0;
+  }
+
+  int64_t index;
+  for (size_t i = 0; i < indice_size; ++i) {
+    index = (indices_data[i] + nrows) % nrows;
+    _add_inplace(output_data[index * stride + id], updates_data[i * stride + id]);
+  }
+}
+
+template <typename T>
+__global__ void masked_addition_inplace_kernel(T *__restrict__ output_data,
+                                               const int64_t *__restrict__ indices_data,
+                                               const T *__restrict__ updates_data,
+                                               const CUDA_LONG indice_size,
+                                               const CUDA_LONG nrows, const CUDA_LONG stride,
+                                               const int64_t masked_value) {
+  auto id = blockDim.x * blockIdx.x + threadIdx.x;
+  if (id >= stride)
+    return;
+
+  for (size_t i = 0; i < nrows; ++i) {
+    output_data[i * stride + id] = 0;
+  }
+
+  for (size_t i = 0; i < indice_size; ++i) {
+    if (indices_data[i] == masked_value)
+      continue;
+    _add_inplace(output_data[indices_data[i] * stride + id], updates_data[i * stride + id]);
+  }
+}
+
+template <typename T, int NTHREAD>
+__global__ void masked_addition_inplace_kernelN(T *__restrict__ output_data,
+                                                const int64_t *__restrict__ indices_data,
+                                                const T *__restrict__ updates_data,
+                                                const CUDA_LONG indice_size,
+                                                const CUDA_LONG nrows, const CUDA_LONG stride,
+                                                const int64_t masked_value) {
+  __shared__ int64_t shared_indices[NTHREAD];
+
+  CUDA_LONG tid = threadIdx.x;
+  CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
+
+  for (size_t i = 0; i < nrows; ++i) {
+    output_data[i * stride + id] = 0;
+  }
+
+  int begin = 0;
+  int end = std::min(begin + NTHREAD, indice_size);
+  while (begin < end && (end == begin + NTHREAD)) {
+    shared_indices[tid] = indices_data[tid + begin];
+    __syncthreads();
+
+    for (size_t i = begin; i < end; ++i) {
+      if (shared_indices[i - begin] == masked_value)
+        continue;
+      _add_inplace(output_data[shared_indices[i - begin] * stride + id],
+                   updates_data[i * stride + id]);
+    }
+
+    begin = end;
+    end = std::min(begin + NTHREAD, indice_size);
+  }
+
+  for (size_t i = begin; i < indice_size; ++i) {
+    if (indices_data[i] == masked_value)
+      continue;
+    _add_inplace(output_data[indices_data[i] * stride + id], updates_data[i * stride + id]);
+  }
+}
+
+template <typename NTYPE>
+NTYPE flattened_dimension(const std::vector<NTYPE>& values, size_t first = 0) {
+  NTYPE r = 1;
+  for (auto it = values.begin() + first; it != values.end(); ++it)
+    r *= *it;
+  return r;
+}
+
+template <typename T>
+cudaError_t ScatterNDOfShapeKernel(cudaStream_t stream,
+                                   const std::vector<int64_t>& output_shape,
+                                   const std::vector<int64_t>& indices_shape,
+                                   const int64_t* indices_data,
+                                   const T* updates_data,
+                                   T* output_data,
+                                   ScatterReduction reduction) {
+  if (reduction != ScatterReduction::Add)
+    ORTX_CXX_API_THROW("Only reduction 'add' is implemented.", ORT_RUNTIME_EXCEPTION);
+  size_t indice_size = static_cast<size_t>(flattened_dimension(indices_shape));
+  size_t output_size = static_cast<size_t>(flattened_dimension(output_shape));
+  size_t rank = output_shape.size() - indices_shape.size();
+  size_t stride = static_cast<size_t>(flattened_dimension(output_shape, output_shape.size() - 1 - rank));
+  size_t nrows = output_size / stride;
+
+  int threads_per_block = 256;
+  int blocks_per_grid = (stride + threads_per_block - 1) / threads_per_block;
+
+  dim3 threads(threads_per_block);
+  dim3 blocks(blocks_per_grid);
+  using TT = typename CudaT<T>::MappedType;
+  addition_inplace_kernel<TT><<<blocks, threads, 0, stream>>>(reinterpret_cast<TT*>(output_data), indices_data,
+                                                              reinterpret_cast<const TT*>(updates_data),
+                                                              indice_size, nrows, stride);
+  return cudaGetLastError();
+}
+
+template <typename T>
+cudaError_t MaskedScatterNDOfShapeKernel(cudaStream_t stream, const std::vector<int64_t> &input_shape,
+                                         const std::vector<int64_t> &indices_shape,
+                                         const int64_t *indices_data, const T *updates_data,
+                                         T *output_data,
+                                         ScatterReduction reduction, int64_t masked_value) {
+  if (reduction != ScatterReduction::Add)
+    ORTX_CXX_API_THROW("Only reduction 'add' is implemented.", ORT_RUNTIME_EXCEPTION);
+  size_t indice_size = static_cast<size_t>(flattened_dimension(indices_shape));
+  size_t input_size = static_cast<size_t>(flattened_dimension(input_shape));
+  size_t stride = input_shape[input_shape.size() - 1];
+  size_t nrows = input_size / stride;
+
+  std::vector<size_t> next_batch(indice_size);
+  std::vector<uint8_t> processed(input_shape[0], 0);
+  std::vector<uint8_t> processed_once(input_shape[0], 0);
+
+  int threads_per_block = 256;
+  bool split = stride / threads_per_block <= 32;
+
+  int blocks_per_grid = (stride + threads_per_block - 1) / threads_per_block;
+  dim3 threads(threads_per_block);
+  dim3 blocks(blocks_per_grid);
+
+  using TT = typename CudaT<T>::MappedType;
+
+  if (split && stride >= 256 && threads_per_block == 256) {
+    masked_addition_inplace_kernelN<TT, 256><<<blocks, threads, 0, stream>>>(
+        reinterpret_cast<TT*>(output_data), indices_data,
+        reinterpret_cast<const TT*>(updates_data),
+        indice_size, nrows, stride, masked_value);
+  } else {
+    masked_addition_inplace_kernel<TT><<<blocks, threads, 0, stream>>>(
+        reinterpret_cast<TT*>(output_data), indices_data,
+        reinterpret_cast<const TT*>(updates_data),
+        indice_size, nrows, stride, masked_value);
+  }
+  return cudaGetLastError();
+}
+
+template <>
+cudaError_t LaunchScatterNDOfShapeKernel<float>(cudaStream_t stream,
+                                                const std::vector<int64_t>& output_shape,
+                                                const std::vector<int64_t>& indices_shape,
+                                                const int64_t* indices,
+                                                const float* updates,
+                                                float* output,
+                                                ScatterReduction reduction) {
+  return ScatterNDOfShapeKernel(stream,
+                                output_shape,
+                                indices_shape,
+                                indices,
+                                updates,
+                                output,
+                                reduction);
+}
+
+template <>
+cudaError_t LaunchScatterNDOfShapeKernel<ortc::MFloat16>(cudaStream_t stream,
+                                                         const std::vector<int64_t>& output_shape,
+                                                         const std::vector<int64_t>& indices_shape,
+                                                         const int64_t* indices,
+                                                         const ortc::MFloat16* updates,
+                                                         ortc::MFloat16* output,
+                                                         ScatterReduction reduction) {
+  return ScatterNDOfShapeKernel(stream,
+                                output_shape,
+                                indices_shape,
+                                indices,
+                                updates,
+                                output,
+                                reduction);
+}
+
+template <>
+cudaError_t LaunchMaskedScatterNDOfShapeKernel<float>(cudaStream_t stream,
+                                                      const std::vector<int64_t>& output_shape,
+                                                      const std::vector<int64_t>& indices_shape,
+                                                      const int64_t* indices,
+                                                      const float* updates,
+                                                      float* output,
+                                                      ScatterReduction reduction,
+                                                      int64_t masked_value) {
+  return MaskedScatterNDOfShapeKernel(stream,
+                                      output_shape,
+                                      indices_shape,
+                                      indices,
+                                      updates,
+                                      output,
+                                      reduction,
+                                      masked_value);
+}
+
+template <>
+cudaError_t LaunchMaskedScatterNDOfShapeKernel<ortc::MFloat16>(cudaStream_t stream,
+                                                               const std::vector<int64_t>& output_shape,
+                                                               const std::vector<int64_t>& indices_shape,
+                                                               const int64_t* indices,
+                                                               const ortc::MFloat16* updates,
+                                                               ortc::MFloat16* output,
+                                                               ScatterReduction reduction,
+                                                               int64_t masked_value) {
+  return MaskedScatterNDOfShapeKernel(stream,
+                                      output_shape,
+                                      indices_shape,
+                                      indices,
+                                      updates,
+                                      output,
+                                      reduction,
+                                      masked_value);
+}
+
+}  // namespace contrib
diff --git a/operators/cuda/scatter_nd_of_shape_impl.cuh b/operators/cuda/scatter_nd_of_shape_impl.cuh
new file mode 100644
index 000000000..75a5479ab
--- /dev/null
+++ b/operators/cuda/scatter_nd_of_shape_impl.cuh
@@ -0,0 +1,37 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include <cuda_runtime.h>
+#include <vector>
+
+namespace contrib {
+
+enum class ScatterReduction : int {
+  None = 0,
+  Add = 1,
+  Mul = 2,
+  Min = 3,
+  Max = 4,
+};
+
+template <typename T>
+cudaError_t LaunchScatterNDOfShapeKernel(cudaStream_t stream,
+                                         const std::vector<int64_t>& output_shape,
+                                         const std::vector<int64_t>& indices_shape,
+                                         const int64_t* indices,
+                                         const T* updates,
+                                         T* output,
+                                         ScatterReduction reduction);
+
+template <typename T>
+cudaError_t LaunchMaskedScatterNDOfShapeKernel(cudaStream_t stream,
+                                               const std::vector<int64_t>& output_shape,
+                                               const std::vector<int64_t>& indices_shape,
+                                               const int64_t* indices,
+                                               const T* updates,
+                                               T* output,
+                                               ScatterReduction reduction,
+                                               int64_t masked_value);
+
+}  // namespace contrib
diff --git a/test/cuda/test_cudaops.py b/test/cuda/test_cudaops.py
index 000f3796b..c5c8dc6e0 100644
--- a/test/cuda/test_cudaops.py
+++ b/test/cuda/test_cudaops.py
@@ -4,6 +4,7 @@
 from onnx import helper, numpy_helper, onnx_pb as onnx_proto, TensorProto
 from onnx.reference import ReferenceEvaluator
 from onnx.reference.op_run import OpRun
+from onnx.reference.ops.op_scatternd import _scatter_nd_impl
 from onnxruntime_extensions import make_onnx_model
 from onnxruntime_extensions import get_library_path as _get_library_path
 
@@ -14,6 +15,15 @@
 def has_cuda():
     return "CUDAExecutionProvider" in _ort.get_available_providers()
 
 
+class ScatterNDOfShape(OpRun):
+    op_domain = "ai.onnx.contrib"
+
+    def _run(self, shape, indices, updates, reduction=None, strategy=None):
+        data = np.zeros(shape, dtype=updates.dtype)
+        y = _scatter_nd_impl(data, indices, updates, reduction=reduction)
+        return (y,)
+
+
 class NegXPlus1(OpRun):
     op_domain = "ai.onnx.contrib"
 
@@ -274,6 +284,193 @@ def test_add_shared_input_cuda_broadcast2(self):
             shapec=(3, 2, 3),
         )
 
+    def _scatternd_of_shape_optimize_cuda(self, optimize, dim3, itype):
+        indices_shape = ["i", "j", 1] if dim3 else ["j", 1]
+        updates_shape = ["i", "j", "b"] if dim3 else ["j", "b"]
+
+        model = helper.make_model(
+            helper.make_graph(
+                [
+                    helper.make_node(
+                        "ScatterNDOfShape",
+                        inputs=["shape", "indices", "updates"],
+                        outputs=["y"],
+                        reduction="add",
+                        strategy="optimize" if optimize else "none",
+                        domain="ai.onnx.contrib",
+                    )
+                ],
+                "nd",
+                [
+                    helper.make_tensor_value_info("shape", TensorProto.INT64, [2]),
+                    helper.make_tensor_value_info("indices", TensorProto.INT64, indices_shape),
+                    helper.make_tensor_value_info("updates", itype, updates_shape),
+                ],
+                [helper.make_tensor_value_info("y", itype, [None, None])],
+            ),
+            opset_imports=[
+                helper.make_opsetid("", 18),
+                helper.make_opsetid("ai.onnx.contrib", 1),
+            ],
+            ir_version=9,
+        )
+
+        if dim3:
+            shape = (128, 1024)
+            indices = np.zeros((2, 64, 1)).astype(np.int64)
+            indices[:, ::2, 0] = 87
+            indices[:, ::3, 0] = 85
+            updates = np.ones((2, 64, 1024)).astype(np.float32)
+        else:
+            shape = (128, 1024)
+            indices = np.zeros((128, 1)).astype(np.int64)
+            indices[::2, 0] = 87
+            indices[::3, 0] = 85
+            updates = np.ones((128, 1024)).astype(np.float32)
+        if itype != 1:
+            updates = updates.astype(np.float16)
+        feeds = dict(shape=np.array(shape, dtype=np.int64), indices=indices, updates=updates)
+
+        ref = ReferenceEvaluator(model, new_ops=[ScatterNDOfShape])
+        expected = ref.run(None, feeds)[0]
+
+        opts = _ort.SessionOptions()
+        opts.register_custom_ops_library(_get_library_path())
+        sess = _ort.InferenceSession(model.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
+        ro = None
+        got = sess.run(None, feeds, ro)[0]
+        self.assertEqual(expected.tolist(), got.tolist())
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_scatternd_of_shape_optimize_cuda(self):
+        with self.subTest(optimize=True, dim3=True):
+            self._scatternd_of_shape_optimize_cuda(True, True, TensorProto.FLOAT)
+            self._scatternd_of_shape_optimize_cuda(False, False, TensorProto.FLOAT)
+            self._scatternd_of_shape_optimize_cuda(False, True, TensorProto.FLOAT)
+        with self.subTest(optimize=True, dim3=False):
+            self._scatternd_of_shape_optimize_cuda(True, False, TensorProto.FLOAT)
+        with self.subTest(optimize=True, dim3=True, itype=TensorProto.FLOAT16):
+            self._scatternd_of_shape_optimize_cuda(True, True, TensorProto.FLOAT16)
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_scatternd_of_shape_standalone_cuda(self):
+        self._scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT)
+        self._scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT16)
+        self._scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT)
+        self._scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16)
+
+    def _masked_scatternd_of_shape_cuda(self, reduction, line, itype, big):
+        dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
+
+        model1 = helper.make_model(
+            helper.make_graph(
+                [
+                    helper.make_node("Equal", ["indices", "mone"], ["masked_indices"]),
+                    helper.make_node(
+                        "Where",
+                        ["masked_indices", "zero", "updates"],
+                        ["masked_updates"],
+                    ),
+                    helper.make_node(
+                        "ScatterND",
+                        inputs=["data", "indices", "masked_updates"],
+                        outputs=["y"],
+                        reduction=reduction,
+                    ),
+                ],
+                "nd",
+                [
+                    helper.make_tensor_value_info("data", itype, [None, None]),
+                    helper.make_tensor_value_info("indices", TensorProto.INT64, [None, None, 1]),
+                    helper.make_tensor_value_info("updates", itype, [None, None, None]),
+                ],
+                [helper.make_tensor_value_info("y", itype, [None, None])],
+                [
+                    numpy_helper.from_array(np.array([-1], dtype=np.int64), name="mone"),
+                    numpy_helper.from_array(np.array([0], dtype=dtype), name="zero"),
+                ],
+            ),
+            opset_imports=[helper.make_opsetid("", 18)],
+            ir_version=9,
+        )
+
+        model2 = helper.make_model(
+            helper.make_graph(
+                [
+                    helper.make_node(
+                        "MaskedScatterNDOfShape",
+                        inputs=["shape", "indices", "updates"],
+                        outputs=["y"],
+                        reduction=reduction,
+                        maskedValue=-1,
+                        domain="ai.onnx.contrib",
+                    )
+                ],
+                "nd",
+                [
+                    helper.make_tensor_value_info("shape", TensorProto.INT64, [None]),
+                    helper.make_tensor_value_info("indices", TensorProto.INT64, [None, None, 1]),
+                    helper.make_tensor_value_info("updates", itype, [None, None, None]),
+                ],
+                [helper.make_tensor_value_info("y", itype, [None, None])],
+            ),
+            opset_imports=[
+                helper.make_opsetid("", 18),
+                helper.make_opsetid("ai.onnx.contrib", 1),
+            ],
+            ir_version=9,
+        )
+
+        if big:
+            data = np.zeros((2048, 4096), dtype=dtype)
+            indices = np.ones((2, 1024), dtype=np.int64)
+            indices = indices[..., np.newaxis]
+            shape = tuple(indices.shape[:2]) + (data.shape[-1],)
+            updates = (np.arange(np.prod(shape)).reshape(shape) / np.prod(shape)).astype(dtype)
+        else:
+            data = np.zeros((32, 16), dtype=dtype)
+            indices = np.array(
+                [
+                    [0, 1, 2],
+                    [2, 3, 4],
+                    [-1, 30, 31],
+                    [-1, 7, 8],
+                    [10, 11, -1],
+                    [20, -1, 21],
+                ],
+                dtype=np.int64,
+            )
+            indices = indices[..., np.newaxis]
+            shape = (6, 3, data.shape[-1])
+            updates = (np.arange(np.prod(shape)).reshape(shape) + 1).astype(dtype)
+
+        feeds1 = dict(data=data, indices=indices, updates=updates)
+        feeds2 = dict(shape=np.array(data.shape, dtype=np.int64), indices=indices, updates=updates)
+        ref = ReferenceEvaluator(model1)
+        expected = ref.run(None, feeds1)[0]
+
+        opts = _ort.SessionOptions()
+        opts.register_custom_ops_library(_get_library_path())
+        # opts.log_severity_level = 0
+        # opts.log_verbosity_level = 0
+        sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
+        got = sess.run(None, feeds2)[0]
+        assert_almost_equal(expected.tolist(), got.tolist())
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_masked_scatternd_of_shape_standalone_cuda_small(self):
+        self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT, False)
+        self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT16, False)
+        self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT, False)
+        self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16, False)
+
+    @unittest.skipIf(not has_cuda(), reason="cuda not available")
+    def test_masked_scatternd_of_shape_standalone_cuda_big(self):
+        self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT, True)
+        self._masked_scatternd_of_shape_cuda("add", 0, TensorProto.FLOAT16, True)
+        self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT, True)
+        self._masked_scatternd_of_shape_cuda("add", 1, TensorProto.FLOAT16, True)
+
     def _transpose_cast_cuda(self, itype):
         dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
         itype2 = TensorProto.FLOAT if itype == TensorProto.FLOAT16 else TensorProto.FLOAT16