From bd0cf82565f6964d1f723d7c4035dbbc6f3e1cf8 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 15 Nov 2024 23:25:41 +0000 Subject: [PATCH 1/3] update --- .../core/providers/shared_library/provider_interfaces.h | 2 +- .../core/providers/shared_library/provider_wrappedtypes.h | 4 +++- onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc | 3 ++- onnxruntime/core/session/provider_bridge_ort.cc | 4 ++-- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 1436afa41c2f8..f9f2bb69a9d1a 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -960,7 +960,7 @@ struct ProviderHost { // GraphViewer virtual void GraphViewer__operator_delete(GraphViewer* p) = 0; - virtual std::unique_ptr GraphViewer__CreateModel(const GraphViewer* p, const logging::Logger& logger) = 0; + virtual std::unique_ptr GraphViewer__CreateModel(const GraphViewer* p, const logging::Logger& logger, const ModelMetaData&) = 0; virtual const std::string& GraphViewer__Name(const GraphViewer* p) noexcept = 0; virtual const std::filesystem::path& GraphViewer__ModelPath(const GraphViewer* p) noexcept = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 5e8996d590db8..a82ddfe64c64b 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -1022,11 +1022,13 @@ struct Graph final { PROVIDER_DISALLOW_ALL(Graph) }; +using ModelMetaData = std::unordered_map; + class GraphViewer final { public: static void operator delete(void* p) { g_host->GraphViewer__operator_delete(reinterpret_cast(p)); } - std::unique_ptr CreateModel(const logging::Logger& logger) const { return g_host->GraphViewer__CreateModel(this, logger); } + std::unique_ptr CreateModel(const logging::Logger& logger, const ModelMetaData& metadata = ModelMetaData()) const { return g_host->GraphViewer__CreateModel(this, logger, metadata); } const std::string& Name() const noexcept { return g_host->GraphViewer__Name(this); } const std::filesystem::path& ModelPath() const noexcept { return g_host->GraphViewer__ModelPath(this); } diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc index ef45d6c85d6a9..f1c6a3477c364 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc @@ -127,8 +127,9 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, ORT_ENFORCE(graph_build.Resolve().IsOK()); // Serialize modelproto to string + auto& metadata = graph_viewer.GetGraph().GetModel().MetaData(); auto new_graph_viewer = graph_build.CreateGraphViewer(); - auto model = new_graph_viewer->CreateModel(*logger); + auto model = new_graph_viewer->CreateModel(*logger, metadata); auto model_proto = model->ToProto(); new_graph_viewer->ToProto(*model_proto->mutable_graph(), true, true); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 0aa93bce354e8..19a1ad8a5a160 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1156,8 +1156,8 @@ struct ProviderHostImpl : ProviderHost { // GraphViewer (wrapped) void GraphViewer__operator_delete(GraphViewer* p) override { delete p; } - std::unique_ptr GraphViewer__CreateModel(const GraphViewer* graph_viewer, const logging::Logger& logger) override { - return std::make_unique(graph_viewer->Name(), true, ModelMetaData(), PathString(), + std::unique_ptr GraphViewer__CreateModel(const GraphViewer* graph_viewer, const logging::Logger& logger, const ModelMetaData& metadata = ModelMetaData()) override { + return std::make_unique(graph_viewer->Name(), true, metadata, PathString(), #if !defined(ORT_MINIMAL_BUILD) IOnnxRuntimeOpSchemaRegistryList({graph_viewer->GetSchemaRegistry()}), graph_viewer->DomainToVersionMap(), #else From 6f128a1675b6b5d1625c514bddc7a45e3a856e45 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 15 Nov 2024 23:35:38 +0000 Subject: [PATCH 2/3] update --- onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc index f1c6a3477c364..fbccd7d4a286b 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc @@ -127,8 +127,8 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, ORT_ENFORCE(graph_build.Resolve().IsOK()); // Serialize modelproto to string - auto& metadata = graph_viewer.GetGraph().GetModel().MetaData(); auto new_graph_viewer = graph_build.CreateGraphViewer(); + auto& metadata = graph_viewer.GetGraph().GetModel().MetaData(); auto model = new_graph_viewer->CreateModel(*logger, metadata); auto model_proto = model->ToProto(); new_graph_viewer->ToProto(*model_proto->mutable_graph(), true, true); From a172f401bc07b2b55ad618c20a889301c68b8946 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 15 Nov 2024 23:35:56 +0000 Subject: [PATCH 3/3] addd openvino --- onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc index b0a7f46521cce..42a2b5d30c25c 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc @@ -21,7 +21,8 @@ Status EPCtxHandler::ExportEPCtxModel(const GraphViewer& graph_viewer, const bool& ep_context_embed_mode, std::string&& model_blob_str, const std::string& openvino_sdk_version) const { - auto model_build = graph_viewer.CreateModel(logger); + auto& metadata = graph_viewer.GetGraph().GetModel().MetaData(); + auto model_build = graph_viewer.CreateModel(logger, metadata); auto& graph_build = model_build->MainGraph(); // Get graph inputs and outputs