Closed
Changes from 39 commits
Commits
52 commits
0b5b4d7
update
chilo-ms Dec 8, 2025
80561db
update
chilo-ms Dec 9, 2025
6bd316f
add API summary
chilo-ms Dec 9, 2025
ad0a023
update
chilo-ms Dec 9, 2025
5e398d4
address reviewer's comments and add GetNumCustomOpDomains
chilo-ms Dec 10, 2025
aeb2386
update example ep to run Custom_Mul op
chilo-ms Dec 10, 2025
3849cd3
address reviewr's comment
chilo-ms Dec 10, 2025
9c987be
lintrunner -a
chilo-ms Dec 10, 2025
fbe2434
update example ep GetCapability()
chilo-ms Dec 10, 2025
40fa8fe
update Example EP
chilo-ms Dec 11, 2025
c7a0491
add more comments in API summary
chilo-ms Dec 12, 2025
4787c3f
address reviewer's comments
chilo-ms Dec 18, 2025
632ce31
lintrunner -a
chilo-ms Dec 18, 2025
5905434
Merge branch 'main' into chi/custom_op_for_ep
chilo-ms Jan 7, 2026
6017c00
Use CreateKernelV2 and ComputeKernelV2
chilo-ms Jan 8, 2026
47bb4dc
address reviewer's comments
chilo-ms Jan 8, 2026
6721a98
lintrunner -a
chilo-ms Jan 8, 2026
1ab246d
update
chilo-ms Jan 8, 2026
3478732
Remove accidentally added file
chilo-ms Jan 8, 2026
a1d36af
address reviewer's comments
chilo-ms Jan 9, 2026
3065e9d
address reviewer's comment
chilo-ms Jan 9, 2026
d340de5
address reveiwer's comment
chilo-ms Jan 9, 2026
ee8851b
Merge branch 'main' into chi/custom_op_for_ep
chilo-ms Jan 9, 2026
15f5baf
update
chilo-ms Jan 9, 2026
6b01e7f
lintrunner -a
chilo-ms Jan 9, 2026
adf565e
fix bug when merging main
chilo-ms Jan 9, 2026
062280e
Make auto ep selection be able to register custom op
chilo-ms Jan 13, 2026
ff58721
Merge branch 'main' into chi/custom_op_for_ep
chilo-ms Jan 13, 2026
002fcdc
add comments
chilo-ms Jan 13, 2026
bb7e082
Make code be able to get model_metadata from model during auto ep sel…
chilo-ms Jan 13, 2026
953dbd3
Use Model::Load
chilo-ms Jan 14, 2026
cf5948a
revert unnecessary change
chilo-ms Jan 14, 2026
cc31408
update API comment
chilo-ms Jan 14, 2026
e2604b9
fix build issue for minimal build
chilo-ms Jan 14, 2026
27cb17a
address reviewer's comments
chilo-ms Jan 14, 2026
f77e5f7
Merge branch 'main' into chi/custom_op_for_ep
chilo-ms Jan 14, 2026
2e91855
lintrunner -a
chilo-ms Jan 14, 2026
1f448a6
fix compile warning for minimal build
chilo-ms Jan 14, 2026
debdc79
address reviewer's comment
chilo-ms Jan 14, 2026
452bb26
Add AddEpCustomDomainsToSessionOptions()
chilo-ms Jan 15, 2026
96d42fb
clean up code
chilo-ms Jan 15, 2026
04b75e8
clean up code and fix compile error
chilo-ms Jan 15, 2026
84cff1f
revert auto ep selection
chilo-ms Jan 15, 2026
4d9849f
Merge branch 'main' into chi/custom_op_for_ep
chilo-ms Jan 15, 2026
32e2e57
add back accidentaly removed code
chilo-ms Jan 15, 2026
b80b451
address reviewer's comments
chilo-ms Jan 16, 2026
0b7302e
update
chilo-ms Jan 16, 2026
be94f18
fix compile error for onnxruntime_pybind_state.cc
chilo-ms Jan 16, 2026
1841117
address reveiwer's comment
chilo-ms Jan 16, 2026
6a571ef
address reviewer's comments
chilo-ms Jan 16, 2026
3b2e5a5
address Copilot comment
chilo-ms Jan 17, 2026
5d0b15b
address Copilot comment
chilo-ms Jan 17, 2026
58 changes: 58 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
@@ -2060,6 +2060,64 @@ struct OrtEpFactory {
ORT_API2_STATUS(CreateExternalResourceImporterForDevice, _In_ OrtEpFactory* this_ptr,
_In_ const OrtEpDevice* ep_device,
_Outptr_result_maybenull_ OrtExternalResourceImporterImpl** out_importer);

/** \brief Returns the number of OrtCustomOpDomains that this factory provides.
*
* \param[in] this_ptr The OrtEpFactory instance.
* \param[out] num_domains Output parameter set to the number of provided OrtCustomOpDomain instances.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(GetNumCustomOpDomains, _In_ OrtEpFactory* this_ptr, _Out_ size_t* num_domains);

/** \brief Gets the EP-specific OrtCustomOpDomains.
*
* This function is used when running inference on a model that contains EP-specific custom operations.
*
* Workflow:
* 1. The EP factory implements this function to supply a list of OrtCustomOpDomain instances.
* 2. The application either (a) calls SessionOptionsAppendExecutionProvider_V2() with an OrtEpDevice containing
* the plugin EP's factory, or (b) enables automatic EP selection.
* 3. In case (a), SessionOptionsAppendExecutionProvider_V2() appends the provided OrtCustomOpDomains to the
* session options. In case (b), ORT registers the OrtCustomOpDomains provided by the selected EP devices.
*
* As a result, any session created from these session options will have these custom op domains registered
* in ORT, ensuring that the custom ops are properly recognized and validated when the model is loaded.
*
* Plugin EPs can provide two types of custom ops:
* 1. A full OrtCustomOp with a concrete kernel implementation
* - A Plugin EP can supply an OrtCustomOp and a corresponding CustomKernel::Compute() implementation.
* - In GetCapability(), it calls EpGraphSupportInfo_AddSingleNode() to inform ORT
* that the custom node should NOT be fused or compiled. Instead, ORT should invoke
* the custom node's Compute() function at runtime.
*
* 2. A "placeholder" OrtCustomOp with an empty kernel implementation
* - A compile-based Plugin EP can supply an OrtCustomOp whose CustomKernel::Compute()
* does nothing. The purpose is to satisfy model validation during model loading by
* registering the custom op as a valid operator in the session.
* - In GetCapability(), the EP should call EpGraphSupportInfo_AddNodesToFuse() to
* notify ORT that this custom node should be fused and compiled by the EP.
* - In Compile(), the EP executes its compiled bits to perform inference for
* the fused custom node.
*
* Note: The OrtCustomOpDomain instances must remain valid while any session is using them.
* The EP factory is responsible for releasing the OrtCustomOpDomain instances it creates. This happens
* automatically when using the C++ Ort::CustomOpDomain class.
*
* \param[in] this_ptr The OrtEpFactory instance.
* \param[out] domains Array of `num_domains` elements pre-allocated by ORT that should be filled with
* OrtCustomOpDomain instances created by the EP. `num_domains` is the value returned by
* GetNumCustomOpDomains(). The implementation is expected to treat `domains` as a fixed-size
* output buffer and write no more than `num_domains` entries.
* \param[in] num_domains The size of the `domains` array pre-allocated by ORT.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(GetCustomOpDomains, _In_ OrtEpFactory* this_ptr,
_Out_writes_all_(num_domains) OrtCustomOpDomain** domains, _In_ size_t num_domains);
};

#ifdef __cplusplus
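The two-function contract documented above (report a count, then fill a buffer that ORT pre-allocated) could be implemented along these lines. This is a minimal sketch under stated assumptions, not the Example EP from this PR: `FakeCustomOpDomain` and `ExampleFactory` are hypothetical stand-ins that avoid the real ORT headers, and the functions return `bool` where a real implementation would return `OrtStatus*` (with `nullptr` meaning success).

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

// Stand-in for OrtCustomOpDomain (hypothetical; the real type is opaque).
struct FakeCustomOpDomain {
  std::string name;
};

// Hypothetical factory that owns its domains so they outlive any session using them.
class ExampleFactory {
 public:
  ExampleFactory() {
    domains_.push_back(
        std::make_unique<FakeCustomOpDomain>(FakeCustomOpDomain{"example.domain"}));
  }

  // Mirrors GetNumCustomOpDomains: report how many domains will be written.
  bool GetNumCustomOpDomains(std::size_t* num_domains) const {
    if (num_domains == nullptr) return false;
    *num_domains = domains_.size();
    return true;
  }

  // Mirrors GetCustomOpDomains: fill the caller-allocated buffer and never
  // write past the capacity the caller reported.
  bool GetCustomOpDomains(FakeCustomOpDomain** domains, std::size_t num_domains) const {
    if (domains == nullptr || num_domains < domains_.size()) return false;
    for (std::size_t i = 0; i < domains_.size(); ++i) {
      domains[i] = domains_[i].get();  // non-owning; the factory keeps ownership
    }
    return true;
  }

 private:
  std::vector<std::unique_ptr<FakeCustomOpDomain>> domains_;
};
```

Keeping ownership inside the factory is one way to satisfy the lifetime note in the API comment: the pointers handed to the caller stay valid for as long as the factory object does.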
23 changes: 23 additions & 0 deletions onnxruntime/core/session/inference_session.h
@@ -662,6 +662,22 @@
return session_id_;
}

void SetExecutionDevices(std::vector<const OrtEpDevice*> execution_devices) {
execution_devices_ = std::move(execution_devices);
}

const std::vector<const OrtEpDevice*>& GetExecutionDevices() noexcept {
return execution_devices_;
}

void SetSelectedDevices(std::vector<const OrtEpDevice*> selected_devices) {
devices_selected_ = std::move(selected_devices);

[cpplint, GitHub Actions / Optional Lint C++] Warning on line 674 of onnxruntime/core/session/inference_session.h: Add #include <utility> for move [build/include_what_you_use] [4]
}

const std::vector<const OrtEpDevice*>& GetSelectedDevices() noexcept {
return devices_selected_;
}

protected:
#if !defined(ORT_MINIMAL_BUILD)

@@ -1054,6 +1070,13 @@
// Enable nodestats collection
std::optional<NodeStatsRecorder> node_stats_recorder_;
#endif

// Holds the list of devices from the environment, ordered via OrderDevices().
// It's used for auto ep selection.
std::vector<const OrtEpDevice*> execution_devices_;

// Holds the list of devices selected by policies.
std::vector<const OrtEpDevice*> devices_selected_;

[cpplint, GitHub Actions / Optional Lint C++] Warning on line 1079 of onnxruntime/core/session/inference_session.h: Add #include <vector> for vector<> [build/include_what_you_use] [4]
};

struct SessionIOBinding {
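The new setters and getters let the session cache the ordered and selected device lists so that a later pass can reuse them instead of recomputing. A minimal sketch of that use-cached-or-compute pattern follows; `Device`, `Session`, and `GetOrComputeDevices` are hypothetical stand-ins, not the ORT types or the code in this PR.

```cpp
#include <cassert>
#include <utility>
#include <vector>

struct Device { int id; };  // stand-in for OrtEpDevice (hypothetical)

// Hypothetical session holding a cached device list, like execution_devices_.
class Session {
 public:
  void SetExecutionDevices(std::vector<const Device*> devices) {
    execution_devices_ = std::move(devices);
  }
  const std::vector<const Device*>& GetExecutionDevices() const noexcept {
    return execution_devices_;
  }

 private:
  std::vector<const Device*> execution_devices_;
};

// Returns the cached list when present; otherwise calls `compute` once and
// stores the result in the session for later callers.
template <typename ComputeFn>
std::vector<const Device*> GetOrComputeDevices(Session& session, ComputeFn compute) {
  if (!session.GetExecutionDevices().empty()) {
    return session.GetExecutionDevices();
  }
  std::vector<const Device*> devices = compute();
  session.SetExecutionDevices(devices);
  return devices;
}
```

One consequence of using emptiness as the "not cached" signal, in this sketch at least, is that an empty result is never cached and will be recomputed on every call.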
2 changes: 1 addition & 1 deletion onnxruntime/core/session/onnxruntime_c_api.cc
@@ -3333,7 +3333,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS
ep_devices_span,
ep_option_keys_span,
ep_option_vals_span,
session_options->value));
*session_options));

session_options->provider_factories.push_back(std::move(provider_factory));

12 changes: 12 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h
@@ -102,6 +102,18 @@ class EpFactoryInternalImpl {
return nullptr;
}

virtual OrtStatus* GetNumCustomOpDomains(_Out_ size_t* num_domains) const noexcept {
*num_domains = 0;
return nullptr;
}

virtual OrtStatus* GetCustomOpDomains(_Out_writes_all_(num_domains) OrtCustomOpDomain** domains,
_In_ size_t num_domains) const noexcept {
ORT_UNUSED_PARAMETER(domains);
ORT_UNUSED_PARAMETER(num_domains);
return nullptr;
}

// Function ORT calls to release an EP instance.
void ReleaseEp(OrtEp* ep);

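The defaults above report zero domains and write nothing, so factories that have no custom ops need no changes. A sketch of how a caller might drive the count-then-fill sequence against both the default and an overriding factory is shown below. All names here (`Domain`, `FactoryBase`, `FactoryWithDomains`, `CollectDomains`) are hypothetical, and `bool` stands in for `OrtStatus*`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

struct Domain { int id; };  // stand-in for OrtCustomOpDomain (hypothetical)

// Hypothetical base mirroring the defaults in the diff: zero domains, nothing written.
class FactoryBase {
 public:
  virtual ~FactoryBase() = default;
  virtual bool GetNumCustomOpDomains(std::size_t* num_domains) const noexcept {
    *num_domains = 0;
    return true;
  }
  virtual bool GetCustomOpDomains(Domain** /*domains*/,
                                  std::size_t /*num_domains*/) const noexcept {
    return true;
  }
};

// Hypothetical factory that overrides both functions to expose two domains.
class FactoryWithDomains : public FactoryBase {
 public:
  bool GetNumCustomOpDomains(std::size_t* num_domains) const noexcept override {
    *num_domains = 2;
    return true;
  }
  bool GetCustomOpDomains(Domain** domains, std::size_t num_domains) const noexcept override {
    if (num_domains < 2) return false;
    domains[0] = &a_;
    domains[1] = &b_;
    return true;
  }

 private:
  mutable Domain a_{1};
  mutable Domain b_{2};
};

// Caller side: query the count, pre-allocate, then ask the factory to fill the buffer.
std::vector<Domain*> CollectDomains(const FactoryBase& factory) {
  std::size_t count = 0;
  if (!factory.GetNumCustomOpDomains(&count) || count == 0) return {};
  std::vector<Domain*> domains(count, nullptr);
  if (!factory.GetCustomOpDomains(domains.data(), domains.size())) return {};
  return domains;
}
```

Skipping the second call when the count is zero means the default `GetCustomOpDomains` is never handed an empty buffer in this sketch.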
201 changes: 121 additions & 80 deletions onnxruntime/core/session/provider_policy_context.cc
@@ -50,6 +50,7 @@
return d->device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU &&
d->ep_vendor == "Microsoft";
}
} // namespace

// Sort devices by type: NPU -> GPU -> CPU.
// Within each type: vendor-owned first, then the rest.
@@ -115,16 +116,10 @@
return sorted_devices;
}

OrtKeyValuePairs GetModelMetadata(const InferenceSession& session) {
OrtKeyValuePairs GetModelMetadataKeyValuePairs(const ModelMetadata& model_metadata) {
OrtKeyValuePairs metadata;
auto status_and_metadata = session.GetModelMetadata();

if (!status_and_metadata.first.IsOK()) {
return metadata;
}

// use field names from onnx.proto
const auto& model_metadata = *status_and_metadata.second;
metadata.Add("producer_name", model_metadata.producer_name);
metadata.Add("producer_version", model_metadata.producer_version);
metadata.Add("domain", model_metadata.domain);
@@ -138,88 +133,42 @@

return metadata;
}
} // namespace

OrtKeyValuePairs GetModelMetadataFromSession(const InferenceSession& session) {
OrtKeyValuePairs metadata;
auto status_and_metadata = session.GetModelMetadata();

if (!status_and_metadata.first.IsOK()) {
return metadata;
}

const auto& model_metadata = *status_and_metadata.second;
return GetModelMetadataKeyValuePairs(model_metadata);
}

// Select execution providers based on the device policy and available devices and add to session
Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const OrtSessionOptions& options,
InferenceSession& sess) {
// Get the list of devices from the environment and order them.
// Ordered by preference within each type. NPU -> GPU -> CPU
// TODO: Should environment.cc do the ordering?
std::vector<const OrtEpDevice*> execution_devices = OrderDevices(env.GetOrtEpDevices());
std::vector<const OrtEpDevice*> execution_devices;

// Check if the sorted list of devices has been cached in the session
if (!sess.GetExecutionDevices().empty()) {
execution_devices = sess.GetExecutionDevices();
} else {
// Get the list of devices from the environment and order them.
// Ordered by preference within each type. NPU -> GPU -> CPU
// TODO: Should environment.cc do the ordering?

[cpplint, GitHub Actions / Optional Lint C++] Warning on line 160 of onnxruntime/core/session/provider_policy_context.cc: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
execution_devices = OrderDevices(env.GetOrtEpDevices());
}

// The list of devices selected by policies
std::vector<const OrtEpDevice*> devices_selected;

// Run the delegate if it was passed in lieu of any other policy
if (options.value.ep_selection_policy.delegate) {
auto model_metadata = GetModelMetadata(sess);
OrtKeyValuePairs runtime_metadata; // TODO: where should this come from?

std::vector<const OrtEpDevice*> delegate_devices(execution_devices.begin(), execution_devices.end());
std::array<const OrtEpDevice*, 8> selected_devices{nullptr};
size_t num_selected = 0;

EpSelectionDelegate delegate = options.value.ep_selection_policy.delegate;
auto* status = delegate(delegate_devices.data(), delegate_devices.size(),
&model_metadata, &runtime_metadata,
selected_devices.data(), selected_devices.size(), &num_selected,
options.value.ep_selection_policy.state);

// return or fall-through for both these cases
// going with explicit failure for now so it's obvious to user what is happening
if (status != nullptr) {
std::string delegate_error_msg = OrtApis::GetErrorMessage(status); // copy
OrtApis::ReleaseStatus(status);

return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate failed: ", delegate_error_msg);
}

if (num_selected == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate did not select anything.");
}

if (num_selected > selected_devices.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"EP selection delegate selected too many EP devices (", num_selected, "). ",
"The limit is ", selected_devices.size(), " EP devices.");
}

// Copy the selected devices to the output vector
devices_selected.reserve(num_selected);
for (size_t i = 0; i < num_selected; ++i) {
devices_selected.push_back(selected_devices[i]);
}
// Check if the list of devices has been selected and cached in the session
if (!sess.GetSelectedDevices().empty()) {
devices_selected = sess.GetSelectedDevices();
} else {
// Create the selector for the chosen policy
std::unique_ptr<IEpPolicySelector> selector;
switch (options.value.ep_selection_policy.policy) {
case OrtExecutionProviderDevicePolicy_DEFAULT:
selector = std::make_unique<DefaultEpPolicy>();
break;
case OrtExecutionProviderDevicePolicy_PREFER_CPU:
selector = std::make_unique<PreferCpuEpPolicy>();
break;
case OrtExecutionProviderDevicePolicy_PREFER_NPU:
case OrtExecutionProviderDevicePolicy_MAX_EFFICIENCY:
case OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER:
selector = std::make_unique<PreferNpuEpPolicy>();
break;
case OrtExecutionProviderDevicePolicy_PREFER_GPU:
case OrtExecutionProviderDevicePolicy_MAX_PERFORMANCE:
selector = std::make_unique<PreferGpuEpPolicy>();
break;
}

// Execute policy

selector->SelectProvidersForDevices(execution_devices, devices_selected);
}

// Fail if we did not find any device matches
if (devices_selected.empty()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"No execution providers selected. Please check the device policy and available devices.");
ORT_RETURN_IF_ERROR(SelectEpDevices(options, execution_devices, devices_selected, nullptr, sess));
}

// Log telemetry for auto EP selection
@@ -314,6 +263,98 @@
}
}

if (sess.GetExecutionDevices().empty()) {
sess.SetExecutionDevices(execution_devices);
}

if (sess.GetSelectedDevices().empty()) {
sess.SetSelectedDevices(devices_selected);
}

return Status::OK();
}

Status ProviderPolicyContext::SelectEpDevices(const OrtSessionOptions& options,
const std::vector<const OrtEpDevice*>& execution_devices,
std::vector<const OrtEpDevice*>& devices_selected,
const OrtKeyValuePairs* metadata_from_model,
const InferenceSession& sess) {
// Run the delegate if it was passed in lieu of any other policy
if (options.value.ep_selection_policy.delegate) {
OrtKeyValuePairs model_metadata;

if (metadata_from_model) {
model_metadata = *metadata_from_model;
} else {
model_metadata = GetModelMetadataFromSession(sess);
}
OrtKeyValuePairs runtime_metadata; // TODO: where should this come from?

[cpplint, GitHub Actions / Optional Lint C++] Warning on line 291 of onnxruntime/core/session/provider_policy_context.cc: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]

std::vector<const OrtEpDevice*> delegate_devices(execution_devices.begin(), execution_devices.end());
std::array<const OrtEpDevice*, 8> selected_devices{nullptr};
size_t num_selected = 0;

EpSelectionDelegate delegate = options.value.ep_selection_policy.delegate;
auto* status = delegate(delegate_devices.data(), delegate_devices.size(),
&model_metadata, &runtime_metadata,
selected_devices.data(), selected_devices.size(), &num_selected,
options.value.ep_selection_policy.state);

// return or fall-through for both these cases
// going with explicit failure for now so it's obvious to user what is happening
if (status != nullptr) {
std::string delegate_error_msg = OrtApis::GetErrorMessage(status); // copy
OrtApis::ReleaseStatus(status);

return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate failed: ", delegate_error_msg);
}

if (num_selected == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate did not select anything.");
}

if (num_selected > selected_devices.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"EP selection delegate selected too many EP devices (", num_selected, "). ",
"The limit is ", selected_devices.size(), " EP devices.");
}

// Copy the selected devices to the output vector
devices_selected.reserve(num_selected);
for (size_t i = 0; i < num_selected; ++i) {
devices_selected.push_back(selected_devices[i]);
}
} else {
// Create the selector for the chosen policy
std::unique_ptr<IEpPolicySelector> selector;
switch (options.value.ep_selection_policy.policy) {
case OrtExecutionProviderDevicePolicy_DEFAULT:
selector = std::make_unique<DefaultEpPolicy>();
break;
case OrtExecutionProviderDevicePolicy_PREFER_CPU:
selector = std::make_unique<PreferCpuEpPolicy>();
break;
case OrtExecutionProviderDevicePolicy_PREFER_NPU:
case OrtExecutionProviderDevicePolicy_MAX_EFFICIENCY:
case OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER:
selector = std::make_unique<PreferNpuEpPolicy>();
break;
case OrtExecutionProviderDevicePolicy_PREFER_GPU:
case OrtExecutionProviderDevicePolicy_MAX_PERFORMANCE:
selector = std::make_unique<PreferGpuEpPolicy>();
break;
}

// Execute policy

selector->SelectProvidersForDevices(execution_devices, devices_selected);
}

// Fail if we did not find any device matches
if (devices_selected.empty()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"No execution providers selected. Please check the device policy and available devices.");
}
return Status::OK();
}

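The delegate branch of `SelectEpDevices()` hands the delegate a fixed-size output array and then validates what comes back: an error status, an empty selection, and a count larger than the buffer are all treated as failures. The sketch below isolates that validation. It is an illustration under assumptions, not the ORT code: `Device`, `SelectFn`, and `RunDelegate` are hypothetical, the model and runtime metadata arguments are omitted, and errors are returned as strings rather than `Status`.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

struct Device { int id; };  // stand-in for OrtEpDevice (hypothetical)

// Hypothetical delegate signature: writes up to `max_selected` devices and
// reports the count. Returns an empty string on success, else an error message.
using SelectFn = std::string (*)(const Device* const* devices, std::size_t num_devices,
                                 const Device** selected, std::size_t max_selected,
                                 std::size_t* num_selected);

// Runs the delegate and validates its result. Returns an empty string on success.
std::string RunDelegate(SelectFn delegate, const std::vector<const Device*>& available,
                        std::vector<const Device*>& out) {
  std::array<const Device*, 8> selected{};  // fixed-size output buffer, as in the diff
  std::size_t num_selected = 0;

  std::string err = delegate(available.data(), available.size(),
                             selected.data(), selected.size(), &num_selected);
  if (!err.empty()) return "EP selection delegate failed: " + err;
  if (num_selected == 0) return "EP selection delegate did not select anything.";
  if (num_selected > selected.size()) {
    return "EP selection delegate selected too many EP devices.";
  }

  // Copy only the entries the delegate reported as valid.
  out.assign(selected.begin(),
             selected.begin() + static_cast<std::ptrdiff_t>(num_selected));
  return {};
}
```

Checking `num_selected` against the buffer size before copying matters because the count comes from user code; without the check, a delegate that over-reports would cause a read past the end of the array.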
8 changes: 8 additions & 0 deletions onnxruntime/core/session/provider_policy_context.h
@@ -7,6 +7,7 @@

#include "core/session/abi_session_options_impl.h"
#include "core/session/environment.h"
#include "core/session/inference_session.h"
#include "core/session/onnxruntime_c_api.h" // For OrtExecutionProviderDevicePolicy

namespace onnxruntime {
@@ -41,6 +42,9 @@
ProviderPolicyContext() = default;

Status SelectEpsForSession(const Environment& env, const OrtSessionOptions& options, InferenceSession& sess);
Status SelectEpDevices(const OrtSessionOptions& options, const std::vector<const OrtEpDevice*>& execution_devices,
std::vector<const OrtEpDevice*>& devices_selected, const OrtKeyValuePairs* model_metadata,
const InferenceSession& sess);
Status AddEpDefaultOptionsToSession(InferenceSession& sess, std::vector<const OrtEpDevice*> devices);
void RemoveOrtCpuDevice(std::vector<const OrtEpDevice*>& devices);
Status CreateExecutionProvider(const Environment& env, OrtSessionOptions& options, const OrtLogger& logger,
@@ -75,6 +79,10 @@
std::vector<const OrtEpDevice*>& selected_devices) override;
};

std::vector<const OrtEpDevice*> OrderDevices(const std::vector<const OrtEpDevice*>& devices);

[cpplint, GitHub Actions / Optional Lint C++] Warning on line 82 of onnxruntime/core/session/provider_policy_context.h: Add #include <vector> for vector<> [build/include_what_you_use] [4]

OrtKeyValuePairs GetModelMetadataKeyValuePairs(const ModelMetadata& model_metadata);

} // namespace onnxruntime

#endif // !ORT_MINIMAL_BUILD
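`OrderDevices()` is now declared in the header so it can be reused outside the translation unit. Its comment in the .cc file describes the ordering as NPU, then GPU, then CPU, with vendor-owned entries first within a type. A minimal sketch of such an ordering follows. The real implementation is not shown in this diff, so everything here is an assumption: `DeviceType`, `Device`, the `vendor_owned` flag, `TypeRank`, and `OrderDevicesSketch` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

enum class DeviceType { CPU, GPU, NPU };  // stand-in for OrtHardwareDeviceType (hypothetical)

struct Device {
  DeviceType type;
  bool vendor_owned;  // assumption: true when the EP comes from the hardware vendor
  int id;
};

// Lower rank sorts first: NPU -> GPU -> CPU.
int TypeRank(DeviceType type) {
  switch (type) {
    case DeviceType::NPU: return 0;
    case DeviceType::GPU: return 1;
    case DeviceType::CPU: return 2;
  }
  return 3;  // not reached for the enum values above
}

// Returns a sorted copy; the input list is left untouched.
std::vector<const Device*> OrderDevicesSketch(const std::vector<const Device*>& devices) {
  std::vector<const Device*> sorted = devices;
  std::stable_sort(sorted.begin(), sorted.end(), [](const Device* a, const Device* b) {
    if (TypeRank(a->type) != TypeRank(b->type)) {
      return TypeRank(a->type) < TypeRank(b->type);
    }
    return a->vendor_owned && !b->vendor_owned;  // vendor-owned first within a type
  });
  return sorted;
}
```

`std::stable_sort` keeps devices that compare equal in their original relative order, which makes the result deterministic for a given input list.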