From 0b5b4d7ed7fe37cc149717da2ca8bc7623d76882 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 8 Dec 2025 14:47:46 -0800 Subject: [PATCH 01/47] update --- .../core/session/onnxruntime_ep_c_api.h | 9 +++++ onnxruntime/core/session/onnxruntime_c_api.cc | 2 +- .../plugin_ep/ep_factory_internal_impl.h | 7 ++++ onnxruntime/core/session/utils.cc | 35 ++++++++++++++++++- onnxruntime/core/session/utils.h | 2 +- .../unittest_util/test_dynamic_plugin_ep.cc | 2 +- 6 files changed, 53 insertions(+), 4 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 6fa5c8dea04e6..1aa11790d6fde 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -1413,6 +1413,15 @@ struct OrtEpFactory { * \since Version 1.24. */ ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options); + + /** \brief + * + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. 
+ */ + ORT_API2_STATUS(CreateCustomOpDomain, _Outptr_result_maybenull_ OrtCustomOpDomain** out, _Out_ size_t* num_domains); }; #ifdef __cplusplus diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 82f7cef4aec49..e0580adc20cc2 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3332,7 +3332,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS ep_devices_span, ep_option_keys_span, ep_option_vals_span, - session_options->value)); + *session_options)); session_options->provider_factories.push_back(std::move(provider_factory)); diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index de9e2d44431bf..da46694ef7af7 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -88,6 +88,13 @@ class EpFactoryInternalImpl { return nullptr; } + virtual OrtStatus* CreateCustomOpDomain(_Outptr_result_maybenull_ OrtCustomOpDomain** out, + _Out_ size_t* num_domains) const noexcept { + *out = nullptr; + *num_domains = 0; + return nullptr; + } + // Function ORT calls to release an EP instance. 
void ReleaseEp(OrtEp* ep); diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 4cb21b80109c8..1e16b833e876f 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -482,7 +482,8 @@ Status CreateIExecutionProviderFactoryForEpDevices(const Environment& env, Status AddEpOptionsToSessionOptions(gsl::span ep_devices, gsl::span ep_option_keys, gsl::span ep_option_vals, - SessionOptions& session_options) { + OrtSessionOptions& ort_session_options) { + SessionOptions& session_options = ort_session_options.value; const size_t num_ep_options = ep_option_keys.size(); if (ep_option_vals.size() != num_ep_options) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -505,6 +506,38 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic ORT_RETURN_IF_ERROR(config_options.AddConfigEntry((prefix + ep_option_keys[j]).c_str(), ep_option_vals[j])); } + + // add custom op domain provided by EP if any + OrtEpFactory* ep_factory = ep_device->ep_factory; + if (ep_factory) { + auto is_already_in_domains = [&](std::string& domain_name, std::vector& domains) { + for (auto ptr : domains) { + if (domain_name == ptr->domain_) { + return true; + } + } + return false; + }; + + if (ep_factory->CreateCustomOpDomain == nullptr) { + continue; + } + + size_t num_domains = 0; + OrtCustomOpDomain** domain_ptrs = nullptr; + ep_factory->CreateCustomOpDomain(domain_ptrs, &num_domains); + + const auto custom_op_domains_span = gsl::span(domain_ptrs, num_domains); + for (auto domain : custom_op_domains_span) { + if (!is_already_in_domains(domain->domain_, ort_session_options.custom_op_domains_) && + domain->custom_ops_.size() > 0) { + ort_session_options.custom_op_domains_.push_back(domain); + } else { + LOGS_DEFAULT(WARNING) << "The custom op domain name " + << domain->domain_ << " is already in the session option. 
Skip it."; + } + } + } } return Status::OK(); diff --git a/onnxruntime/core/session/utils.h b/onnxruntime/core/session/utils.h index 2ccd4d464a261..da951b5cb9810 100644 --- a/onnxruntime/core/session/utils.h +++ b/onnxruntime/core/session/utils.h @@ -69,7 +69,7 @@ Status CreateIExecutionProviderFactoryForEpDevices(const Environment& env, Status AddEpOptionsToSessionOptions(gsl::span ep_devices, gsl::span ep_options_keys, gsl::span ep_options_vals, - SessionOptions& session_options); + OrtSessionOptions& session_options); } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc index 6ac741fb616a8..a944b1cea7515 100644 --- a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc +++ b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc @@ -185,7 +185,7 @@ std::unique_ptr MakeEp(const logging::Logger* logger) { ORT_THROW_IF_ERROR(AddEpOptionsToSessionOptions(state.selected_c_ep_devices, default_ep_option_key_cstrs, default_ep_option_value_cstrs, - ort_session_options.value)); + ort_session_options)); return state.ep_factory->CreateProvider(ort_session_options, *logger->ToExternal()); } From 80561db254fbdb4bcb56e9788b5dea42765b1add Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 8 Dec 2025 17:43:51 -0800 Subject: [PATCH 02/47] update --- .../core/session/onnxruntime_ep_c_api.h | 3 +- onnxruntime/core/session/utils.cc | 8 +- .../library/example_plugin_ep/ep_factory.cc | 24 ++++++ .../library/example_plugin_ep/ep_factory.h | 75 +++++++++++++++++++ onnxruntime/test/autoep/test_execution.cc | 25 +++++++ 5 files changed, 130 insertions(+), 5 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 1aa11790d6fde..9a51715fdf3a3 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ 
b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -1421,7 +1421,8 @@ struct OrtEpFactory { * * \since Version 1.24. */ - ORT_API2_STATUS(CreateCustomOpDomain, _Outptr_result_maybenull_ OrtCustomOpDomain** out, _Out_ size_t* num_domains); + ORT_API2_STATUS(CreateCustomOpDomain, _In_ OrtEpFactory* this_ptr, + _Outptr_result_maybenull_ OrtCustomOpDomain** out, _Out_ size_t* num_domains); }; #ifdef __cplusplus diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 1e16b833e876f..3d53b5420d90a 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -507,7 +507,7 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic ORT_RETURN_IF_ERROR(config_options.AddConfigEntry((prefix + ep_option_keys[j]).c_str(), ep_option_vals[j])); } - // add custom op domain provided by EP if any + // add custom op domain provided by EP to the session options if any OrtEpFactory* ep_factory = ep_device->ep_factory; if (ep_factory) { auto is_already_in_domains = [&](std::string& domain_name, std::vector& domains) { @@ -524,10 +524,10 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic } size_t num_domains = 0; - OrtCustomOpDomain** domain_ptrs = nullptr; - ep_factory->CreateCustomOpDomain(domain_ptrs, &num_domains); + OrtCustomOpDomain* domain_ptrs = nullptr; + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->CreateCustomOpDomain(ep_factory, &domain_ptrs, &num_domains))); - const auto custom_op_domains_span = gsl::span(domain_ptrs, num_domains); + const auto custom_op_domains_span = gsl::span(&domain_ptrs, num_domains); for (auto domain : custom_op_domains_span) { if (!is_already_in_domains(domain->domain_, ort_session_options.custom_op_domains_) && domain->custom_ops_.size() > 0) { diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 85f9504b14a4d..715e2afcd0d4f 100644 --- 
a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -37,6 +37,8 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL IsStreamAware = IsStreamAwareImpl; CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; + CreateCustomOpDomain = CreateCustomOpDomainImpl; + // setup the OrtMemoryInfo instances required by the EP. // We pretend the device the EP is running on is GPU. default_memory_info_ = Ort::MemoryInfo{"ExampleEP GPU", @@ -68,6 +70,9 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL OrtDeviceMemoryType_HOST_ACCESSIBLE, /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator}; + + // Custom Op Domain + custom_op_domain_ = Ort::CustomOpDomain{"test"}; } /*static*/ @@ -308,3 +313,22 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateSyncStreamForDeviceImpl(OrtEpFac return nullptr; } + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::CreateCustomOpDomainImpl(OrtEpFactory* this_ptr, + _Outptr_result_maybenull_ OrtCustomOpDomain** out, + _Out_ size_t* num_domains) noexcept { + auto* factory = static_cast(this_ptr); + + std::vector> created_custom_op_list; + created_custom_op_list.push_back(std::make_unique(factory->ep_name_.c_str(), nullptr)); + created_custom_op_list.back().get()->SetName("VariadicNode"); + factory->custom_op_domain_.Add(created_custom_op_list.back().get()); + + *out = factory->custom_op_domain_; + *num_domains = 1; + + factory->created_custom_op_list_ = std::move(created_custom_op_list); + + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index 196e67fc5c558..1cdb93ee2774d 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -5,10 +5,77 @@ #include +#include "onnxruntime_c_api.h" #include 
"ep_arena.h" #include "ep_data_transfer.h" #include "../plugin_ep_utils.h" +// This is a placeholder for "compile-based" plugin EP to provide custom op domains to ORT. +// Please note that this is not for "kernel registration" plugin EP to register kernels. +struct PluginEpCustomKernel { + PluginEpCustomKernel(const OrtKernelInfo* /*info*/, void* compute_stream) + : compute_stream_(compute_stream) { + } + + void Compute(OrtKernelContext* /*context*/) { + // The implementation is in plugin EP's compiled bits. No need to implement it here. + }; + + private: + void* compute_stream_; +}; + +struct PluginEpCustomOp : Ort::CustomOpBase { + explicit PluginEpCustomOp(const char* provider, void* compute_stream) : provider_(provider), + compute_stream_(compute_stream) { + } + + void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const { + return new PluginEpCustomKernel(info, compute_stream_); + }; + + const char* GetName() const { return name_; }; + + void SetName(const char* name) { name_ = name; }; + + const char* GetExecutionProviderType() const { return provider_; }; + + size_t GetInputTypeCount() const { return num_inputs_; }; + + void SetInputTypeCount(size_t num) { num_inputs_ = num; }; + + ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; + + OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const { + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; + }; + + size_t GetOutputTypeCount() const { return num_outputs_; }; + + void SetOutputTypeCount(size_t num) { num_outputs_ = num; }; + + ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; + + OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const { + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; + }; + + bool GetVariadicInputHomogeneity() const { + return false; // heterogenous + } + + bool 
GetVariadicOutputHomogeneity() const { + return false; // heterogeneous + } + + private: + const char* provider_ = nullptr; + void* compute_stream_ = nullptr; + const char* name_ = nullptr; + size_t num_inputs_ = 1; // set to 1 to match with default min_arity for variadic input + size_t num_outputs_ = 1; // set to 1 to match with default min_arity for variadic output +}; + /// /// Example EP factory that can create an OrtEp and return information about the supported hardware devices. /// @@ -67,6 +134,10 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { const OrtKeyValuePairs* stream_options, OrtSyncStreamImpl** stream) noexcept; + static OrtStatus* ORT_API_CALL CreateCustomOpDomainImpl(OrtEpFactory* this_ptr, + _Outptr_result_maybenull_ OrtCustomOpDomain** out, + _Out_ size_t* num_domains) noexcept; + const OrtLogger& default_logger_; // default logger for the EP factory const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name @@ -83,4 +154,8 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { std::mutex mutex_; // mutex to protect arena_allocator_ and num_arena_users_ std::unique_ptr data_transfer_impl_; // data transfer implementation for this factory + + // std::unique_ptr> custom_op_domain_; + Ort::CustomOpDomain custom_op_domain_; + std::vector> created_custom_op_list_; }; diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index bb391bb0bca23..d38981a2d8089 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -282,5 +282,30 @@ TEST(OrtEpLibrary, KernelPluginEp_Inference) { gsl::span output_span(output_data, 6); EXPECT_THAT(output_span, ::testing::ElementsAre(4, 0, 24, 0, 0, 84)); } + +// Creates a session with the example plugin EP and runs a model with a single Mul node. +// Uses AppendExecutionProvider_V2 to append the example plugin EP to the session. 
+TEST(OrtEpLibrary, PluginEp_Create_OrtCustomOpDomain) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + // Create session with example plugin EP + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + try { + Ort::Session session(*ort_env, ORT_TSTR("testdata/custom_op_variadic_io.onnx"), session_options); + FAIL(); + } catch (const Ort::Exception& excpt) { + // The session is expected to pass the custom onnx model validation as example plugin EP provides the custom op domain for VariadicNode op" + // If the custom op domain is not provided, the error message should be: + // "Load model from testdata/custom_op_variadic_io.onnx failed:Fatal error: test:VariadicNode(-1) is not a registered function/op" + ASSERT_THAT(excpt.what(), testing::Not(testing::HasSubstr("test:VariadicNode(-1) is not a registered function/op"))); + + // But still, session creation is expected to fail as example plugin EP is not able to run VariadicNode op. 
+ } +} } // namespace test } // namespace onnxruntime From 6bd316f92f51d193413edc1794308b0256767683 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 8 Dec 2025 22:26:13 -0800 Subject: [PATCH 03/47] add API summary --- .../core/session/onnxruntime_ep_c_api.h | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 9a51715fdf3a3..c6f5b69d27fe6 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -1414,8 +1414,24 @@ struct OrtEpFactory { */ ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options); - /** \brief + /** \brief Create the EP-specific custom op domain list. * + * This function is used when running inference on a model that contains EP-specific custom operations. + * For compile-based EPs, the EP does not need to provide a concrete kernel implementation for each custom op. + * Instead, it may provide only placeholder custom ops with the correct names so they can be recognized + * during model loading. + * + * Workflow: + * 1. The EP implements this function to supply a list of OrtCustomOpDomain instances. + * 2. The application calls SessionOptionsAppendExecutionProvider_V2() with an OrtEpDevice containing + * the plugin EP's factory. + * 3. SessionOptionsAppendExecutionProvider_V2() appends the provided OrtCustomOpDomain list to the + * session options. + * + * As a result, any session created from these session options will have these custom op domains registered + * in ORT, ensuring that the custom ops are properly recognized and validated when the model is loaded. + * + * Note: EP has the responsibility to release OrtCustomOpDomain instances it creates. 
* * \snippet{doc} snippets.dox OrtStatus Return Value * From ad0a023a727185c2392626cbf1cf6171329c68a4 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 8 Dec 2025 22:37:25 -0800 Subject: [PATCH 04/47] update --- include/onnxruntime/core/session/onnxruntime_ep_c_api.h | 2 +- .../core/session/plugin_ep/ep_factory_internal_impl.h | 4 ++-- onnxruntime/core/session/utils.cc | 4 ++-- .../test/autoep/library/example_plugin_ep/ep_factory.cc | 8 ++++---- .../test/autoep/library/example_plugin_ep/ep_factory.h | 6 +++--- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index c6f5b69d27fe6..f8b81ca157f3f 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -1437,7 +1437,7 @@ struct OrtEpFactory { * * \since Version 1.24. */ - ORT_API2_STATUS(CreateCustomOpDomain, _In_ OrtEpFactory* this_ptr, + ORT_API2_STATUS(CreateCustomOpDomains, _In_ OrtEpFactory* this_ptr, _Outptr_result_maybenull_ OrtCustomOpDomain** out, _Out_ size_t* num_domains); }; diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index da46694ef7af7..306b0c60a5e23 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -88,8 +88,8 @@ class EpFactoryInternalImpl { return nullptr; } - virtual OrtStatus* CreateCustomOpDomain(_Outptr_result_maybenull_ OrtCustomOpDomain** out, - _Out_ size_t* num_domains) const noexcept { + virtual OrtStatus* CreateCustomOpDomains(_Outptr_result_maybenull_ OrtCustomOpDomain** out, + _Out_ size_t* num_domains) const noexcept { *out = nullptr; *num_domains = 0; return nullptr; diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 3d53b5420d90a..139486c0405ff 100644 --- 
a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -519,13 +519,13 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic return false; }; - if (ep_factory->CreateCustomOpDomain == nullptr) { + if (ep_factory->CreateCustomOpDomains == nullptr) { continue; } size_t num_domains = 0; OrtCustomOpDomain* domain_ptrs = nullptr; - ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->CreateCustomOpDomain(ep_factory, &domain_ptrs, &num_domains))); + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->CreateCustomOpDomains(ep_factory, &domain_ptrs, &num_domains))); const auto custom_op_domains_span = gsl::span(&domain_ptrs, num_domains); for (auto domain : custom_op_domains_span) { diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 715e2afcd0d4f..e10bc86a28d19 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -37,7 +37,7 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL IsStreamAware = IsStreamAwareImpl; CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; - CreateCustomOpDomain = CreateCustomOpDomainImpl; + CreateCustomOpDomains = CreateCustomOpDomainsImpl; // setup the OrtMemoryInfo instances required by the EP. // We pretend the device the EP is running on is GPU. 
@@ -315,9 +315,9 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateSyncStreamForDeviceImpl(OrtEpFac } /*static*/ -OrtStatus* ORT_API_CALL ExampleEpFactory::CreateCustomOpDomainImpl(OrtEpFactory* this_ptr, - _Outptr_result_maybenull_ OrtCustomOpDomain** out, - _Out_ size_t* num_domains) noexcept { +OrtStatus* ORT_API_CALL ExampleEpFactory::CreateCustomOpDomainsImpl(OrtEpFactory* this_ptr, + _Outptr_result_maybenull_ OrtCustomOpDomain** out, + _Out_ size_t* num_domains) noexcept { auto* factory = static_cast(this_ptr); std::vector> created_custom_op_list; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index 1cdb93ee2774d..dd9fc3e8ed0ea 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -134,9 +134,9 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { const OrtKeyValuePairs* stream_options, OrtSyncStreamImpl** stream) noexcept; - static OrtStatus* ORT_API_CALL CreateCustomOpDomainImpl(OrtEpFactory* this_ptr, - _Outptr_result_maybenull_ OrtCustomOpDomain** out, - _Out_ size_t* num_domains) noexcept; + static OrtStatus* ORT_API_CALL CreateCustomOpDomainsImpl(OrtEpFactory* this_ptr, + _Outptr_result_maybenull_ OrtCustomOpDomain** out, + _Out_ size_t* num_domains) noexcept; const OrtLogger& default_logger_; // default logger for the EP factory const std::string ep_name_; // EP name From 5e398d4310ea93fb39657d033d916c5ed86c573a Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 9 Dec 2025 16:29:34 -0800 Subject: [PATCH 05/47] address reviewer's comments and add GetNumCustomOpDomains --- .../core/session/onnxruntime_ep_c_api.h | 23 ++++++++++-- onnxruntime/core/session/utils.cc | 26 ++++++++----- .../library/example_plugin_ep/ep_factory.cc | 37 +++++++++++++++---- .../library/example_plugin_ep/ep_factory.h | 11 ++++-- 4 files changed, 72 insertions(+), 25 deletions(-) diff 
--git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index f8b81ca157f3f..d276926caa402 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -1414,7 +1414,18 @@ struct OrtEpFactory { */ ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options); - /** \brief Create the EP-specific custom op domain list. + /** \brief Returns the number of OrtCustomOpDomains that this factory creates. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[out] num_domains Output parameter set to the number of created OrtCustomOpDomain instances. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(GetNumCustomOpDomains, _In_ OrtEpFactory* this_ptr, _Out_ size_t* num_domains); + + /** \brief Creates the EP-specific OrtCustomOpDomains. * * This function is used when running inference on a model that contains EP-specific custom operations. * For compile-based EPs, the EP does not need to provide a concrete kernel implementation for each custom op. @@ -1431,14 +1442,20 @@ struct OrtEpFactory { * As a result, any session created from these session options will have these custom op domains registered * in ORT, ensuring that the custom ops are properly recognized and validated when the model is loaded. * - * Note: EP has the responsibility to release OrtCustomOpDomain instances it creates. + * Note: EP has the responsibility to release OrtCustomOpDomain instances it creates. It happens + * automatically if using ORT C++ api. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[out] domains Pre-allocated array of `num_domains` elements by ORT that should be filled with + OrtCustomOpDomain created by the EP. + * \param[in] num_domains The size of the `domains` array pre-allocated by ORT. 
* * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.24. */ ORT_API2_STATUS(CreateCustomOpDomains, _In_ OrtEpFactory* this_ptr, - _Outptr_result_maybenull_ OrtCustomOpDomain** out, _Out_ size_t* num_domains); + _Outptr_result_maybenull_ OrtCustomOpDomain** domains, _Out_ size_t num_domains); }; #ifdef __cplusplus diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 139486c0405ff..734445fcda5fd 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -507,9 +507,13 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic ORT_RETURN_IF_ERROR(config_options.AddConfigEntry((prefix + ep_option_keys[j]).c_str(), ep_option_vals[j])); } - // add custom op domain provided by EP to the session options if any + // Add custom op domain provided by EP to the session options if any. + // OrtEpFactory::GetNumCustomOpDomains and OrtEpFactory::CreateCustomOpDomains + // were added in ORT 1.24. OrtEpFactory* ep_factory = ep_device->ep_factory; - if (ep_factory) { + if (ep_factory && + ep_factory->ort_version_supported >= 24 && + ep_factory->CreateCustomOpDomains != nullptr) { auto is_already_in_domains = [&](std::string& domain_name, std::vector& domains) { for (auto ptr : domains) { if (domain_name == ptr->domain_) { @@ -519,16 +523,18 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic return false; }; - if (ep_factory->CreateCustomOpDomains == nullptr) { - continue; - } - size_t num_domains = 0; - OrtCustomOpDomain* domain_ptrs = nullptr; - ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->CreateCustomOpDomains(ep_factory, &domain_ptrs, &num_domains))); + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetNumCustomOpDomains(ep_factory, &num_domains))); + + InlinedVector domains; + domains.resize(num_domains); + + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->CreateCustomOpDomains(ep_factory, + domains.data(), + num_domains))); - const auto custom_op_domains_span = 
gsl::span(&domain_ptrs, num_domains); - for (auto domain : custom_op_domains_span) { + const auto domains_span = gsl::span(domains.data(), domains.size()); + for (auto domain : domains_span) { if (!is_already_in_domains(domain->domain_, ort_session_options.custom_op_domains_) && domain->custom_ops_.size() > 0) { ort_session_options.custom_op_domains_.push_back(domain); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index e10bc86a28d19..3247c4bcc0011 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -37,6 +37,7 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL IsStreamAware = IsStreamAwareImpl; CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; + GetNumCustomOpDomains = GetNumCustomOpDomainsImpl; CreateCustomOpDomains = CreateCustomOpDomainsImpl; // setup the OrtMemoryInfo instances required by the EP. 
@@ -71,8 +72,9 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator}; - // Custom Op Domain - custom_op_domain_ = Ort::CustomOpDomain{"test"}; + // Custom Op Domains + custom_op_domains_[0] = Ort::CustomOpDomain{"test"}; + custom_op_domains_[1] = Ort::CustomOpDomain{"test2"}; } /*static*/ @@ -315,20 +317,39 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateSyncStreamForDeviceImpl(OrtEpFac } /*static*/ -OrtStatus* ORT_API_CALL ExampleEpFactory::CreateCustomOpDomainsImpl(OrtEpFactory* this_ptr, - _Outptr_result_maybenull_ OrtCustomOpDomain** out, +OrtStatus* ORT_API_CALL ExampleEpFactory::GetNumCustomOpDomainsImpl(OrtEpFactory* this_ptr, _Out_ size_t* num_domains) noexcept { auto* factory = static_cast(this_ptr); + *num_domains = factory->custom_op_domains_.size(); + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::CreateCustomOpDomainsImpl( + OrtEpFactory* this_ptr, + _Outptr_result_maybenull_ OrtCustomOpDomain** domains, + _Out_ size_t num_domains) noexcept { + auto* factory = static_cast(this_ptr); std::vector> created_custom_op_list; created_custom_op_list.push_back(std::make_unique(factory->ep_name_.c_str(), nullptr)); created_custom_op_list.back().get()->SetName("VariadicNode"); - factory->custom_op_domain_.Add(created_custom_op_list.back().get()); + factory->custom_op_domains_[0].Add(created_custom_op_list.back().get()); + + std::vector> created_custom_op_list_2; + created_custom_op_list_2.push_back(std::make_unique(factory->ep_name_.c_str(), nullptr)); + created_custom_op_list_2.back().get()->SetName("VariadicNode2"); + factory->custom_op_domains_[1].Add(created_custom_op_list_2.back().get()); - *out = factory->custom_op_domain_; - *num_domains = 1; + // The `num_domains` should be 2 as ORT calls GetNumCustomOpDomainsImpl() to get the number prior to + // call this function. 
+ gsl::span domains_span(domains, num_domains); + domains_span[0] = factory->custom_op_domains_[0]; + domains_span[1] = factory->custom_op_domains_[1]; - factory->created_custom_op_list_ = std::move(created_custom_op_list); + factory->created_custom_op_lists_[0] = std::move(created_custom_op_list); + factory->created_custom_op_lists_[1] = std::move(created_custom_op_list_2); return nullptr; } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index dd9fc3e8ed0ea..028b5b64a7ff4 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -134,10 +134,13 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { const OrtKeyValuePairs* stream_options, OrtSyncStreamImpl** stream) noexcept; - static OrtStatus* ORT_API_CALL CreateCustomOpDomainsImpl(OrtEpFactory* this_ptr, - _Outptr_result_maybenull_ OrtCustomOpDomain** out, + static OrtStatus* ORT_API_CALL GetNumCustomOpDomainsImpl(OrtEpFactory* this_ptr, _Out_ size_t* num_domains) noexcept; + static OrtStatus* ORT_API_CALL CreateCustomOpDomainsImpl(OrtEpFactory* this_ptr, + _Outptr_result_maybenull_ OrtCustomOpDomain** domains, + _Out_ size_t num_domains) noexcept; + const OrtLogger& default_logger_; // default logger for the EP factory const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name @@ -156,6 +159,6 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { std::unique_ptr data_transfer_impl_; // data transfer implementation for this factory // std::unique_ptr> custom_op_domain_; - Ort::CustomOpDomain custom_op_domain_; - std::vector> created_custom_op_list_; + std::vector custom_op_domains_{2}; + std::vector>> created_custom_op_lists_{2}; }; From aeb238657c8871fab22e88928617d7e31f08b20f Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 10 Dec 2025 15:02:55 -0800 Subject: [PATCH 
06/47] update example ep to run Custom_Mul op --- .../python/onnxruntime_pybind_state.cc | 2 +- .../autoep/library/example_plugin_ep/ep.cc | 218 +++++++++--------- .../autoep/library/example_plugin_ep/ep.h | 29 ++- .../library/example_plugin_ep/ep_factory.cc | 23 +- .../library/example_plugin_ep/ep_factory.h | 59 +++-- onnxruntime/test/autoep/test_execution.cc | 43 +++- onnxruntime/test/testdata/custom_mul.onnx | Bin 0 -> 155 bytes onnxruntime/test/testdata/custom_mul.py | 42 ++++ 8 files changed, 265 insertions(+), 151 deletions(-) create mode 100644 onnxruntime/test/testdata/custom_mul.onnx create mode 100644 onnxruntime/test/testdata/custom_mul.py diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 0bd0daf837645..7a8c11b713c76 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1346,7 +1346,7 @@ static Status AddEpFactoryFromEpDevices(PySessionOptions& py_sess_options, ORT_RETURN_IF_ERROR(AddEpOptionsToSessionOptions(ep_devices, ep_option_keys, ep_option_vals, - py_sess_options.value)); + py_sess_options)); py_sess_options.provider_factories.push_back(std::move(provider_factory)); return Status::OK(); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc index bce9b59ff0ea4..254bee4e2fd4e 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc @@ -15,117 +15,97 @@ #include "ep_factory.h" #include "ep_stream_support.h" -/// -/// Example implementation of ONNX Mul. Does not handle many things like broadcasting. 
-/// -struct MulKernel { - MulKernel(const OrtApi& ort_api, const OrtLogger& logger, - const std::unordered_map& float_initializers, - std::string input0_name, std::string input1_name) - : ort_api(ort_api), - logger(logger), - float_initializers(float_initializers), - input0_name(input0_name), - input1_name(input1_name) {} - - const FloatInitializer* TryGetSavedInitializer(const std::string& name) const { - auto iter = float_initializers.find(name); - return iter != float_initializers.end() ? &iter->second : nullptr; - } - - void GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index, - /*out*/ gsl::span& data, - /*out*/ std::vector& shape) const { - Ort::ConstValue input = kernel_context.GetInput(index); - auto type_shape = input.GetTensorTypeAndShapeInfo(); +const FloatInitializer* MulKernel::TryGetSavedInitializer(const std::string& name) const { + auto iter = float_initializers.find(name); + return iter != float_initializers.end() ? &iter->second : nullptr; +} - ONNXTensorElementDataType elem_type = type_shape.GetElementType(); - if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) - throw Ort::Exception("EP Expected float32 inputs", ORT_EP_FAIL); +void MulKernel::GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index, + /*out*/ gsl::span& data, + /*out*/ std::vector& shape) const { + Ort::ConstValue input = kernel_context.GetInput(index); + auto type_shape = input.GetTensorTypeAndShapeInfo(); - const float* float_data = input.GetTensorData(); - size_t num_elems = type_shape.GetElementCount(); - data = gsl::span(float_data, num_elems); - shape = type_shape.GetShape(); - } + ONNXTensorElementDataType elem_type = type_shape.GetElementType(); + if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) + throw Ort::Exception("EP Expected float32 inputs", ORT_EP_FAIL); - OrtStatus* Compute(OrtKernelContext* kernel_ctx) { - RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, - OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, - "MulKernel::Compute", 
ORT_FILE, __LINE__, __FUNCTION__)); - Ort::KernelContext kernel_context(kernel_ctx); - try { - gsl::span input0; - gsl::span input1; - std::vector shape0; - std::vector shape1; - - size_t num_inputs = kernel_context.GetInputCount(); - - if (num_inputs == 2) { - // Both inputs are non-constant. Get them from ORT's KernelContext. - GetInputDataAndShape(kernel_context, 0, input0, shape0); - GetInputDataAndShape(kernel_context, 1, input1, shape1); - } else if (num_inputs == 1) { - // ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs. - // Get the constant input from the initializers saved by the EP. - // Refer to "NodeFusionOptions_DropConstantInitializers()". - - if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) { - GetInputDataAndShape(kernel_context, 0, input1, shape1); - input0 = gsl::span(const_input0->data); - shape0 = const_input0->shape; - } else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) { - GetInputDataAndShape(kernel_context, 0, input0, shape0); - input1 = gsl::span(const_input1->data); - shape1 = const_input1->shape; - } - } else { - // Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding) - // are disabled. 
- const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); - const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); - RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api, - "Expected 2 initializer inputs to be saved by EP"); + const float* float_data = input.GetTensorData(); + size_t num_elems = type_shape.GetElementCount(); + data = gsl::span(float_data, num_elems); + shape = type_shape.GetShape(); +} +OrtStatus* MulKernel::Compute(OrtKernelContext* kernel_ctx) { + RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + "MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__)); + Ort::KernelContext kernel_context(kernel_ctx); + try { + gsl::span input0; + gsl::span input1; + std::vector shape0; + std::vector shape1; + + size_t num_inputs = kernel_context.GetInputCount(); + + if (num_inputs == 2) { + // Both inputs are non-constant. Get them from ORT's KernelContext. + GetInputDataAndShape(kernel_context, 0, input0, shape0); + GetInputDataAndShape(kernel_context, 1, input1, shape1); + } else if (num_inputs == 1) { + // ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs. + // Get the constant input from the initializers saved by the EP. + // Refer to "NodeFusionOptions_DropConstantInitializers()". + + if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) { + GetInputDataAndShape(kernel_context, 0, input1, shape1); input0 = gsl::span(const_input0->data); - input1 = gsl::span(const_input1->data); shape0 = const_input0->shape; + } else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) { + GetInputDataAndShape(kernel_context, 0, input0, shape0); + input1 = gsl::span(const_input1->data); shape1 = const_input1->shape; } + } else { + // Both inputs are constant. 
Should never happen unless all ORT optimizations (specifically constant-folding) + // are disabled. + const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); + const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); + RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api, + "Expected 2 initializer inputs to be saved by EP"); + + input0 = gsl::span(const_input0->data); + input1 = gsl::span(const_input1->data); + shape0 = const_input0->shape; + shape1 = const_input1->shape; + } - if (shape0 != shape1) { - throw Ort::Exception("Expected same dimensions for both inputs", ORT_INVALID_ARGUMENT); - } + if (shape0 != shape1) { + throw Ort::Exception("Expected same dimensions for both inputs", ORT_INVALID_ARGUMENT); + } - size_t num_outputs = kernel_context.GetOutputCount(); - if (num_outputs != 1) { - throw Ort::Exception("Expected 1 output for MulKernel", ORT_INVALID_ARGUMENT); - } + size_t num_outputs = kernel_context.GetOutputCount(); + if (num_outputs != 1) { + throw Ort::Exception("Expected 1 output for MulKernel", ORT_INVALID_ARGUMENT); + } - auto output = kernel_context.GetOutput(0, shape0); - float* output_data = output.GetTensorMutableData(); + auto output = kernel_context.GetOutput(0, shape0); + float* output_data = output.GetTensorMutableData(); - for (size_t i = 0; i < input0.size(); ++i) { - output_data[i] = input0[i] * input1[i]; - } - } catch (const Ort::Exception& ex) { - Ort::Status status(ex); - return status.release(); - } catch (const std::exception& ex) { - Ort::Status status(ex.what(), ORT_EP_FAIL); - return status.release(); + for (size_t i = 0; i < input0.size(); ++i) { + output_data[i] = input0[i] * input1[i]; } - - return nullptr; + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); } - const OrtApi& ort_api; - const OrtLogger& logger; - const 
std::unordered_map& float_initializers; - std::string input0_name; - std::string input1_name; -}; + return nullptr; +} /// /// Example OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph. @@ -226,6 +206,17 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG return nullptr; // No nodes to process } + if (nodes.size() != 1) { + Ort::Status status("Expected to get capability from a model with only one node", ORT_EP_FAIL); + return status.release(); + } + + auto node_op_type = nodes[0].GetOperatorType(); + if (node_op_type != "Mul" && node_op_type != "Custom_Mul") { + Ort::Status status("Expected to get capability from a model with only a Mul or Custom_Mul node", ORT_EP_FAIL); + return status.release(); + } + std::vector supported_nodes; for (const auto& node : nodes) { @@ -262,6 +253,8 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG supported_nodes.push_back(node); // Only support a single Mul for now. break; + } else if (op_type == "Custom_Mul") { + supported_nodes.push_back(node); } } @@ -269,19 +262,26 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG return nullptr; } - // Create (optional) fusion options for the supported nodes to fuse. - OrtNodeFusionOptions node_fusion_options = {}; - node_fusion_options.ort_version_supported = ORT_API_VERSION; - - // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers - // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. - // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use - // during inference. 
- node_fusion_options.drop_constant_initializers = true; - RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, - reinterpret_cast(supported_nodes.data()), - supported_nodes.size(), - &node_fusion_options)); + if (nodes[0].GetOperatorType() == "Mul") { + // Create (optional) fusion options for the supported nodes to fuse. + OrtNodeFusionOptions node_fusion_options = {}; + node_fusion_options.ort_version_supported = ORT_API_VERSION; + + // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers + // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. + // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use + // during inference. + node_fusion_options.drop_constant_initializers = true; + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, + reinterpret_cast(supported_nodes.data()), + supported_nodes.size(), + &node_fusion_options)); + } else { + // Calls EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled, + // as CustomMul has the concrete kernel implementation. + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_nodes[0])); + } + } catch (const Ort::Exception& ex) { Ort::Status status(ex); return status.release(); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h index 7e96a523cf285..5d4788ed76bf2 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h @@ -8,7 +8,34 @@ #include "../plugin_ep_utils.h" class ExampleEpFactory; -struct MulKernel; + +/// +/// Example implementation of ONNX Mul. Does not handle many things like broadcasting. 
+/// +struct MulKernel { + MulKernel(const OrtApi& ort_api, const OrtLogger& logger, + const std::unordered_map& float_initializers, + std::string input0_name, std::string input1_name) + : ort_api(ort_api), + logger(logger), + float_initializers(float_initializers), + input0_name(input0_name), + input1_name(input1_name) {} + + const FloatInitializer* TryGetSavedInitializer(const std::string& name) const; + + void GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index, + /*out*/ gsl::span& data, + /*out*/ std::vector& shape) const; + + OrtStatus* Compute(OrtKernelContext* kernel_ctx); + + const OrtApi& ort_api; + const OrtLogger& logger; + const std::unordered_map& float_initializers; + std::string input0_name; + std::string input1_name; +}; /// /// Example EP that can compile a single Mul operator. diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 3247c4bcc0011..fa789f217a596 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -332,14 +332,14 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateCustomOpDomainsImpl( _Out_ size_t num_domains) noexcept { auto* factory = static_cast(this_ptr); - std::vector> created_custom_op_list; - created_custom_op_list.push_back(std::make_unique(factory->ep_name_.c_str(), nullptr)); - created_custom_op_list.back().get()->SetName("VariadicNode"); + std::vector> created_custom_op_list; + created_custom_op_list.push_back(std::make_unique(factory->ep_name_.c_str(), factory)); + created_custom_op_list.back().get()->SetName("Custom_Mul"); factory->custom_op_domains_[0].Add(created_custom_op_list.back().get()); - std::vector> created_custom_op_list_2; - created_custom_op_list_2.push_back(std::make_unique(factory->ep_name_.c_str(), nullptr)); - created_custom_op_list_2.back().get()->SetName("VariadicNode2"); + std::vector> 
created_custom_op_list_2; + created_custom_op_list_2.push_back(std::make_unique(factory->ep_name_.c_str(), factory)); + created_custom_op_list_2.back().get()->SetName("Custom_Mul2"); factory->custom_op_domains_[1].Add(created_custom_op_list_2.back().get()); // The `num_domains` should be 2 as ORT calls GetNumCustomOpDomainsImpl() to get the number prior to @@ -353,3 +353,14 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateCustomOpDomainsImpl( return nullptr; } + +void* ExampleEpCustomOp::CreateKernel(const OrtApi& /*api*/, const OrtKernelInfo* /*info*/) const { + std::string node_input_0 = "X"; + std::string node_input_1 = "W"; + auto custom_kernel_op = std::make_unique(factory_->ort_api, + factory_->default_logger_, + float_initializers_, + node_input_0, + node_input_1); + return custom_kernel_op.release(); +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index 028b5b64a7ff4..bf474160fb24f 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -9,30 +9,43 @@ #include "ep_arena.h" #include "ep_data_transfer.h" #include "../plugin_ep_utils.h" - -// This is a placeholder for "compile-based" plugin EP to provide custom op domains to ORT. -// Please note that this is not for "kernel registration" plugin EP to register kernels. -struct PluginEpCustomKernel { - PluginEpCustomKernel(const OrtKernelInfo* /*info*/, void* compute_stream) - : compute_stream_(compute_stream) { +#include "ep.h" + +// Plugin EPs can provide two types of custom ops: +// +// 1. A full OrtCustomOp with a concrete kernel implementation +// - This Example EP demonstrates this approach. +// - In GetCapability(), it calls EpGraphSupportInfo_AddSingleNode() to inform ORT +// that the custom node should NOT be fused or compiled. Instead, ORT should invoke +// the custom node's Compute() function at runtime. +// +// 2. 
A "placeholder" OrtCustomOp with an empty kernel implementation +// - A compile-based Plugin EP can supply an OrtCustomOp whose CustomKernel::Compute() +// does nothing. The purpose is to satisfy model validation during model loading by +// registering the custom op as a valid operator in the session. +// - In GetCapability(), the EP should call EpGraphSupportInfo_AddNodesToFuse() to +// notify ORT that this custom node should be fused and compiled by the EP. +// - In Compile(), the EP executes its compiled bits to perform inference for +// the fused custom node. +// +// Note: Approach #2 is suitable for plugin TRT RTX EP to support TRT plugins. + +struct CustomMulKernel : MulKernel { + CustomMulKernel(const OrtApi& ort_api, + const OrtLogger& logger, + const std::unordered_map& float_initializers, + std::string input0_name, + std::string input1_name) : MulKernel(ort_api, logger, float_initializers, + input0_name, input1_name) { } - - void Compute(OrtKernelContext* /*context*/) { - // The implementation is in plugin EP's compiled bits. No need to implement it here. 
- }; - - private: - void* compute_stream_; }; -struct PluginEpCustomOp : Ort::CustomOpBase { - explicit PluginEpCustomOp(const char* provider, void* compute_stream) : provider_(provider), - compute_stream_(compute_stream) { +struct ExampleEpCustomOp : Ort::CustomOpBase { + explicit ExampleEpCustomOp(const char* provider, ExampleEpFactory* factory) : provider_(provider), + factory_(factory) { } - void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const { - return new PluginEpCustomKernel(info, compute_stream_); - }; + void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; const char* GetName() const { return name_; }; @@ -70,10 +83,11 @@ struct PluginEpCustomOp : Ort::CustomOpBase float_initializers_; }; /// @@ -92,6 +106,8 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { return arena_allocator_.get(); } + const OrtLogger& default_logger_; // default logger for the EP factory + private: static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; @@ -141,7 +157,6 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { _Outptr_result_maybenull_ OrtCustomOpDomain** domains, _Out_ size_t num_domains) noexcept; - const OrtLogger& default_logger_; // default logger for the EP factory const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name const uint32_t vendor_id_{0xB357}; // EP vendor ID @@ -160,5 +175,5 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { // std::unique_ptr> custom_op_domain_; std::vector custom_op_domains_{2}; - std::vector>> created_custom_op_lists_{2}; + std::vector>> created_custom_op_lists_{2}; }; diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index d38981a2d8089..fda0ec4c5e53b 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -47,6 +47,35 @@ void RunMulModelWithPluginEp(const Ort::SessionOptions& 
session_options) { EXPECT_THAT(output_span, ::testing::ElementsAre(2, 4, 6, 8, 10, 12)); } +void RunCustomMulModelWithPluginEp(const Ort::SessionOptions& session_options) { + Ort::Session session(*ort_env, ORT_TSTR("testdata/custom_mul.onnx"), session_options); + + // Create two inputs with same values + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::vector shape = {3, 2}; + std::vector input0_data{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector ort_inputs; + std::vector ort_input_names; + + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input0_data.data(), input0_data.size(), shape.data(), shape.size())); + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input0_data.data(), input0_data.size(), shape.data(), shape.size())); + ort_input_names.push_back("X"); + ort_input_names.push_back("W"); + + // Run session and get outputs + std::array output_names{"Y"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check expected output values + Ort::Value& ort_output = ort_outputs[0]; + const float* output_data = ort_output.GetTensorData(); + gsl::span output_span(output_data, 6); + EXPECT_THAT(output_span, ::testing::ElementsAre(1, 4, 9, 16, 25, 36)); +} + void RunPartiallySupportedModelWithPluginEp(const Ort::SessionOptions& session_options) { // This model has Add -> Mul -> Add. The example plugin EP only supports Mul. Ort::Session session(*ort_env, ORT_TSTR("testdata/add_mul_add.onnx"), session_options); @@ -283,7 +312,7 @@ TEST(OrtEpLibrary, KernelPluginEp_Inference) { EXPECT_THAT(output_span, ::testing::ElementsAre(4, 0, 24, 0, 0, 84)); } -// Creates a session with the example plugin EP and runs a model with a single Mul node. +// Creates a session with the example plugin EP and runs a model with a single Custom_Mul node.
// Uses AppendExecutionProvider_V2 to append the example plugin EP to the session. TEST(OrtEpLibrary, PluginEp_Create_OrtCustomOpDomain) { RegisteredEpDeviceUniquePtr example_ep; @@ -295,17 +324,7 @@ TEST(OrtEpLibrary, PluginEp_Create_OrtCustomOpDomain) { std::unordered_map ep_options; session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - try { - Ort::Session session(*ort_env, ORT_TSTR("testdata/custom_op_variadic_io.onnx"), session_options); - FAIL(); - } catch (const Ort::Exception& excpt) { - // The session is expected to pass the custom onnx model validation as example plugin EP provides the custom op domain for VariadicNode op" - // If the custom op domain is not provided, the error message should be: - // "Load model from testdata/custom_op_variadic_io.onnx failed:Fatal error: test:VariadicNode(-1) is not a registered function/op" - ASSERT_THAT(excpt.what(), testing::Not(testing::HasSubstr("test:VariadicNode(-1) is not a registered function/op"))); - - // But still, session creation is expected to fail as example plugin EP is not able to run VariadicNode op. 
- } + RunCustomMulModelWithPluginEp(session_options); } } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/custom_mul.onnx b/onnxruntime/test/testdata/custom_mul.onnx new file mode 100644 index 0000000000000000000000000000000000000000..87bb64764a66956b18c70fdc21e1789c2e315091 GIT binary patch literal 155 zcmdEiTE=jn6I3iBBrc%t=WtvdZC-=3B1g3o@dFfr^FrxOg}ig*dpFIGBN$2_zVfE|>%qj6@f7V&P&C;C15X;!e)b P)l02N%q_@C6<`DaJ=7ld literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/custom_mul.py b/onnxruntime/test/testdata/custom_mul.py new file mode 100644 index 0000000000000..ee5fb229503e7 --- /dev/null +++ b/onnxruntime/test/testdata/custom_mul.py @@ -0,0 +1,42 @@ +import onnx +from onnx import helper, TensorProto + +def create_custom_mul_model(): + # === Inputs === + x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2]) + w = helper.make_tensor_value_info("W", TensorProto.FLOAT, [3, 2]) + + # === Output === + y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 2]) + + # === Custom Node: Custom_Mul === + # Replace "Mul" with your custom op name and domain + custom_node = helper.make_node( + op_type="Custom_Mul", # <-- custom op name + inputs=["X", "W"], + outputs=["Y"], + domain="test" # <-- custom domain + ) + + # === Graph === + graph = helper.make_graph( + nodes=[custom_node], + name="CustomMulGraph", + inputs=[x, w], + outputs=[y], + ) + + # === Model (opset version 13 or later is fine) === + model = helper.make_model( + graph, + opset_imports=[helper.make_opsetid("", 13), # standard ONNX domain + helper.make_opsetid("com.example", 1)], # your custom domain + producer_name="custom_mul_builder" + ) + + return model + +# ===== Save the Model ===== +model = create_custom_mul_model() +onnx.save(model, "custom_mul.onnx") +print("Saved custom_mul.onnx") From 3849cd3f3fc93e706b2669db603de1f536184949 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 10 Dec 2025 15:09:54 -0800 Subject: [PATCH 07/47] address reviewr's 
comment --- onnxruntime/core/session/utils.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 734445fcda5fd..d155aa1308e38 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -514,7 +514,7 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic if (ep_factory && ep_factory->ort_version_supported >= 24 && ep_factory->CreateCustomOpDomains != nullptr) { - auto is_already_in_domains = [&](std::string& domain_name, std::vector& domains) { + auto is_already_in_domains = [&](const std::string& domain_name, const std::vector& domains) { for (auto ptr : domains) { if (domain_name == ptr->domain_) { return true; From 9c987be28f9b9a471bf219de21bf9d76e53b73a0 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 10 Dec 2025 15:10:56 -0800 Subject: [PATCH 08/47] lintrunner -a --- onnxruntime/test/testdata/custom_mul.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/onnxruntime/test/testdata/custom_mul.py b/onnxruntime/test/testdata/custom_mul.py index ee5fb229503e7..c8fd8b0b720a3 100644 --- a/onnxruntime/test/testdata/custom_mul.py +++ b/onnxruntime/test/testdata/custom_mul.py @@ -1,5 +1,6 @@ import onnx -from onnx import helper, TensorProto +from onnx import TensorProto, helper + def create_custom_mul_model(): # === Inputs === @@ -12,10 +13,10 @@ def create_custom_mul_model(): # === Custom Node: Custom_Mul === # Replace "Mul" with your custom op name and domain custom_node = helper.make_node( - op_type="Custom_Mul", # <-- custom op name + op_type="Custom_Mul", # <-- custom op name inputs=["X", "W"], outputs=["Y"], - domain="test" # <-- custom domain + domain="test", # <-- custom domain ) # === Graph === @@ -29,13 +30,16 @@ def create_custom_mul_model(): # === Model (opset version 13 or later is fine) === model = helper.make_model( graph, - opset_imports=[helper.make_opsetid("", 13), # standard ONNX domain - 
helper.make_opsetid("com.example", 1)], # your custom domain - producer_name="custom_mul_builder" + opset_imports=[ + helper.make_opsetid("", 13), # standard ONNX domain + helper.make_opsetid("com.example", 1), + ], # your custom domain + producer_name="custom_mul_builder", ) return model + # ===== Save the Model ===== model = create_custom_mul_model() onnx.save(model, "custom_mul.onnx") From fbe24345d1b7aa4fd0391d8890a9d4bf2cbd9753 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 10 Dec 2025 15:52:22 -0800 Subject: [PATCH 09/47] update example ep GetCapability() --- .../test/autoep/library/example_plugin_ep/ep.cc | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc index 254bee4e2fd4e..41fdbc51a0ff0 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc @@ -206,17 +206,6 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG return nullptr; // No nodes to process } - if (nodes.size() != 1) { - Ort::Status status("Expected to get capability from a model with only one node", ORT_EP_FAIL); - return status.release(); - } - - auto node_op_type = nodes[0].GetOperatorType(); - if (node_op_type != "Mul" && node_op_type != "Custom_Mul") { - Ort::Status status("Expected to get capability from a model with only a Mul or Custom_Mul node", ORT_EP_FAIL); - return status.release(); - } - std::vector supported_nodes; for (const auto& node : nodes) { @@ -262,7 +251,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG return nullptr; } - if (nodes[0].GetOperatorType() == "Mul") { + if (supported_nodes[0].GetOperatorType() == "Mul") { // Create (optional) fusion options for the supported nodes to fuse. 
OrtNodeFusionOptions node_fusion_options = {}; node_fusion_options.ort_version_supported = ORT_API_VERSION; @@ -276,7 +265,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG reinterpret_cast(supported_nodes.data()), supported_nodes.size(), &node_fusion_options)); - } else { + } else if (supported_nodes[0].GetOperatorType() == "Custom_Mul") { // Calls EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled, // as CustomMul has the concrete kernel implementation. RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_nodes[0])); From 40fa8fe518c16538472dcba4c4e218fbf0b12a51 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 11 Dec 2025 11:25:18 -0800 Subject: [PATCH 10/47] update Example EP --- .../test/autoep/library/example_plugin_ep/ep_factory.cc | 8 ++++---- .../test/autoep/library/example_plugin_ep/ep_factory.h | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index fa789f217a596..1ee3f433606b9 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -71,10 +71,6 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL OrtDeviceMemoryType_HOST_ACCESSIBLE, /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator}; - - // Custom Op Domains - custom_op_domains_[0] = Ort::CustomOpDomain{"test"}; - custom_op_domains_[1] = Ort::CustomOpDomain{"test2"}; } /*static*/ @@ -332,6 +328,10 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateCustomOpDomainsImpl( _Out_ size_t num_domains) noexcept { auto* factory = static_cast(this_ptr); + // Custom Op Domains + factory->custom_op_domains_[0] = Ort::CustomOpDomain{"test"}; + factory->custom_op_domains_[1] = Ort::CustomOpDomain{"test2"}; + std::vector> 
created_custom_op_list; created_custom_op_list.push_back(std::make_unique(factory->ep_name_.c_str(), factory)); created_custom_op_list.back().get()->SetName("Custom_Mul"); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index bf474160fb24f..a03c5a0083792 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -173,7 +173,6 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { std::unique_ptr data_transfer_impl_; // data transfer implementation for this factory - // std::unique_ptr> custom_op_domain_; std::vector custom_op_domains_{2}; std::vector>> created_custom_op_lists_{2}; }; From c7a0491efaf5912318b46dc260711d0312714d87 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 12 Dec 2025 09:46:07 -0800 Subject: [PATCH 11/47] add more comments in API summary --- .../core/session/onnxruntime_ep_c_api.h | 19 ++++++++++++++++--- onnxruntime/core/session/utils.cc | 19 ++++++++++--------- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index d276926caa402..d1a01af4d1046 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -1428,9 +1428,6 @@ struct OrtEpFactory { /** \brief Creates the EP-specific OrtCustomOpDomains. * * This function is used when running inference on a model that contains EP-specific custom operations. - * For compile-based EPs, the EP does not need to provide a concrete kernel implementation for each custom op. - * Instead, it may provide only placeholder custom ops with the correct names so they can be recognized - * during model loading. * * Workflow: * 1. The EP implements this function to supply a list of OrtCustomOpDomain instances. 
@@ -1442,6 +1439,22 @@ struct OrtEpFactory { * As a result, any session created from these session options will have these custom op domains registered * in ORT, ensuring that the custom ops are properly recognized and validated when the model is loaded. * + * Plugin EPs can provide two types of custom ops: + * 1. A full OrtCustomOp with a concrete kernel implementation + * - This Example EP demonstrates this approach. + * - In GetCapability(), it calls EpGraphSupportInfo_AddSingleNode() to inform ORT + * that the custom node should NOT be fused or compiled. Instead, ORT should invoke + * the custom node's Compute() function at runtime. + * + * 2. A "placeholder" OrtCustomOp with an empty kernel implementation + * - A compile-based Plugin EP can supply an OrtCustomOp whose CustomKernel::Compute() + * does nothing. The purpose is to satisfy model validation during model loading by + * registering the custom op as a valid operator in the session. + * - In GetCapability(), the EP should call EpGraphSupportInfo_AddNodesToFuse() to + * notify ORT that this custom node should be fused and compiled by the EP. + * - In Compile(), the EP executes its compiled bits to perform inference for + * the fused custom node. + * * Note: EP has the responsibility to release OrtCustomOpDomain instances it creates. It happens * automatically if using ORT C++ api. 
* diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index d155aa1308e38..830f0ac378e1d 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -514,14 +514,15 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic if (ep_factory && ep_factory->ort_version_supported >= 24 && ep_factory->CreateCustomOpDomains != nullptr) { - auto is_already_in_domains = [&](const std::string& domain_name, const std::vector& domains) { - for (auto ptr : domains) { - if (domain_name == ptr->domain_) { - return true; - } - } - return false; - }; + auto is_already_in_domains = + [&](const std::string& domain_name, const std::vector& domains) { + for (auto ptr : domains) { + if (domain_name == ptr->domain_) { + return true; + } + } + return false; + }; size_t num_domains = 0; ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetNumCustomOpDomains(ep_factory, &num_domains))); @@ -531,7 +532,7 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->CreateCustomOpDomains(ep_factory, domains.data(), - num_domains))); + domains.size()))); const auto domains_span = gsl::span(domains.data(), domains.size()); for (auto domain : domains_span) { From 4787c3f22d9528556a9ff9b41d4aff76f26db0e3 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 18 Dec 2025 00:39:54 -0800 Subject: [PATCH 12/47] address reviewer's comments --- .../core/session/onnxruntime_ep_c_api.h | 2 +- .../plugin_ep/ep_factory_internal_impl.h | 9 +- onnxruntime/core/session/utils.cc | 1 + .../library/example_plugin_ep/ep_custom_op.h | 88 +++++++++++++++++++ .../library/example_plugin_ep/ep_factory.h | 81 +---------------- 5 files changed, 97 insertions(+), 84 deletions(-) create mode 100644 onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index d1a01af4d1046..67d0aeddf015e 
100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -1468,7 +1468,7 @@ struct OrtEpFactory { * \since Version 1.24. */ ORT_API2_STATUS(CreateCustomOpDomains, _In_ OrtEpFactory* this_ptr, - _Outptr_result_maybenull_ OrtCustomOpDomain** domains, _Out_ size_t num_domains); + _Outptr_result_maybenull_ OrtCustomOpDomain** domains, _In_ size_t num_domains); }; #ifdef __cplusplus diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index 306b0c60a5e23..d281465dd3e17 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -88,13 +88,16 @@ class EpFactoryInternalImpl { return nullptr; } - virtual OrtStatus* CreateCustomOpDomains(_Outptr_result_maybenull_ OrtCustomOpDomain** out, - _Out_ size_t* num_domains) const noexcept { - *out = nullptr; + virtual OrtStatus* GetNumCustomOpDomains(_Out_ size_t* num_domains) const noexcept { *num_domains = 0; return nullptr; } + virtual OrtStatus* CreateCustomOpDomains(_Outptr_result_maybenull_ OrtCustomOpDomain** /*domains*/, + _In_ size_t /*num_domains*/) const noexcept { + return nullptr; + } + // Function ORT calls to release an EP instance. 
void ReleaseEp(OrtEp* ep); diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 830f0ac378e1d..1e6fcb9743d28 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -513,6 +513,7 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic OrtEpFactory* ep_factory = ep_device->ep_factory; if (ep_factory && ep_factory->ort_version_supported >= 24 && + ep_factory->GetNumCustomOpDomains != nullptr && ep_factory->CreateCustomOpDomains != nullptr) { auto is_already_in_domains = [&](const std::string& domain_name, const std::vector& domains) { diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h new file mode 100644 index 0000000000000..a49a772d676f4 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "onnxruntime_c_api.h" +#include "ep.h" + + +// Plugin EPs can provide two types of custom ops: +// +// 1. A full OrtCustomOp with a concrete kernel implementation +// - This Example EP demonstrates this approach. +// - In GetCapability(), it calls EpGraphSupportInfo_AddSingleNode() to inform ORT +// that the custom node should NOT be fused or compiled. Instead, ORT should invoke +// the custom node's Compute() function at runtime. +// +// 2. A "placeholder" OrtCustomOp with an empty kernel implementation +// - A compile-based Plugin EP can supply an OrtCustomOp whose CustomKernel::Compute() +// does nothing. The purpose is to satisfy model validation during model loading by +// registering the custom op as a valid operator in the session. +// - In GetCapability(), the EP should call EpGraphSupportInfo_AddNodesToFuse() to +// notify ORT that this custom node should be fused and compiled by the EP. 
+// - In Compile(), the EP executes its compiled bits to perform inference for +// the fused custom node. +// +// Note: Approach #2 is suitable for plugin TRT RTX EP to support TRT plugins. + +struct CustomMulKernel : MulKernel { + CustomMulKernel(const OrtApi& ort_api, + const OrtLogger& logger, + const std::unordered_map& float_initializers, + std::string input0_name, + std::string input1_name) : MulKernel(ort_api, logger, float_initializers, + input0_name, input1_name) { + } +}; + +struct ExampleEpCustomOp : Ort::CustomOpBase { + explicit ExampleEpCustomOp(const char* provider, ExampleEpFactory* factory) : provider_(provider), + factory_(factory) { + } + + void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; + + const char* GetName() const { return name_; }; + + void SetName(const char* name) { name_ = name; }; + + const char* GetExecutionProviderType() const { return provider_; }; + + size_t GetInputTypeCount() const { return num_inputs_; }; + + void SetInputTypeCount(size_t num) { num_inputs_ = num; }; + + ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; + + OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const { + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; + }; + + size_t GetOutputTypeCount() const { return num_outputs_; }; + + void SetOutputTypeCount(size_t num) { num_outputs_ = num; }; + + ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; + + OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const { + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; + }; + + bool GetVariadicInputHomogeneity() const { + return false; // heterogenous + } + + bool GetVariadicOutputHomogeneity() const { + return false; // heterogeneous + } + + private: + const char* provider_ = nullptr; + const char* name_ = nullptr; + size_t num_inputs_ = 1; // 
set to 1 to match with default min_arity for variadic input + size_t num_outputs_ = 1; // set to 1 to match with default min_arity for variadic output + ExampleEpFactory* factory_ = nullptr; + std::unordered_map float_initializers_; +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index a03c5a0083792..d262471000c9e 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -5,90 +5,11 @@ #include -#include "onnxruntime_c_api.h" #include "ep_arena.h" #include "ep_data_transfer.h" #include "../plugin_ep_utils.h" #include "ep.h" - -// Plugin EPs can provide two types of custom ops: -// -// 1. A full OrtCustomOp with a concrete kernel implementation -// - This Example EP demonstrates this approach. -// - In GetCapability(), it calls EpGraphSupportInfo_AddSingleNode() to inform ORT -// that the custom node should NOT be fused or compiled. Instead, ORT should invoke -// the custom node's Compute() function at runtime. -// -// 2. A "placeholder" OrtCustomOp with an empty kernel implementation -// - A compile-based Plugin EP can supply an OrtCustomOp whose CustomKernel::Compute() -// does nothing. The purpose is to satisfy model validation during model loading by -// registering the custom op as a valid operator in the session. -// - In GetCapability(), the EP should call EpGraphSupportInfo_AddNodesToFuse() to -// notify ORT that this custom node should be fused and compiled by the EP. -// - In Compile(), the EP executes its compiled bits to perform inference for -// the fused custom node. -// -// Note: Approach #2 is suitable for plugin TRT RTX EP to support TRT plugins. 
- -struct CustomMulKernel : MulKernel { - CustomMulKernel(const OrtApi& ort_api, - const OrtLogger& logger, - const std::unordered_map& float_initializers, - std::string input0_name, - std::string input1_name) : MulKernel(ort_api, logger, float_initializers, - input0_name, input1_name) { - } -}; - -struct ExampleEpCustomOp : Ort::CustomOpBase { - explicit ExampleEpCustomOp(const char* provider, ExampleEpFactory* factory) : provider_(provider), - factory_(factory) { - } - - void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; - - const char* GetName() const { return name_; }; - - void SetName(const char* name) { name_ = name; }; - - const char* GetExecutionProviderType() const { return provider_; }; - - size_t GetInputTypeCount() const { return num_inputs_; }; - - void SetInputTypeCount(size_t num) { num_inputs_ = num; }; - - ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; - - OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const { - return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; - }; - - size_t GetOutputTypeCount() const { return num_outputs_; }; - - void SetOutputTypeCount(size_t num) { num_outputs_ = num; }; - - ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; - - OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const { - return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; - }; - - bool GetVariadicInputHomogeneity() const { - return false; // heterogenous - } - - bool GetVariadicOutputHomogeneity() const { - return false; // heterogeneous - } - - private: - const char* provider_ = nullptr; - const char* name_ = nullptr; - size_t num_inputs_ = 1; // set to 1 to match with default min_arity for variadic input - size_t num_outputs_ = 1; // set to 1 to match with default min_arity for variadic output - ExampleEpFactory* factory_ = nullptr; 
- std::unordered_map float_initializers_; -}; +#include "ep_custom_op.h" /// /// Example EP factory that can create an OrtEp and return information about the supported hardware devices. From 632ce3141cbbbb557184253e832aca24b1d223b8 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 18 Dec 2025 15:32:09 -0800 Subject: [PATCH 13/47] lintrunner -a --- onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h index a49a772d676f4..00775f9a8d9e2 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h @@ -7,7 +7,6 @@ #include "onnxruntime_c_api.h" #include "ep.h" - // Plugin EPs can provide two types of custom ops: // // 1. A full OrtCustomOp with a concrete kernel implementation From 6017c00902b99a3f783578df3dea96fcd99dfc0c Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 7 Jan 2026 17:24:59 -0800 Subject: [PATCH 14/47] Use CreateKernelV2 and ComputeKernelV2 --- .../autoep/library/example_plugin_ep/ep_custom_op.h | 10 ++++++++-- .../autoep/library/example_plugin_ep/ep_factory.cc | 11 +++++++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h index 00775f9a8d9e2..c37038a727067 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h @@ -34,14 +34,20 @@ struct CustomMulKernel : MulKernel { std::string input1_name) : MulKernel(ort_api, logger, float_initializers, input0_name, input1_name) { } + + OrtStatusPtr ComputeV2(OrtKernelContext* kernel_ctx) { + return MulKernel::Compute(kernel_ctx); + } }; -struct ExampleEpCustomOp : Ort::CustomOpBase { +struct 
ExampleEpCustomOp : Ort::CustomOpBase { explicit ExampleEpCustomOp(const char* provider, ExampleEpFactory* factory) : provider_(provider), factory_(factory) { } - void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const; + OrtStatusPtr CreateKernelV2(const OrtApi& api, const OrtKernelInfo* info, void** op_kernel) const; + + OrtStatusPtr KernelComputeV2(void* op_kernel, OrtKernelContext* context) const; const char* GetName() const { return name_; }; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 1ee3f433606b9..5dcc675c85641 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -354,7 +354,9 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateCustomOpDomainsImpl( return nullptr; } -void* ExampleEpCustomOp::CreateKernel(const OrtApi& /*api*/, const OrtKernelInfo* /*info*/) const { +OrtStatusPtr ExampleEpCustomOp::CreateKernelV2(const OrtApi& /*api*/, + const OrtKernelInfo* /*info*/, + void** op_kernel) const { std::string node_input_0 = "X"; std::string node_input_1 = "W"; auto custom_kernel_op = std::make_unique(factory_->ort_api, @@ -362,5 +364,10 @@ void* ExampleEpCustomOp::CreateKernel(const OrtApi& /*api*/, const OrtKernelInfo float_initializers_, node_input_0, node_input_1); - return custom_kernel_op.release(); + *op_kernel = custom_kernel_op.release(); + return nullptr; +} + +OrtStatusPtr ExampleEpCustomOp::KernelComputeV2(void* op_kernel, OrtKernelContext* context) const { + return static_cast(op_kernel)->ComputeV2(context); } From 47bb4dcde5b3950e69f20a7f67cf229b748835c1 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 7 Jan 2026 17:35:06 -0800 Subject: [PATCH 15/47] address reviewer's comments --- .../core/session/onnxruntime_ep_c_api.h | 6 +- .../tensorrt/utilities/common/exceptions.h | 71 +++++++++++++++++++ .../plugin_ep/ep_factory_internal_impl.h 
| 2 +- onnxruntime/core/session/utils.cc | 6 +- .../library/example_plugin_ep/ep_factory.cc | 4 +- .../library/example_plugin_ep/ep_factory.h | 2 +- onnxruntime/test/autoep/test_execution.cc | 4 +- 7 files changed, 83 insertions(+), 12 deletions(-) create mode 100644 onnxruntime/core/providers/tensorrt/utilities/common/exceptions.h diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 2f823a4bcfcbb..da31a13ddc6e2 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -1576,7 +1576,7 @@ struct OrtEpFactory { */ ORT_API2_STATUS(GetNumCustomOpDomains, _In_ OrtEpFactory* this_ptr, _Out_ size_t* num_domains); - /** \brief Creates the EP-specific OrtCustomOpDomains. + /** \brief Gets the EP-specific OrtCustomOpDomains. * * This function is used when running inference on a model that contains EP-specific custom operations. * @@ -1618,8 +1618,8 @@ struct OrtEpFactory { * * \since Version 1.24. */ - ORT_API2_STATUS(CreateCustomOpDomains, _In_ OrtEpFactory* this_ptr, - _Outptr_result_maybenull_ OrtCustomOpDomain** domains, _In_ size_t num_domains); + ORT_API2_STATUS(GetCustomOpDomains, _In_ OrtEpFactory* this_ptr, + _Outptr_ OrtCustomOpDomain** domains, _In_ size_t num_domains); }; #ifdef __cplusplus diff --git a/onnxruntime/core/providers/tensorrt/utilities/common/exceptions.h b/onnxruntime/core/providers/tensorrt/utilities/common/exceptions.h new file mode 100644 index 0000000000000..494a770b8db98 --- /dev/null +++ b/onnxruntime/core/providers/tensorrt/utilities/common/exceptions.h @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/common/code_location.h" + +namespace onnxruntime { + +class NotImplementedException : public std::logic_error { + public: + explicit NotImplementedException(const char* _Message = "Function not yet implemented") noexcept : std::logic_error(_Message) {}; + explicit NotImplementedException(const std::string& _Message = "Function not yet implemented") noexcept : std::logic_error(_Message) {}; +}; + +class TypeMismatchException : public std::logic_error { + public: + TypeMismatchException() noexcept : logic_error("Type mismatch") {}; +}; + +class OnnxRuntimeException : public std::exception { + public: + OnnxRuntimeException(const CodeLocation& location, const std::string& msg) noexcept + : OnnxRuntimeException(location, nullptr, msg) { + } + + /** + Create a new exception that captures the location it was thrown from. + @param location Location in the source code the exception is being thrown from + @param failed_condition Optional string containing the condition that failed. + e.g. "tensor.Size() == input.Size()". May be nullptr. + @param msg Message containing additional information about the exception cause. 
+ */ + OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) + : location_{location} { + std::ostringstream ss; + + ss << location.ToString(CodeLocation::kFilenameAndPath); // output full path in case just the filename is ambiguous + if (failed_condition != nullptr) { + ss << " " << failed_condition << " was false."; + } + + ss << " " << msg << "\n"; + if (!location.stacktrace.empty()) { + ss << "Stacktrace:\n"; + // skip the first entry in the stacktrace as we have that information from location.ToString() + std::copy(std::next(location.stacktrace.begin()), location.stacktrace.end(), std::ostream_iterator(ss, "\n")); + } + + what_ = ss.str(); + } + + const char* what() const noexcept override { + return what_.c_str(); + } + + private: + const CodeLocation location_; + const std::vector stacktrace_; + std::string what_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index d281465dd3e17..56fda0074147e 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -93,7 +93,7 @@ class EpFactoryInternalImpl { return nullptr; } - virtual OrtStatus* CreateCustomOpDomains(_Outptr_result_maybenull_ OrtCustomOpDomain** /*domains*/, + virtual OrtStatus* GetCustomOpDomains(_Outptr_result_maybenull_ OrtCustomOpDomain** /*domains*/, _In_ size_t /*num_domains*/) const noexcept { return nullptr; } diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index d33a542979125..6a191b3e48fb0 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -545,13 +545,13 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic } // Add custom op domain provided by EP to the session options if any. 
- // OrtEpFactory::GetNumCustomOpDomains and OrtEpFactory::CreateCustomOpDomains + // OrtEpFactory::GetNumCustomOpDomains and OrtEpFactory::GetCustomOpDomains // were added in ORT 1.24. OrtEpFactory* ep_factory = ep_device->ep_factory; if (ep_factory && ep_factory->ort_version_supported >= 24 && ep_factory->GetNumCustomOpDomains != nullptr && - ep_factory->CreateCustomOpDomains != nullptr) { + ep_factory->GetCustomOpDomains != nullptr) { auto is_already_in_domains = [&](const std::string& domain_name, const std::vector& domains) { for (auto ptr : domains) { @@ -568,7 +568,7 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic InlinedVector domains; domains.resize(num_domains); - ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->CreateCustomOpDomains(ep_factory, + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetCustomOpDomains(ep_factory, domains.data(), domains.size()))); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 5dcc675c85641..9c90503bf3fc8 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -38,7 +38,7 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; GetNumCustomOpDomains = GetNumCustomOpDomainsImpl; - CreateCustomOpDomains = CreateCustomOpDomainsImpl; + GetCustomOpDomains = GetCustomOpDomainsImpl; // setup the OrtMemoryInfo instances required by the EP. // We pretend the device the EP is running on is GPU. 
@@ -322,7 +322,7 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::GetNumCustomOpDomainsImpl(OrtEpFactory } /*static*/ -OrtStatus* ORT_API_CALL ExampleEpFactory::CreateCustomOpDomainsImpl( +OrtStatus* ORT_API_CALL ExampleEpFactory::GetCustomOpDomainsImpl( OrtEpFactory* this_ptr, _Outptr_result_maybenull_ OrtCustomOpDomain** domains, _Out_ size_t num_domains) noexcept { diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index d262471000c9e..0ca53f54d6735 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -74,7 +74,7 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { static OrtStatus* ORT_API_CALL GetNumCustomOpDomainsImpl(OrtEpFactory* this_ptr, _Out_ size_t* num_domains) noexcept; - static OrtStatus* ORT_API_CALL CreateCustomOpDomainsImpl(OrtEpFactory* this_ptr, + static OrtStatus* ORT_API_CALL GetCustomOpDomainsImpl(OrtEpFactory* this_ptr, _Outptr_result_maybenull_ OrtCustomOpDomain** domains, _Out_ size_t num_domains) noexcept; diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index c47a154a9a33e..6643de46fb9dd 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -46,8 +46,8 @@ void RunMulModelWithPluginEp(const Ort::SessionOptions& session_options) { gsl::span output_span(output_data, 6); EXPECT_THAT(output_span, ::testing::ElementsAre(2, 4, 6, 8, 10, 12)); } - - void RunCustomMulModelWithPluginEp(const Ort::SessionOptions& session_options) { + +void RunCustomMulModelWithPluginEp(const Ort::SessionOptions& session_options) { Ort::Session session(*ort_env, ORT_TSTR("testdata/custom_mul.onnx"), session_options); // Create two inputs with same values From 6721a98904a1f5ba8cbe206c4a40aa842a2dda55 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 7 Jan 2026 19:42:26 -0800 
Subject: [PATCH 16/47] lintrunner -a --- onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h | 2 +- onnxruntime/core/session/utils.cc | 4 ++-- .../test/autoep/library/example_plugin_ep/ep_factory.h | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index 56fda0074147e..2cc8b4182abdc 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -94,7 +94,7 @@ class EpFactoryInternalImpl { } virtual OrtStatus* GetCustomOpDomains(_Outptr_result_maybenull_ OrtCustomOpDomain** /*domains*/, - _In_ size_t /*num_domains*/) const noexcept { + _In_ size_t /*num_domains*/) const noexcept { return nullptr; } diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 6a191b3e48fb0..699a3d89f6784 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -569,8 +569,8 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic domains.resize(num_domains); ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetCustomOpDomains(ep_factory, - domains.data(), - domains.size()))); + domains.data(), + domains.size()))); const auto domains_span = gsl::span(domains.data(), domains.size()); for (auto domain : domains_span) { diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index 0ca53f54d6735..e0f9440d6fc25 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -75,8 +75,8 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { _Out_ size_t* num_domains) noexcept; static OrtStatus* ORT_API_CALL GetCustomOpDomainsImpl(OrtEpFactory* this_ptr, - _Outptr_result_maybenull_ OrtCustomOpDomain** domains, - _Out_ 
size_t num_domains) noexcept; + _Outptr_result_maybenull_ OrtCustomOpDomain** domains, + _Out_ size_t num_domains) noexcept; const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name From 1ab246d118e7856f19429b3ff432a632ddc84b95 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 7 Jan 2026 19:57:20 -0800 Subject: [PATCH 17/47] update --- include/onnxruntime/core/session/onnxruntime_ep_c_api.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index da31a13ddc6e2..bece95b5cfdaf 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -1619,7 +1619,7 @@ struct OrtEpFactory { * \since Version 1.24. */ ORT_API2_STATUS(GetCustomOpDomains, _In_ OrtEpFactory* this_ptr, - _Outptr_ OrtCustomOpDomain** domains, _In_ size_t num_domains); + _Outptr_result_maybenull_ OrtCustomOpDomain** domains, _In_ size_t num_domains); }; #ifdef __cplusplus From 3478732552a9650cf10599a5cb3afb425279a475 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 8 Jan 2026 15:51:16 -0800 Subject: [PATCH 18/47] Remove accidentally added file --- .../tensorrt/utilities/common/exceptions.h | 71 ------------------- 1 file changed, 71 deletions(-) delete mode 100644 onnxruntime/core/providers/tensorrt/utilities/common/exceptions.h diff --git a/onnxruntime/core/providers/tensorrt/utilities/common/exceptions.h b/onnxruntime/core/providers/tensorrt/utilities/common/exceptions.h deleted file mode 100644 index 494a770b8db98..0000000000000 --- a/onnxruntime/core/providers/tensorrt/utilities/common/exceptions.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "core/common/common.h" -#include "core/common/code_location.h" - -namespace onnxruntime { - -class NotImplementedException : public std::logic_error { - public: - explicit NotImplementedException(const char* _Message = "Function not yet implemented") noexcept : std::logic_error(_Message) {}; - explicit NotImplementedException(const std::string& _Message = "Function not yet implemented") noexcept : std::logic_error(_Message) {}; -}; - -class TypeMismatchException : public std::logic_error { - public: - TypeMismatchException() noexcept : logic_error("Type mismatch") {}; -}; - -class OnnxRuntimeException : public std::exception { - public: - OnnxRuntimeException(const CodeLocation& location, const std::string& msg) noexcept - : OnnxRuntimeException(location, nullptr, msg) { - } - - /** - Create a new exception that captures the location it was thrown from. - @param location Location in the source code the exception is being thrown from - @param failed_condition Optional string containing the condition that failed. - e.g. "tensor.Size() == input.Size()". May be nullptr. - @param msg Message containing additional information about the exception cause. 
- */ - OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) - : location_{location} { - std::ostringstream ss; - - ss << location.ToString(CodeLocation::kFilenameAndPath); // output full path in case just the filename is ambiguous - if (failed_condition != nullptr) { - ss << " " << failed_condition << " was false."; - } - - ss << " " << msg << "\n"; - if (!location.stacktrace.empty()) { - ss << "Stacktrace:\n"; - // skip the first entry in the stacktrace as we have that information from location.ToString() - std::copy(std::next(location.stacktrace.begin()), location.stacktrace.end(), std::ostream_iterator(ss, "\n")); - } - - what_ = ss.str(); - } - - const char* what() const noexcept override { - return what_.c_str(); - } - - private: - const CodeLocation location_; - const std::vector stacktrace_; - std::string what_; -}; - -} // namespace onnxruntime From a1d36af712a0eb38b6a7be9b2ed9de9476780152 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 9 Jan 2026 10:58:32 -0800 Subject: [PATCH 19/47] address reviewer's comments --- .../core/session/onnxruntime_ep_c_api.h | 12 ++++--- .../library/example_plugin_ep/ep_factory.cc | 33 +++++++++---------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index bece95b5cfdaf..f4f0bc7bb3326 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -1565,7 +1565,7 @@ struct OrtEpFactory { */ ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options); - /** \brief Returns the number of OrtCustomOpDomains that this factory creates. + /** \brief Returns the number of OrtCustomOpDomains that this factory provides. * * \param[in] this_ptr The OrtEpFactory instance. 
* \param[out] num_domains Output parameter set to the number of created OrtCustomOpDomain instances. @@ -1581,7 +1581,7 @@ struct OrtEpFactory { * This function is used when running inference on a model that contains EP-specific custom operations. * * Workflow: - * 1. The EP implements this function to supply a list of OrtCustomOpDomain instances. + * 1. The EP factory implements this function to supply a list of OrtCustomOpDomain instances. * 2. The application calls SessionOptionsAppendExecutionProvider_V2() with an OrtEpDevice containing * the plugin EP's factory. * 3. SessionOptionsAppendExecutionProvider_V2() appends the provided OrtCustomOpDomain list to the @@ -1597,7 +1597,7 @@ struct OrtEpFactory { * that the custom node should NOT be fused or compiled. Instead, ORT should invoke * the custom node's Compute() function at runtime. * - * 2. A "placeholder" OrtCustomOp with an empty kernel implementation + * 2. A "placeholder" OrtCustomOp with an empty kernel implementation * - A compile-based Plugin EP can supply an OrtCustomOp whose CustomKernel::Compute() * does nothing. The purpose is to satisfy model validation during model loading by * registering the custom op as a valid operator in the session. @@ -1606,20 +1606,22 @@ struct OrtEpFactory { * - In Compile(), the EP executes its compiled bits to perform inference for * the fused custom node. * - * Note: EP has the responsibility to release OrtCustomOpDomain instances it creates. It happens + * Note: The OrtCustomOpDomain instances must be valid while any session is using them. + EP factory has the responsibility to release OrtCustomOpDomain instances it creates. It happens * automatically if using ORT C++ api. * * \param[in] this_ptr The OrtEpFactory instance. * \param[out] domains Pre-allocated array of `num_domains` elements by ORT that should be filled with OrtCustomOpDomain created by the EP. * \param[in] num_domains The size of the `domains` array pre-allocated by ORT. 
+ The value is returned by GetNumCustomOpDomains(). * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.24. */ ORT_API2_STATUS(GetCustomOpDomains, _In_ OrtEpFactory* this_ptr, - _Outptr_result_maybenull_ OrtCustomOpDomain** domains, _In_ size_t num_domains); + _Out_writes_all_(num_domains) OrtCustomOpDomain** domains, _In_ size_t num_domains); }; #ifdef __cplusplus diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 9c90503bf3fc8..6e8ec148b39eb 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -71,6 +71,22 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL OrtDeviceMemoryType_HOST_ACCESSIBLE, /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator}; + // Custom Op Domains + custom_op_domains_[0] = Ort::CustomOpDomain{"test"}; + custom_op_domains_[1] = Ort::CustomOpDomain{"test2"}; + + std::vector> created_custom_op_list; + created_custom_op_list.push_back(std::make_unique(ep_name_.c_str(), this)); + created_custom_op_list.back().get()->SetName("Custom_Mul"); + custom_op_domains_[0].Add(created_custom_op_list.back().get()); + + std::vector> created_custom_op_list_2; + created_custom_op_list_2.push_back(std::make_unique(ep_name_.c_str(), this)); + created_custom_op_list_2.back().get()->SetName("Custom_Mul2"); + custom_op_domains_[1].Add(created_custom_op_list_2.back().get()); + + created_custom_op_lists_[0] = std::move(created_custom_op_list); + created_custom_op_lists_[1] = std::move(created_custom_op_list_2); } /*static*/ @@ -328,29 +344,12 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::GetCustomOpDomainsImpl( _Out_ size_t num_domains) noexcept { auto* factory = static_cast(this_ptr); - // Custom Op Domains - factory->custom_op_domains_[0] = Ort::CustomOpDomain{"test"}; - factory->custom_op_domains_[1] = 
Ort::CustomOpDomain{"test2"}; - - std::vector> created_custom_op_list; - created_custom_op_list.push_back(std::make_unique(factory->ep_name_.c_str(), factory)); - created_custom_op_list.back().get()->SetName("Custom_Mul"); - factory->custom_op_domains_[0].Add(created_custom_op_list.back().get()); - - std::vector> created_custom_op_list_2; - created_custom_op_list_2.push_back(std::make_unique(factory->ep_name_.c_str(), factory)); - created_custom_op_list_2.back().get()->SetName("Custom_Mul2"); - factory->custom_op_domains_[1].Add(created_custom_op_list_2.back().get()); - // The `num_domains` should be 2 as ORT calls GetNumCustomOpDomainsImpl() to get the number prior to // call this function. gsl::span domains_span(domains, num_domains); domains_span[0] = factory->custom_op_domains_[0]; domains_span[1] = factory->custom_op_domains_[1]; - factory->created_custom_op_lists_[0] = std::move(created_custom_op_list); - factory->created_custom_op_lists_[1] = std::move(created_custom_op_list_2); - return nullptr; } From 3065e9d3406d867b835ccceca50a7534921eda3c Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 9 Jan 2026 11:06:40 -0800 Subject: [PATCH 20/47] address reviewer's comment --- include/onnxruntime/core/session/onnxruntime_ep_c_api.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index f4f0bc7bb3326..37b8ca370a473 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -1612,9 +1612,9 @@ struct OrtEpFactory { * * \param[in] this_ptr The OrtEpFactory instance. * \param[out] domains Pre-allocated array of `num_domains` elements by ORT that should be filled with - OrtCustomOpDomain created by the EP. + OrtCustomOpDomain created by the EP. The `num_domains` is the value returned by + GetNumCustomOpDomains(). 
The implementation is expected to treat `domains` as a buffer. * \param[in] num_domains The size of the `domains` array pre-allocated by ORT. - The value is returned by GetNumCustomOpDomains(). * * \snippet{doc} snippets.dox OrtStatus Return Value * From d340de5c7da29bbf1457919a26a014d116a72d53 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 9 Jan 2026 12:34:39 -0800 Subject: [PATCH 21/47] address reveiwer's comment --- onnxruntime/test/autoep/library/example_plugin_ep/ep.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc index 41fdbc51a0ff0..76b2502da5c3c 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc @@ -210,6 +210,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG for (const auto& node : nodes) { auto op_type = node.GetOperatorType(); + auto domain = node.GetDomain(); if (op_type == "Mul") { // Check that Mul has inputs/output of type float @@ -242,7 +243,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG supported_nodes.push_back(node); // Only support a single Mul for now. 
break; - } else if (op_type == "Custom_Mul") { + } else if (op_type == "Custom_Mul" && domain == "test") { supported_nodes.push_back(node); } } From 15f5baf317c7ba4019a54d3f7fd73bc56d06b4a5 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 9 Jan 2026 12:51:39 -0800 Subject: [PATCH 22/47] update --- .../core/session/plugin_ep/ep_factory_internal_impl.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index 3651f90d0ff2e..43b6f33608c8a 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -101,8 +101,10 @@ class EpFactoryInternalImpl { return nullptr; } - virtual OrtStatus* GetCustomOpDomains(_Outptr_result_maybenull_ OrtCustomOpDomain** /*domains*/, - _In_ size_t /*num_domains*/) const noexcept { + virtual OrtStatus* GetCustomOpDomains(_Out_writes_all_(num_domains) OrtCustomOpDomain** domains, + _In_ size_t num_domains) const noexcept { + ORT_UNUSED_PARAMETER(domains); + ORT_UNUSED_PARAMETER(num_domains); return nullptr; } From 6b01e7f6fec1d3b137e51a0548a1e2a003aff61c Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 9 Jan 2026 12:52:21 -0800 Subject: [PATCH 23/47] lintrunner -a --- include/onnxruntime/core/session/onnxruntime_ep_c_api.h | 2 +- onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h | 2 +- .../test/autoep/library/example_plugin_ep/ep_factory.cc | 3 +-- onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h | 2 +- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index edecfa3f3cb05..42a44bc082583 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -1798,7 +1798,7 @@ struct OrtEpFactory { * \since 
Version 1.24. */ ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options); - + /** \brief Create an OrtExternalResourceImporterImpl for external resource import. * * This is used to create an external resource importer that enables zero-copy import of diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index 43b6f33608c8a..2fa3f456658ac 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -95,7 +95,7 @@ class EpFactoryInternalImpl { *importer = nullptr; return nullptr; } - + virtual OrtStatus* GetNumCustomOpDomains(_Out_ size_t* num_domains) const noexcept { *num_domains = 0; return nullptr; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 4c5666cd1a0b5..437d2afcef90d 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -38,11 +38,10 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; CreateExternalResourceImporterForDevice = CreateExternalResourceImporterForDeviceImpl; - + GetNumCustomOpDomains = GetNumCustomOpDomainsImpl; GetCustomOpDomains = GetCustomOpDomainsImpl; - // setup the OrtMemoryInfo instances required by the EP. // We pretend the device the EP is running on is GPU. 
default_memory_info_ = Ort::MemoryInfo{"ExampleEP GPU", diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index e13dfacc617ec..ca47be300e341 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -76,7 +76,7 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { OrtEpFactory* this_ptr, const OrtEpDevice* ep_device, OrtExternalResourceImporterImpl** out_importer) noexcept; - + static OrtStatus* ORT_API_CALL GetNumCustomOpDomainsImpl(OrtEpFactory* this_ptr, _Out_ size_t* num_domains) noexcept; From adf565ea0a4f96176be19057a75aacf1c65768c4 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 9 Jan 2026 13:24:09 -0800 Subject: [PATCH 24/47] fix bug when merging main --- onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index ca47be300e341..737276203826c 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -84,7 +84,6 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { _Outptr_result_maybenull_ OrtCustomOpDomain** domains, _Out_ size_t num_domains) noexcept; - const OrtLogger& default_logger_; // default logger for the EP factory const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name const uint32_t vendor_id_{0xB357}; // EP vendor ID From 062280eccc7483db7b33585c04654c698dea7cfc Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 12 Jan 2026 21:46:40 -0800 Subject: [PATCH 25/47] Make auto ep selection be able to register custom op --- .../core/session/provider_policy_context.cc | 148 +++++++++--------- .../core/session/provider_policy_context.h | 4 + 
onnxruntime/core/session/utils.cc | 120 +++++++++----- onnxruntime/test/autoep/test_execution.cc | 17 +- 4 files changed, 178 insertions(+), 111 deletions(-) diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index aa2859985a479..713bd1f091dd9 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -50,6 +50,7 @@ bool IsDefaultCpuEp(const OrtEpDevice* d) { return d->device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU && d->ep_vendor == "Microsoft"; } +} // namespace // Sort devices. NPU -> GPU -> CPU // Within in type, vendor owned, not. @@ -138,7 +139,6 @@ OrtKeyValuePairs GetModelMetadata(const InferenceSession& session) { return metadata; } -} // namespace // Select execution providers based on the device policy and available devices and add to session Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const OrtSessionOptions& options, @@ -151,76 +151,7 @@ Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const // The list of devices selected by policies std::vector devices_selected; - // Run the delegate if it was passed in lieu of any other policy - if (options.value.ep_selection_policy.delegate) { - auto model_metadata = GetModelMetadata(sess); - OrtKeyValuePairs runtime_metadata; // TODO: where should this come from? 
- - std::vector delegate_devices(execution_devices.begin(), execution_devices.end()); - std::array selected_devices{nullptr}; - size_t num_selected = 0; - - EpSelectionDelegate delegate = options.value.ep_selection_policy.delegate; - auto* status = delegate(delegate_devices.data(), delegate_devices.size(), - &model_metadata, &runtime_metadata, - selected_devices.data(), selected_devices.size(), &num_selected, - options.value.ep_selection_policy.state); - - // return or fall-through for both these cases - // going with explicit failure for now so it's obvious to user what is happening - if (status != nullptr) { - std::string delegate_error_msg = OrtApis::GetErrorMessage(status); // copy - OrtApis::ReleaseStatus(status); - - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate failed: ", delegate_error_msg); - } - - if (num_selected == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate did not select anything."); - } - - if (num_selected > selected_devices.size()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "EP selection delegate selected too many EP devices (", num_selected, "). 
", - "The limit is ", selected_devices.size(), " EP devices."); - } - - // Copy the selected devices to the output vector - devices_selected.reserve(num_selected); - for (size_t i = 0; i < num_selected; ++i) { - devices_selected.push_back(selected_devices[i]); - } - } else { - // Create the selector for the chosen policy - std::unique_ptr selector; - switch (options.value.ep_selection_policy.policy) { - case OrtExecutionProviderDevicePolicy_DEFAULT: - selector = std::make_unique(); - break; - case OrtExecutionProviderDevicePolicy_PREFER_CPU: - selector = std::make_unique(); - break; - case OrtExecutionProviderDevicePolicy_PREFER_NPU: - case OrtExecutionProviderDevicePolicy_MAX_EFFICIENCY: - case OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER: - selector = std::make_unique(); - break; - case OrtExecutionProviderDevicePolicy_PREFER_GPU: - case OrtExecutionProviderDevicePolicy_MAX_PERFORMANCE: - selector = std::make_unique(); - break; - } - - // Execute policy - - selector->SelectProvidersForDevices(execution_devices, devices_selected); - } - - // Fail if we did not find any device matches - if (devices_selected.empty()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "No execution providers selected. Please check the device policy and available devices."); - } + ORT_RETURN_IF_ERROR(SelectEpDevices(options, execution_devices, devices_selected, sess)); // Log telemetry for auto EP selection { @@ -317,6 +248,81 @@ Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const return Status::OK(); } +Status ProviderPolicyContext::SelectEpDevices(const OrtSessionOptions& options, std::vector& execution_devices, + std::vector& devices_selected, InferenceSession& sess) { + // Run the delegate if it was passed in lieu of any other policy + if (options.value.ep_selection_policy.delegate) { + auto model_metadata = GetModelMetadata(sess); + OrtKeyValuePairs runtime_metadata; // TODO: where should this come from? 
+ + std::vector delegate_devices(execution_devices.begin(), execution_devices.end()); + std::array selected_devices{nullptr}; + size_t num_selected = 0; + + EpSelectionDelegate delegate = options.value.ep_selection_policy.delegate; + auto* status = delegate(delegate_devices.data(), delegate_devices.size(), + &model_metadata, &runtime_metadata, + selected_devices.data(), selected_devices.size(), &num_selected, + options.value.ep_selection_policy.state); + + // return or fall-through for both these cases + // going with explicit failure for now so it's obvious to user what is happening + if (status != nullptr) { + std::string delegate_error_msg = OrtApis::GetErrorMessage(status); // copy + OrtApis::ReleaseStatus(status); + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate failed: ", delegate_error_msg); + } + + if (num_selected == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate did not select anything."); + } + + if (num_selected > selected_devices.size()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "EP selection delegate selected too many EP devices (", num_selected, "). 
", + "The limit is ", selected_devices.size(), " EP devices."); + } + + // Copy the selected devices to the output vector + devices_selected.reserve(num_selected); + for (size_t i = 0; i < num_selected; ++i) { + devices_selected.push_back(selected_devices[i]); + } + } else { + // Create the selector for the chosen policy + std::unique_ptr selector; + switch (options.value.ep_selection_policy.policy) { + case OrtExecutionProviderDevicePolicy_DEFAULT: + selector = std::make_unique(); + break; + case OrtExecutionProviderDevicePolicy_PREFER_CPU: + selector = std::make_unique(); + break; + case OrtExecutionProviderDevicePolicy_PREFER_NPU: + case OrtExecutionProviderDevicePolicy_MAX_EFFICIENCY: + case OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER: + selector = std::make_unique(); + break; + case OrtExecutionProviderDevicePolicy_PREFER_GPU: + case OrtExecutionProviderDevicePolicy_MAX_PERFORMANCE: + selector = std::make_unique(); + break; + } + + // Execute policy + + selector->SelectProvidersForDevices(execution_devices, devices_selected); + } + + // Fail if we did not find any device matches + if (devices_selected.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "No execution providers selected. 
Please check the device policy and available devices."); + } + return Status::OK(); +} + void ProviderPolicyContext::FoldSelectedDevices(std::vector devices_selected, std::vector& eps_selected) { while (devices_selected.size() > 0) { diff --git a/onnxruntime/core/session/provider_policy_context.h b/onnxruntime/core/session/provider_policy_context.h index 295ac21ca4aa5..f2d3aa6f780d9 100644 --- a/onnxruntime/core/session/provider_policy_context.h +++ b/onnxruntime/core/session/provider_policy_context.h @@ -41,6 +41,8 @@ class ProviderPolicyContext { ProviderPolicyContext() = default; Status SelectEpsForSession(const Environment& env, const OrtSessionOptions& options, InferenceSession& sess); + Status SelectEpDevices(const OrtSessionOptions& options, std::vector& execution_devices, + std::vector& devices_selected, InferenceSession& sess); Status AddEpDefaultOptionsToSession(InferenceSession& sess, std::vector devices); void RemoveOrtCpuDevice(std::vector& devices); Status CreateExecutionProvider(const Environment& env, OrtSessionOptions& options, const OrtLogger& logger, @@ -75,6 +77,8 @@ class PreferGpuEpPolicy : public IEpPolicySelector { std::vector& selected_devices) override; }; +std::vector OrderDevices(const std::vector& devices); + } // namespace onnxruntime #endif // !ORT_MINIMAL_BUILD diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 699a3d89f6784..55ce7ce4badaa 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -20,6 +20,7 @@ #include "core/session/ort_apis.h" #include "core/session/ort_env.h" #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" +#include "core/session/provider_policy_context.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/session/plugin_ep/ep_factory_internal.h" @@ -31,8 +32,8 @@ #endif // !defined(ORT_MINIMAL_BUILD) using namespace onnxruntime; -#if !defined(ORT_MINIMAL_BUILD) namespace { +#if !defined(ORT_MINIMAL_BUILD) // temporary 
implementation for testing. EP to 'select' is specified in config option Status TestAutoSelectEPsImpl(const Environment& env, InferenceSession& sess, const std::string& ep_to_select) { const auto& execution_devices = env.GetOrtEpDevices(); @@ -98,9 +99,38 @@ Status TestAutoSelectEPsImpl(const Environment& env, InferenceSession& sess, con return Status::OK(); } -} // namespace #endif // !defined(ORT_MINIMAL_BUILD) +Status GetCustomOpDomainsFromOrtEpDevices(const OrtEpDevice& ep_device, InlinedVector& domains) { + // Get custom op domain provided by EP if any. + // OrtEpFactory::GetNumCustomOpDomains and OrtEpFactory::GetCustomOpDomains were added in ORT 1.24. + OrtEpFactory* ep_factory = ep_device.ep_factory; + if (ep_factory && + ep_factory->ort_version_supported >= 24 && + ep_factory->GetNumCustomOpDomains != nullptr && + ep_factory->GetCustomOpDomains != nullptr) { + size_t num_domains = 0; + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetNumCustomOpDomains(ep_factory, &num_domains))); + + domains.resize(num_domains); + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetCustomOpDomains(ep_factory, + domains.data(), + domains.size()))); + } + + return Status::OK(); +} + +bool IsDomainExisted(const std::string& domain_name, const std::vector& domains) { + for (auto ptr : domains) { + if (domain_name == ptr->domain_) { + return true; + } + } + return false; +}; +} // namespace + common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size) { if (size == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "`size` argument is NULL"); @@ -195,6 +225,44 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op } #endif +#if !defined(ORT_MINIMAL_BUILD) + if (options != nullptr && options->value.ep_selection_policy.enable) { + ProviderPolicyContext context; + + // Get the list of devices from the environment and order them. + // Ordered by preference within each type. 
NPU -> GPU -> NPU + // TODO: Should environment.cc do the ordering? + std::vector execution_devices = OrderDevices(env.GetOrtEpDevices()); + + // The list of devices selected by policies + std::vector devices_selected; + ORT_API_RETURN_IF_STATUS_NOT_OK(context.SelectEpDevices(*options, execution_devices, + devices_selected, *sess)); + + InlinedVector all_ep_custom_op_domains; + + for (const OrtEpDevice* ep_device : devices_selected) { + InlinedVector domains; + ORT_API_RETURN_IF_STATUS_NOT_OK(GetCustomOpDomainsFromOrtEpDevices(*ep_device, domains)); + + const auto domains_span = gsl::span(domains.data(), domains.size()); + for (auto domain : domains_span) { + if (!IsDomainExisted(domain->domain_, options->custom_op_domains_) && + domain->custom_ops_.size() > 0) { + all_ep_custom_op_domains.push_back(domain); + } else { + LOGS_DEFAULT(WARNING) << "The custom op domain name " + << domain->domain_ << " is already in the session option. Skip it."; + } + } + } + + if (!all_ep_custom_op_domains.empty()) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(all_ep_custom_op_domains)); + } + } +#endif + // Finish load if (load_config_from_model) { #if !defined(ORT_MINIMAL_BUILD) @@ -544,43 +612,17 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic ORT_RETURN_IF_ERROR(config_options.AddConfigEntry((prefix + ep_option_keys[j]).c_str(), ep_option_vals[j])); } - // Add custom op domain provided by EP to the session options if any. - // OrtEpFactory::GetNumCustomOpDomains and OrtEpFactory::GetCustomOpDomains - // were added in ORT 1.24. 
- OrtEpFactory* ep_factory = ep_device->ep_factory; - if (ep_factory && - ep_factory->ort_version_supported >= 24 && - ep_factory->GetNumCustomOpDomains != nullptr && - ep_factory->GetCustomOpDomains != nullptr) { - auto is_already_in_domains = - [&](const std::string& domain_name, const std::vector& domains) { - for (auto ptr : domains) { - if (domain_name == ptr->domain_) { - return true; - } - } - return false; - }; - - size_t num_domains = 0; - ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetNumCustomOpDomains(ep_factory, &num_domains))); - - InlinedVector domains; - domains.resize(num_domains); - - ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetCustomOpDomains(ep_factory, - domains.data(), - domains.size()))); - - const auto domains_span = gsl::span(domains.data(), domains.size()); - for (auto domain : domains_span) { - if (!is_already_in_domains(domain->domain_, ort_session_options.custom_op_domains_) && - domain->custom_ops_.size() > 0) { - ort_session_options.custom_op_domains_.push_back(domain); - } else { - LOGS_DEFAULT(WARNING) << "The custom op domain name " - << domain->domain_ << " is already in the session option. Skip it."; - } + InlinedVector domains; + ORT_RETURN_IF_ERROR(GetCustomOpDomainsFromOrtEpDevices(*ep_device, domains)); + + const auto domains_span = gsl::span(domains.data(), domains.size()); + for (auto domain : domains_span) { + if (!IsDomainExisted(domain->domain_, ort_session_options.custom_op_domains_) && + domain->custom_ops_.size() > 0) { + ort_session_options.custom_op_domains_.push_back(domain); + } else { + LOGS_DEFAULT(WARNING) << "The custom op domain name " + << domain->domain_ << " is already in the session option. 
Skip it."; } } } diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index 6643de46fb9dd..1a408f8d37bc9 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -384,7 +384,7 @@ TEST(OrtEpLibrary, KernelPluginEp_Inference) { // Creates a session with the example plugin EP and runs a model with a single Costom_Mul node. // Uses AppendExecutionProvider_V2 to append the example plugin EP to the session. -TEST(OrtEpLibrary, PluginEp_Create_OrtCustomOpDomain) { +TEST(OrtEpLibrary, PluginEp_custom_op_inference_with_explicit_ep) { RegisteredEpDeviceUniquePtr example_ep; ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); Ort::ConstEpDevice plugin_ep_device(example_ep.get()); @@ -396,5 +396,20 @@ TEST(OrtEpLibrary, PluginEp_Create_OrtCustomOpDomain) { RunCustomMulModelWithPluginEp(session_options); } + +// Creates a session with the example plugin EP and runs a model with a single Costom_Mul node. +// Uses the PREFER_CPU policy to append the example plugin EP to the session. +TEST(OrtEpLibrary, PluginEp_custom_op_inference_with_prefer_cpu) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + { + // PREFER_CPU pick our example EP over ORT CPU EP. TODO: Actually assert this. 
+ Ort::SessionOptions session_options; + session_options.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_CPU); + RunCustomMulModelWithPluginEp(session_options); + } +} } // namespace test } // namespace onnxruntime From 002fcdc290cf2b4149f3fa51ebdc0cc741899998 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 13 Jan 2026 10:21:54 -0800 Subject: [PATCH 26/47] add comments --- .../core/session/provider_policy_context.cc | 10 ++++--- .../core/session/provider_policy_context.h | 3 ++- onnxruntime/core/session/utils.cc | 27 ++++++++++++------- onnxruntime/test/autoep/test_execution.cc | 6 ++--- 4 files changed, 30 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index 713bd1f091dd9..fc87cd95d5b09 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -151,7 +151,7 @@ Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const // The list of devices selected by policies std::vector devices_selected; - ORT_RETURN_IF_ERROR(SelectEpDevices(options, execution_devices, devices_selected, sess)); + ORT_RETURN_IF_ERROR(SelectEpDevices(options, execution_devices, devices_selected, sess, true)); // Log telemetry for auto EP selection { @@ -249,10 +249,14 @@ Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const } Status ProviderPolicyContext::SelectEpDevices(const OrtSessionOptions& options, std::vector& execution_devices, - std::vector& devices_selected, InferenceSession& sess) { + std::vector& devices_selected, InferenceSession& sess, + bool model_metadata_reference) { // Run the delegate if it was passed in lieu of any other policy if (options.value.ep_selection_policy.delegate) { - auto model_metadata = GetModelMetadata(sess); + OrtKeyValuePairs model_metadata; + if (model_metadata_reference) { + model_metadata = GetModelMetadata(sess); + } 
OrtKeyValuePairs runtime_metadata; // TODO: where should this come from? std::vector delegate_devices(execution_devices.begin(), execution_devices.end()); diff --git a/onnxruntime/core/session/provider_policy_context.h b/onnxruntime/core/session/provider_policy_context.h index f2d3aa6f780d9..5a9cc45237e90 100644 --- a/onnxruntime/core/session/provider_policy_context.h +++ b/onnxruntime/core/session/provider_policy_context.h @@ -42,7 +42,8 @@ class ProviderPolicyContext { Status SelectEpsForSession(const Environment& env, const OrtSessionOptions& options, InferenceSession& sess); Status SelectEpDevices(const OrtSessionOptions& options, std::vector& execution_devices, - std::vector& devices_selected, InferenceSession& sess); + std::vector& devices_selected, InferenceSession& sess, + bool model_metadata_reference = true); Status AddEpDefaultOptionsToSession(InferenceSession& sess, std::vector devices); void RemoveOrtCpuDevice(std::vector& devices); Status CreateExecutionProvider(const Environment& env, OrtSessionOptions& options, const OrtLogger& logger, diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 55ce7ce4badaa..e452e2e508e12 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -101,8 +101,8 @@ Status TestAutoSelectEPsImpl(const Environment& env, InferenceSession& sess, con } #endif // !defined(ORT_MINIMAL_BUILD) -Status GetCustomOpDomainsFromOrtEpDevices(const OrtEpDevice& ep_device, InlinedVector& domains) { - // Get custom op domain provided by EP if any. +Status GetCustomOpDomainsFromEpDevice(const OrtEpDevice& ep_device, InlinedVector& domains) { + // Get custom op domain provided by EP factory if any. // OrtEpFactory::GetNumCustomOpDomains and OrtEpFactory::GetCustomOpDomains were added in ORT 1.24. 
OrtEpFactory* ep_factory = ep_device.ep_factory; if (ep_factory && @@ -226,24 +226,32 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op #endif #if !defined(ORT_MINIMAL_BUILD) - if (options != nullptr && options->value.ep_selection_policy.enable) { + // Add custom domains if the selected ep from auto ep selection has custom domains to register. + // The custom domains should be registered before model load for ORT to validate the custom ops. + if (options != nullptr && + options->provider_factories.empty() && + options->value.ep_selection_policy.enable) { ProviderPolicyContext context; - // Get the list of devices from the environment and order them. - // Ordered by preference within each type. NPU -> GPU -> NPU - // TODO: Should environment.cc do the ordering? + // Following code calls the same ep selection functions that InitializeSession() calls as well + // to get `execution_devices` and `devices_selected`. + // Note: If the selection policy is delegate, the model metadata should be provided to the delegate function. + // However, the model metadata is not known at this point as ORT hasn't loaded the model yet. So the empty + // model metadata is provided for now. + // TODO: might need to fetch model metadata from model proto. 
+ std::vector execution_devices = OrderDevices(env.GetOrtEpDevices()); // The list of devices selected by policies std::vector devices_selected; ORT_API_RETURN_IF_STATUS_NOT_OK(context.SelectEpDevices(*options, execution_devices, - devices_selected, *sess)); + devices_selected, *sess, false)); InlinedVector all_ep_custom_op_domains; for (const OrtEpDevice* ep_device : devices_selected) { InlinedVector domains; - ORT_API_RETURN_IF_STATUS_NOT_OK(GetCustomOpDomainsFromOrtEpDevices(*ep_device, domains)); + ORT_API_RETURN_IF_STATUS_NOT_OK(GetCustomOpDomainsFromEpDevice(*ep_device, domains)); const auto domains_span = gsl::span(domains.data(), domains.size()); for (auto domain : domains_span) { @@ -612,8 +620,9 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic ORT_RETURN_IF_ERROR(config_options.AddConfigEntry((prefix + ep_option_keys[j]).c_str(), ep_option_vals[j])); } + // Add custom domains if EP factory has any. InlinedVector domains; - ORT_RETURN_IF_ERROR(GetCustomOpDomainsFromOrtEpDevices(*ep_device, domains)); + ORT_RETURN_IF_ERROR(GetCustomOpDomainsFromEpDevice(*ep_device, domains)); const auto domains_span = gsl::span(domains.data(), domains.size()); for (auto domain : domains_span) { diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index 1aac304b3e2ab..1480ad18f8497 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -578,10 +578,10 @@ TEST(OrtEpLibrary, KernelPluginEp_ControlFlow_Scan) { ASSERT_NO_FATAL_FAILURE(RunScanMulModel(session_options)); } } - + // Creates a session with the example plugin EP and runs a model with a single Costom_Mul node. // Uses AppendExecutionProvider_V2 to append the example plugin EP to the session. 
-TEST(OrtEpLibrary, PluginEp_custom_op_inference_with_explicit_ep) { +TEST(OrtEpLibrary, PluginEp_Custom_Op_Inference_With_Explicit_Ep) { RegisteredEpDeviceUniquePtr example_ep; ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); Ort::ConstEpDevice plugin_ep_device(example_ep.get()); @@ -596,7 +596,7 @@ TEST(OrtEpLibrary, PluginEp_custom_op_inference_with_explicit_ep) { // Creates a session with the example plugin EP and runs a model with a single Costom_Mul node. // Uses the PREFER_CPU policy to append the example plugin EP to the session. -TEST(OrtEpLibrary, PluginEp_custom_op_inference_with_prefer_cpu) { +TEST(OrtEpLibrary, PluginEp_Custom_Op_Inference_With_Prefer_Cpu) { RegisteredEpDeviceUniquePtr example_ep; ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); Ort::ConstEpDevice plugin_ep_device(example_ep.get()); From bb7e082b887ce65bed1242099fa0a2a687ab9f00 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 13 Jan 2026 15:26:21 -0800 Subject: [PATCH 27/47] Make code be able to get model_metadata from model during auto ep selection before model load --- onnxruntime/core/graph/model.cc | 2 +- onnxruntime/core/graph/model.h | 3 + .../core/session/provider_policy_context.cc | 34 +++++---- .../core/session/provider_policy_context.h | 7 +- onnxruntime/core/session/utils.cc | 72 ++++++++++++++++++- 5 files changed, 101 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 59878052e7499..cf8c4d1975661 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -580,7 +580,7 @@ static Status LoadModelHelper(const T& file_path, Loader loader) { } template -static Status LoadModel(const T& file_path, ONNX_NAMESPACE::ModelProto& model_proto) { +Status LoadModel(const T& file_path, ONNX_NAMESPACE::ModelProto& model_proto) { const auto loader = [&model_proto](int fd) { return 
Model::Load(fd, model_proto); }; diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index c86aac44806bd..2c3065be92532 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -374,4 +374,7 @@ class Model { CheckLoadCancellationFn check_load_cancellation_fn_; }; + +template +Status LoadModel(const T& file_path, ONNX_NAMESPACE::ModelProto& model_proto); } // namespace onnxruntime diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index fc87cd95d5b09..61a6df44d12fc 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -116,16 +116,10 @@ std::vector OrderDevices(const std::vector devices_selected; - ORT_RETURN_IF_ERROR(SelectEpDevices(options, execution_devices, devices_selected, sess, true)); + ORT_RETURN_IF_ERROR(SelectEpDevices(options, execution_devices, devices_selected, nullptr, sess)); // Log telemetry for auto EP selection { @@ -249,13 +255,17 @@ Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const } Status ProviderPolicyContext::SelectEpDevices(const OrtSessionOptions& options, std::vector& execution_devices, - std::vector& devices_selected, InferenceSession& sess, - bool model_metadata_reference) { + std::vector& devices_selected, + OrtKeyValuePairs* metadata_from_model, + InferenceSession& sess) { // Run the delegate if it was passed in lieu of any other policy if (options.value.ep_selection_policy.delegate) { OrtKeyValuePairs model_metadata; - if (model_metadata_reference) { - model_metadata = GetModelMetadata(sess); + + if (metadata_from_model) { + model_metadata = *metadata_from_model; + } else { + model_metadata = GetModelMetadataFromSession(sess); } OrtKeyValuePairs runtime_metadata; // TODO: where should this come from? 
diff --git a/onnxruntime/core/session/provider_policy_context.h b/onnxruntime/core/session/provider_policy_context.h index 5a9cc45237e90..87498c81cff13 100644 --- a/onnxruntime/core/session/provider_policy_context.h +++ b/onnxruntime/core/session/provider_policy_context.h @@ -7,6 +7,7 @@ #include "core/session/abi_session_options_impl.h" #include "core/session/environment.h" +#include "core/session/inference_session.h" #include "core/session/onnxruntime_c_api.h" // For OrtExecutionProviderDevicePolicy namespace onnxruntime { @@ -42,8 +43,8 @@ class ProviderPolicyContext { Status SelectEpsForSession(const Environment& env, const OrtSessionOptions& options, InferenceSession& sess); Status SelectEpDevices(const OrtSessionOptions& options, std::vector& execution_devices, - std::vector& devices_selected, InferenceSession& sess, - bool model_metadata_reference = true); + std::vector& devices_selected, OrtKeyValuePairs* model_metadata, + InferenceSession& sess); Status AddEpDefaultOptionsToSession(InferenceSession& sess, std::vector devices); void RemoveOrtCpuDevice(std::vector& devices); Status CreateExecutionProvider(const Environment& env, OrtSessionOptions& options, const OrtLogger& logger, @@ -80,6 +81,8 @@ class PreferGpuEpPolicy : public IEpPolicySelector { std::vector OrderDevices(const std::vector& devices); +OrtKeyValuePairs GetModelMetadataKeyValuePairs(const ModelMetadata& session); + } // namespace onnxruntime #endif // !ORT_MINIMAL_BUILD diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index e452e2e508e12..61af42cdb8207 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -21,6 +21,7 @@ #include "core/session/ort_env.h" #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/provider_policy_context.h" +#include "core/graph/model.h" #if !defined(ORT_MINIMAL_BUILD) #include "core/session/plugin_ep/ep_factory_internal.h" @@ -128,9 +129,70 @@ bool 
IsDomainExisted(const std::string& domain_name, const std::vector void { + if (model_proto.has_producer_name()) { + model_metadata.producer_name = model_proto.producer_name(); + } + + if (model_proto.has_producer_version()) { + model_metadata.producer_version = model_proto.producer_version(); + } + + if (model_proto.has_doc_string()) { + model_metadata.description = model_proto.doc_string(); + } + + if (model_proto.has_graph() && model_proto.graph().has_doc_string()) { + model_metadata.graph_description = model_proto.graph().doc_string(); + } + + if (model_proto.has_domain()) { + model_metadata.domain = model_proto.domain(); + } + + if (model_proto.has_model_version()) { + model_metadata.version = model_proto.model_version(); + } + + std::unordered_map metadata; + for (auto& prop : model_proto.metadata_props()) { + metadata[prop.key()] = prop.value(); + } + model_metadata.custom_metadata_map = metadata; + + if (model_proto.has_graph() && model_proto.graph().has_name()) { + model_metadata.graph_name = model_proto.graph().name(); + } + + return; + }; + + if (model_path != nullptr) { + ONNX_NAMESPACE::ModelProto model_proto; + onnxruntime::PathString path(model_path); + ORT_RETURN_IF_ERROR(LoadModel(path, model_proto)); + get_model_metadata(model_proto, model_metadata); + } else if (model_data != nullptr && model_data_length > 0) { + ONNX_NAMESPACE::ModelProto model_proto; + const bool result = model_proto.ParseFromArray(model_data, static_cast(model_data_length)); + if (!result) { + return Status(common::ONNXRUNTIME, common::INVALID_PROTOBUF, + "Failed to load model because protobuf parsing failed."); + } + get_model_metadata(model_proto, model_metadata); + } + + return Status::OK(); +} + common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size) { if (size == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "`size` argument is NULL"); @@ -244,8 +306,14 @@ static OrtStatus* 
CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op // The list of devices selected by policies std::vector devices_selected; + + ModelMetadata model_metadata; + ORT_API_RETURN_IF_STATUS_NOT_OK(GetModelMetaData(model_path, model_data, model_data_length, model_metadata)); + OrtKeyValuePairs model_metadata_key_value_pairs; + model_metadata_key_value_pairs = GetModelMetadataKeyValuePairs(model_metadata); + ORT_API_RETURN_IF_STATUS_NOT_OK(context.SelectEpDevices(*options, execution_devices, - devices_selected, *sess, false)); + devices_selected, &model_metadata_key_value_pairs, *sess)); InlinedVector all_ep_custom_op_domains; From 953dbd3bf526f0d364dc5a729270a2848a039e51 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 13 Jan 2026 16:17:16 -0800 Subject: [PATCH 28/47] Use Model::Load --- onnxruntime/core/graph/model.h | 3 --- onnxruntime/core/session/utils.cc | 26 +++++++++++++------------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 2c3065be92532..c86aac44806bd 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -374,7 +374,4 @@ class Model { CheckLoadCancellationFn check_load_cancellation_fn_; }; - -template -Status LoadModel(const T& file_path, ONNX_NAMESPACE::ModelProto& model_proto); } // namespace onnxruntime diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 61af42cdb8207..98a1e89af8977 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -132,7 +132,7 @@ bool IsDomainExisted(const std::string& domain_name, const std::vector metadata; + ModelMetaData metadata; for (auto& prop : model_proto.metadata_props()) { metadata[prop.key()] = prop.value(); } @@ -175,21 +175,23 @@ Status static GetModelMetaData(const ORTCHAR_T* model_path, return; }; + ONNX_NAMESPACE::ModelProto model_proto; + if (model_path != nullptr) { - ONNX_NAMESPACE::ModelProto model_proto; - 
onnxruntime::PathString path(model_path); - ORT_RETURN_IF_ERROR(LoadModel(path, model_proto)); - get_model_metadata(model_proto, model_metadata); + PathString path(model_path); + ORT_RETURN_IF_ERROR(Model::Load(path, model_proto)); } else if (model_data != nullptr && model_data_length > 0) { - ONNX_NAMESPACE::ModelProto model_proto; const bool result = model_proto.ParseFromArray(model_data, static_cast(model_data_length)); if (!result) { return Status(common::ONNXRUNTIME, common::INVALID_PROTOBUF, "Failed to load model because protobuf parsing failed."); } - get_model_metadata(model_proto, model_metadata); + } else { + return Status::OK(); } + get_model_metadata(model_proto, model_metadata); + return Status::OK(); } @@ -297,18 +299,16 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op // Following code calls the same ep selection functions that InitializeSession() calls as well // to get `execution_devices` and `devices_selected`. - // Note: If the selection policy is delegate, the model metadata should be provided to the delegate function. - // However, the model metadata is not known at this point as ORT hasn't loaded the model yet. So the empty - // model metadata is provided for now. - // TODO: might need to fetch model metadata from model proto. 
std::vector execution_devices = OrderDevices(env.GetOrtEpDevices()); // The list of devices selected by policies std::vector devices_selected; + // If the selection policy is delegate, the model metadata as key-value paris should be provided to + // the delegate function ModelMetadata model_metadata; - ORT_API_RETURN_IF_STATUS_NOT_OK(GetModelMetaData(model_path, model_data, model_data_length, model_metadata)); + ORT_API_RETURN_IF_STATUS_NOT_OK(GetModelMetadata(model_path, model_data, model_data_length, model_metadata)); OrtKeyValuePairs model_metadata_key_value_pairs; model_metadata_key_value_pairs = GetModelMetadataKeyValuePairs(model_metadata); From cf5948a1a1321913bad8f3ce64cbcf469a7981a7 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 13 Jan 2026 16:18:51 -0800 Subject: [PATCH 29/47] revert unnecessary change --- onnxruntime/core/graph/model.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index cf8c4d1975661..59878052e7499 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -580,7 +580,7 @@ static Status LoadModelHelper(const T& file_path, Loader loader) { } template -Status LoadModel(const T& file_path, ONNX_NAMESPACE::ModelProto& model_proto) { +static Status LoadModel(const T& file_path, ONNX_NAMESPACE::ModelProto& model_proto) { const auto loader = [&model_proto](int fd) { return Model::Load(fd, model_proto); }; From cc3140845c47dc624d1344747746b10d452e399b Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 13 Jan 2026 16:29:20 -0800 Subject: [PATCH 30/47] update API comment --- include/onnxruntime/core/session/onnxruntime_ep_c_api.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index e9a4eb35932b1..9aef7f27f8ef7 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ 
b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -2033,10 +2033,10 @@ struct OrtEpFactory { * * Workflow: * 1. The EP factory implements this function to supply a list of OrtCustomOpDomain instances. - * 2. The application calls SessionOptionsAppendExecutionProvider_V2() with an OrtEpDevice containing - * the plugin EP's factory. - * 3. SessionOptionsAppendExecutionProvider_V2() appends the provided OrtCustomOpDomain list to the - * session options. + * 2. The application either 1) calls SessionOptionsAppendExecutionProvider_V2() with an OrtEpDevice containing + * the plugin EP's factory or 2) enables auto ep selection. + * 3. 1) SessionOptionsAppendExecutionProvider_V2() appends the provided OrtCustomOpDomains to the + * session options or 2) ORT registers the OrtCustomOpDomains from the selected EP devices. * * As a result, any session created from these session options will have these custom op domains registered * in ORT, ensuring that the custom ops are properly recognized and validated when the model is loaded. 
From e2604b9d5af64d6c47550419daead54a4bd710f0 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 13 Jan 2026 16:44:47 -0800 Subject: [PATCH 31/47] fix build issue for minimal build --- onnxruntime/core/session/utils.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 98a1e89af8977..a27507cfac45a 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -132,6 +132,7 @@ bool IsDomainExisted(const std::string& domain_name, const std::vector Date: Wed, 14 Jan 2026 08:59:10 -0800 Subject: [PATCH 32/47] address reviewer's comments --- .../core/session/onnxruntime_ep_c_api.h | 12 ++--- onnxruntime/core/session/inference_session.h | 23 ++++++++++ .../core/session/provider_policy_context.cc | 37 +++++++++++---- .../core/session/provider_policy_context.h | 6 +-- onnxruntime/core/session/utils.cc | 46 +++++++++++-------- 5 files changed, 89 insertions(+), 35 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 9aef7f27f8ef7..02cb49d99a8ae 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -2019,7 +2019,7 @@ struct OrtEpFactory { /** \brief Returns the number of OrtCustomOpDomains that this factory provides. * * \param[in] this_ptr The OrtEpFactory instance. - * \param[out] num_domains Output parameter set to the number of created OrtCustomOpDomain instances. + * \param[out] num_domains Output parameter set to the number of provided OrtCustomOpDomain instances. * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -2036,14 +2036,14 @@ struct OrtEpFactory { * 2. The application either 1) calls SessionOptionsAppendExecutionProvider_V2() with an OrtEpDevice containing * the plugin EP's factory or 2) enables auto ep selection. * 3. 
1) SessionOptionsAppendExecutionProvider_V2() appends the provided OrtCustomOpDomains to the - * session options or 2) ORT registers the OrtCustomOpDomains from the selected EP devices. + * session options or 2) ORT registers the OrtCustomOpDomains provided by the selected EP devices. * * As a result, any session created from these session options will have these custom op domains registered * in ORT, ensuring that the custom ops are properly recognized and validated when the model is loaded. * * Plugin EPs can provide two types of custom ops: * 1. A full OrtCustomOp with a concrete kernel implementation - * - This Example EP demonstrates this approach. + * - A Plugin EP can supply an OrtCustomOp and a corresponding CustomKernel::Compute() implementation. * - In GetCapability(), it calls EpGraphSupportInfo_AddSingleNode() to inform ORT * that the custom node should NOT be fused or compiled. Instead, ORT should invoke * the custom node's Compute() function at runtime. @@ -2059,11 +2059,11 @@ struct OrtEpFactory { * * Note: The OrtCustomOpDomain instances must be valid while any session is using them. EP factory has the responsibility to release OrtCustomOpDomain instances it creates. It happens - * automatically if using ORT C++ api. + * automatically if using the C++ Ort::CustomOpDomain class. * * \param[in] this_ptr The OrtEpFactory instance. - * \param[out] domains Pre-allocated array of `num_domains` elements by ORT that should be filled with - OrtCustomOpDomain created by the EP. The `num_domains` is the value returned by + * \param[out] domains Array of `num_domains` elements pre-allocated by ORT that should be filled with + OrtCustomOpDomain instances created by the EP. The `num_domains` is the value returned by GetNumCustomOpDomains(). The implementation is expected to treat `domains` as a buffer. * \param[in] num_domains The size of the `domains` array pre-allocated by ORT. 
* diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index ff54f6fa7bca0..7ad273f03ae3b 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -662,6 +662,22 @@ class InferenceSession { return session_id_; } + void SetExecutionDevices(std::vector& execution_devices) { + execution_devices_ = std::move(execution_devices); + } + + const std::vector& GetExecutionDevices() noexcept { + return execution_devices_; + } + + void SetSelectedDevices(std::vector selected_devices) { + devices_selected_ = std::move(selected_devices); + } + + const std::vector& GetSelectedDevices() noexcept { + return devices_selected_; + } + protected: #if !defined(ORT_MINIMAL_BUILD) @@ -1054,6 +1070,13 @@ class InferenceSession { // Enable nodestats collection std::optional node_stats_recorder_; #endif + + // Holds the list of devices from the environment, ordered via OrderDevices(). + // It's used for auto ep selection. + std::vector execution_devices_; + + // Holds the list of devices selected by policies. + std::vector devices_selected_; }; struct SessionIOBinding { diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index 61a6df44d12fc..8699fea3fc73c 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -149,15 +149,27 @@ OrtKeyValuePairs GetModelMetadataFromSession(const InferenceSession& session) { // Select execution providers based on the device policy and available devices and add to session Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const OrtSessionOptions& options, InferenceSession& sess) { - // Get the list of devices from the environment and order them. - // Ordered by preference within each type. NPU -> GPU -> NPU - // TODO: Should environment.cc do the ordering? 
- std::vector execution_devices = OrderDevices(env.GetOrtEpDevices()); + std::vector execution_devices; + + // Check if the sorted list of devices has cached in the session + if (!sess.GetExecutionDevices().empty()) { + execution_devices = sess.GetExecutionDevices(); + } else { + // Get the list of devices from the environment and order them. + // Ordered by preference within each type. NPU -> GPU -> NPU + // TODO: Should environment.cc do the ordering? + execution_devices = OrderDevices(env.GetOrtEpDevices()); + } // The list of devices selected by policies std::vector devices_selected; - ORT_RETURN_IF_ERROR(SelectEpDevices(options, execution_devices, devices_selected, nullptr, sess)); + // Check if the list of devices has been selected and cached in the session + if (!sess.GetSelectedDevices().empty()) { + devices_selected = sess.GetSelectedDevices(); + } else { + ORT_RETURN_IF_ERROR(SelectEpDevices(options, execution_devices, devices_selected, nullptr, sess)); + } // Log telemetry for auto EP selection { @@ -251,13 +263,22 @@ Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const } } + if (sess.GetExecutionDevices().empty()) { + sess.SetExecutionDevices(execution_devices); + } + + if (sess.GetSelectedDevices().empty()) { + sess.SetSelectedDevices(devices_selected); + } + return Status::OK(); } -Status ProviderPolicyContext::SelectEpDevices(const OrtSessionOptions& options, std::vector& execution_devices, +Status ProviderPolicyContext::SelectEpDevices(const OrtSessionOptions& options, + const std::vector& execution_devices, std::vector& devices_selected, - OrtKeyValuePairs* metadata_from_model, - InferenceSession& sess) { + const OrtKeyValuePairs* metadata_from_model, + const InferenceSession& sess) { // Run the delegate if it was passed in lieu of any other policy if (options.value.ep_selection_policy.delegate) { OrtKeyValuePairs model_metadata; diff --git a/onnxruntime/core/session/provider_policy_context.h 
b/onnxruntime/core/session/provider_policy_context.h index 87498c81cff13..5f4cdf477f8ce 100644 --- a/onnxruntime/core/session/provider_policy_context.h +++ b/onnxruntime/core/session/provider_policy_context.h @@ -42,9 +42,9 @@ class ProviderPolicyContext { ProviderPolicyContext() = default; Status SelectEpsForSession(const Environment& env, const OrtSessionOptions& options, InferenceSession& sess); - Status SelectEpDevices(const OrtSessionOptions& options, std::vector& execution_devices, - std::vector& devices_selected, OrtKeyValuePairs* model_metadata, - InferenceSession& sess); + Status SelectEpDevices(const OrtSessionOptions& options, const std::vector& execution_devices, + std::vector& devices_selected, const OrtKeyValuePairs* model_metadata, + const InferenceSession& sess); Status AddEpDefaultOptionsToSession(InferenceSession& sess, std::vector devices); void RemoveOrtCpuDevice(std::vector& devices); Status CreateExecutionProvider(const Environment& env, OrtSessionOptions& options, const OrtLogger& logger, diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index a27507cfac45a..a58e57690627b 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -102,7 +102,9 @@ Status TestAutoSelectEPsImpl(const Environment& env, InferenceSession& sess, con } #endif // !defined(ORT_MINIMAL_BUILD) -Status GetCustomOpDomainsFromEpDevice(const OrtEpDevice& ep_device, InlinedVector& domains) { +Status GetCustomOpDomainsFromEpDevice(const OrtEpDevice& ep_device, InlinedVector& domains_out) { + InlinedVector domains{}; + // Get custom op domain provided by EP factory if any. // OrtEpFactory::GetNumCustomOpDomains and OrtEpFactory::GetCustomOpDomains were added in ORT 1.24. 
OrtEpFactory* ep_factory = ep_device.ep_factory; @@ -114,15 +116,15 @@ Status GetCustomOpDomainsFromEpDevice(const OrtEpDevice& ep_device, InlinedVecto ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetNumCustomOpDomains(ep_factory, &num_domains))); domains.resize(num_domains); - ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetCustomOpDomains(ep_factory, - domains.data(), + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetCustomOpDomains(ep_factory, domains.data(), domains.size()))); } + domains_out = std::move(domains); return Status::OK(); } -bool IsDomainExisted(const std::string& domain_name, const std::vector& domains) { +bool DoesDomainWithNameExist(const std::string& domain_name, const std::vector& domains) { for (auto ptr : domains) { if (domain_name == ptr->domain_) { return true; @@ -130,13 +132,12 @@ bool IsDomainExisted(const std::string& domain_name, const std::vector void { if (model_proto.has_producer_name()) { @@ -172,8 +173,6 @@ Status static GetModelMetadata(const ORTCHAR_T* model_path, if (model_proto.has_graph() && model_proto.graph().has_name()) { model_metadata.graph_name = model_proto.graph().name(); } - - return; }; ONNX_NAMESPACE::ModelProto model_proto; @@ -196,6 +195,7 @@ Status static GetModelMetadata(const ORTCHAR_T* model_path, return Status::OK(); } #endif +} // namespace common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size) { if (size == nullptr) { @@ -299,9 +299,9 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op options->value.ep_selection_policy.enable) { ProviderPolicyContext context; - // Following code calls the same ep selection functions that InitializeSession() calls as well - // to get `execution_devices` and `devices_selected`. - + // Get the list of devices from the environment and order them. + // Ordered by preference within each type. NPU -> GPU -> NPU + // TODO: Should environment.cc do the ordering? 
std::vector execution_devices = OrderDevices(env.GetOrtEpDevices()); // The list of devices selected by policies @@ -310,9 +310,11 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op // If the selection policy is delegate, the model metadata as key-value paris should be provided to // the delegate function ModelMetadata model_metadata; - ORT_API_RETURN_IF_STATUS_NOT_OK(GetModelMetadata(model_path, model_data, model_data_length, model_metadata)); OrtKeyValuePairs model_metadata_key_value_pairs; - model_metadata_key_value_pairs = GetModelMetadataKeyValuePairs(model_metadata); + if (options->value.ep_selection_policy.delegate) { + ORT_API_RETURN_IF_STATUS_NOT_OK(GetModelMetadata(model_path, model_data, model_data_length, model_metadata)); + model_metadata_key_value_pairs = GetModelMetadataKeyValuePairs(model_metadata); + } ORT_API_RETURN_IF_STATUS_NOT_OK(context.SelectEpDevices(*options, execution_devices, devices_selected, &model_metadata_key_value_pairs, *sess)); @@ -325,7 +327,7 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op const auto domains_span = gsl::span(domains.data(), domains.size()); for (auto domain : domains_span) { - if (!IsDomainExisted(domain->domain_, options->custom_op_domains_) && + if (!DoesDomainWithNameExist(domain->domain_, options->custom_op_domains_) && domain->custom_ops_.size() > 0) { all_ep_custom_op_domains.push_back(domain); } else { @@ -338,6 +340,14 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op if (!all_ep_custom_op_domains.empty()) { ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(all_ep_custom_op_domains)); } + + if (!execution_devices.empty()) { + sess->SetExecutionDevices(execution_devices); + } + + if (!devices_selected.empty()) { + sess->SetSelectedDevices(devices_selected); + } } #endif @@ -696,7 +706,7 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic const auto domains_span = gsl::span(domains.data(), 
domains.size()); for (auto domain : domains_span) { - if (!IsDomainExisted(domain->domain_, ort_session_options.custom_op_domains_) && + if (!DoesDomainWithNameExist(domain->domain_, ort_session_options.custom_op_domains_) && domain->custom_ops_.size() > 0) { ort_session_options.custom_op_domains_.push_back(domain); } else { From 2e91855ba6c365be2a0bcda926e7c87b49c52060 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 14 Jan 2026 09:04:38 -0800 Subject: [PATCH 33/47] lintrunner -a --- onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h | 2 +- onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index f0839e663ca38..0e6e3b11a769c 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -101,7 +101,7 @@ class EpFactoryInternalImpl { // Default implementation: leave details unchanged (device assumed compatible) return nullptr; } - + virtual OrtStatus* GetNumCustomOpDomains(_Out_ size_t* num_domains) const noexcept { *num_domains = 0; return nullptr; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index 1aaaffb277716..244051dd5e4d0 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -81,7 +81,7 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { OrtEpFactory* this_ptr, const OrtHardwareDevice* hw, OrtDeviceEpIncompatibilityDetails* details) noexcept; - + static OrtStatus* ORT_API_CALL GetNumCustomOpDomainsImpl(OrtEpFactory* this_ptr, _Out_ size_t* num_domains) noexcept; From 1f448a69cd61983773728156ad6c9d534208bdde Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 14 Jan 
2026 09:18:41 -0800 Subject: [PATCH 34/47] fix compile warning for minimal build --- onnxruntime/core/session/utils.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index a58e57690627b..c30a78a927243 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -100,7 +100,6 @@ Status TestAutoSelectEPsImpl(const Environment& env, InferenceSession& sess, con return Status::OK(); } -#endif // !defined(ORT_MINIMAL_BUILD) Status GetCustomOpDomainsFromEpDevice(const OrtEpDevice& ep_device, InlinedVector& domains_out) { InlinedVector domains{}; @@ -133,7 +132,6 @@ bool DoesDomainWithNameExist(const std::string& domain_name, const std::vector Date: Wed, 14 Jan 2026 15:51:16 -0800 Subject: [PATCH 35/47] address reviewer's comment --- onnxruntime/core/session/inference_session.h | 2 +- onnxruntime/core/session/utils.cc | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 7ad273f03ae3b..d28aeb7c515d8 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -662,7 +662,7 @@ class InferenceSession { return session_id_; } - void SetExecutionDevices(std::vector& execution_devices) { + void SetExecutionDevices(std::vector execution_devices) { execution_devices_ = std::move(execution_devices); } diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index c30a78a927243..8dbda62628092 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -340,11 +340,11 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op } if (!execution_devices.empty()) { - sess->SetExecutionDevices(execution_devices); + sess->SetExecutionDevices(std::move(execution_devices)); } if (!devices_selected.empty()) { - 
sess->SetSelectedDevices(devices_selected); + sess->SetSelectedDevices(std::move(devices_selected)); } } #endif From 452bb264c39cc05753d509cae46069194d743d1d Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 15 Jan 2026 09:45:41 -0800 Subject: [PATCH 36/47] Add AddEpCustomDomainsToSessionOptions() --- onnxruntime/core/session/onnxruntime_c_api.cc | 4 ++++ onnxruntime/core/session/utils.cc | 7 +++++++ onnxruntime/core/session/utils.h | 4 ++++ 3 files changed, 15 insertions(+) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 37f98f1b7cd76..7d367e6c9650d 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3335,6 +3335,10 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS ep_option_vals_span, *session_options)); + ORT_API_RETURN_IF_STATUS_NOT_OK(AddEpCustomDomainsToSessionOptions( + ep_devices_span, + *session_options)); + session_options->provider_factories.push_back(std::move(provider_factory)); return nullptr; diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 8dbda62628092..466af20c32be7 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -697,7 +697,14 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic ORT_RETURN_IF_ERROR(config_options.AddConfigEntry((prefix + ep_option_keys[j]).c_str(), ep_option_vals[j])); } + } + return Status::OK(); +} + +Status AddEpCustomDomainsToSessionOptions(gsl::span ep_devices, + OrtSessionOptions& ort_session_options) { + for (const OrtEpDevice* ep_device : ep_devices) { // Add custom domains if EP factory has any. 
InlinedVector domains; ORT_RETURN_IF_ERROR(GetCustomOpDomainsFromEpDevice(*ep_device, domains)); diff --git a/onnxruntime/core/session/utils.h b/onnxruntime/core/session/utils.h index da951b5cb9810..a39e7cf57303b 100644 --- a/onnxruntime/core/session/utils.h +++ b/onnxruntime/core/session/utils.h @@ -71,5 +71,9 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic gsl::span ep_options_vals, OrtSessionOptions& session_options); +// Adss EP specific custom domains to the OrtSessionOptions configuration. +Status AddEpCustomDomainsToSessionOptions(gsl::span ep_devices, + OrtSessionOptions& ort_session_options); + } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) From 96d42fbcc91c562a74a53b2b038423831ad72928 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 15 Jan 2026 13:10:51 -0800 Subject: [PATCH 37/47] clean up code --- onnxruntime/core/session/utils.cc | 3 +-- onnxruntime/core/session/utils.h | 2 +- onnxruntime/python/onnxruntime_pybind_state.cc | 5 ++++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 466af20c32be7..0dd8318537834 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -673,8 +673,7 @@ Status CreateIExecutionProviderFactoryForEpDevices(const Environment& env, Status AddEpOptionsToSessionOptions(gsl::span ep_devices, gsl::span ep_option_keys, gsl::span ep_option_vals, - OrtSessionOptions& ort_session_options) { - SessionOptions& session_options = ort_session_options.value; + SessionOptions& session_options) { const size_t num_ep_options = ep_option_keys.size(); if (ep_option_vals.size() != num_ep_options) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/core/session/utils.h b/onnxruntime/core/session/utils.h index a39e7cf57303b..59b4d9f0944c3 100644 --- a/onnxruntime/core/session/utils.h +++ b/onnxruntime/core/session/utils.h @@ -69,7 +69,7 @@ Status 
CreateIExecutionProviderFactoryForEpDevices(const Environment& env, Status AddEpOptionsToSessionOptions(gsl::span ep_devices, gsl::span ep_options_keys, gsl::span ep_options_vals, - OrtSessionOptions& session_options); + SessionOptions& session_options); // Adss EP specific custom domains to the OrtSessionOptions configuration. Status AddEpCustomDomainsToSessionOptions(gsl::span ep_devices, diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f2aea061b244d..1ce320629c0f2 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1346,7 +1346,10 @@ static Status AddEpFactoryFromEpDevices(PySessionOptions& py_sess_options, ORT_RETURN_IF_ERROR(AddEpOptionsToSessionOptions(ep_devices, ep_option_keys, ep_option_vals, - py_sess_options)); + py_sess_options.value)); + + ORT_RETURN_IF_ERROR(AddEpCustomDomainsToSessionOptions(ep_devices, + py_sess_options.value)); py_sess_options.provider_factories.push_back(std::move(provider_factory)); return Status::OK(); From 04b75e87fe495cfb021c780bf60a516fabf2731f Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 15 Jan 2026 13:25:52 -0800 Subject: [PATCH 38/47] clean up code and fix compile error --- onnxruntime/core/session/onnxruntime_c_api.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 7d367e6c9650d..ed492d8984b1a 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3333,7 +3333,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS ep_devices_span, ep_option_keys_span, ep_option_vals_span, - *session_options)); + session_options->value)); ORT_API_RETURN_IF_STATUS_NOT_OK(AddEpCustomDomainsToSessionOptions( ep_devices_span, From 84cff1fccc1d9fecb150aa79ce1026a645b044d7 Mon Sep 17 00:00:00 2001 From: Chi 
Lo Date: Thu, 15 Jan 2026 13:44:08 -0800 Subject: [PATCH 39/47] revert auto ep selection --- onnxruntime/core/session/inference_session.h | 32 --- .../core/session/provider_policy_context.cc | 201 +++++++----------- .../core/session/provider_policy_context.h | 8 - onnxruntime/core/session/utils.cc | 101 +-------- 4 files changed, 84 insertions(+), 258 deletions(-) diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index d28aeb7c515d8..8bea15c169ed4 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -484,15 +484,6 @@ class InferenceSession { * This is required for a user to know the location of the input/output when autoep selection is enabled. */ common::Status GetEpDeviceForInputs(InlinedVector& memory_info) const; - - /** - * Get the OrtEpDevice (if available) for the outputs of the model. - * - * This is required for a user to validate that outputs will be placed on the expected device - * for external resource sharing. - */ - common::Status GetEpDeviceForOutputs(InlinedVector& memory_info) const; - /** * Get the current number of in-progress concurrent Run calls. */ @@ -662,22 +653,6 @@ class InferenceSession { return session_id_; } - void SetExecutionDevices(std::vector execution_devices) { - execution_devices_ = std::move(execution_devices); - } - - const std::vector& GetExecutionDevices() noexcept { - return execution_devices_; - } - - void SetSelectedDevices(std::vector selected_devices) { - devices_selected_ = std::move(selected_devices); - } - - const std::vector& GetSelectedDevices() noexcept { - return devices_selected_; - } - protected: #if !defined(ORT_MINIMAL_BUILD) @@ -1070,13 +1045,6 @@ class InferenceSession { // Enable nodestats collection std::optional node_stats_recorder_; #endif - - // Holds the list of devices from the environment, ordered via OrderDevices(). - // It's used for auto ep selection. 
- std::vector execution_devices_; - - // Holds the list of devices selected by policies. - std::vector devices_selected_; }; struct SessionIOBinding { diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index 8699fea3fc73c..aa2859985a479 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -50,7 +50,6 @@ bool IsDefaultCpuEp(const OrtEpDevice* d) { return d->device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU && d->ep_vendor == "Microsoft"; } -} // namespace // Sort devices. NPU -> GPU -> CPU // Within in type, vendor owned, not. @@ -116,10 +115,16 @@ std::vector OrderDevices(const std::vector execution_devices; - - // Check if the sorted list of devices has cached in the session - if (!sess.GetExecutionDevices().empty()) { - execution_devices = sess.GetExecutionDevices(); - } else { - // Get the list of devices from the environment and order them. - // Ordered by preference within each type. NPU -> GPU -> NPU - // TODO: Should environment.cc do the ordering? - execution_devices = OrderDevices(env.GetOrtEpDevices()); - } + // Get the list of devices from the environment and order them. + // Ordered by preference within each type. NPU -> GPU -> NPU + // TODO: Should environment.cc do the ordering? + std::vector execution_devices = OrderDevices(env.GetOrtEpDevices()); // The list of devices selected by policies std::vector devices_selected; - // Check if the list of devices has been selected and cached in the session - if (!sess.GetSelectedDevices().empty()) { - devices_selected = sess.GetSelectedDevices(); + // Run the delegate if it was passed in lieu of any other policy + if (options.value.ep_selection_policy.delegate) { + auto model_metadata = GetModelMetadata(sess); + OrtKeyValuePairs runtime_metadata; // TODO: where should this come from? 
+ + std::vector delegate_devices(execution_devices.begin(), execution_devices.end()); + std::array selected_devices{nullptr}; + size_t num_selected = 0; + + EpSelectionDelegate delegate = options.value.ep_selection_policy.delegate; + auto* status = delegate(delegate_devices.data(), delegate_devices.size(), + &model_metadata, &runtime_metadata, + selected_devices.data(), selected_devices.size(), &num_selected, + options.value.ep_selection_policy.state); + + // return or fall-through for both these cases + // going with explicit failure for now so it's obvious to user what is happening + if (status != nullptr) { + std::string delegate_error_msg = OrtApis::GetErrorMessage(status); // copy + OrtApis::ReleaseStatus(status); + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate failed: ", delegate_error_msg); + } + + if (num_selected == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate did not select anything."); + } + + if (num_selected > selected_devices.size()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "EP selection delegate selected too many EP devices (", num_selected, "). 
", + "The limit is ", selected_devices.size(), " EP devices."); + } + + // Copy the selected devices to the output vector + devices_selected.reserve(num_selected); + for (size_t i = 0; i < num_selected; ++i) { + devices_selected.push_back(selected_devices[i]); + } } else { - ORT_RETURN_IF_ERROR(SelectEpDevices(options, execution_devices, devices_selected, nullptr, sess)); + // Create the selector for the chosen policy + std::unique_ptr selector; + switch (options.value.ep_selection_policy.policy) { + case OrtExecutionProviderDevicePolicy_DEFAULT: + selector = std::make_unique(); + break; + case OrtExecutionProviderDevicePolicy_PREFER_CPU: + selector = std::make_unique(); + break; + case OrtExecutionProviderDevicePolicy_PREFER_NPU: + case OrtExecutionProviderDevicePolicy_MAX_EFFICIENCY: + case OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER: + selector = std::make_unique(); + break; + case OrtExecutionProviderDevicePolicy_PREFER_GPU: + case OrtExecutionProviderDevicePolicy_MAX_PERFORMANCE: + selector = std::make_unique(); + break; + } + + // Execute policy + + selector->SelectProvidersForDevices(execution_devices, devices_selected); + } + + // Fail if we did not find any device matches + if (devices_selected.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "No execution providers selected. 
Please check the device policy and available devices."); } // Log telemetry for auto EP selection @@ -263,98 +314,6 @@ Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const } } - if (sess.GetExecutionDevices().empty()) { - sess.SetExecutionDevices(execution_devices); - } - - if (sess.GetSelectedDevices().empty()) { - sess.SetSelectedDevices(devices_selected); - } - - return Status::OK(); -} - -Status ProviderPolicyContext::SelectEpDevices(const OrtSessionOptions& options, - const std::vector& execution_devices, - std::vector& devices_selected, - const OrtKeyValuePairs* metadata_from_model, - const InferenceSession& sess) { - // Run the delegate if it was passed in lieu of any other policy - if (options.value.ep_selection_policy.delegate) { - OrtKeyValuePairs model_metadata; - - if (metadata_from_model) { - model_metadata = *metadata_from_model; - } else { - model_metadata = GetModelMetadataFromSession(sess); - } - OrtKeyValuePairs runtime_metadata; // TODO: where should this come from? 
- - std::vector delegate_devices(execution_devices.begin(), execution_devices.end()); - std::array selected_devices{nullptr}; - size_t num_selected = 0; - - EpSelectionDelegate delegate = options.value.ep_selection_policy.delegate; - auto* status = delegate(delegate_devices.data(), delegate_devices.size(), - &model_metadata, &runtime_metadata, - selected_devices.data(), selected_devices.size(), &num_selected, - options.value.ep_selection_policy.state); - - // return or fall-through for both these cases - // going with explicit failure for now so it's obvious to user what is happening - if (status != nullptr) { - std::string delegate_error_msg = OrtApis::GetErrorMessage(status); // copy - OrtApis::ReleaseStatus(status); - - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate failed: ", delegate_error_msg); - } - - if (num_selected == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate did not select anything."); - } - - if (num_selected > selected_devices.size()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "EP selection delegate selected too many EP devices (", num_selected, "). 
", - "The limit is ", selected_devices.size(), " EP devices."); - } - - // Copy the selected devices to the output vector - devices_selected.reserve(num_selected); - for (size_t i = 0; i < num_selected; ++i) { - devices_selected.push_back(selected_devices[i]); - } - } else { - // Create the selector for the chosen policy - std::unique_ptr selector; - switch (options.value.ep_selection_policy.policy) { - case OrtExecutionProviderDevicePolicy_DEFAULT: - selector = std::make_unique(); - break; - case OrtExecutionProviderDevicePolicy_PREFER_CPU: - selector = std::make_unique(); - break; - case OrtExecutionProviderDevicePolicy_PREFER_NPU: - case OrtExecutionProviderDevicePolicy_MAX_EFFICIENCY: - case OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER: - selector = std::make_unique(); - break; - case OrtExecutionProviderDevicePolicy_PREFER_GPU: - case OrtExecutionProviderDevicePolicy_MAX_PERFORMANCE: - selector = std::make_unique(); - break; - } - - // Execute policy - - selector->SelectProvidersForDevices(execution_devices, devices_selected); - } - - // Fail if we did not find any device matches - if (devices_selected.empty()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "No execution providers selected. 
Please check the device policy and available devices."); - } return Status::OK(); } diff --git a/onnxruntime/core/session/provider_policy_context.h b/onnxruntime/core/session/provider_policy_context.h index 5f4cdf477f8ce..295ac21ca4aa5 100644 --- a/onnxruntime/core/session/provider_policy_context.h +++ b/onnxruntime/core/session/provider_policy_context.h @@ -7,7 +7,6 @@ #include "core/session/abi_session_options_impl.h" #include "core/session/environment.h" -#include "core/session/inference_session.h" #include "core/session/onnxruntime_c_api.h" // For OrtExecutionProviderDevicePolicy namespace onnxruntime { @@ -42,9 +41,6 @@ class ProviderPolicyContext { ProviderPolicyContext() = default; Status SelectEpsForSession(const Environment& env, const OrtSessionOptions& options, InferenceSession& sess); - Status SelectEpDevices(const OrtSessionOptions& options, const std::vector& execution_devices, - std::vector& devices_selected, const OrtKeyValuePairs* model_metadata, - const InferenceSession& sess); Status AddEpDefaultOptionsToSession(InferenceSession& sess, std::vector devices); void RemoveOrtCpuDevice(std::vector& devices); Status CreateExecutionProvider(const Environment& env, OrtSessionOptions& options, const OrtLogger& logger, @@ -79,10 +75,6 @@ class PreferGpuEpPolicy : public IEpPolicySelector { std::vector& selected_devices) override; }; -std::vector OrderDevices(const std::vector& devices); - -OrtKeyValuePairs GetModelMetadataKeyValuePairs(const ModelMetadata& session); - } // namespace onnxruntime #endif // !ORT_MINIMAL_BUILD diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 0dd8318537834..ca126b0c19693 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -20,8 +20,6 @@ #include "core/session/ort_apis.h" #include "core/session/ort_env.h" #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" -#include "core/session/provider_policy_context.h" -#include "core/graph/model.h" 
#if !defined(ORT_MINIMAL_BUILD) #include "core/session/plugin_ep/ep_factory_internal.h" @@ -33,8 +31,8 @@ #endif // !defined(ORT_MINIMAL_BUILD) using namespace onnxruntime; -namespace { #if !defined(ORT_MINIMAL_BUILD) +namespace { // temporary implementation for testing. EP to 'select' is specified in config option Status TestAutoSelectEPsImpl(const Environment& env, InferenceSession& sess, const std::string& ep_to_select) { const auto& execution_devices = env.GetOrtEpDevices(); @@ -131,69 +129,8 @@ bool DoesDomainWithNameExist(const std::string& domain_name, const std::vector void { - if (model_proto.has_producer_name()) { - model_metadata.producer_name = model_proto.producer_name(); - } - - if (model_proto.has_producer_version()) { - model_metadata.producer_version = model_proto.producer_version(); - } - - if (model_proto.has_doc_string()) { - model_metadata.description = model_proto.doc_string(); - } - - if (model_proto.has_graph() && model_proto.graph().has_doc_string()) { - model_metadata.graph_description = model_proto.graph().doc_string(); - } - - if (model_proto.has_domain()) { - model_metadata.domain = model_proto.domain(); - } - - if (model_proto.has_model_version()) { - model_metadata.version = model_proto.model_version(); - } - - ModelMetaData metadata; - for (auto& prop : model_proto.metadata_props()) { - metadata[prop.key()] = prop.value(); - } - model_metadata.custom_metadata_map = metadata; - - if (model_proto.has_graph() && model_proto.graph().has_name()) { - model_metadata.graph_name = model_proto.graph().name(); - } - }; - - ONNX_NAMESPACE::ModelProto model_proto; - - if (model_path != nullptr) { - PathString path(model_path); - ORT_RETURN_IF_ERROR(Model::Load(path, model_proto)); - } else if (model_data != nullptr && model_data_length > 0) { - const bool result = model_proto.ParseFromArray(model_data, static_cast(model_data_length)); - if (!result) { - return Status(common::ONNXRUNTIME, common::INVALID_PROTOBUF, - "Failed to load model because 
protobuf parsing failed."); - } - } else { - return Status::OK(); - } - - get_model_metadata(model_proto, model_metadata); - - return Status::OK(); -} -#endif // !defined(ORT_MINIMAL_BUILD) } // namespace +#endif // !defined(ORT_MINIMAL_BUILD) common::Status CopyStringToOutputArg(std::string_view str, const char* err_msg, char* out, size_t* size) { if (size == nullptr) { @@ -290,36 +227,14 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op #endif #if !defined(ORT_MINIMAL_BUILD) - // Add custom domains if the selected ep from auto ep selection has custom domains to register. + // Add custom domains for all OrtEpDevice instances to inference session. // The custom domains should be registered before model load for ORT to validate the custom ops. if (options != nullptr && options->provider_factories.empty() && options->value.ep_selection_policy.enable) { - ProviderPolicyContext context; - - // Get the list of devices from the environment and order them. - // Ordered by preference within each type. NPU -> GPU -> NPU - // TODO: Should environment.cc do the ordering? 
- std::vector execution_devices = OrderDevices(env.GetOrtEpDevices()); - - // The list of devices selected by policies - std::vector devices_selected; - - // If the selection policy is delegate, the model metadata as key-value paris should be provided to - // the delegate function - ModelMetadata model_metadata; - OrtKeyValuePairs model_metadata_key_value_pairs; - if (options->value.ep_selection_policy.delegate) { - ORT_API_RETURN_IF_STATUS_NOT_OK(GetModelMetadata(model_path, model_data, model_data_length, model_metadata)); - model_metadata_key_value_pairs = GetModelMetadataKeyValuePairs(model_metadata); - } - - ORT_API_RETURN_IF_STATUS_NOT_OK(context.SelectEpDevices(*options, execution_devices, - devices_selected, &model_metadata_key_value_pairs, *sess)); - InlinedVector all_ep_custom_op_domains; - for (const OrtEpDevice* ep_device : devices_selected) { + for (const OrtEpDevice* ep_device : env.GetOrtEpDevices()) { InlinedVector domains; ORT_API_RETURN_IF_STATUS_NOT_OK(GetCustomOpDomainsFromEpDevice(*ep_device, domains)); @@ -338,14 +253,6 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op if (!all_ep_custom_op_domains.empty()) { ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(all_ep_custom_op_domains)); } - - if (!execution_devices.empty()) { - sess->SetExecutionDevices(std::move(execution_devices)); - } - - if (!devices_selected.empty()) { - sess->SetSelectedDevices(std::move(devices_selected)); - } } #endif From 32e2e577759413fe5549e25fd5524e44d5fef4b4 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 15 Jan 2026 13:56:38 -0800 Subject: [PATCH 40/47] add back accidentaly removed code --- onnxruntime/core/session/inference_session.h | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 8bea15c169ed4..ff54f6fa7bca0 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h 
@@ -484,6 +484,15 @@ class InferenceSession { * This is required for a user to know the location of the input/output when autoep selection is enabled. */ common::Status GetEpDeviceForInputs(InlinedVector& memory_info) const; + + /** + * Get the OrtEpDevice (if available) for the outputs of the model. + * + * This is required for a user to validate that outputs will be placed on the expected device + * for external resource sharing. + */ + common::Status GetEpDeviceForOutputs(InlinedVector& memory_info) const; + /** * Get the current number of in-progress concurrent Run calls. */ From b80b451a4c7d809dd34cfff4268eda86b09152f7 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 15 Jan 2026 16:41:15 -0800 Subject: [PATCH 41/47] address reviewer's comments --- onnxruntime/core/session/utils.cc | 33 +++++++++++++++---- .../python/onnxruntime_pybind_state.cc | 2 +- .../unittest_util/test_dynamic_plugin_ep.cc | 2 +- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index ca126b0c19693..6e92f2ccd9bf7 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -617,17 +617,36 @@ Status AddEpCustomDomainsToSessionOptions(gsl::span ep const auto domains_span = gsl::span(domains.data(), domains.size()); for (auto domain : domains_span) { - if (!DoesDomainWithNameExist(domain->domain_, ort_session_options.custom_op_domains_) && - domain->custom_ops_.size() > 0) { + const bool has_custom_ops = !domain->custom_ops_.empty(); + const bool domain_name_exists = + DoesDomainWithNameExist(domain->domain_, ort_session_options.custom_op_domains_); + + // new domain + has ops => add it to session options. + if (!domain_name_exists && has_custom_ops) { ort_session_options.custom_op_domains_.push_back(domain); - } else { - LOGS_DEFAULT(WARNING) << "The custom op domain name " - << domain->domain_ << " is already in the session option. 
Skip it."; + continue; + } + + // Everything else is a skip; log a reason. + if (!has_custom_ops) { + if (domain_name_exists) { + LOGS_DEFAULT(WARNING) << "Skipping custom op domain '" << domain->domain_ + << "': domain already exists in session options and this domain " + << "provides no custom ops."; + } else { + LOGS_DEFAULT(WARNING) << "Skipping custom op domain '" << domain->domain_ + << "': no custom ops provided."; + } + continue; } + + // has_custom_ops && domain_name_exists + LOGS_DEFAULT(WARNING) << "Skipping custom op domain '" << domain->domain_ + << "': domain already exists in session options."; } - } - return Status::OK(); + return Status::OK(); + } } #endif // !defined(ORT_MINIMAL_BUILD) } // namespace onnxruntime diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 1ce320629c0f2..118768e4c38dc 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1349,7 +1349,7 @@ static Status AddEpFactoryFromEpDevices(PySessionOptions& py_sess_options, py_sess_options.value)); ORT_RETURN_IF_ERROR(AddEpCustomDomainsToSessionOptions(ep_devices, - py_sess_options.value)); + py_sess_options.GetOrtSessionOptions())); py_sess_options.provider_factories.push_back(std::move(provider_factory)); return Status::OK(); diff --git a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc index ec9c5f7f0397f..fd2cf2f712628 100644 --- a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc +++ b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc @@ -186,7 +186,7 @@ std::unique_ptr MakeEp(const logging::Logger* logger) { ORT_THROW_IF_ERROR(AddEpOptionsToSessionOptions(state.selected_c_ep_devices, default_ep_option_key_cstrs, default_ep_option_value_cstrs, - ort_session_options)); + ort_session_options.value)); return state.ep_factory->CreateProvider(ort_session_options, 
*logger->ToExternal()); } From 0b7302e5712e5f12c1b8f7fb396d8c32efbb2315 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 15 Jan 2026 16:53:03 -0800 Subject: [PATCH 42/47] update --- onnxruntime/core/session/utils.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 6e92f2ccd9bf7..cc1ad9ca0e986 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -644,9 +644,8 @@ Status AddEpCustomDomainsToSessionOptions(gsl::span ep LOGS_DEFAULT(WARNING) << "Skipping custom op domain '" << domain->domain_ << "': domain already exists in session options."; } - - return Status::OK(); } + return Status::OK(); } #endif // !defined(ORT_MINIMAL_BUILD) } // namespace onnxruntime From be94f18923f35bdb0a460301775b376931fb6ff7 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 15 Jan 2026 22:55:27 -0800 Subject: [PATCH 43/47] fix compile error for onnxruntime_pybind_state.cc --- onnxruntime/python/onnxruntime_pybind_state.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 118768e4c38dc..0a5cb812be106 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1349,7 +1349,7 @@ static Status AddEpFactoryFromEpDevices(PySessionOptions& py_sess_options, py_sess_options.value)); ORT_RETURN_IF_ERROR(AddEpCustomDomainsToSessionOptions(ep_devices, - py_sess_options.GetOrtSessionOptions())); + py_sess_options)); py_sess_options.provider_factories.push_back(std::move(provider_factory)); return Status::OK(); From 1841117ebacb9046de9ccdf21f735d33331df245 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 16 Jan 2026 09:28:06 -0800 Subject: [PATCH 44/47] address reviewer's comment --- onnxruntime/core/session/utils.cc | 59 +++++++++++++++---------------- 1 file changed, 29 insertions(+), 30 deletions(-)
diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index cc1ad9ca0e986..78a326c8b0e48 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -121,7 +121,7 @@ Status GetCustomOpDomainsFromEpDevice(const OrtEpDevice& ep_device, InlinedVecto return Status::OK(); } -bool DoesDomainWithNameExist(const std::string& domain_name, const std::vector& domains) { +bool DoesDomainWithNameExist(const std::string& domain_name, gsl::span domains) { for (auto ptr : domains) { if (domain_name == ptr->domain_) { return true; @@ -129,6 +129,27 @@ bool DoesDomainWithNameExist(const std::string& domain_name, const std::vector existing_domains) { + if (!domain_to_add) { + return false; + } + + if (domain_to_add->custom_ops_.size() == 0) { + LOGS_DEFAULT(WARNING) << "Skipping custom op domain '" << domain_to_add->domain_ + << "': custom ops is empty."; + return false; + } + + if (DoesDomainWithNameExist(domain_to_add->domain_, existing_domains)) { + LOGS_DEFAULT(WARNING) << "Skipping custom op domain '" << domain_to_add->domain_ + << "': domain already exists in session options."; + return false; + } + + return true; +} } // namespace #endif // !defined(ORT_MINIMAL_BUILD) @@ -239,13 +260,11 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op ORT_API_RETURN_IF_STATUS_NOT_OK(GetCustomOpDomainsFromEpDevice(*ep_device, domains)); const auto domains_span = gsl::span(domains.data(), domains.size()); + const auto existing_domains = gsl::span(options->custom_op_domains_.data(), + options->custom_op_domains_.size()); for (auto domain : domains_span) { - if (!DoesDomainWithNameExist(domain->domain_, options->custom_op_domains_) && - domain->custom_ops_.size() > 0) { + if (ShouldAddDomain(domain, existing_domains)) { all_ep_custom_op_domains.push_back(domain); - } else { - LOGS_DEFAULT(WARNING) << "The custom op domain name " - << domain->domain_ << " is already in the session option. 
Skip it."; } } } @@ -616,35 +635,15 @@ Status AddEpCustomDomainsToSessionOptions(gsl::span ep ORT_RETURN_IF_ERROR(GetCustomOpDomainsFromEpDevice(*ep_device, domains)); const auto domains_span = gsl::span(domains.data(), domains.size()); + const auto existing_domains = gsl::span(ort_session_options.custom_op_domains_.data(), + ort_session_options.custom_op_domains_.size()); for (auto domain : domains_span) { - const bool has_custom_ops = !domain->custom_ops_.empty(); - const bool domain_name_exists = - DoesDomainWithNameExist(domain->domain_, ort_session_options.custom_op_domains_); - - // new domain + has ops => add it to session options. - if (!domain_name_exists && has_custom_ops) { + if (ShouldAddDomain(domain, existing_domains)) { ort_session_options.custom_op_domains_.push_back(domain); - continue; } - - // Everything else is a skip; log a reason. - if (!has_custom_ops) { - if (domain_name_exists) { - LOGS_DEFAULT(WARNING) << "Skipping custom op domain '" << domain->domain_ - << "': domain already exists in session options and this domain " - << "provides no custom ops."; - } else { - LOGS_DEFAULT(WARNING) << "Skipping custom op domain '" << domain->domain_ - << "': no custom ops provided."; - } - continue; - } - - // has_custom_ops && domain_name_exists - LOGS_DEFAULT(WARNING) << "Skipping custom op domain '" << domain->domain_ - << "': domain already exists in session options."; } } + return Status::OK(); } #endif // !defined(ORT_MINIMAL_BUILD) From 6a571ef9d14c091f0ede054f7fcdd35d0f6f66c2 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 16 Jan 2026 11:18:18 -0800 Subject: [PATCH 45/47] address reviewer's comments --- .../core/session/onnxruntime_ep_c_api.h | 5 +++-- onnxruntime/core/session/utils.cc | 14 ++++---------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 1d2d18417dd77..7d34bae267404 100644 --- 
a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -2069,7 +2069,8 @@ struct OrtEpFactory { * 2. The application either 1) calls SessionOptionsAppendExecutionProvider_V2() with an OrtEpDevice containing * the plugin EP's factory or 2) enables auto ep selection. * 3. 1) SessionOptionsAppendExecutionProvider_V2() appends the provided OrtCustomOpDomains to the - * session options or 2) ORT registers the OrtCustomOpDomains provided by the selected EP devices. + * session options or 2) ORT registers the OrtCustomOpDomains provided by the EP devices + * that could be potentially selected. * * As a result, any session created from these session options will have these custom op domains registered * in ORT, ensuring that the custom ops are properly recognized and validated when the model is loaded. @@ -2097,7 +2098,7 @@ struct OrtEpFactory { * \param[in] this_ptr The OrtEpFactory instance. * \param[out] domains Array of `num_domains` elements pre-allocated by ORT that should be filled with OrtCustomOpDomain instances created by the EP. The `num_domains` is the value returned by - GetNumCustomOpDomains(). The implementation is expected to treat `domains` as a buffer. + GetNumCustomOpDomains(). * \param[in] num_domains The size of the `domains` array pre-allocated by ORT. 
* * \snippet{doc} snippets.dox OrtStatus Return Value diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 78a326c8b0e48..9bed045bb609f 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -259,11 +259,8 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op InlinedVector domains; ORT_API_RETURN_IF_STATUS_NOT_OK(GetCustomOpDomainsFromEpDevice(*ep_device, domains)); - const auto domains_span = gsl::span(domains.data(), domains.size()); - const auto existing_domains = gsl::span(options->custom_op_domains_.data(), - options->custom_op_domains_.size()); - for (auto domain : domains_span) { - if (ShouldAddDomain(domain, existing_domains)) { + for (auto domain : domains) { + if (ShouldAddDomain(domain, options->custom_op_domains_)) { all_ep_custom_op_domains.push_back(domain); } } @@ -634,11 +631,8 @@ Status AddEpCustomDomainsToSessionOptions(gsl::span ep InlinedVector domains; ORT_RETURN_IF_ERROR(GetCustomOpDomainsFromEpDevice(*ep_device, domains)); - const auto domains_span = gsl::span(domains.data(), domains.size()); - const auto existing_domains = gsl::span(ort_session_options.custom_op_domains_.data(), - ort_session_options.custom_op_domains_.size()); - for (auto domain : domains_span) { - if (ShouldAddDomain(domain, existing_domains)) { + for (auto domain : domains) { + if (ShouldAddDomain(domain, ort_session_options.custom_op_domains_)) { ort_session_options.custom_op_domains_.push_back(domain); } } From 3b2e5a51d71591a7e322bb166f20394053fb2587 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 16 Jan 2026 21:14:00 -0800 Subject: [PATCH 46/47] address Copilot comment --- onnxruntime/test/testdata/custom_mul.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/test/testdata/custom_mul.py b/onnxruntime/test/testdata/custom_mul.py index c8fd8b0b720a3..e29e57bbb579e 100644 --- a/onnxruntime/test/testdata/custom_mul.py +++ 
b/onnxruntime/test/testdata/custom_mul.py @@ -1,10 +1,9 @@ import onnx -from onnx import TensorProto, helper def create_custom_mul_model(): # === Inputs === - x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2]) + x = onnx.helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2]) w = helper.make_tensor_value_info("W", TensorProto.FLOAT, [3, 2]) # === Output === From 5d0b15bd9e9a283b0aa318a061f58e5511895f15 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Fri, 16 Jan 2026 21:22:57 -0800 Subject: [PATCH 47/47] address Copilot comment --- onnxruntime/test/testdata/custom_mul.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/onnxruntime/test/testdata/custom_mul.py b/onnxruntime/test/testdata/custom_mul.py index e29e57bbb579e..2639648561fe1 100644 --- a/onnxruntime/test/testdata/custom_mul.py +++ b/onnxruntime/test/testdata/custom_mul.py @@ -3,15 +3,15 @@ def create_custom_mul_model(): # === Inputs === - x = onnx.helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2]) - w = helper.make_tensor_value_info("W", TensorProto.FLOAT, [3, 2]) + x = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [3, 2]) + w = onnx.helper.make_tensor_value_info("W", onnx.TensorProto.FLOAT, [3, 2]) # === Output === - y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 2]) + y = onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [3, 2]) # === Custom Node: Custom_Mul === # Replace "Mul" with your custom op name and domain - custom_node = helper.make_node( + custom_node = onnx.helper.make_node( op_type="Custom_Mul", # <-- custom op name inputs=["X", "W"], outputs=["Y"], @@ -19,7 +19,7 @@ def create_custom_mul_model(): ) # === Graph === - graph = helper.make_graph( + graph = onnx.helper.make_graph( nodes=[custom_node], name="CustomMulGraph", inputs=[x, w], @@ -27,11 +27,11 @@ def create_custom_mul_model(): ) # === Model (opset version 13 or later is fine) === - model = helper.make_model( + model = 
onnx.helper.make_model( graph, opset_imports=[ - helper.make_opsetid("", 13), # standard ONNX domain - helper.make_opsetid("com.example", 1), + onnx.helper.make_opsetid("", 13), # standard ONNX domain + onnx.helper.make_opsetid("com.example", 1), ], # your custom domain producer_name="custom_mul_builder", )