diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc index b0a7f46521cce..42a2b5d30c25c 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc @@ -21,7 +21,8 @@ Status EPCtxHandler::ExportEPCtxModel(const GraphViewer& graph_viewer, const bool& ep_context_embed_mode, std::string&& model_blob_str, const std::string& openvino_sdk_version) const { - auto model_build = graph_viewer.CreateModel(logger); + auto& metadata = graph_viewer.GetGraph().GetModel().MetaData(); + auto model_build = graph_viewer.CreateModel(logger, metadata); auto& graph_build = model_build->MainGraph(); // Get graph inputs and outputs diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 1436afa41c2f8..f9f2bb69a9d1a 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -960,7 +960,7 @@ struct ProviderHost { // GraphViewer virtual void GraphViewer__operator_delete(GraphViewer* p) = 0; - virtual std::unique_ptr GraphViewer__CreateModel(const GraphViewer* p, const logging::Logger& logger) = 0; + virtual std::unique_ptr GraphViewer__CreateModel(const GraphViewer* p, const logging::Logger& logger, const ModelMetaData&) = 0; virtual const std::string& GraphViewer__Name(const GraphViewer* p) noexcept = 0; virtual const std::filesystem::path& GraphViewer__ModelPath(const GraphViewer* p) noexcept = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 5e8996d590db8..a82ddfe64c64b 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -1022,11 +1022,13 @@ struct Graph final { PROVIDER_DISALLOW_ALL(Graph) }; +using ModelMetaData = std::unordered_map; + class GraphViewer final { public: static void operator delete(void* p) { g_host->GraphViewer__operator_delete(reinterpret_cast(p)); } - std::unique_ptr CreateModel(const logging::Logger& logger) const { return g_host->GraphViewer__CreateModel(this, logger); } + std::unique_ptr CreateModel(const logging::Logger& logger, const ModelMetaData& metadata = ModelMetaData()) const { return g_host->GraphViewer__CreateModel(this, logger, metadata); } const std::string& Name() const noexcept { return g_host->GraphViewer__Name(this); } const std::filesystem::path& ModelPath() const noexcept { return g_host->GraphViewer__ModelPath(this); } diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc index ef45d6c85d6a9..fbccd7d4a286b 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc @@ -128,7 +128,8 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer, // Serialize modelproto to string auto new_graph_viewer = graph_build.CreateGraphViewer(); - auto model = new_graph_viewer->CreateModel(*logger); + auto& metadata = graph_viewer.GetGraph().GetModel().MetaData(); + auto model = new_graph_viewer->CreateModel(*logger, metadata); auto model_proto = model->ToProto(); new_graph_viewer->ToProto(*model_proto->mutable_graph(), true, true); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 0aa93bce354e8..19a1ad8a5a160 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1156,8 +1156,8 @@ struct ProviderHostImpl : ProviderHost { // GraphViewer (wrapped) void GraphViewer__operator_delete(GraphViewer* p) override { delete p; } - std::unique_ptr GraphViewer__CreateModel(const GraphViewer* graph_viewer, const logging::Logger& logger) override { - return std::make_unique(graph_viewer->Name(), true, ModelMetaData(), PathString(), + std::unique_ptr GraphViewer__CreateModel(const GraphViewer* graph_viewer, const logging::Logger& logger, const ModelMetaData& metadata = ModelMetaData()) override { + return std::make_unique(graph_viewer->Name(), true, metadata, PathString(), #if !defined(ORT_MINIMAL_BUILD) IOnnxRuntimeOpSchemaRegistryList({graph_viewer->GetSchemaRegistry()}), graph_viewer->DomainToVersionMap(), #else