8
8
#include < unordered_set>
9
9
#include < utility>
10
10
#include < vector>
11
+ #include " core/framework/abi_pointer_array.h"
11
12
#include " core/framework/compute_capability.h"
12
13
#include " core/framework/error_code_helper.h"
13
14
#include " core/framework/model_metadef_id_generator.h"
14
15
#include " core/graph/ep_api_types.h"
15
- #include " core/session/ort_apis .h"
16
+ #include " core/graph/model_editor_api_types .h"
16
17
#include " core/session/abi_devices.h"
17
18
#include " core/session/abi_ep_types.h"
18
19
#include " core/session/abi_logger.h"
20
+ #include " core/session/abi_session_options_impl.h"
19
21
#include " core/session/allocator_adapters.h"
22
+ #include " core/session/ort_apis.h"
20
23
#include " core/providers/partitioning_utils.h"
21
24
22
25
namespace onnxruntime {
@@ -48,7 +51,8 @@ PluginExecutionProviderFactory::CreateProvider(const OrtSessionOptions& session_
48
51
ORT_THROW (" Error creating execution provider: " , status.ToString ());
49
52
}
50
53
51
- auto ep_wrapper = std::make_unique<PluginExecutionProvider>(UniqueOrtEp (ort_ep, OrtEpDeleter (ep_factory_)));
54
+ auto ep_wrapper = std::make_unique<PluginExecutionProvider>(UniqueOrtEp (ort_ep, OrtEpDeleter (ep_factory_)),
55
+ session_options);
52
56
ep_wrapper->SetLogger (session_logger.ToInternal ());
53
57
54
58
return ep_wrapper;
@@ -80,9 +84,10 @@ struct PluginEpMetaDefNameFunctor {
80
84
// PluginExecutionProvider
81
85
//
82
86
83
- PluginExecutionProvider::PluginExecutionProvider (UniqueOrtEp ep)
87
+ PluginExecutionProvider::PluginExecutionProvider (UniqueOrtEp ep, const OrtSessionOptions& session_options )
84
88
: IExecutionProvider(ep->GetName (ep.get()), OrtDevice()), // TODO: What to do about OrtDevice for plugins?
85
89
ort_ep_(std::move(ep)) {
90
+ generate_ep_ctx_model_ = session_options.value .GetEpContextGenerationOptions ().enable ;
86
91
}
87
92
88
93
PluginExecutionProvider::~PluginExecutionProvider () {
@@ -185,6 +190,87 @@ Status PluginExecutionProvider::FusedNodeState::AddFusedNode(const Node& fused_n
185
190
return Status::OK ();
186
191
}
187
192
193
+ // / <summary>
194
+ // / Converts the EPContext nodes provided by the plugin EP (OrtNode instances) to onnxruntime::Node instances.
195
+ // / Note that the EP plugin uses the model editor API to create the OrtNode instances.
196
+ // / </summary>
197
+ // / <param name="ep_name">Name of the plugin EP.</param>
198
+ // / <param name="plugin_ep_context_nodes">EPContext nodes provided by the plugin EP.</param>
199
+ // / <param name="result_nodes">Output parameter set to the resulting array of EPContext nodes.</param>
200
+ // / <param name="result_node_args">Output parameter that stores the NodeArgs used by the EPContext nodes.</param>
201
+ // / <returns>A status indicating success or an error.</returns>
202
+ static Status ConvertEpContextNodes (const std::string& ep_name, const std::vector<OrtNode*> plugin_ep_context_nodes,
203
+ /* out*/ std::vector<std::unique_ptr<Node>>& result_nodes,
204
+ /* out*/ std::vector<std::unique_ptr<NodeArg>>& result_node_args) {
205
+ #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
206
+ if (plugin_ep_context_nodes.empty ()) {
207
+ return Status::OK (); // No EPContext nodes.
208
+ }
209
+
210
+ std::vector<std::unique_ptr<Node>> ep_context_nodes_holder;
211
+ std::vector<std::unique_ptr<NodeArg>> ep_context_node_args_holder;
212
+
213
+ ep_context_nodes_holder.reserve (plugin_ep_context_nodes.size ());
214
+
215
+ for (const OrtNode* ort_node : plugin_ep_context_nodes) {
216
+ ORT_RETURN_IF_NOT (ort_node != nullptr , ep_name, " : OrtEp::Compile() returned a NULL EPContext node." );
217
+
218
+ const ModelEditorNode* editor_node = ModelEditorNode::ToInternal (ort_node);
219
+ ORT_RETURN_IF_NOT (editor_node != nullptr , ep_name, " : OrtEp::Compile() returned OrtNode objects " ,
220
+ " that were not created with OrtModelEditorApi." );
221
+
222
+ // Create NodeArg for each input/output.
223
+ std::vector<NodeArg*> input_node_args;
224
+ std::vector<NodeArg*> output_node_args;
225
+
226
+ input_node_args.reserve (editor_node->input_names .size ());
227
+ output_node_args.reserve (editor_node->output_names .size ());
228
+
229
+ for (const std::string& input_name : editor_node->input_names ) {
230
+ auto node_arg = std::make_unique<NodeArg>(input_name, /* p_arg_type*/ nullptr ); // Graph.Resolve() sets type.
231
+ input_node_args.push_back (node_arg.get ());
232
+ ep_context_node_args_holder.push_back (std::move (node_arg));
233
+ }
234
+
235
+ for (const std::string& output_name : editor_node->output_names ) {
236
+ auto node_arg = std::make_unique<NodeArg>(output_name, /* p_arg_type*/ nullptr ); // Graph.Resolve() sets type.
237
+ output_node_args.push_back (node_arg.get ());
238
+ ep_context_node_args_holder.push_back (std::move (node_arg));
239
+ }
240
+
241
+ // Create a name -> attribute map.
242
+ NodeAttributes attributes;
243
+ attributes.reserve (editor_node->attributes .size ());
244
+
245
+ for (const ONNX_NAMESPACE::AttributeProto& attr : editor_node->attributes ) {
246
+ attributes.emplace (attr.name (), attr);
247
+ }
248
+
249
+ // Create Node
250
+ auto internal_node = std::make_unique<Node>(editor_node->node_name ,
251
+ editor_node->operator_name ,
252
+ " EPContext node for " + ep_name,
253
+ input_node_args,
254
+ output_node_args,
255
+ &attributes,
256
+ editor_node->domain_name );
257
+
258
+ ep_context_nodes_holder.push_back (std::move (internal_node));
259
+ }
260
+
261
+ result_nodes = std::move (ep_context_nodes_holder);
262
+ result_node_args = std::move (ep_context_node_args_holder);
263
+
264
+ return Status::OK ();
265
+ #else
266
+ ORT_UNUSED_PARAMETER (ep_name);
267
+ ORT_UNUSED_PARAMETER (plugin_ep_context_nodes);
268
+ ORT_UNUSED_PARAMETER (result_nodes);
269
+ ORT_UNUSED_PARAMETER (result_node_args);
270
+ return ORT_MAKE_STATUS (ONNXRUNTIME, NOT_IMPLEMENTED, " Creating EPContext models is not supported in this build" );
271
+ #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
272
+ }
273
+
188
274
common::Status PluginExecutionProvider::Compile (const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
189
275
std::vector<NodeComputeInfo>& node_compute_infos) {
190
276
const logging::Logger* logger = GetLogger ();
@@ -220,8 +306,21 @@ common::Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGr
220
306
api_fused_nodes.push_back (ep_fused_node->ToExternal ());
221
307
}
222
308
223
- ORT_RETURN_IF_ERROR (ToStatusAndRelease (ort_ep_->Compile (ort_ep_.get (), api_graphs.data (), api_fused_nodes.data (),
224
- num_graphs, api_node_compute_infos.data ())));
309
+ // Provide an output buffer for the plugin EP to store EPContext nodes if it needs to (i.e., enabled in session options).
310
+ std::vector<std::unique_ptr<OrtNode, decltype (&OrtApis::ReleaseNode)>> plugin_ep_context_nodes_holder;
311
+ std::vector<OrtNode*> plugin_ep_context_nodes;
312
+ plugin_ep_context_nodes_holder.reserve (num_graphs);
313
+ plugin_ep_context_nodes.resize (num_graphs, nullptr );
314
+
315
+ Status compile_status = ToStatusAndRelease (ort_ep_->Compile (ort_ep_.get (), api_graphs.data (), api_fused_nodes.data (),
316
+ num_graphs, api_node_compute_infos.data (),
317
+ plugin_ep_context_nodes.data ()));
318
+
319
+ // Store any EPContext nodes provided by the plugin EP in std::unique_ptr so that they are always properly released.
320
+ for (OrtNode* ort_node : plugin_ep_context_nodes) {
321
+ auto unique_ort_node = std::unique_ptr<OrtNode, decltype (&OrtApis::ReleaseNode)>(ort_node, OrtApis::ReleaseNode);
322
+ plugin_ep_context_nodes_holder.push_back (std::move (unique_ort_node));
323
+ }
225
324
226
325
// Save OrtNodeComputeInfo created by OrtEp instance. They're freed when this IExecutionProvider
227
326
// is destroyed.
@@ -231,6 +330,8 @@ common::Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGr
231
330
}
232
331
}
233
332
333
+ ORT_RETURN_IF_ERROR (compile_status);
334
+
234
335
// Initialize node_compute_infos as wrappers to api_node_compute_infos.
235
336
for (size_t i = 0 ; i < num_graphs; i++) {
236
337
OrtNodeComputeInfo* api_node_compute_info = api_node_compute_infos[i];
@@ -268,6 +369,25 @@ common::Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGr
268
369
node_compute_infos.push_back (std::move (compute_info));
269
370
}
270
371
372
+ // Convert the EPContext nodes provided by the plugin EP into onnxruntime::Node instances.
373
+ // We store the converted Node and NodeArg instances as members to ensure they can be returned to the ORT graph
374
+ // partitioner via a call to IExecutionProvider::GetEpContextNodes().
375
+ if (generate_ep_ctx_model_) {
376
+ ORT_RETURN_IF_ERROR (ConvertEpContextNodes (Type (), plugin_ep_context_nodes,
377
+ /* out*/ ep_context_nodes_, /* out*/ ep_context_node_args_));
378
+ }
379
+
271
380
return Status::OK ();
272
381
}
382
+
383
+ const InlinedVector<const Node*> PluginExecutionProvider::GetEpContextNodes () const {
384
+ InlinedVector<const Node*> result;
385
+
386
+ for (const std::unique_ptr<Node>& node : ep_context_nodes_) {
387
+ result.push_back (node.get ());
388
+ }
389
+
390
+ return result;
391
+ }
392
+
273
393
} // namespace onnxruntime
0 commit comments