diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h
index b64e13531c260..7d34bae267404 100644
--- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h
@@ -2048,6 +2048,65 @@ struct OrtEpFactory {
   ORT_API2_STATUS(CreateExternalResourceImporterForDevice, _In_ OrtEpFactory* this_ptr,
                   _In_ const OrtEpDevice* ep_device,
                   _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** out_importer);
+
+  /** \brief Returns the number of OrtCustomOpDomains that this factory provides.
+   *
+   * \param[in] this_ptr The OrtEpFactory instance.
+   * \param[out] num_domains Output parameter set to the number of provided OrtCustomOpDomain instances.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.24.
+   */
+  ORT_API2_STATUS(GetNumCustomOpDomains, _In_ OrtEpFactory* this_ptr, _Out_ size_t* num_domains);
+
+  /** \brief Gets the EP-specific OrtCustomOpDomains.
+   *
+   * This function is used when running inference on a model that contains EP-specific custom operators.
+   *
+   * Workflow:
+   * 1. The EP factory implements this function to supply a list of OrtCustomOpDomain instances.
+   * 2. The application either (a) calls SessionOptionsAppendExecutionProvider_V2() with an OrtEpDevice containing
+   *    the plugin EP's factory, or (b) enables automatic EP selection.
+   * 3. In case (a), SessionOptionsAppendExecutionProvider_V2() appends the provided OrtCustomOpDomains to the
+   *    session options; in case (b), ORT registers the OrtCustomOpDomains provided by the EP devices
+   *    that could potentially be selected.
+   *
+   * As a result, any session created from these session options has these custom op domains registered
+   * in ORT, ensuring that the custom ops are properly recognized and validated when the model is loaded.
+   *
+   * Plugin EPs can provide two types of custom ops:
+   * 1. A full OrtCustomOp with a concrete kernel implementation
+   *    - The plugin EP supplies an OrtCustomOp and a corresponding CustomKernel::Compute() implementation.
+   *    - In GetCapability(), the EP calls EpGraphSupportInfo_AddSingleNode() to inform ORT
+   *      that the custom node should NOT be fused or compiled. Instead, ORT invokes
+   *      the custom node's Compute() function at runtime.
+   *
+   * 2. A "placeholder" OrtCustomOp with an empty kernel implementation
+   *    - A compile-based plugin EP can supply an OrtCustomOp whose CustomKernel::Compute()
+   *      does nothing. Its purpose is to satisfy model validation during model loading by
+   *      registering the custom op as a valid operator in the session.
+   *    - In GetCapability(), the EP should call EpGraphSupportInfo_AddNodesToFuse() to
+   *      notify ORT that the custom node should be fused and compiled by the EP.
+   *    - In Compile(), the EP executes its compiled artifacts to perform inference for
+   *      the fused custom node.
+   *
+   * Note: The OrtCustomOpDomain instances must remain valid while any session is using them.
+   * The EP factory is responsible for releasing the OrtCustomOpDomain instances it creates; this happens
+   * automatically if the C++ Ort::CustomOpDomain class is used.
+   *
+   * \param[in] this_ptr The OrtEpFactory instance.
+   * \param[out] domains Pre-allocated array of `num_domains` elements that the EP fills with the
+   *                     OrtCustomOpDomain instances it provides. `num_domains` is the value returned by
+   *                     GetNumCustomOpDomains().
+   * \param[in] num_domains The size of the `domains` array pre-allocated by ORT.
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(GetCustomOpDomains, _In_ OrtEpFactory* this_ptr, + _Out_writes_all_(num_domains) OrtCustomOpDomain** domains, _In_ size_t num_domains); }; #ifdef __cplusplus diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index afb17f867fc00..810adeadf5f29 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3390,6 +3390,10 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS ep_option_vals_span, session_options->value)); + ORT_API_RETURN_IF_STATUS_NOT_OK(AddEpCustomDomainsToSessionOptions( + ep_devices_span, + *session_options)); + session_options->provider_factories.push_back(std::move(provider_factory)); return nullptr; diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index f562ee73f2aaa..01f7bc67a522e 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -97,6 +97,18 @@ class EpFactoryInternalImpl { return nullptr; } + virtual OrtStatus* GetNumCustomOpDomains(_Out_ size_t* num_domains) const noexcept { + *num_domains = 0; + return nullptr; + } + + virtual OrtStatus* GetCustomOpDomains(_Out_writes_all_(num_domains) OrtCustomOpDomain** domains, + _In_ size_t num_domains) const noexcept { + ORT_UNUSED_PARAMETER(domains); + ORT_UNUSED_PARAMETER(num_domains); + return nullptr; + } + // Function ORT calls to release an EP instance. void ReleaseEp(OrtEp* ep); diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 944e83d8cad66..9bed045bb609f 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -98,6 +98,58 @@ Status TestAutoSelectEPsImpl(const Environment& env, InferenceSession& sess, con return Status::OK(); } + +Status GetCustomOpDomainsFromEpDevice(const OrtEpDevice& ep_device, InlinedVector& domains_out) { + InlinedVector domains{}; + + // Get custom op domain provided by EP factory if any. + // OrtEpFactory::GetNumCustomOpDomains and OrtEpFactory::GetCustomOpDomains were added in ORT 1.24. 
+ OrtEpFactory* ep_factory = ep_device.ep_factory; + if (ep_factory && + ep_factory->ort_version_supported >= 24 && + ep_factory->GetNumCustomOpDomains != nullptr && + ep_factory->GetCustomOpDomains != nullptr) { + size_t num_domains = 0; + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetNumCustomOpDomains(ep_factory, &num_domains))); + + domains.resize(num_domains); + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetCustomOpDomains(ep_factory, domains.data(), + domains.size()))); + } + + domains_out = std::move(domains); + return Status::OK(); +} + +bool DoesDomainWithNameExist(const std::string& domain_name, gsl::span domains) { + for (auto ptr : domains) { + if (domain_name == ptr->domain_) { + return true; + } + } + return false; +} + +bool ShouldAddDomain(const OrtCustomOpDomain* domain_to_add, + gsl::span existing_domains) { + if (!domain_to_add) { + return false; + } + + if (domain_to_add->custom_ops_.size() == 0) { + LOGS_DEFAULT(WARNING) << "Skipping custom op domain '" << domain_to_add->domain_ + << "': custom ops is empty."; + return false; + } + + if (DoesDomainWithNameExist(domain_to_add->domain_, existing_domains)) { + LOGS_DEFAULT(WARNING) << "Skipping custom op domain '" << domain_to_add->domain_ + << "': domain already exists in session options."; + return false; + } + + return true; +} } // namespace #endif // !defined(ORT_MINIMAL_BUILD) @@ -195,6 +247,31 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op } #endif +#if !defined(ORT_MINIMAL_BUILD) + // Add custom domains for all OrtEpDevice instances to inference session. + // The custom domains should be registered before model load for ORT to validate the custom ops. + if (options != nullptr && + options->provider_factories.empty() && + options->value.ep_selection_policy.enable) { + InlinedVector all_ep_custom_op_domains; + + for (const OrtEpDevice* ep_device : env.GetOrtEpDevices()) { + InlinedVector domains; + ORT_API_RETURN_IF_STATUS_NOT_OK(GetCustomOpDomainsFromEpDevice(*ep_device, domains)); + + for (auto domain : domains) { + if (ShouldAddDomain(domain, options->custom_op_domains_)) { + all_ep_custom_op_domains.push_back(domain); + } + } + } + + if (!all_ep_custom_op_domains.empty()) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(all_ep_custom_op_domains)); + } + } +#endif + // Finish load if (load_config_from_model) { #if !defined(ORT_MINIMAL_BUILD) @@ -546,5 +623,22 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic return Status::OK(); } + +Status AddEpCustomDomainsToSessionOptions(gsl::span ep_devices, + OrtSessionOptions& ort_session_options) { + for (const OrtEpDevice* ep_device : ep_devices) { + // Add custom domains if EP factory has any. + InlinedVector domains; + ORT_RETURN_IF_ERROR(GetCustomOpDomainsFromEpDevice(*ep_device, domains)); + + for (auto domain : domains) { + if (ShouldAddDomain(domain, ort_session_options.custom_op_domains_)) { + ort_session_options.custom_op_domains_.push_back(domain); + } + } + } + + return Status::OK(); +} #endif // !defined(ORT_MINIMAL_BUILD) } // namespace onnxruntime diff --git a/onnxruntime/core/session/utils.h b/onnxruntime/core/session/utils.h index 2ccd4d464a261..59b4d9f0944c3 100644 --- a/onnxruntime/core/session/utils.h +++ b/onnxruntime/core/session/utils.h @@ -71,5 +71,9 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic gsl::span ep_options_vals, SessionOptions& session_options); +// Adss EP specific custom domains to the OrtSessionOptions configuration. 
+Status AddEpCustomDomainsToSessionOptions(gsl::span ep_devices, + OrtSessionOptions& ort_session_options); + } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f0d8906d99c14..0a5cb812be106 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1348,6 +1348,9 @@ static Status AddEpFactoryFromEpDevices(PySessionOptions& py_sess_options, ep_option_vals, py_sess_options.value)); + ORT_RETURN_IF_ERROR(AddEpCustomDomainsToSessionOptions(ep_devices, + py_sess_options)); + py_sess_options.provider_factories.push_back(std::move(provider_factory)); return Status::OK(); } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc index bce9b59ff0ea4..76b2502da5c3c 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc @@ -15,117 +15,97 @@ #include "ep_factory.h" #include "ep_stream_support.h" -/// -/// Example implementation of ONNX Mul. Does not handle many things like broadcasting. -/// -struct MulKernel { - MulKernel(const OrtApi& ort_api, const OrtLogger& logger, - const std::unordered_map& float_initializers, - std::string input0_name, std::string input1_name) - : ort_api(ort_api), - logger(logger), - float_initializers(float_initializers), - input0_name(input0_name), - input1_name(input1_name) {} - - const FloatInitializer* TryGetSavedInitializer(const std::string& name) const { - auto iter = float_initializers.find(name); - return iter != float_initializers.end() ? &iter->second : nullptr; - } +const FloatInitializer* MulKernel::TryGetSavedInitializer(const std::string& name) const { + auto iter = float_initializers.find(name); + return iter != float_initializers.end() ? &iter->second : nullptr; +} - void GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index, - /*out*/ gsl::span& data, - /*out*/ std::vector& shape) const { - Ort::ConstValue input = kernel_context.GetInput(index); - auto type_shape = input.GetTensorTypeAndShapeInfo(); +void MulKernel::GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index, + /*out*/ gsl::span& data, + /*out*/ std::vector& shape) const { + Ort::ConstValue input = kernel_context.GetInput(index); + auto type_shape = input.GetTensorTypeAndShapeInfo(); - ONNXTensorElementDataType elem_type = type_shape.GetElementType(); - if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) - throw Ort::Exception("EP Expected float32 inputs", ORT_EP_FAIL); + ONNXTensorElementDataType elem_type = type_shape.GetElementType(); + if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) + throw Ort::Exception("EP Expected float32 inputs", ORT_EP_FAIL); - const float* float_data = input.GetTensorData(); - size_t num_elems = type_shape.GetElementCount(); - data = gsl::span(float_data, num_elems); - shape = type_shape.GetShape(); - } - - OrtStatus* Compute(OrtKernelContext* kernel_ctx) { - RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, - OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, - "MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__)); - Ort::KernelContext kernel_context(kernel_ctx); - try { - gsl::span input0; - gsl::span input1; - std::vector shape0; - std::vector shape1; - - size_t num_inputs = kernel_context.GetInputCount(); - - if (num_inputs == 2) { - // Both inputs are non-constant. Get them from ORT's KernelContext. 
- GetInputDataAndShape(kernel_context, 0, input0, shape0); - GetInputDataAndShape(kernel_context, 1, input1, shape1); - } else if (num_inputs == 1) { - // ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs. - // Get the constant input from the initializers saved by the EP. - // Refer to "NodeFusionOptions_DropConstantInitializers()". - - if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) { - GetInputDataAndShape(kernel_context, 0, input1, shape1); - input0 = gsl::span(const_input0->data); - shape0 = const_input0->shape; - } else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) { - GetInputDataAndShape(kernel_context, 0, input0, shape0); - input1 = gsl::span(const_input1->data); - shape1 = const_input1->shape; - } - } else { - // Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding) - // are disabled. - const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); - const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); - RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api, - "Expected 2 initializer inputs to be saved by EP"); + const float* float_data = input.GetTensorData(); + size_t num_elems = type_shape.GetElementCount(); + data = gsl::span(float_data, num_elems); + shape = type_shape.GetShape(); +} +OrtStatus* MulKernel::Compute(OrtKernelContext* kernel_ctx) { + RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + "MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__)); + Ort::KernelContext kernel_context(kernel_ctx); + try { + gsl::span input0; + gsl::span input1; + std::vector shape0; + std::vector shape1; + + size_t num_inputs = kernel_context.GetInputCount(); + + if (num_inputs == 2) { + // Both inputs are non-constant. Get them from ORT's KernelContext. + GetInputDataAndShape(kernel_context, 0, input0, shape0); + GetInputDataAndShape(kernel_context, 1, input1, shape1); + } else if (num_inputs == 1) { + // ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs. + // Get the constant input from the initializers saved by the EP. + // Refer to "NodeFusionOptions_DropConstantInitializers()". + + if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) { + GetInputDataAndShape(kernel_context, 0, input1, shape1); input0 = gsl::span(const_input0->data); - input1 = gsl::span(const_input1->data); shape0 = const_input0->shape; + } else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) { + GetInputDataAndShape(kernel_context, 0, input0, shape0); + input1 = gsl::span(const_input1->data); shape1 = const_input1->shape; } + } else { + // Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding) + // are disabled. 
+ const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); + const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); + RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api, + "Expected 2 initializer inputs to be saved by EP"); + + input0 = gsl::span(const_input0->data); + input1 = gsl::span(const_input1->data); + shape0 = const_input0->shape; + shape1 = const_input1->shape; + } - if (shape0 != shape1) { - throw Ort::Exception("Expected same dimensions for both inputs", ORT_INVALID_ARGUMENT); - } + if (shape0 != shape1) { + throw Ort::Exception("Expected same dimensions for both inputs", ORT_INVALID_ARGUMENT); + } - size_t num_outputs = kernel_context.GetOutputCount(); - if (num_outputs != 1) { - throw Ort::Exception("Expected 1 output for MulKernel", ORT_INVALID_ARGUMENT); - } + size_t num_outputs = kernel_context.GetOutputCount(); + if (num_outputs != 1) { + throw Ort::Exception("Expected 1 output for MulKernel", ORT_INVALID_ARGUMENT); + } - auto output = kernel_context.GetOutput(0, shape0); - float* output_data = output.GetTensorMutableData(); + auto output = kernel_context.GetOutput(0, shape0); + float* output_data = output.GetTensorMutableData(); - for (size_t i = 0; i < input0.size(); ++i) { - output_data[i] = input0[i] * input1[i]; - } - } catch (const Ort::Exception& ex) { - Ort::Status status(ex); - return status.release(); - } catch (const std::exception& ex) { - Ort::Status status(ex.what(), ORT_EP_FAIL); - return status.release(); + for (size_t i = 0; i < input0.size(); ++i) { + output_data[i] = input0[i] * input1[i]; } - - return nullptr; + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); } - const OrtApi& ort_api; - const OrtLogger& logger; - const std::unordered_map& float_initializers; - std::string input0_name; - std::string input1_name; -}; + return nullptr; +} /// /// Example OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph. @@ -230,6 +210,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG for (const auto& node : nodes) { auto op_type = node.GetOperatorType(); + auto domain = node.GetDomain(); if (op_type == "Mul") { // Check that Mul has inputs/output of type float @@ -262,6 +243,8 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG supported_nodes.push_back(node); // Only support a single Mul for now. break; + } else if (op_type == "Custom_Mul" && domain == "test") { + supported_nodes.push_back(node); } } @@ -269,19 +252,26 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG return nullptr; } - // Create (optional) fusion options for the supported nodes to fuse. - OrtNodeFusionOptions node_fusion_options = {}; - node_fusion_options.ort_version_supported = ORT_API_VERSION; - - // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers - // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. - // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use - // during inference. 
- node_fusion_options.drop_constant_initializers = true; - RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, - reinterpret_cast(supported_nodes.data()), - supported_nodes.size(), - &node_fusion_options)); + if (supported_nodes[0].GetOperatorType() == "Mul") { + // Create (optional) fusion options for the supported nodes to fuse. + OrtNodeFusionOptions node_fusion_options = {}; + node_fusion_options.ort_version_supported = ORT_API_VERSION; + + // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers + // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. + // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use + // during inference. + node_fusion_options.drop_constant_initializers = true; + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, + reinterpret_cast(supported_nodes.data()), + supported_nodes.size(), + &node_fusion_options)); + } else if (supported_nodes[0].GetOperatorType() == "Custom_Mul") { + // Calls EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled, + // as CustomMul has the concrete kernel implementation. + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_nodes[0])); + } + } catch (const Ort::Exception& ex) { Ort::Status status(ex); return status.release(); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h index 7e96a523cf285..5d4788ed76bf2 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h @@ -8,7 +8,34 @@ #include "../plugin_ep_utils.h" class ExampleEpFactory; -struct MulKernel; + +/// +/// Example implementation of ONNX Mul. Does not handle many things like broadcasting. +/// +struct MulKernel { + MulKernel(const OrtApi& ort_api, const OrtLogger& logger, + const std::unordered_map& float_initializers, + std::string input0_name, std::string input1_name) + : ort_api(ort_api), + logger(logger), + float_initializers(float_initializers), + input0_name(input0_name), + input1_name(input1_name) {} + + const FloatInitializer* TryGetSavedInitializer(const std::string& name) const; + + void GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index, + /*out*/ gsl::span& data, + /*out*/ std::vector& shape) const; + + OrtStatus* Compute(OrtKernelContext* kernel_ctx); + + const OrtApi& ort_api; + const OrtLogger& logger; + const std::unordered_map& float_initializers; + std::string input0_name; + std::string input1_name; +}; /// /// Example EP that can compile a single Mul operator. diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h new file mode 100644 index 0000000000000..c37038a727067 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "onnxruntime_c_api.h" +#include "ep.h" + +// Plugin EPs can provide two types of custom ops: +// +// 1. A full OrtCustomOp with a concrete kernel implementation +// - This Example EP demonstrates this approach. 
+// - In GetCapability(), it calls EpGraphSupportInfo_AddSingleNode() to inform ORT +// that the custom node should NOT be fused or compiled. Instead, ORT should invoke +// the custom node's Compute() function at runtime. +// +// 2. A "placeholder" OrtCustomOp with an empty kernel implementation +// - A compile-based Plugin EP can supply an OrtCustomOp whose CustomKernel::Compute() +// does nothing. The purpose is to satisfy model validation during model loading by +// registering the custom op as a valid operator in the session. +// - In GetCapability(), the EP should call EpGraphSupportInfo_AddNodesToFuse() to +// notify ORT that this custom node should be fused and compiled by the EP. +// - In Compile(), the EP executes its compiled bits to perform inference for +// the fused custom node. +// +// Note: Approach #2 is suitable for plugin TRT RTX EP to support TRT plugins. + +struct CustomMulKernel : MulKernel { + CustomMulKernel(const OrtApi& ort_api, + const OrtLogger& logger, + const std::unordered_map& float_initializers, + std::string input0_name, + std::string input1_name) : MulKernel(ort_api, logger, float_initializers, + input0_name, input1_name) { + } + + OrtStatusPtr ComputeV2(OrtKernelContext* kernel_ctx) { + return MulKernel::Compute(kernel_ctx); + } +}; + +struct ExampleEpCustomOp : Ort::CustomOpBase { + explicit ExampleEpCustomOp(const char* provider, ExampleEpFactory* factory) : provider_(provider), + factory_(factory) { + } + + OrtStatusPtr CreateKernelV2(const OrtApi& api, const OrtKernelInfo* info, void** op_kernel) const; + + OrtStatusPtr KernelComputeV2(void* op_kernel, OrtKernelContext* context) const; + + const char* GetName() const { return name_; }; + + void SetName(const char* name) { name_ = name; }; + + const char* GetExecutionProviderType() const { return provider_; }; + + size_t GetInputTypeCount() const { return num_inputs_; }; + + void SetInputTypeCount(size_t num) { num_inputs_ = num; }; + + ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; + + OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const { + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; + }; + + size_t GetOutputTypeCount() const { return num_outputs_; }; + + void SetOutputTypeCount(size_t num) { num_outputs_ = num; }; + + ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; + + OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const { + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; + }; + + bool GetVariadicInputHomogeneity() const { + return false; // heterogenous + } + + bool GetVariadicOutputHomogeneity() const { + return false; // heterogeneous + } + + private: + const char* provider_ = nullptr; + const char* name_ = nullptr; + size_t num_inputs_ = 1; // set to 1 to match with default min_arity for variadic input + size_t num_outputs_ = 1; // set to 1 to match with default min_arity for variadic output + ExampleEpFactory* factory_ = nullptr; + std::unordered_map float_initializers_; +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 79ec3fe3a3780..c56f0f74ab74a 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -40,6 +40,9 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs 
apis, const OrtL CreateExternalResourceImporterForDevice = CreateExternalResourceImporterForDeviceImpl; + GetNumCustomOpDomains = GetNumCustomOpDomainsImpl; + GetCustomOpDomains = GetCustomOpDomainsImpl; + // setup the OrtMemoryInfo instances required by the EP. // We pretend the device the EP is running on is GPU. default_memory_info_ = Ort::MemoryInfo{"ExampleEP GPU", @@ -71,6 +74,22 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL OrtDeviceMemoryType_HOST_ACCESSIBLE, /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator}; + // Custom Op Domains + custom_op_domains_[0] = Ort::CustomOpDomain{"test"}; + custom_op_domains_[1] = Ort::CustomOpDomain{"test2"}; + + std::vector> created_custom_op_list; + created_custom_op_list.push_back(std::make_unique(ep_name_.c_str(), this)); + created_custom_op_list.back().get()->SetName("Custom_Mul"); + custom_op_domains_[0].Add(created_custom_op_list.back().get()); + + std::vector> created_custom_op_list_2; + created_custom_op_list_2.push_back(std::make_unique(ep_name_.c_str(), this)); + created_custom_op_list_2.back().get()->SetName("Custom_Mul2"); + custom_op_domains_[1].Add(created_custom_op_list_2.back().get()); + + created_custom_op_lists_[0] = std::move(created_custom_op_list); + created_custom_op_lists_[1] = std::move(created_custom_op_list_2); } /*static*/ @@ -313,6 +332,48 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateSyncStreamForDeviceImpl(OrtEpFac } /*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::GetNumCustomOpDomainsImpl(OrtEpFactory* this_ptr, + _Out_ size_t* num_domains) noexcept { + auto* factory = static_cast(this_ptr); + *num_domains = factory->custom_op_domains_.size(); + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::GetCustomOpDomainsImpl( + OrtEpFactory* this_ptr, + _Outptr_result_maybenull_ OrtCustomOpDomain** domains, + _Out_ size_t num_domains) noexcept { + auto* factory = static_cast(this_ptr); + + // The `num_domains` should be 2 as ORT calls GetNumCustomOpDomainsImpl() to get the number prior to + // call this function. 
+ gsl::span domains_span(domains, num_domains); + domains_span[0] = factory->custom_op_domains_[0]; + domains_span[1] = factory->custom_op_domains_[1]; + + return nullptr; +} + +OrtStatusPtr ExampleEpCustomOp::CreateKernelV2(const OrtApi& /*api*/, + const OrtKernelInfo* /*info*/, + void** op_kernel) const { + std::string node_input_0 = "X"; + std::string node_input_1 = "W"; + auto custom_kernel_op = std::make_unique(factory_->ort_api, + factory_->default_logger_, + float_initializers_, + node_input_0, + node_input_1); + *op_kernel = custom_kernel_op.release(); + return nullptr; +} + +OrtStatusPtr ExampleEpCustomOp::KernelComputeV2(void* op_kernel, OrtKernelContext* context) const { + return static_cast(op_kernel)->ComputeV2(context); +} + OrtStatus* ORT_API_CALL ExampleEpFactory::GetHardwareDeviceIncompatibilityDetailsImpl( OrtEpFactory* this_ptr, const OrtHardwareDevice* hw, diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index 9306b0fc88ec9..244051dd5e4d0 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -9,6 +9,8 @@ #include "ep_data_transfer.h" #include "ep_external_resource_importer.h" #include "../plugin_ep_utils.h" +#include "ep.h" +#include "ep_custom_op.h" /// /// Example EP factory that can create an OrtEp and return information about the supported hardware devices. @@ -26,6 +28,8 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { return arena_allocator_.get(); } + const OrtLogger& default_logger_; // default logger for the EP factory + private: static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; @@ -78,7 +82,13 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { const OrtHardwareDevice* hw, OrtDeviceEpIncompatibilityDetails* details) noexcept; - const OrtLogger& default_logger_; // default logger for the EP factory + static OrtStatus* ORT_API_CALL GetNumCustomOpDomainsImpl(OrtEpFactory* this_ptr, + _Out_ size_t* num_domains) noexcept; + + static OrtStatus* ORT_API_CALL GetCustomOpDomainsImpl(OrtEpFactory* this_ptr, + _Outptr_result_maybenull_ OrtCustomOpDomain** domains, + _Out_ size_t num_domains) noexcept; + const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name const uint32_t vendor_id_{0xB357}; // EP vendor ID @@ -94,4 +104,7 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { std::mutex mutex_; // mutex to protect arena_allocator_ and num_arena_users_ std::unique_ptr data_transfer_impl_; // data transfer implementation for this factory + + std::vector custom_op_domains_{2}; + std::vector>> created_custom_op_lists_{2}; }; diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index 0970654b48ca1..bf805b0707e7d 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -48,6 +48,35 @@ void RunMulModelWithPluginEp(const Ort::SessionOptions& session_options) { EXPECT_THAT(output_span, ::testing::ElementsAre(2, 4, 6, 8, 10, 12)); } +void RunCustomMulModelWithPluginEp(const Ort::SessionOptions& session_options) { + Ort::Session session(*ort_env, ORT_TSTR("testdata/custom_mul.onnx"), session_options); + + // Create two inputs with same values + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::vector shape = {3, 2}; + std::vector 
input0_data{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector ort_inputs; + std::vector ort_input_names; + + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input0_data.data(), input0_data.size(), shape.data(), shape.size())); + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input0_data.data(), input0_data.size(), shape.data(), shape.size())); + ort_input_names.push_back("X"); + ort_input_names.push_back("W"); + + // Run session and get outputs + std::array output_names{"Y"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check expected output values + Ort::Value& ort_output = ort_outputs[0]; + const float* output_data = ort_output.GetTensorData(); + gsl::span output_span(output_data, 6); + EXPECT_THAT(output_span, ::testing::ElementsAre(1, 4, 9, 16, 25, 36)); +} + void RunSqueezeMulReluModel(const Ort::SessionOptions& session_options) { Ort::Session session(*ort_env, ORT_TSTR("testdata/squeeze_mul_relu.onnx"), session_options); @@ -551,6 +580,36 @@ TEST(OrtEpLibrary, KernelPluginEp_ControlFlow_Scan) { } } +// Creates a session with the example plugin EP and runs a model with a single Costom_Mul node. +// Uses AppendExecutionProvider_V2 to append the example plugin EP to the session. +TEST(OrtEpLibrary, PluginEp_Custom_Op_Inference_With_Explicit_Ep) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + // Create session with example plugin EP + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + RunCustomMulModelWithPluginEp(session_options); +} + +// Creates a session with the example plugin EP and runs a model with a single Costom_Mul node. +// Uses the PREFER_CPU policy to append the example plugin EP to the session. +TEST(OrtEpLibrary, PluginEp_Custom_Op_Inference_With_Prefer_Cpu) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + { + // PREFER_CPU pick our example EP over ORT CPU EP. TODO: Actually assert this. + Ort::SessionOptions session_options; + session_options.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_CPU); + RunCustomMulModelWithPluginEp(session_options); + } +} + // Tests the GetHardwareDeviceEpIncompatibilityDetails C API with the example plugin EP. // The example plugin EP supports CPU devices, so this test verifies that a CPU device // is reported as compatible (reasons_bitmask == 0). 
@@ -646,6 +705,5 @@ TEST(OrtEpLibrary, PluginEp_GpuDevice_ReturnsInCompatible) {
 
   api->ReleaseDeviceEpIncompatibilityDetails(details);
 }
-
 }  // namespace test
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/testdata/custom_mul.onnx b/onnxruntime/test/testdata/custom_mul.onnx
new file mode 100644
index 0000000000000..87bb64764a669
Binary files /dev/null and b/onnxruntime/test/testdata/custom_mul.onnx differ
diff --git a/onnxruntime/test/testdata/custom_mul.py b/onnxruntime/test/testdata/custom_mul.py
new file mode 100644
index 0000000000000..2639648561fe1
--- /dev/null
+++ b/onnxruntime/test/testdata/custom_mul.py
@@ -0,0 +1,45 @@
+import onnx
+
+
+def create_custom_mul_model():
+    # === Inputs ===
+    x = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [3, 2])
+    w = onnx.helper.make_tensor_value_info("W", onnx.TensorProto.FLOAT, [3, 2])
+
+    # === Output ===
+    y = onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [3, 2])
+
+    # === Custom Node: Custom_Mul ===
+    # Replace the op name and domain below to build a model for a different custom op.
+    custom_node = onnx.helper.make_node(
+        op_type="Custom_Mul",  # <-- custom op name
+        inputs=["X", "W"],
+        outputs=["Y"],
+        domain="test",  # <-- custom domain
+    )
+
+    # === Graph ===
+    graph = onnx.helper.make_graph(
+        nodes=[custom_node],
+        name="CustomMulGraph",
+        inputs=[x, w],
+        outputs=[y],
+    )
+
+    # === Model (opset version 13 or later is fine) ===
+    model = onnx.helper.make_model(
+        graph,
+        opset_imports=[
+            onnx.helper.make_opsetid("", 13),  # standard ONNX domain
+            onnx.helper.make_opsetid("test", 1),  # custom domain used by the Custom_Mul node
+        ],
+        producer_name="custom_mul_builder",
+    )
+
+    return model
+
+
+# ===== Save the Model =====
+model = create_custom_mul_model()
+onnx.save(model, "custom_mul.onnx")
+print("Saved custom_mul.onnx")
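
A quick way to sanity-check the generated test model (a suggested follow-up, not part of this change) is to run it through the ONNX checker after executing custom_mul.py; the checker should reject a graph whose node domain is missing from the model's opset_imports, which is exactly the kind of mismatch that would otherwise only surface when the plugin EP tests fail to load the model.

    import onnx

    # Assumes custom_mul.py has already been run in the current directory.
    model = onnx.load("custom_mul.onnx")

    # check_model raises onnx.checker.ValidationError if, for example, a node references
    # a domain (such as "test") that is not listed in the model's opset_imports.
    onnx.checker.check_model(model)
    print("custom_mul.onnx passed the ONNX checker")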