diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 617788fcab8bb..42a44bc082583 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -1824,6 +1824,64 @@ struct OrtEpFactory { ORT_API2_STATUS(CreateExternalResourceImporterForDevice, _In_ OrtEpFactory* this_ptr, _In_ const OrtEpDevice* ep_device, _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** out_importer); + + /** \brief Returns the number of OrtCustomOpDomains that this factory provides. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[out] num_domains Output parameter set to the number of created OrtCustomOpDomain instances. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(GetNumCustomOpDomains, _In_ OrtEpFactory* this_ptr, _Out_ size_t* num_domains); + + /** \brief Gets the EP-specific OrtCustomOpDomains. + * + * This function is used when running inference on a model that contains EP-specific custom operations. + * + * Workflow: + * 1. The EP factory implements this function to supply a list of OrtCustomOpDomain instances. + * 2. The application calls SessionOptionsAppendExecutionProvider_V2() with an OrtEpDevice containing + * the plugin EP's factory. + * 3. SessionOptionsAppendExecutionProvider_V2() appends the provided OrtCustomOpDomain list to the + * session options. + * + * As a result, any session created from these session options will have these custom op domains registered + * in ORT, ensuring that the custom ops are properly recognized and validated when the model is loaded. + * + * Plugin EPs can provide two types of custom ops: + * 1. A full OrtCustomOp with a concrete kernel implementation + * - This Example EP demonstrates this approach. 
+ * - In GetCapability(), it calls EpGraphSupportInfo_AddSingleNode() to inform ORT + * that the custom node should NOT be fused or compiled. Instead, ORT should invoke + * the custom node's Compute() function at runtime. + * + * 2. A "placeholder" OrtCustomOp with an empty kernel implementation + * - A compile-based Plugin EP can supply an OrtCustomOp whose CustomKernel::Compute() + * does nothing. The purpose is to satisfy model validation during model loading by + * registering the custom op as a valid operator in the session. + * - In GetCapability(), the EP should call EpGraphSupportInfo_AddNodesToFuse() to + * notify ORT that this custom node should be fused and compiled by the EP. + * - In Compile(), the EP executes its compiled bits to perform inference for + * the fused custom node. + * + * Note: The OrtCustomOpDomain instances must be valid while any session is using them. + * The EP factory is responsible for releasing the OrtCustomOpDomain instances it creates. This happens + * automatically when using the ORT C++ API. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[out] domains Array of `num_domains` elements, pre-allocated by ORT, that should be filled with the + * OrtCustomOpDomain instances created by the EP. `num_domains` is the value returned by + * GetNumCustomOpDomains(). The implementation is expected to treat `domains` as a buffer. + * \param[in] num_domains The size of the `domains` array pre-allocated by ORT. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. 
+ */ + ORT_API2_STATUS(GetCustomOpDomains, _In_ OrtEpFactory* this_ptr, + _Out_writes_all_(num_domains) OrtCustomOpDomain** domains, _In_ size_t num_domains); }; #ifdef __cplusplus diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index c3bf74a4607e8..b67927e5548a1 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3333,7 +3333,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS ep_devices_span, ep_option_keys_span, ep_option_vals_span, - session_options->value)); + *session_options)); session_options->provider_factories.push_back(std::move(provider_factory)); diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index 20a47715df2b8..2fa3f456658ac 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -96,6 +96,18 @@ class EpFactoryInternalImpl { return nullptr; } + virtual OrtStatus* GetNumCustomOpDomains(_Out_ size_t* num_domains) const noexcept { + *num_domains = 0; + return nullptr; + } + + virtual OrtStatus* GetCustomOpDomains(_Out_writes_all_(num_domains) OrtCustomOpDomain** domains, + _In_ size_t num_domains) const noexcept { + ORT_UNUSED_PARAMETER(domains); + ORT_UNUSED_PARAMETER(num_domains); + return nullptr; + } + // Function ORT calls to release an EP instance. 
void ReleaseEp(OrtEp* ep); diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 944e83d8cad66..699a3d89f6784 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -519,7 +519,8 @@ Status CreateIExecutionProviderFactoryForEpDevices(const Environment& env, Status AddEpOptionsToSessionOptions(gsl::span ep_devices, gsl::span ep_option_keys, gsl::span ep_option_vals, - SessionOptions& session_options) { + OrtSessionOptions& ort_session_options) { + SessionOptions& session_options = ort_session_options.value; const size_t num_ep_options = ep_option_keys.size(); if (ep_option_vals.size() != num_ep_options) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -542,6 +543,56 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic ORT_RETURN_IF_ERROR(config_options.AddConfigEntry((prefix + ep_option_keys[j]).c_str(), ep_option_vals[j])); } + + // Add custom op domain provided by EP to the session options if any. + // OrtEpFactory::GetNumCustomOpDomains and OrtEpFactory::GetCustomOpDomains + // were added in ORT 1.24. 
+ OrtEpFactory* ep_factory = ep_device->ep_factory; + if (ep_factory && + ep_factory->ort_version_supported >= 24 && + ep_factory->GetNumCustomOpDomains != nullptr && + ep_factory->GetCustomOpDomains != nullptr) { + auto is_already_in_domains = + [&](const std::string& domain_name, const std::vector& domains) { + for (auto ptr : domains) { + if (domain_name == ptr->domain_) { + return true; + } + } + return false; + }; + + size_t num_domains = 0; + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetNumCustomOpDomains(ep_factory, &num_domains))); + + InlinedVector domains; + domains.resize(num_domains); + + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetCustomOpDomains(ep_factory, + domains.data(), + domains.size()))); + + const auto domains_span = gsl::span(domains.data(), domains.size()); + for (auto domain : domains_span) { + // `domains` was value-initialized by resize(), so a slot the EP did not fill is still nullptr. + // Reject it here instead of dereferencing it below. + if (domain == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "EP factory returned a null OrtCustomOpDomain."); + } + + if (domain->custom_ops_.empty()) { + // A domain without custom ops has nothing to register; report that reason explicitly. + LOGS_DEFAULT(WARNING) << "The custom op domain name " + << domain->domain_ << " has no custom ops. Skip it."; + } else if (!is_already_in_domains(domain->domain_, ort_session_options.custom_op_domains_)) { + ort_session_options.custom_op_domains_.push_back(domain); + } else { + LOGS_DEFAULT(WARNING) << "The custom op domain name " + << domain->domain_ << " is already in the session option. 
Skip it."; + } + } + } } return Status::OK(); diff --git a/onnxruntime/core/session/utils.h b/onnxruntime/core/session/utils.h index 2ccd4d464a261..da951b5cb9810 100644 --- a/onnxruntime/core/session/utils.h +++ b/onnxruntime/core/session/utils.h @@ -69,7 +69,7 @@ Status CreateIExecutionProviderFactoryForEpDevices(const Environment& env, Status AddEpOptionsToSessionOptions(gsl::span ep_devices, gsl::span ep_options_keys, gsl::span ep_options_vals, - SessionOptions& session_options); + OrtSessionOptions& session_options); } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f0d8906d99c14..f2aea061b244d 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1346,7 +1346,7 @@ static Status AddEpFactoryFromEpDevices(PySessionOptions& py_sess_options, ORT_RETURN_IF_ERROR(AddEpOptionsToSessionOptions(ep_devices, ep_option_keys, ep_option_vals, - py_sess_options.value)); + py_sess_options)); py_sess_options.provider_factories.push_back(std::move(provider_factory)); return Status::OK(); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc index bce9b59ff0ea4..76b2502da5c3c 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc @@ -15,117 +15,100 @@ #include "ep_factory.h" #include "ep_stream_support.h" -/// -/// Example implementation of ONNX Mul. Does not handle many things like broadcasting. 
-/// -struct MulKernel { - MulKernel(const OrtApi& ort_api, const OrtLogger& logger, - const std::unordered_map& float_initializers, - std::string input0_name, std::string input1_name) - : ort_api(ort_api), - logger(logger), - float_initializers(float_initializers), - input0_name(input0_name), - input1_name(input1_name) {} - - const FloatInitializer* TryGetSavedInitializer(const std::string& name) const { - auto iter = float_initializers.find(name); - return iter != float_initializers.end() ? &iter->second : nullptr; - } +const FloatInitializer* MulKernel::TryGetSavedInitializer(const std::string& name) const { + auto iter = float_initializers.find(name); + return iter != float_initializers.end() ? &iter->second : nullptr; +} - void GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index, - /*out*/ gsl::span& data, - /*out*/ std::vector& shape) const { - Ort::ConstValue input = kernel_context.GetInput(index); - auto type_shape = input.GetTensorTypeAndShapeInfo(); +void MulKernel::GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index, + /*out*/ gsl::span& data, + /*out*/ std::vector& shape) const { + Ort::ConstValue input = kernel_context.GetInput(index); + auto type_shape = input.GetTensorTypeAndShapeInfo(); - ONNXTensorElementDataType elem_type = type_shape.GetElementType(); - if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) - throw Ort::Exception("EP Expected float32 inputs", ORT_EP_FAIL); + ONNXTensorElementDataType elem_type = type_shape.GetElementType(); + if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) + throw Ort::Exception("EP Expected float32 inputs", ORT_EP_FAIL); - const float* float_data = input.GetTensorData(); - size_t num_elems = type_shape.GetElementCount(); - data = gsl::span(float_data, num_elems); - shape = type_shape.GetShape(); - } - - OrtStatus* Compute(OrtKernelContext* kernel_ctx) { - RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, - OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, - "MulKernel::Compute", 
ORT_FILE, __LINE__, __FUNCTION__)); - Ort::KernelContext kernel_context(kernel_ctx); - try { - gsl::span input0; - gsl::span input1; - std::vector shape0; - std::vector shape1; - - size_t num_inputs = kernel_context.GetInputCount(); - - if (num_inputs == 2) { - // Both inputs are non-constant. Get them from ORT's KernelContext. - GetInputDataAndShape(kernel_context, 0, input0, shape0); - GetInputDataAndShape(kernel_context, 1, input1, shape1); - } else if (num_inputs == 1) { - // ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs. - // Get the constant input from the initializers saved by the EP. - // Refer to "NodeFusionOptions_DropConstantInitializers()". - - if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) { - GetInputDataAndShape(kernel_context, 0, input1, shape1); - input0 = gsl::span(const_input0->data); - shape0 = const_input0->shape; - } else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) { - GetInputDataAndShape(kernel_context, 0, input0, shape0); - input1 = gsl::span(const_input1->data); - shape1 = const_input1->shape; - } - } else { - // Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding) - // are disabled. 
- const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); - const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); - RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api, - "Expected 2 initializer inputs to be saved by EP"); + const float* float_data = input.GetTensorData(); + size_t num_elems = type_shape.GetElementCount(); + data = gsl::span(float_data, num_elems); + shape = type_shape.GetShape(); +} +OrtStatus* MulKernel::Compute(OrtKernelContext* kernel_ctx) { + RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + "MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__)); + Ort::KernelContext kernel_context(kernel_ctx); + try { + gsl::span input0; + gsl::span input1; + std::vector shape0; + std::vector shape1; + + size_t num_inputs = kernel_context.GetInputCount(); + + if (num_inputs == 2) { + // Both inputs are non-constant. Get them from ORT's KernelContext. + GetInputDataAndShape(kernel_context, 0, input0, shape0); + GetInputDataAndShape(kernel_context, 1, input1, shape1); + } else if (num_inputs == 1) { + // ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs. + // Get the constant input from the initializers saved by the EP. + // Refer to "NodeFusionOptions_DropConstantInitializers()". 
+ + if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) { + GetInputDataAndShape(kernel_context, 0, input1, shape1); input0 = gsl::span(const_input0->data); - input1 = gsl::span(const_input1->data); shape0 = const_input0->shape; + } else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) { + GetInputDataAndShape(kernel_context, 0, input0, shape0); + input1 = gsl::span(const_input1->data); shape1 = const_input1->shape; + } else { + // No saved initializer matches either input name, so there is no second operand to multiply. + throw Ort::Exception("Expected one input to be an initializer saved by EP", ORT_EP_FAIL); } + } else { + // Both inputs are constant. 
Should never happen unless all ORT optimizations (specifically constant-folding) + // are disabled. + const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); + const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); + RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api, + "Expected 2 initializer inputs to be saved by EP"); + + input0 = gsl::span(const_input0->data); + input1 = gsl::span(const_input1->data); + shape0 = const_input0->shape; + shape1 = const_input1->shape; + } - if (shape0 != shape1) { - throw Ort::Exception("Expected same dimensions for both inputs", ORT_INVALID_ARGUMENT); - } + if (shape0 != shape1) { + throw Ort::Exception("Expected same dimensions for both inputs", ORT_INVALID_ARGUMENT); + } - size_t num_outputs = kernel_context.GetOutputCount(); - if (num_outputs != 1) { - throw Ort::Exception("Expected 1 output for MulKernel", ORT_INVALID_ARGUMENT); - } + size_t num_outputs = kernel_context.GetOutputCount(); + if (num_outputs != 1) { + throw Ort::Exception("Expected 1 output for MulKernel", ORT_INVALID_ARGUMENT); + } - auto output = kernel_context.GetOutput(0, shape0); - float* output_data = output.GetTensorMutableData(); + auto output = kernel_context.GetOutput(0, shape0); + float* output_data = output.GetTensorMutableData(); - for (size_t i = 0; i < input0.size(); ++i) { - output_data[i] = input0[i] * input1[i]; - } - } catch (const Ort::Exception& ex) { - Ort::Status status(ex); - return status.release(); - } catch (const std::exception& ex) { - Ort::Status status(ex.what(), ORT_EP_FAIL); - return status.release(); + for (size_t i = 0; i < input0.size(); ++i) { + output_data[i] = input0[i] * input1[i]; } - - return nullptr; + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); } - const OrtApi& ort_api; - const OrtLogger& logger; - const 
std::unordered_map& float_initializers; - std::string input0_name; - std::string input1_name; -}; + return nullptr; +} /// /// Example OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph. @@ -230,6 +213,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG for (const auto& node : nodes) { auto op_type = node.GetOperatorType(); + auto domain = node.GetDomain(); if (op_type == "Mul") { // Check that Mul has inputs/output of type float @@ -262,6 +246,8 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG supported_nodes.push_back(node); // Only support a single Mul for now. break; + } else if (op_type == "Custom_Mul" && domain == "test") { + supported_nodes.push_back(node); } } @@ -269,19 +255,26 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG return nullptr; } - // Create (optional) fusion options for the supported nodes to fuse. - OrtNodeFusionOptions node_fusion_options = {}; - node_fusion_options.ort_version_supported = ORT_API_VERSION; - - // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers - // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. - // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use - // during inference. - node_fusion_options.drop_constant_initializers = true; - RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, - reinterpret_cast(supported_nodes.data()), - supported_nodes.size(), - &node_fusion_options)); + if (supported_nodes[0].GetOperatorType() == "Mul") { + // Create (optional) fusion options for the supported nodes to fuse. 
+ OrtNodeFusionOptions node_fusion_options = {}; + node_fusion_options.ort_version_supported = ORT_API_VERSION; + + // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers + // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. + // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use + // during inference. + node_fusion_options.drop_constant_initializers = true; + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, + reinterpret_cast(supported_nodes.data()), + supported_nodes.size(), + &node_fusion_options)); + } else if (supported_nodes[0].GetOperatorType() == "Custom_Mul") { + // Call EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled, + // because Custom_Mul has a concrete kernel implementation. + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_nodes[0])); + } + } catch (const Ort::Exception& ex) { Ort::Status status(ex); return status.release(); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h index 7e96a523cf285..5d4788ed76bf2 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h @@ -8,7 +8,34 @@ #include "../plugin_ep_utils.h" class ExampleEpFactory; -struct MulKernel; + +/// +/// Example implementation of ONNX Mul. Does not handle many things like broadcasting. 
+/// +struct MulKernel { + MulKernel(const OrtApi& ort_api, const OrtLogger& logger, + const std::unordered_map& float_initializers, + std::string input0_name, std::string input1_name) + : ort_api(ort_api), + logger(logger), + float_initializers(float_initializers), + input0_name(input0_name), + input1_name(input1_name) {} + + const FloatInitializer* TryGetSavedInitializer(const std::string& name) const; + + void GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index, + /*out*/ gsl::span& data, + /*out*/ std::vector& shape) const; + + OrtStatus* Compute(OrtKernelContext* kernel_ctx); + + const OrtApi& ort_api; + const OrtLogger& logger; + const std::unordered_map& float_initializers; + std::string input0_name; + std::string input1_name; +}; /// /// Example EP that can compile a single Mul operator. diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h new file mode 100644 index 0000000000000..c37038a727067 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "onnxruntime_c_api.h" +#include "ep.h" + +// Plugin EPs can provide two types of custom ops: +// +// 1. A full OrtCustomOp with a concrete kernel implementation +// - This Example EP demonstrates this approach. +// - In GetCapability(), it calls EpGraphSupportInfo_AddSingleNode() to inform ORT +// that the custom node should NOT be fused or compiled. Instead, ORT should invoke +// the custom node's Compute() function at runtime. +// +// 2. A "placeholder" OrtCustomOp with an empty kernel implementation +// - A compile-based Plugin EP can supply an OrtCustomOp whose CustomKernel::Compute() +// does nothing. 
The purpose is to satisfy model validation during model loading by +// registering the custom op as a valid operator in the session. +// - In GetCapability(), the EP should call EpGraphSupportInfo_AddNodesToFuse() to +// notify ORT that this custom node should be fused and compiled by the EP. +// - In Compile(), the EP executes its compiled bits to perform inference for +// the fused custom node. +// +// Note: Approach #2 is suitable for plugin TRT RTX EP to support TRT plugins. + +struct CustomMulKernel : MulKernel { + CustomMulKernel(const OrtApi& ort_api, + const OrtLogger& logger, + const std::unordered_map& float_initializers, + std::string input0_name, + std::string input1_name) : MulKernel(ort_api, logger, float_initializers, + input0_name, input1_name) { + } + + OrtStatusPtr ComputeV2(OrtKernelContext* kernel_ctx) { + return MulKernel::Compute(kernel_ctx); + } +}; + +struct ExampleEpCustomOp : Ort::CustomOpBase { + explicit ExampleEpCustomOp(const char* provider, ExampleEpFactory* factory) : provider_(provider), + factory_(factory) { + } + + OrtStatusPtr CreateKernelV2(const OrtApi& api, const OrtKernelInfo* info, void** op_kernel) const; + + OrtStatusPtr KernelComputeV2(void* op_kernel, OrtKernelContext* context) const; + + const char* GetName() const { return name_; }; + + void SetName(const char* name) { name_ = name; }; + + const char* GetExecutionProviderType() const { return provider_; }; + + size_t GetInputTypeCount() const { return num_inputs_; }; + + void SetInputTypeCount(size_t num) { num_inputs_ = num; }; + + ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; + + OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const { + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; + }; + + size_t GetOutputTypeCount() const { return num_outputs_; }; + + void SetOutputTypeCount(size_t num) { num_outputs_ = num; }; + + ONNXTensorElementDataType 
GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; + + OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const { + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; + }; + + bool GetVariadicInputHomogeneity() const { + return false; // heterogeneous + } + + bool GetVariadicOutputHomogeneity() const { + return false; // heterogeneous + } + + private: + const char* provider_ = nullptr; + const char* name_ = nullptr; + size_t num_inputs_ = 1; // set to 1 to match with default min_arity for variadic input + size_t num_outputs_ = 1; // set to 1 to match with default min_arity for variadic output + ExampleEpFactory* factory_ = nullptr; + std::unordered_map float_initializers_; +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 7c2b8e59ade89..437d2afcef90d 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -39,6 +39,9 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL CreateExternalResourceImporterForDevice = CreateExternalResourceImporterForDeviceImpl; + GetNumCustomOpDomains = GetNumCustomOpDomainsImpl; + GetCustomOpDomains = GetCustomOpDomainsImpl; + // setup the OrtMemoryInfo instances required by the EP. // We pretend the device the EP is running on is GPU. 
default_memory_info_ = Ort::MemoryInfo{"ExampleEP GPU", @@ -70,6 +73,22 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL OrtDeviceMemoryType_HOST_ACCESSIBLE, /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator}; + // Custom Op Domains + custom_op_domains_[0] = Ort::CustomOpDomain{"test"}; + custom_op_domains_[1] = Ort::CustomOpDomain{"test2"}; + + std::vector> created_custom_op_list; + created_custom_op_list.push_back(std::make_unique(ep_name_.c_str(), this)); + created_custom_op_list.back().get()->SetName("Custom_Mul"); + custom_op_domains_[0].Add(created_custom_op_list.back().get()); + + std::vector> created_custom_op_list_2; + created_custom_op_list_2.push_back(std::make_unique(ep_name_.c_str(), this)); + created_custom_op_list_2.back().get()->SetName("Custom_Mul2"); + custom_op_domains_[1].Add(created_custom_op_list_2.back().get()); + + created_custom_op_lists_[0] = std::move(created_custom_op_list); + created_custom_op_lists_[1] = std::move(created_custom_op_list_2); } /*static*/ @@ -312,6 +331,55 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateSyncStreamForDeviceImpl(OrtEpFac } /*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::GetNumCustomOpDomainsImpl(OrtEpFactory* this_ptr, + _Out_ size_t* num_domains) noexcept { + auto* factory = static_cast(this_ptr); + *num_domains = factory->custom_op_domains_.size(); + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::GetCustomOpDomainsImpl( + OrtEpFactory* this_ptr, + _Out_writes_all_(num_domains) OrtCustomOpDomain** domains, + _In_ size_t num_domains) noexcept { + auto* factory = static_cast(this_ptr); + + // `num_domains` should equal custom_op_domains_.size() because ORT calls GetNumCustomOpDomainsImpl() + // before calling this function. 
+ // Validate anyway so that a mismatched caller gets an error status instead of an out-of-bounds write. + if (domains == nullptr || num_domains < factory->custom_op_domains_.size()) { + return factory->ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "`domains` is null or too small for the EP's custom op domains"); + } + + gsl::span domains_span(domains, num_domains); + for (size_t i = 0; i < factory->custom_op_domains_.size(); ++i) { + domains_span[i] = factory->custom_op_domains_[i]; + } + + return nullptr; +} + +OrtStatusPtr ExampleEpCustomOp::CreateKernelV2(const OrtApi& /*api*/, + const OrtKernelInfo* /*info*/, + void** op_kernel) const { + std::string node_input_0 = "X"; + std::string node_input_1 = "W"; + auto custom_kernel_op = std::make_unique(factory_->ort_api, + factory_->default_logger_, + float_initializers_, + node_input_0, + node_input_1); + *op_kernel = custom_kernel_op.release(); + return nullptr; +} + +OrtStatusPtr ExampleEpCustomOp::KernelComputeV2(void* op_kernel, OrtKernelContext* context) const { + return static_cast(op_kernel)->ComputeV2(context); +} + OrtStatus* ORT_API_CALL ExampleEpFactory::CreateExternalResourceImporterForDeviceImpl( OrtEpFactory* this_ptr, const OrtEpDevice* /*ep_device*/, diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index 230fdef772e2f..737276203826c 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -9,6 +9,8 @@ #include "ep_data_transfer.h" #include "ep_external_resource_importer.h" #include "../plugin_ep_utils.h" +#include "ep.h" +#include "ep_custom_op.h" /// /// Example EP factory that can create an OrtEp and return information about the supported hardware devices. 
@@ -26,6 +28,8 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { return arena_allocator_.get(); } + const OrtLogger& default_logger_; // default logger for the EP factory + private: static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; @@ -73,7 +77,13 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { const OrtEpDevice* ep_device, OrtExternalResourceImporterImpl** out_importer) noexcept; - const OrtLogger& default_logger_; // default logger for the EP factory + static OrtStatus* ORT_API_CALL GetNumCustomOpDomainsImpl(OrtEpFactory* this_ptr, + _Out_ size_t* num_domains) noexcept; + + static OrtStatus* ORT_API_CALL GetCustomOpDomainsImpl(OrtEpFactory* this_ptr, + _Out_writes_all_(num_domains) OrtCustomOpDomain** domains, + _In_ size_t num_domains) noexcept; + const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name const uint32_t vendor_id_{0xB357}; // EP vendor ID @@ -89,4 +99,7 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { std::mutex mutex_; // mutex to protect arena_allocator_ and num_arena_users_ std::unique_ptr data_transfer_impl_; // data transfer implementation for this factory + + std::vector custom_op_domains_{2}; + std::vector>> created_custom_op_lists_{2}; }; diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index 437ca37c1a7b6..6643de46fb9dd 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -47,6 +47,35 @@ void RunMulModelWithPluginEp(const Ort::SessionOptions& session_options) { EXPECT_THAT(output_span, ::testing::ElementsAre(2, 4, 6, 8, 10, 12)); } +void RunCustomMulModelWithPluginEp(const Ort::SessionOptions& session_options) { + Ort::Session session(*ort_env, ORT_TSTR("testdata/custom_mul.onnx"), session_options); + + // Create two inputs with same values + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, 
OrtMemTypeCPU); + std::vector shape = {3, 2}; + std::vector input0_data{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector ort_inputs; + std::vector ort_input_names; + + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input0_data.data(), input0_data.size(), shape.data(), shape.size())); + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input0_data.data(), input0_data.size(), shape.data(), shape.size())); + ort_input_names.push_back("X"); + ort_input_names.push_back("W"); + + // Run session and get outputs + std::array output_names{"Y"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check expected output values + Ort::Value& ort_output = ort_outputs[0]; + const float* output_data = ort_output.GetTensorData(); + gsl::span output_span(output_data, 6); + EXPECT_THAT(output_span, ::testing::ElementsAre(1, 4, 9, 16, 25, 36)); +} + void RunSqueezeMulReluModel(const Ort::SessionOptions& session_options) { Ort::Session session(*ort_env, ORT_TSTR("testdata/squeeze_mul_relu.onnx"), session_options); @@ -352,5 +381,20 @@ TEST(OrtEpLibrary, KernelPluginEp_Inference) { ASSERT_NO_FATAL_FAILURE(RunSubMulSubModel(session_options)); } } + +// Creates a session with the example plugin EP and runs a model with a single Custom_Mul node. 
+// Uses AppendExecutionProvider_V2 to append the example plugin EP to the session. +TEST(OrtEpLibrary, PluginEp_Create_OrtCustomOpDomain) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + // Create session with example plugin EP + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + RunCustomMulModelWithPluginEp(session_options); +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/custom_mul.onnx b/onnxruntime/test/testdata/custom_mul.onnx new file mode 100644 index 0000000000000..87bb64764a669 Binary files /dev/null and b/onnxruntime/test/testdata/custom_mul.onnx differ diff --git a/onnxruntime/test/testdata/custom_mul.py b/onnxruntime/test/testdata/custom_mul.py new file mode 100644 index 0000000000000..c8fd8b0b720a3 --- /dev/null +++ b/onnxruntime/test/testdata/custom_mul.py @@ -0,0 +1,46 @@ +import onnx +from onnx import TensorProto, helper + + +def create_custom_mul_model(): + # === Inputs === + x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2]) + w = helper.make_tensor_value_info("W", TensorProto.FLOAT, [3, 2]) + + # === Output === + y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 2]) + + # === Custom Node: Custom_Mul === + # Replace "Mul" with your custom op name and domain + custom_node = helper.make_node( + op_type="Custom_Mul", # <-- custom op name + inputs=["X", "W"], + outputs=["Y"], + domain="test", # <-- custom domain + ) + + # === Graph === + graph = helper.make_graph( + nodes=[custom_node], + name="CustomMulGraph", + inputs=[x, w], + outputs=[y], + ) + + # === Model (opset version 13 or later is fine) === + model = helper.make_model( + graph, + opset_imports=[ + helper.make_opsetid("", 13), # standard ONNX domain + 
helper.make_opsetid("test", 1), # custom domain; must match the domain of the Custom_Mul node + ], + 
producer_name="custom_mul_builder", + ) + + return model + + +# ===== Save the Model ===== +model = create_custom_mul_model() +onnx.save(model, "custom_mul.onnx") +print("Saved custom_mul.onnx") diff --git a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc index fd2cf2f712628..ec9c5f7f0397f 100644 --- a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc +++ b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc @@ -186,7 +186,7 @@ std::unique_ptr MakeEp(const logging::Logger* logger) { ORT_THROW_IF_ERROR(AddEpOptionsToSessionOptions(state.selected_c_ep_devices, default_ep_option_key_cstrs, default_ep_option_value_cstrs, - ort_session_options.value)); + ort_session_options)); return state.ep_factory->CreateProvider(ort_session_options, *logger->ToExternal()); }