From 49d3d54c366d838a9095785b7097d681b789c322 Mon Sep 17 00:00:00 2001 From: Kee Date: Thu, 21 Nov 2024 21:59:44 +0800 Subject: [PATCH 1/2] [VSINPU]Split\Pad and some element-wise OPs support Description -Add split/pad/neg/not/ceil/round/min/max op support -Fix conv2d op default pads value issue Signed-off-by: Kee --- .../vsinpu/builders/impl/conv_op_builder.h | 2 +- .../builders/impl/elementwise_op_builder.h | 6 + .../vsinpu/builders/impl/pad_op_builder.h | 191 ++++++++++++++++++ .../vsinpu/builders/impl/split_op_builder.h | 190 +++++++++++++++++ .../vsinpu/builders/op_builder_factory.h | 12 +- 5 files changed, 399 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/pad_op_builder.h create mode 100644 onnxruntime/core/providers/vsinpu/builders/impl/split_op_builder.h diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h index 3ed432c2efa1c..5278efdb4a400 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/conv_op_builder.h @@ -112,7 +112,7 @@ class ConvOpBuilder : public BaseOpBuilder { } } } else { - auto pads = helper.Get("pads", std::vector{0U, 0U}); + auto pads = helper.Get("pads", std::vector{0U, 0U, 0U, 0U}); if (group != 1 && group != weight_tensor->GetShape()[OChannel_idx]) { if (is_1d_conv) { op = graph_ep->GetGraph() diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h index 4c10ba01b1c2e..7da1e6e674601 100644 --- a/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h +++ b/onnxruntime/core/providers/vsinpu/builders/impl/elementwise_op_builder.h @@ -65,6 +65,12 @@ ELEMENTWISE_OP_BUILDER(Floor, Floor); ELEMENTWISE_OP_BUILDER(Log, Log); ELEMENTWISE_OP_BUILDER(Sin, Sin); ELEMENTWISE_OP_BUILDER(HardSwish, HardSwish); +ELEMENTWISE_OP_BUILDER(Neg, Neg); +ELEMENTWISE_OP_BUILDER(Not, LogicalNot); +ELEMENTWISE_OP_BUILDER(Ceil, Ceil); +ELEMENTWISE_OP_BUILDER(Round, Round); +ELEMENTWISE_OP_BUILDER(Min, Minimum); +ELEMENTWISE_OP_BUILDER(Max, Maximum); class PowOpBuilder : public BaseOpBuilder { bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/pad_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/pad_op_builder.h new file mode 100644 index 0000000000000..19cbe4e7f3e48 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/pad_op_builder.h @@ -0,0 +1,191 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include +#include +#include "core/optimizer/initializer.h" +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { + +typedef tim::vx::ops::PadV2::pad_mode_type PadMode; + +class PadOpBuilder : public BaseOpBuilder { + public: + int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 11; } + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + NodeAttrHelper helper(*node); + const auto mode = helper.Get("mode", "constant"); + auto input_defs = node->InputDefs(); + size_t num_inputs = input_defs.size(); + auto input_shape = vsi::npu::util::GetTensorShape(*input_defs[0]); + int32_t rank = input_shape.NumDimensions(); + const auto& initializers = graph_viewer.GetAllInitializedTensors(); + + if (mode == "wrap") { + LOGS_DEFAULT(WARNING) << "`wrap` mode Pad is not currently supported for now."; + return false; + } + if (mode == "constant") { + if (num_inputs > 2 && input_defs[2]->Exists()) { + // only support if `constant_value` input is a constant initializer + if (!Contains(initializers, input_defs[2]->Name())) { + LOGS_DEFAULT(WARNING) << "constant_value must be a constant initializer."; + return false; + } + } + } + // only support if `pads` input is known and does not contain negative values + { + const auto* pads_initializer = graph_viewer.GetConstantInitializer(input_defs[1]->Name()); + if (!pads_initializer) { + LOGS_DEFAULT(WARNING) << "pads must be a constant initializer"; + return false; + } + + Initializer unpacked_tensor(*pads_initializer); + auto tensor_data = unpacked_tensor.DataAsSpan(); + for (size_t i = 0; i < unpacked_tensor.size(); i++) { + if (tensor_data[i] < 0) { + LOGS_DEFAULT(WARNING) << "Negative pad value is not supported: pads[" + << i << "] = " << tensor_data[i]; + return false; + } + } + } + return true; + } + + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + for (size_t i = 0; i < node_unit.Inputs().size(); ++i) { + const auto& iodef = node_unit.Inputs()[i]; + if (0 == i) { + if (!util::IsTypeSupported(&iodef.node_arg) || + (*iodef.node_arg.Type() == "tensor(int64)") || + (*iodef.node_arg.Type() == "tensor(bool)")) { + LOGS_DEFAULT(WARNING) << "Unspport tensor data type:" << *iodef.node_arg.Type(); + return false; + } + } else if (1 == i) { + if (!Contains(initializers, iodef.node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "pads must be a constant initializer."; + return false; + } + } else if (2 == i) { + if (iodef.node_arg.Exists() && !Contains(initializers, iodef.node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "constant_value must be a constant initializer."; + return false; + } + } else if (i == 3) { + if (!Contains(initializers, iodef.node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "axes must be a constant initializer.."; + return false; + } + } + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Pad Op."; + NodeAttrHelper helper(node_unit); + const auto mode = helper.Get("mode", "constant"); + auto input_defs = node_unit.Inputs(); + PadMode pad_mode = PadMode::PAD_MODE_CONSTANT; + float const_val = 0.0f; + std::vector axes_tensor_data; + int32_t input_rank = inputs[0]->GetShape().size(); + + if (mode == "constant") { + pad_mode = PadMode::PAD_MODE_CONSTANT; + } else if (mode == "reflect") { + pad_mode = PadMode::PAD_MODE_REFLECT; + } else if (mode == "edge") { + pad_mode = PadMode::PAD_MODE_EDGE; + } else { + LOGS_DEFAULT(WARNING) << "`wrap` mode Pad is not currently supported for now."; + return false; + } + + // `pads` input + std::vector onnx_pads(inputs[1]->GetSpec().GetElementNum()); + inputs[1]->CopyDataFromTensor(onnx_pads.data()); + + // `constant_value` input + if (inputs.size() > 2 && pad_mode == PadMode::PAD_MODE_CONSTANT) { + if (input_defs[2].node_arg.Exists()) { + inputs[2]->CopyDataFromTensor(&const_val); + } + } + // `axes` input + if (inputs.size() > 3) { + // optional input axes is provided, use axes initializer data + std::vector axes_tensor(inputs[3]->GetSpec().GetElementNum()); + inputs[3]->CopyDataFromTensor(axes_tensor.data()); + std::transform( + axes_tensor.begin(), axes_tensor.end(), std::back_inserter(axes_tensor_data), + [input_rank](int64_t axis) { return HandleNegativeAxis(axis, input_rank); }); + } else { + // if not provided, make a default axes as [0, 1, ..., input_rank - 1] + std::vector default_axes(input_rank); + std::iota(std::begin(default_axes), std::end(default_axes), 0); + axes_tensor_data = std::move(default_axes); + } + + int64_t num_axes = axes_tensor_data.size(); + std::vector front_size(input_rank, 0); + std::vector back_size(input_rank, 0); + + int64_t axes_index = 0; + for (int64_t axes : axes_tensor_data) { + front_size[axes] = onnx_pads[axes_index]; + back_size[axes] = onnx_pads[axes_index + num_axes]; + axes_index++; + } + + std::reverse(front_size.begin(), front_size.end()); + std::reverse(back_size.begin(), back_size.end()); + + auto op = graph_ep->GetGraph()->CreateOperation( + front_size, back_size, const_val, pad_mode); + op->BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/impl/split_op_builder.h b/onnxruntime/core/providers/vsinpu/builders/impl/split_op_builder.h new file mode 100644 index 0000000000000..e08416bda70d4 --- /dev/null +++ b/onnxruntime/core/providers/vsinpu/builders/impl/split_op_builder.h @@ -0,0 +1,190 @@ +/**************************************************************************** + * + * Copyright (c) 2024 Vivante Corporation + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + * + *****************************************************************************/ +#pragma once +#include +#include +#include +#include +#include +#include "core/optimizer/initializer.h" +#include "core/providers/vsinpu/builders/impl/base_op_builder.h" +#include "core/providers/common.h" +#include "core/providers/shared/utils/utils.h" + +namespace onnxruntime { +namespace vsi { +namespace npu { + +class SplitOpBuilder : public BaseOpBuilder { + public: + bool IsOpSupported(const onnxruntime::GraphViewer& graph_viewer, + const Node* node) const override { + NodeAttrHelper helper(*node); + auto axis = helper.Get("axis", 0); + auto input_defs = node->InputDefs(); + size_t num_inputs = input_defs.size(); + size_t num_outputs = node->OutputDefs().size(); + auto input_shape = vsi::npu::util::GetTensorShape(*input_defs[0]); + int32_t rank = input_shape.NumDimensions(); + std::vector splits_list; + bool split_provided = false; + if (axis >= rank || axis < -rank) { + LOGS_DEFAULT(WARNING) << "Axis is invalid in Split. Axis(" << axis + << ") is out of rank[" << -rank << "," << rank - 1 << "]"; + return false; + } + axis = HandleNegativeAxis(axis, rank); + const auto split_dims_at_axis = input_shape.GetDims()[axis]; + if (num_inputs > 1 && input_defs[1]->Exists()) { + // if optional input `split` is provided + const auto* splits = graph_viewer.GetConstantInitializer(input_defs[1]->Name()); + if (!splits) { + LOGS_DEFAULT(WARNING) << "Optional input 'split' must be a constant initializer if provided."; + return false; + } + Initializer unpacked_tensor(*splits); + auto split_sizes_ = unpacked_tensor.DataAsSpan(); + splits_list.assign(split_sizes_.begin(), split_sizes_.end()); + split_provided = true; + } + if (num_inputs == 1) { + // opset1,2,11 split as attribute + if (helper.HasAttr("split")) { + auto split_sizes_ = *helper.GetInt64s("split"); + splits_list.assign(split_sizes_.begin(), split_sizes_.end()); + split_provided = true; + } else if (node->SinceVersion() >= 18) { + const auto outputs_count = helper.GetInt64("num_outputs"); + if (!outputs_count.has_value()) { + LOGS_DEFAULT(WARNING) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute."; + return false; + } + if (outputs_count.value() != static_cast(num_outputs) || + outputs_count.value() > split_dims_at_axis) { + LOGS_DEFAULT(WARNING) << "Invalid num_outputs provided.\n. The value should be smaller or equal to the size " + "of dimension being split. num_outputs: " + << outputs_count.value(); + return false; + } + } + } + if (!split_provided) { + // populate split sizes based on num_outputs so existing code can be utilized + int32_t size = narrow(std::ceil(float(split_dims_at_axis) / num_outputs)); + int32_t remainder = split_dims_at_axis % size; + std::vector split_sizes_(num_outputs, size); + if (remainder) { + split_sizes_.back() = remainder; + } + splits_list.assign(split_sizes_.begin(), split_sizes_.end()); + } + + uint32_t sum_of_splits = std::accumulate(splits_list.begin(), splits_list.end(), SafeInt(0)); + if (sum_of_splits != split_dims_at_axis) { + LOGS_DEFAULT(WARNING) << "Sum of the 'split' input values must equal to the dim value at 'axis' specified. " + << "dim value at 'axis' specified: " + << split_dims_at_axis + << ", sum of 'split' input values: " + << sum_of_splits; + return false; + } + if (!std::all_of(splits_list.begin(), splits_list.end(), [](int64_t value) { return value >= 0; })) { + LOGS_DEFAULT(WARNING) << "Invalid value in 'split' attribute. All values must be > 0"; + return false; + } + auto average_split = sum_of_splits / num_outputs; + if (!std::all_of(splits_list.begin(), splits_list.end(), [average_split](int64_t value) { return value == average_split; })) { + // TO DO, remove this check after driver supports it. + LOGS_DEFAULT(WARNING) << "Uneven splits are not currently supported for now."; + return false; + } + + return true; + } + + bool HasSupportedInputOutputsImpl(const InitializedTensorSet& initializers, + const NodeUnit& node_unit) const override { + for (size_t i = 0; i < node_unit.Inputs().size(); ++i) { + const auto& iodef = node_unit.Inputs()[i]; + if (0 == i) { + if (!util::IsTypeSupported(&iodef.node_arg) || + (*iodef.node_arg.Type() == "tensor(int64)") || + (*iodef.node_arg.Type() == "tensor(bool)")) { + LOGS_DEFAULT(WARNING) << "Unsupport tensor data type:" << *iodef.node_arg.Type(); + return false; + } + } else if (!Contains(initializers, iodef.node_arg.Name())) { + LOGS_DEFAULT(WARNING) << "Optional input 'split' must be a constant initializer if provided."; + return false; + } + } + return true; + } + + bool HandleBuildOp(vsi::npu::GraphEP* graph_ep, + std::vector>& inputs, + std::vector>& outputs, + const NodeUnit& node_unit) override { + LOGS_DEFAULT(VERBOSE) << "Creating Split Op."; + NodeAttrHelper helper(node_unit); + auto axis = helper.Get("axis", 0); + axis = util::ReverseAxis(axis, inputs[0]->GetShape().size()); + const auto split_dims_at_axis = inputs[0]->GetShape()[axis]; + auto num_outputs = outputs.size(); + // transform splite vector to timvx slice + std::vector onnx_split; + if (inputs.size() > 1) { + std::vector split_sizes_(inputs[1]->GetSpec().GetElementNum()); + inputs[1]->CopyDataFromTensor(split_sizes_.data()); + onnx_split.assign(split_sizes_.begin(), split_sizes_.end()); + } + if (inputs.size() == 1) { + if (helper.HasAttr("split")) { + auto split_sizes_ = *helper.GetInt64s("split"); + onnx_split.assign(split_sizes_.begin(), split_sizes_.end()); + } + if (node_unit.SinceVersion() >= 18 || !helper.HasAttr("split")) { + // populate split sizes based on num_outputs so existing code can be utilized + int32_t size = narrow(std::ceil(float(split_dims_at_axis) / num_outputs)); + int32_t remainder = split_dims_at_axis % size; + std::vector split_sizes_(num_outputs, size); + if (remainder) { + split_sizes_.back() = remainder; + } + onnx_split.assign(split_sizes_.begin(), split_sizes_.end()); + } + } + std::vector slices(onnx_split.begin(), onnx_split.end()); + std::reverse(slices.begin(), slices.end()); + + auto op = graph_ep->GetGraph()->CreateOperation( + axis, slices); + op->BindInput(inputs[0]).BindOutputs(outputs); + graph_ep->GetOps().push_back(std::move(op)); + return true; + } +}; +} // namespace npu +} // namespace vsi +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h b/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h index dc0969429b8ff..fcf9479a6058b 100644 --- a/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/vsinpu/builders/op_builder_factory.h @@ -53,6 +53,8 @@ #include "impl/cast_op_builder.h" #include "impl/dropout_op_builder.h" #include "impl/slice_op_builder.h" +#include "impl/split_op_builder.h" +#include "impl/pad_op_builder.h" namespace onnxruntime { namespace vsi { namespace npu { @@ -110,7 +112,15 @@ static const std::map reg = { REGISTER_OP_BUILDER("Resize", ResizeOpBuilder), REGISTER_OP_BUILDER("Cast", CastOpBuilder), REGISTER_OP_BUILDER("Dropout", DropoutOpBuilder), - REGISTER_OP_BUILDER("Slice", SliceOpBuilder) + REGISTER_OP_BUILDER("Slice", SliceOpBuilder), + REGISTER_OP_BUILDER("Split", SplitOpBuilder), + REGISTER_OP_BUILDER("Neg", NegOpBuilder), + REGISTER_OP_BUILDER("Not", NotOpBuilder), + REGISTER_OP_BUILDER("Ceil", CeilOpBuilder), + REGISTER_OP_BUILDER("Round", RoundOpBuilder), + REGISTER_OP_BUILDER("Min", MinOpBuilder), + REGISTER_OP_BUILDER("Max", MaxOpBuilder), + REGISTER_OP_BUILDER("Pad", PadOpBuilder) #undef REGISTER_OP_BUILDER }; From 450f7c4dfb56cbb36cd82f7a38610ba0093963a1 Mon Sep 17 00:00:00 2001 From: Kee Date: Fri, 22 Nov 2024 00:00:30 +0800 Subject: [PATCH 2/2] [VSINPU EP]Add VSINPU EP to support python bindings Signed-off-by: Kee --- cmake/onnxruntime_python.cmake | 10 ++++++++++ onnxruntime/python/onnxruntime_pybind_schema.cc | 3 +++ onnxruntime/python/onnxruntime_pybind_state.cc | 4 ++++ onnxruntime/python/onnxruntime_pybind_state_common.h | 1 + 4 files changed, 18 insertions(+) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index d2c022e4e0269..5a87252b08573 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -170,6 +170,7 @@ target_link_libraries(onnxruntime_pybind11_state PRIVATE onnxruntime_session ${onnxruntime_libs} ${PROVIDERS_NNAPI} + ${PROVIDERS_VSINPU} ${PROVIDERS_XNNPACK} ${PROVIDERS_COREML} ${PROVIDERS_RKNPU} @@ -1018,4 +1019,13 @@ if (onnxruntime_USE_QNN) endif() endif() +if (onnxruntime_USE_VSINPU) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + $ + $/onnxruntime/capi/ + ) +endif() + endif() diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc index 1319e8f6fe959..77503c2ef4fcb 100644 --- a/onnxruntime/python/onnxruntime_pybind_schema.cc +++ b/onnxruntime/python/onnxruntime_pybind_schema.cc @@ -69,6 +69,9 @@ void addGlobalSchemaFunctions(pybind11::module& m) { #ifdef USE_NNAPI onnxruntime::NnapiProviderFactoryCreator::Create(0, std::optional()), #endif +#ifdef USE_VSINPU + onnxruntime::VSINPUProviderFactoryCreator::Create(), +#endif #ifdef USE_RKNPU onnxruntime::RknpuProviderFactoryCreator::Create(), #endif diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 54accf7ed88f3..db32a82ac4361 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1180,6 +1180,10 @@ std::unique_ptr CreateExecutionProviderInstance( const auto partitioning_stop_ops_list = session_options.config_options.GetConfigEntry( kOrtSessionOptionsConfigNnapiEpPartitioningStopOps); return onnxruntime::NnapiProviderFactoryCreator::Create(0, partitioning_stop_ops_list)->CreateProvider(); +#endif + } else if (type == kVSINPUExecutionProvider) { +#ifdef USE_VSINPU + return onnxruntime::VSINPUProviderFactoryCreator::Create()->CreateProvider(); #endif } else if (type == kRknpuExecutionProvider) { #ifdef USE_RKNPU diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index b71081bf20efc..995341b0f8dc0 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -440,6 +440,7 @@ std::shared_ptr CreateExecutionProviderFactory_ArmNN( std::shared_ptr CreateExecutionProviderFactory_DML(int device_id); std::shared_ptr CreateExecutionProviderFactory_Nnapi( uint32_t flags, const optional& partitioning_stop_ops_list); +std::shared_ptr CreateExecutionProviderFactory_VSINPU(); std::shared_ptr CreateExecutionProviderFactory_Rknpu(); std::shared_ptr CreateExecutionProviderFactory_CoreML(uint32_t flags); constexpr const char* kDefaultExecutionProviderEntry = "GetProvider";