Skip to content

Commit 7d22c09

Browse files
adrianlizarragaCopilotedgchen1
authored
[EP ABI] Add support for creating EP Context models. (microsoft#25124)
### Description - Updates `OrtEp::Compile()` to allow a plugin EP to create and return EPContext nodes. - Updates the example EP plugin to generate an example EPContext model: <img width="747" alt="image" src="https://github.com/user-attachments/assets/e5d98a10-ec15-45aa-bfaf-887d3b6226e2" /> ### Motivation and Context Adds more of the functionality missing from the EP ABI used for plugin EPs. --------- Co-authored-by: Copilot <[email protected]> Co-authored-by: Edward Chen <[email protected]>
1 parent f80e6f4 commit 7d22c09

File tree

10 files changed

+424
-82
lines changed

10 files changed

+424
-82
lines changed

cmake/onnxruntime_unittests.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1834,6 +1834,8 @@ if (WIN32 AND onnxruntime_BUILD_SHARED_LIB AND
18341834
NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND
18351835
NOT onnxruntime_MINIMAL_BUILD)
18361836
onnxruntime_add_shared_library_module(example_plugin_ep
1837+
${TEST_SRC_DIR}/autoep/library/example_plugin_ep_utils.h
1838+
${TEST_SRC_DIR}/autoep/library/example_plugin_ep_utils.cc
18371839
${TEST_SRC_DIR}/autoep/library/example_plugin_ep.cc)
18381840
target_include_directories(example_plugin_ep PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session)
18391841
target_link_libraries(example_plugin_ep PRIVATE onnxruntime)

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3678,7 +3678,7 @@ struct OrtApi {
36783678
*
36793679
* \param[in] name Name of the attribute
36803680
* \param[in] data Data content of the attribute
3681-
* \param[in] len Number of bytes stored in data
3681+
* \param[in] len Number of elements if data represents an array (e.g., ORT_OP_ATTR_INTS). Otherwise, set to 1.
36823682
* \param[in] type Data type
36833683
* \param[out] op_attr Attribute that has been created, which must be released by OrtApi::ReleaseOpAttr
36843684
*

include/onnxruntime/core/session/onnxruntime_ep_c_api.h

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,14 @@ struct OrtEp {
182182
/** \brief Compile OrtGraph instances assigned to the OrtEp. Implementer must set a OrtNodeComputeInfo instance
183183
* for each OrtGraph in order to define its computation function.
184184
*
185+
* If the session is configured to generate a pre-compiled model, the execution provider must return EPContext nodes,
186+
* as OrtNode instances, that ONNX Runtime uses to create a pre-compiled model, known as an "EPContext model".
187+
* An EPContext model contains EPContext nodes. Each EPContext node encapsulates the pre-compiled binary data for a
188+
* OrtGraph compiled for a specific execution provider. For more details about the EPContext design, refer to:
189+
* \htmlonly
190+
* <a href="https://onnxruntime.ai/docs/execution-providers/EP-Context-Design.html">EPContext design document.</a>
191+
* \endhtmlonly
192+
*
185193
* \param[in] this_ptr The OrtEp instance.
186194
* \param[in] graphs Array of `count` OrtGraph instances to compile. Each graph contains only the nodes for
187195
* which the execution provider indicated support. Nested subgraphs contained by a
@@ -190,9 +198,15 @@ struct OrtEp {
190198
* Each fused node is an OrtNode initialized with the intended fused node name and
191199
* input/output information.
192200
* \param[in] count The number of OrtGraph instances to compile.
193-
* \param[inout] node_compute_infos Array of `count` OrtNodeComputeInfo instances that define each OrtGraph instance's
194-
* computation function. The implementer allocates the OrtNodeComputeInfo instances.
195-
* ORT calls ReleaseNodeComputeInfos() to release multiple instances in a batch.
201+
* \param[out] node_compute_infos Array of `count` OrtNodeComputeInfo instances that define each OrtGraph instance's
202+
* computation function. The implementer allocates the OrtNodeComputeInfo instances.
203+
* ORT calls ReleaseNodeComputeInfos() to release multiple instances in a batch.
204+
* \param[out] ep_context_nodes Output array of `count` OrtNode instances, each representing an EPContext
205+
* node for a compiled OrtGraph. The execution provider must use
206+
* OrtModelEditorApi::CreateNode to create the OrtNode instances. ONNX Runtime takes
207+
* ownership of the OrtNode instances, so the execution provider must NOT call
208+
* OrtApi::ReleaseNode. Should be ignored if the session is not configured to generate an
209+
* EPContext model.
196210
*
197211
* \snippet{doc} snippets.dox OrtStatus Return Value
198212
*
@@ -204,7 +218,8 @@ struct OrtEp {
204218
*/
205219
OrtStatus*(ORT_API_CALL* Compile)(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs,
206220
_In_ const OrtNode** fused_nodes, _In_ size_t count,
207-
_Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos);
221+
_Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos,
222+
_Out_writes_(count) OrtNode** ep_context_nodes);
208223

209224
/** \brief Release OrtNodeComputeInfo instances.
210225
*

onnxruntime/core/session/ep_plugin_provider_interfaces.cc

Lines changed: 125 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,18 @@
88
#include <unordered_set>
99
#include <utility>
1010
#include <vector>
11+
#include "core/framework/abi_pointer_array.h"
1112
#include "core/framework/compute_capability.h"
1213
#include "core/framework/error_code_helper.h"
1314
#include "core/framework/model_metadef_id_generator.h"
1415
#include "core/graph/ep_api_types.h"
15-
#include "core/session/ort_apis.h"
16+
#include "core/graph/model_editor_api_types.h"
1617
#include "core/session/abi_devices.h"
1718
#include "core/session/abi_ep_types.h"
1819
#include "core/session/abi_logger.h"
20+
#include "core/session/abi_session_options_impl.h"
1921
#include "core/session/allocator_adapters.h"
22+
#include "core/session/ort_apis.h"
2023
#include "core/providers/partitioning_utils.h"
2124

2225
namespace onnxruntime {
@@ -48,7 +51,8 @@ PluginExecutionProviderFactory::CreateProvider(const OrtSessionOptions& session_
4851
ORT_THROW("Error creating execution provider: ", status.ToString());
4952
}
5053

51-
auto ep_wrapper = std::make_unique<PluginExecutionProvider>(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)));
54+
auto ep_wrapper = std::make_unique<PluginExecutionProvider>(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)),
55+
session_options);
5256
ep_wrapper->SetLogger(session_logger.ToInternal());
5357

5458
return ep_wrapper;
@@ -80,9 +84,10 @@ struct PluginEpMetaDefNameFunctor {
8084
// PluginExecutionProvider
8185
//
8286

83-
PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep)
87+
PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options)
8488
: IExecutionProvider(ep->GetName(ep.get()), OrtDevice()), // TODO: What to do about OrtDevice for plugins?
8589
ort_ep_(std::move(ep)) {
90+
generate_ep_ctx_model_ = session_options.value.GetEpContextGenerationOptions().enable;
8691
}
8792

8893
PluginExecutionProvider::~PluginExecutionProvider() {
@@ -185,6 +190,87 @@ Status PluginExecutionProvider::FusedNodeState::AddFusedNode(const Node& fused_n
185190
return Status::OK();
186191
}
187192

193+
/// <summary>
194+
/// Converts the EPContext nodes provided by the plugin EP (OrtNode instances) to onnxruntime::Node instances.
195+
/// Note that the EP plugin uses the model editor API to create the OrtNode instances.
196+
/// </summary>
197+
/// <param name="ep_name">Name of the plugin EP.</param>
198+
/// <param name="plugin_ep_context_nodes">EPContext nodes provided by the plugin EP.</param>
199+
/// <param name="result_nodes">Output parameter set to the resulting array of EPContext nodes.</param>
200+
/// <param name="result_node_args">Output parameter that stores the NodeArgs used by the EPContext nodes.</param>
201+
/// <returns>A status indicating success or an error.</returns>
202+
static Status ConvertEpContextNodes(const std::string& ep_name, const std::vector<OrtNode*> plugin_ep_context_nodes,
203+
/*out*/ std::vector<std::unique_ptr<Node>>& result_nodes,
204+
/*out*/ std::vector<std::unique_ptr<NodeArg>>& result_node_args) {
205+
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
206+
if (plugin_ep_context_nodes.empty()) {
207+
return Status::OK(); // No EPContext nodes.
208+
}
209+
210+
std::vector<std::unique_ptr<Node>> ep_context_nodes_holder;
211+
std::vector<std::unique_ptr<NodeArg>> ep_context_node_args_holder;
212+
213+
ep_context_nodes_holder.reserve(plugin_ep_context_nodes.size());
214+
215+
for (const OrtNode* ort_node : plugin_ep_context_nodes) {
216+
ORT_RETURN_IF_NOT(ort_node != nullptr, ep_name, ": OrtEp::Compile() returned a NULL EPContext node.");
217+
218+
const ModelEditorNode* editor_node = ModelEditorNode::ToInternal(ort_node);
219+
ORT_RETURN_IF_NOT(editor_node != nullptr, ep_name, ": OrtEp::Compile() returned OrtNode objects ",
220+
"that were not created with OrtModelEditorApi.");
221+
222+
// Create NodeArg for each input/output.
223+
std::vector<NodeArg*> input_node_args;
224+
std::vector<NodeArg*> output_node_args;
225+
226+
input_node_args.reserve(editor_node->input_names.size());
227+
output_node_args.reserve(editor_node->output_names.size());
228+
229+
for (const std::string& input_name : editor_node->input_names) {
230+
auto node_arg = std::make_unique<NodeArg>(input_name, /*p_arg_type*/ nullptr); // Graph.Resolve() sets type.
231+
input_node_args.push_back(node_arg.get());
232+
ep_context_node_args_holder.push_back(std::move(node_arg));
233+
}
234+
235+
for (const std::string& output_name : editor_node->output_names) {
236+
auto node_arg = std::make_unique<NodeArg>(output_name, /*p_arg_type*/ nullptr); // Graph.Resolve() sets type.
237+
output_node_args.push_back(node_arg.get());
238+
ep_context_node_args_holder.push_back(std::move(node_arg));
239+
}
240+
241+
// Create a name -> attribute map.
242+
NodeAttributes attributes;
243+
attributes.reserve(editor_node->attributes.size());
244+
245+
for (const ONNX_NAMESPACE::AttributeProto& attr : editor_node->attributes) {
246+
attributes.emplace(attr.name(), attr);
247+
}
248+
249+
// Create Node
250+
auto internal_node = std::make_unique<Node>(editor_node->node_name,
251+
editor_node->operator_name,
252+
"EPContext node for " + ep_name,
253+
input_node_args,
254+
output_node_args,
255+
&attributes,
256+
editor_node->domain_name);
257+
258+
ep_context_nodes_holder.push_back(std::move(internal_node));
259+
}
260+
261+
result_nodes = std::move(ep_context_nodes_holder);
262+
result_node_args = std::move(ep_context_node_args_holder);
263+
264+
return Status::OK();
265+
#else
266+
ORT_UNUSED_PARAMETER(ep_name);
267+
ORT_UNUSED_PARAMETER(plugin_ep_context_nodes);
268+
ORT_UNUSED_PARAMETER(result_nodes);
269+
ORT_UNUSED_PARAMETER(result_node_args);
270+
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Creating EPContext models is not supported in this build");
271+
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
272+
}
273+
188274
common::Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
189275
std::vector<NodeComputeInfo>& node_compute_infos) {
190276
const logging::Logger* logger = GetLogger();
@@ -220,8 +306,21 @@ common::Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGr
220306
api_fused_nodes.push_back(ep_fused_node->ToExternal());
221307
}
222308

223-
ORT_RETURN_IF_ERROR(ToStatusAndRelease(ort_ep_->Compile(ort_ep_.get(), api_graphs.data(), api_fused_nodes.data(),
224-
num_graphs, api_node_compute_infos.data())));
309+
// Provide an output buffer for the plugin EP to store EPContext nodes if it needs to (i.e., enabled in session options).
310+
std::vector<std::unique_ptr<OrtNode, decltype(&OrtApis::ReleaseNode)>> plugin_ep_context_nodes_holder;
311+
std::vector<OrtNode*> plugin_ep_context_nodes;
312+
plugin_ep_context_nodes_holder.reserve(num_graphs);
313+
plugin_ep_context_nodes.resize(num_graphs, nullptr);
314+
315+
Status compile_status = ToStatusAndRelease(ort_ep_->Compile(ort_ep_.get(), api_graphs.data(), api_fused_nodes.data(),
316+
num_graphs, api_node_compute_infos.data(),
317+
plugin_ep_context_nodes.data()));
318+
319+
// Store any EPContext nodes provided by the plugin EP in std::unique_ptr so that they are always properly released.
320+
for (OrtNode* ort_node : plugin_ep_context_nodes) {
321+
auto unique_ort_node = std::unique_ptr<OrtNode, decltype(&OrtApis::ReleaseNode)>(ort_node, OrtApis::ReleaseNode);
322+
plugin_ep_context_nodes_holder.push_back(std::move(unique_ort_node));
323+
}
225324

226325
// Save OrtNodeComputeInfo created by OrtEp instance. They're freed when this IExecutionProvider
227326
// is destroyed.
@@ -231,6 +330,8 @@ common::Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGr
231330
}
232331
}
233332

333+
ORT_RETURN_IF_ERROR(compile_status);
334+
234335
// Initialize node_compute_infos as wrappers to api_node_compute_infos.
235336
for (size_t i = 0; i < num_graphs; i++) {
236337
OrtNodeComputeInfo* api_node_compute_info = api_node_compute_infos[i];
@@ -268,6 +369,25 @@ common::Status PluginExecutionProvider::Compile(const std::vector<FusedNodeAndGr
268369
node_compute_infos.push_back(std::move(compute_info));
269370
}
270371

372+
// Convert the EPContext nodes provided by the plugin EP into onnxruntime::Node instances.
373+
// We store the converted Node and NodeArg instances as members to ensure they can be returned to the ORT graph
374+
// partitioner via a call to IExecutionProvider::GetEpContextNodes().
375+
if (generate_ep_ctx_model_) {
376+
ORT_RETURN_IF_ERROR(ConvertEpContextNodes(Type(), plugin_ep_context_nodes,
377+
/*out*/ ep_context_nodes_, /*out*/ ep_context_node_args_));
378+
}
379+
271380
return Status::OK();
272381
}
382+
383+
const InlinedVector<const Node*> PluginExecutionProvider::GetEpContextNodes() const {
384+
InlinedVector<const Node*> result;
385+
386+
for (const std::unique_ptr<Node>& node : ep_context_nodes_) {
387+
result.push_back(node.get());
388+
}
389+
390+
return result;
391+
}
392+
273393
} // namespace onnxruntime

onnxruntime/core/session/ep_plugin_provider_interfaces.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
namespace onnxruntime {
1818
struct EpNode;
1919
struct EpValueInfo;
20+
class NodeArg;
2021

2122
/// <summary>
2223
/// IExecutionProviderFactory that wraps a OrtEpFactory. Required for SessionOptionsAppendExecutionProvider_V2.
@@ -59,7 +60,7 @@ using UniqueOrtEp = std::unique_ptr<OrtEp, OrtEpDeleter>;
5960
/// </summary>
6061
class PluginExecutionProvider : public IExecutionProvider {
6162
public:
62-
explicit PluginExecutionProvider(UniqueOrtEp ep);
63+
explicit PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options);
6364
~PluginExecutionProvider();
6465

6566
std::vector<std::unique_ptr<ComputeCapability>>
@@ -71,6 +72,8 @@ class PluginExecutionProvider : public IExecutionProvider {
7172
common::Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
7273
std::vector<NodeComputeInfo>& node_compute_funcs) override;
7374

75+
const InlinedVector<const Node*> GetEpContextNodes() const override;
76+
7477
private:
7578
struct FusedNodeState {
7679
FusedNodeState() = default;
@@ -83,12 +86,19 @@ class PluginExecutionProvider : public IExecutionProvider {
8386
};
8487

8588
UniqueOrtEp ort_ep_;
89+
bool generate_ep_ctx_model_ = false;
8690
std::vector<OrtNodeComputeInfo*> api_node_compute_infos_;
8791

8892
// Fused nodes have to be valid throughout model inference because they may be cached in NodeComputeInfo instances.
8993
// For each fused node, the Compile() function creates EpNode and EpValueInfo instances on the heap,
9094
// which are then passed to the underlying OrtEp instance. This class stores this "fused node state"
9195
// so that it is not destroyed until the EP itself is destroyed.
9296
std::vector<FusedNodeState> fused_node_states_;
97+
98+
// Stores the EPContext Nodes created from the OrtNode instances returned by the underlying plugin EP.
99+
// Need to store both the Node and NodeArg instances so that they are available when the GraphPartitioner
100+
// calls IExecutionProvider::GetEpContextNodes().
101+
std::vector<std::unique_ptr<Node>> ep_context_nodes_;
102+
std::vector<std::unique_ptr<NodeArg>> ep_context_node_args_;
93103
};
94104
} // namespace onnxruntime

onnxruntime/core/session/provider_policy_context.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ Status ProviderPolicyContext::CreateExecutionProvider(const Environment& env, Or
292292
ORT_RETURN_IF_ERROR(ToStatusAndRelease(info.ep_factory->CreateEp(info.ep_factory, info.devices.data(),
293293
info.ep_metadata.data(), info.devices.size(),
294294
&options, &logger, &api_ep)));
295-
ep = std::make_unique<PluginExecutionProvider>(UniqueOrtEp(api_ep, OrtEpDeleter(*info.ep_factory)));
295+
ep = std::make_unique<PluginExecutionProvider>(UniqueOrtEp(api_ep, OrtEpDeleter(*info.ep_factory)), options);
296296
}
297297

298298
return Status::OK();

0 commit comments

Comments
 (0)