2
2
// Licensed under the MIT License
3
3
4
4
#include < fstream>
5
+ #include < sstream>
5
6
#include < utility>
6
- #include < exception>
7
7
8
8
#include " core/providers/shared_library/provider_api.h"
9
- #include " contexts.h"
10
- #include " backend_manager.h"
11
- #include " ibackend.h"
12
- #include " backend_utils.h"
9
+ #include " core/providers/openvino/ contexts.h"
10
+ #include " core/providers/openvino/ backend_manager.h"
11
+ #include " core/providers/openvino/ ibackend.h"
12
+ #include " core/providers/openvino/ backend_utils.h"
13
13
14
14
namespace onnxruntime {
15
15
namespace openvino_ep {
@@ -21,8 +21,17 @@ GlobalContext& BackendManager::GetGlobalContext() {
21
21
BackendManager::BackendManager (const GlobalContext& global_context,
22
22
const onnxruntime::Node& fused_node,
23
23
const onnxruntime::GraphViewer& subgraph,
24
- const logging::Logger& logger) {
24
+ const logging::Logger& logger,
25
+ EPCtxHandler& ctx_handle) {
25
26
global_context_ = global_context;
27
+ ep_ctx_handle_ = ctx_handle;
28
+
29
+ openvino_sdk_version_ = std::to_string (global_context_.OpenVINO_Version .at (0 )) + " ." +
30
+ std::to_string (global_context_.OpenVINO_Version .at (1 ));
31
+ if (ep_ctx_handle_.CheckForOVEPCtxNode (subgraph, openvino_sdk_version_)) {
32
+ if (ep_ctx_handle_.ImportBlobFromEPCtxModel (subgraph) != Status::OK ())
33
+ ORT_THROW (" Import blob from model failed" );
34
+ }
26
35
27
36
auto prec_str = GetGlobalContext ().precision_str ;
28
37
@@ -66,7 +75,8 @@ BackendManager::BackendManager(const GlobalContext& global_context,
66
75
try {
67
76
concrete_backend_ = BackendFactory::MakeBackend (*model_proto_,
68
77
GetGlobalContext (),
69
- subgraph_context_);
78
+ subgraph_context_,
79
+ ep_ctx_handle_);
70
80
} catch (std::string const & msg) {
71
81
ORT_THROW (msg);
72
82
}
@@ -85,7 +95,8 @@ BackendManager::BackendManager(const GlobalContext& global_context,
85
95
try {
86
96
concrete_backend_ = BackendFactory::MakeBackend (*model_proto_,
87
97
GetGlobalContext (),
88
- subgraph_context_);
98
+ subgraph_context_,
99
+ ep_ctx_handle_);
89
100
} catch (const OnnxRuntimeException& ex) {
90
101
if (device_type.find (" NPU" ) != std::string::npos) {
91
102
LOGS_DEFAULT (WARNING) << ex.what ();
@@ -96,7 +107,8 @@ BackendManager::BackendManager(const GlobalContext& global_context,
96
107
try {
97
108
concrete_backend_ = BackendFactory::MakeBackend (*model_proto_,
98
109
GetGlobalContext (),
99
- subgraph_context_);
110
+ subgraph_context_,
111
+ ep_ctx_handle_);
100
112
} catch (std::string const & msg) {
101
113
ORT_THROW (msg);
102
114
}
@@ -107,6 +119,45 @@ BackendManager::BackendManager(const GlobalContext& global_context,
107
119
}
108
120
}
109
121
122
+ // Call EPContext model exporter here if the provider option for exporting
123
+ // precompiled blob is set. If that's the case:
124
+ // By default, create model in embed mode where the blob stream is exported as data within
125
+ // the EPContext node.
126
+ Status BackendManager::ExportCompiledBlobAsEPCtxNode (const onnxruntime::GraphViewer& graph_body_viewer,
127
+ const logging::Logger& logger) {
128
+ std::string model_blob_str;
129
+ auto compiled_model = concrete_backend_->GetOVCompiledModel ();
130
+ auto graph_name = global_context_.onnx_model_path_name ;
131
+ // Remove extension so we can append suffix to form the complete name of output graph
132
+ graph_name = [&]() {
133
+ size_t dot = graph_name.find_last_of (" ." );
134
+ if (dot == std::string::npos) return graph_name;
135
+ return graph_name.substr (0 , dot);
136
+ }();
137
+ // If embed_mode, then pass on the serialized blob
138
+ // If not embed_mode, dump the blob here and only pass on the path to the blob
139
+ if (global_context_.ep_context_embed_mode ) {
140
+ std::ostringstream model_blob_stream;
141
+ compiled_model.export_model (model_blob_stream);
142
+ model_blob_str = model_blob_stream.str ();
143
+ ORT_ENFORCE (model_blob_str.size () != 0 );
144
+ } else {
145
+ std::ofstream f (graph_name + " .blob" , std::ios::out | std::ios::trunc | std::ios::binary);
146
+ compiled_model.export_model (f);
147
+ model_blob_str = graph_name + " .blob" ;
148
+ }
149
+
150
+ ORT_RETURN_IF_ERROR (ep_ctx_handle_.ExportEPCtxModel (graph_body_viewer,
151
+ graph_name,
152
+ logger,
153
+ global_context_.ep_context_embed_mode ,
154
+ model_blob_str,
155
+ openvino_sdk_version_,
156
+ GetGlobalContext ().device_type ));
157
+
158
+ return Status::OK ();
159
+ }
160
+
110
161
bool BackendManager::ModelHasBatchedInputs (const ONNX_NAMESPACE::ModelProto& model_proto) const {
111
162
bool has_batched_inputs = true ;
112
163
@@ -182,7 +233,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node,
182
233
return model_proto;
183
234
}
184
235
185
- std::vector<std::vector<int64_t >> GetInputTensorShapes (Ort::KernelContext& context) {
236
+ std::vector<std::vector<int64_t >> GetInputTensorShapes (const Ort::KernelContext& context) {
186
237
const auto input_count = context.GetInputCount ();
187
238
std::vector<std::vector<int64_t >> input_shapes;
188
239
input_shapes.reserve (input_count);
@@ -289,7 +340,8 @@ void BackendManager::Compute(OrtKernelContext* context) {
289
340
try {
290
341
dynamic_backend = BackendFactory::MakeBackend (*modelproto_with_concrete_shapes,
291
342
GetGlobalContext (),
292
- subgraph_context_);
343
+ subgraph_context_,
344
+ ep_ctx_handle_);
293
345
} catch (const OnnxRuntimeException& ex) {
294
346
if (GetGlobalContext ().device_type .find (" NPU" ) != std::string::npos) {
295
347
LOGS_DEFAULT (WARNING) << ex.what ();
@@ -301,7 +353,8 @@ void BackendManager::Compute(OrtKernelContext* context) {
301
353
try {
302
354
dynamic_backend = BackendFactory::MakeBackend (*modelproto_with_concrete_shapes,
303
355
GetGlobalContext (),
304
- subgraph_context_);
356
+ subgraph_context_,
357
+ ep_ctx_handle_);
305
358
} catch (std::string const & msg) {
306
359
ORT_THROW (msg);
307
360
}
0 commit comments