From 047f32c79d85e38cdbf6f7cb7c06701e1d7fdaba Mon Sep 17 00:00:00 2001 From: mingyueliuh <131847423+mingyueliuh@users.noreply.github.com> Date: Sat, 31 Aug 2024 11:57:23 -0400 Subject: [PATCH] [VitisAI] Remove shape infer from bridge ort (#21331) ### Description Vitis AI EP's custom ops are completely self-contained within the Vitis AI EP implementation (rather than needing to add static functions in provider_bridge). --------- Co-authored-by: liumingyue --- .../shared_library/provider_interfaces.h | 2 +- .../providers/vitisai/imp/register_xir_ops.cc | 9 +- .../vitisai/include/vaip/vaip_ort_api.h | 2 +- .../core/session/provider_bridge_ort.cc | 124 +----------------- 4 files changed, 4 insertions(+), 133 deletions(-) diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 1059443469067..041b387ff874d 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -567,7 +567,7 @@ struct ProviderHost { virtual int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) = 0; virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0; - virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) = 0; + virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op) = 0; virtual const ONNX_NAMESPACE::OpSchema* GetSchema(const std::string& name, const int maxInclusiveVersion, const std::string& domain) = 0; virtual const std::string& OpSchema__inputs__GetName(const ONNX_NAMESPACE::OpSchema* p, const size_t i) = 0; virtual const std::string& OpSchema__inputs__GetTypeStr(const ONNX_NAMESPACE::OpSchema* p, const size_t i) = 0; diff --git a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc index 
03458f42d5f28..ea5687f2691b1 100644 --- a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc +++ b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc @@ -13,14 +13,7 @@ void register_xir_ops(const std::vector& domains) { for (auto domain : domains) { for (auto op : domain->custom_ops_) { if (Provider_GetHost()->GetSchema(op->GetName(op), op->GetStartVersion(op), domain->domain_) == nullptr) { - auto name = op->GetName(op); - if ((std::string)name == "super_layer") { - Provider_GetHost()->RegisterSchema(domain->domain_, op, 1); - } else if ((std::string)name == "FixNeuron") { - Provider_GetHost()->RegisterSchema(domain->domain_, op, 2); - } else { - Provider_GetHost()->RegisterSchema(domain->domain_, op, 3); - } + Provider_GetHost()->RegisterSchema(domain->domain_, op); } } } diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index 17fd9ef21d34a..db70ef0cc17d5 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -13,7 +13,7 @@ struct OrtApi; namespace vaip_core { -#define VAIP_ORT_API_MAJOR (7u) +#define VAIP_ORT_API_MAJOR (8u) #define VAIP_ORT_API_MINOR (0u) #define VAIP_ORT_API_PATCH (0u) struct OrtApiForVaip { diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index dc5b983f86cbb..ce259946944ca 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -682,135 +682,13 @@ struct ProviderHostImpl : ProviderHost { int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) override { return p->metadata_props_size(); } ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) override { return p->add_metadata_props(); } - static int32_t convert_elem_type(const 
ONNX_NAMESPACE::AttributeProto* data_type) { - int32_t elemType = 0; - if (data_type->s() == "float32") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT; - } else if (data_type->s() == "int8") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_INT8; - } else if (data_type->s() == "uint8") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT8; - } else if (data_type->s() == "int32") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_INT32; - } else if (data_type->s() == "uint32") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT32; - } else if (data_type->s() == "int64") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_INT64; - } else if (data_type->s() == "uint64") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT64; - } else if (data_type->s() == "int1") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_BOOL; - } else if (data_type->s() == "bfloat16") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16; - } else if (data_type->s() == "float16") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; - } else if (data_type->s() == "uint16") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT16; - } else if (data_type->s() == "int16") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_INT16; - } else if (data_type->s() == "double") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_DOUBLE; - } else if (data_type->s() == "string") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_STRING; - } else if (data_type->s() == "complex64") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64; - } else if (data_type->s() == "complex128") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128; - } else if (data_type->s() == "float8e4m3fn") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN; - } else if (data_type->s() == "float8e4m3fnuz") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ; - } else if (data_type->s() == "float8e5m2") { - elemType = 
ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2; - } else if (data_type->s() == "float8e5m2funz") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ; - } else if (data_type->s() == "uint4") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT4; - } else if (data_type->s() == "int4") { - elemType = ONNX_NAMESPACE::TensorProto_DataType_INT4; - } - return elemType; - } - - static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { - auto num_output = ctx.getNumOutputs(); - if (num_output == 1) { - auto* shape = ctx.getAttribute("shape"); - auto* data_type = ctx.getAttribute("data_type"); - if (data_type == nullptr) { - std::cerr << "Custom op is missing `data_type` attr." << std::endl; - return; - } - int32_t elemType = convert_elem_type(data_type); - ONNX_NAMESPACE::updateOutputElemType(ctx, 0, elemType); - if (shape != nullptr) { - for (auto i = 0; i < shape->ints_size(); ++i) { - ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i)); - } - } else { - // set scalar type. 
- ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->clear_dim(); - } - } else { - for (auto idx = 0u; idx < num_output; idx++) { - auto* shape = ctx.getAttribute("shape_" + std::to_string(idx)); - auto* data_type = ctx.getAttribute("data_type_" + std::to_string(idx)); - if (shape == nullptr || data_type == nullptr) { - // this output is optional - } else { - int32_t elemType = convert_elem_type(data_type); - ONNX_NAMESPACE::updateOutputElemType(ctx, idx, elemType); - for (auto i = 0; i < shape->ints_size(); ++i) { - ONNX_NAMESPACE::getOutputShape(ctx, idx, ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i)); - } - } - } - } - } - - static void xir_fixneuron_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) { - ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); - ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0); - } - - static void xir_subgraph_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) { - auto num_inputs = ctx.getNumInputs(); - - // Run inferencing on the subgraph - auto* graphInferencer = ctx.getGraphAttributeInferencer("body"); - - std::vector input_data; - std::vector subgraph_input_types; - for (size_t i = 0; i < num_inputs; ++i) { - input_data.push_back(ctx.getInputData(i)); - subgraph_input_types.push_back(ctx.getInputType(i)); - } - - auto output_types = graphInferencer->doInferencing(subgraph_input_types, input_data); - for (size_t i = 0, end = output_types.size(); i < end; ++i) { - *ctx.getOutputType(i) = *output_types[i]; - } - } - void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) override { + void RegisterSchema(const std::string& domain, const OrtCustomOp* op) override { auto& domain_instance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); const auto& domain_to_version_map = domain_instance.Map(); if (domain_to_version_map.find(domain) == domain_to_version_map.end()) { 
domain_instance.AddDomainToVersion(domain, 1, 1000); } auto schema = CreateSchema(domain, {op}); - switch (type) { - case 1: - schema.TypeAndShapeInferenceFunction(xir_subgraph_shape_inference); - break; - case 2: - schema.TypeAndShapeInferenceFunction(xir_fixneuron_shape_inference); - break; - case 3: - schema.TypeAndShapeInferenceFunction(xir_shape_infer); - break; - default: - break; - } ONNX_NAMESPACE::RegisterSchema(schema, ORT_API_VERSION); } const ONNX_NAMESPACE::OpSchema* GetSchema(const std::string& name, const int maxInclusiveVersion, const std::string& domain) override {