-
Notifications
You must be signed in to change notification settings - Fork 3.6k
[EP ABI] Add CreateCustomOpDomains() API for plugin EP to register custom ops #26759
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 10 commits
0b5b4d7
80561db
6bd316f
ad0a023
5e398d4
aeb2386
3849cd3
9c987be
fbe2434
40fa8fe
c7a0491
4787c3f
632ce31
5905434
6017c00
47bb4dc
6721a98
1ab246d
3478732
a1d36af
3065e9d
d340de5
ee8851b
15f5baf
6b01e7f
adf565e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1413,6 +1413,49 @@ struct OrtEpFactory { | |
| * \since Version 1.24. | ||
| */ | ||
| ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options); | ||
|
|
||
| /** \brief Returns the number of OrtCustomOpDomains that this factory creates. | ||
| * | ||
| * \param[in] this_ptr The OrtEpFactory instance. | ||
| * \param[out] num_domains Output parameter set to the number of created OrtCustomOpDomain instances. | ||
| * | ||
| * \snippet{doc} snippets.dox OrtStatus Return Value | ||
| * | ||
| * \since Version 1.24. | ||
| */ | ||
| ORT_API2_STATUS(GetNumCustomOpDomains, _In_ OrtEpFactory* this_ptr, _Out_ size_t* num_domains); | ||
|
|
||
| /** \brief Creates the EP-specific OrtCustomOpDomains. | ||
| * | ||
| * This function is used when running inference on a model that contains EP-specific custom operations. | ||
| * For compile-based EPs, the EP does not need to provide a concrete kernel implementation for each custom op. | ||
| * Instead, it may provide only placeholder custom ops with the correct names so they can be recognized | ||
| * during model loading. | ||
| * | ||
| * Workflow: | ||
| * 1. The EP implements this function to supply a list of OrtCustomOpDomain instances. | ||
|
||
| * 2. The application calls SessionOptionsAppendExecutionProvider_V2() with an OrtEpDevice containing | ||
| * the plugin EP's factory. | ||
| * 3. SessionOptionsAppendExecutionProvider_V2() appends the provided OrtCustomOpDomain list to the | ||
| * session options. | ||
| * | ||
| * As a result, any session created from these session options will have these custom op domains registered | ||
| * in ORT, ensuring that the custom ops are properly recognized and validated when the model is loaded. | ||
| * | ||
| * Note: EP has the responsibility to release OrtCustomOpDomain instances it creates. It happens | ||
|
||
| * automatically if using ORT C++ api. | ||
|
||
| * | ||
| * \param[in] this_ptr The OrtEpFactory instance. | ||
| * \param[out] domains Pre-allocated array of `num_domains` elements by ORT that should be filled with | ||
| OrtCustomOpDomain created by the EP. | ||
| * \param[in] num_domains The size of the `domains` array pre-allocated by ORT. | ||
| * | ||
| * \snippet{doc} snippets.dox OrtStatus Return Value | ||
adrianlizarraga marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| * | ||
| * \since Version 1.24. | ||
| */ | ||
| ORT_API2_STATUS(CreateCustomOpDomains, _In_ OrtEpFactory* this_ptr, | ||
| _Outptr_result_maybenull_ OrtCustomOpDomain** domains, _Out_ size_t num_domains); | ||
|
||
| }; | ||
|
|
||
| #ifdef __cplusplus | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -88,6 +88,13 @@ class EpFactoryInternalImpl { | |
| return nullptr; | ||
| } | ||
|
|
||
| virtual OrtStatus* CreateCustomOpDomains(_Outptr_result_maybenull_ OrtCustomOpDomain** out, | ||
|
||
| _Out_ size_t* num_domains) const noexcept { | ||
| *out = nullptr; | ||
| *num_domains = 0; | ||
| return nullptr; | ||
| } | ||
|
|
||
| // Function ORT calls to release an EP instance. | ||
| void ReleaseEp(OrtEp* ep); | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -482,7 +482,8 @@ | |
| Status AddEpOptionsToSessionOptions(gsl::span<const OrtEpDevice* const> ep_devices, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this not called during auto ep selection path, when the session_options.set_provider_selection_policy(ort.OrtExecutionProviderDevicePolicy.PREFER_GPU) is set?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the case of EP using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Registering custom ops should be independent of EP , so it should be available in auto ep selectin path as well |
||
| gsl::span<const char* const> ep_option_keys, | ||
| gsl::span<const char* const> ep_option_vals, | ||
| SessionOptions& session_options) { | ||
| OrtSessionOptions& ort_session_options) { | ||
| SessionOptions& session_options = ort_session_options.value; | ||
| const size_t num_ep_options = ep_option_keys.size(); | ||
| if (ep_option_vals.size() != num_ep_options) { | ||
| return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, | ||
|
|
@@ -505,6 +506,44 @@ | |
|
|
||
| ORT_RETURN_IF_ERROR(config_options.AddConfigEntry((prefix + ep_option_keys[j]).c_str(), ep_option_vals[j])); | ||
| } | ||
|
|
||
| // Add custom op domain provided by EP to the session options if any. | ||
| // OrtEpFactory::GetNumCustomOpDomains and OrtEpFactory::CreateCustomOpDomains | ||
| // were added in ORT 1.24. | ||
| OrtEpFactory* ep_factory = ep_device->ep_factory; | ||
| if (ep_factory && | ||
| ep_factory->ort_version_supported >= 24 && | ||
| ep_factory->CreateCustomOpDomains != nullptr) { | ||
|
||
| auto is_already_in_domains = [&](const std::string& domain_name, const std::vector<OrtCustomOpDomain*>& domains) { | ||
|
Check warning on line 517 in onnxruntime/core/session/utils.cc
|
||
| for (auto ptr : domains) { | ||
| if (domain_name == ptr->domain_) { | ||
| return true; | ||
| } | ||
| } | ||
| return false; | ||
| }; | ||
|
|
||
| size_t num_domains = 0; | ||
| ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetNumCustomOpDomains(ep_factory, &num_domains))); | ||
|
|
||
| InlinedVector<OrtCustomOpDomain*> domains; | ||
| domains.resize(num_domains); | ||
|
|
||
| ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->CreateCustomOpDomains(ep_factory, | ||
| domains.data(), | ||
| num_domains))); | ||
|
|
||
| const auto domains_span = gsl::span<OrtCustomOpDomain*>(domains.data(), domains.size()); | ||
| for (auto domain : domains_span) { | ||
| if (!is_already_in_domains(domain->domain_, ort_session_options.custom_op_domains_) && | ||
| domain->custom_ops_.size() > 0) { | ||
| ort_session_options.custom_op_domains_.push_back(domain); | ||
| } else { | ||
| LOGS_DEFAULT(WARNING) << "The custom op domain name " | ||
| << domain->domain_ << " is already in the session option. Skip it."; | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| return Status::OK(); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,117 +15,97 @@ | |
| #include "ep_factory.h" | ||
| #include "ep_stream_support.h" | ||
|
|
||
| /// <summary> | ||
| /// Example implementation of ONNX Mul. Does not handle many things like broadcasting. | ||
| /// </summary> | ||
| struct MulKernel { | ||
| MulKernel(const OrtApi& ort_api, const OrtLogger& logger, | ||
| const std::unordered_map<std::string, FloatInitializer>& float_initializers, | ||
| std::string input0_name, std::string input1_name) | ||
| : ort_api(ort_api), | ||
| logger(logger), | ||
| float_initializers(float_initializers), | ||
| input0_name(input0_name), | ||
| input1_name(input1_name) {} | ||
|
|
||
| const FloatInitializer* TryGetSavedInitializer(const std::string& name) const { | ||
| auto iter = float_initializers.find(name); | ||
| return iter != float_initializers.end() ? &iter->second : nullptr; | ||
| } | ||
| const FloatInitializer* MulKernel::TryGetSavedInitializer(const std::string& name) const { | ||
| auto iter = float_initializers.find(name); | ||
| return iter != float_initializers.end() ? &iter->second : nullptr; | ||
| } | ||
|
|
||
| void GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index, | ||
| /*out*/ gsl::span<const float>& data, | ||
| /*out*/ std::vector<int64_t>& shape) const { | ||
| Ort::ConstValue input = kernel_context.GetInput(index); | ||
| auto type_shape = input.GetTensorTypeAndShapeInfo(); | ||
| void MulKernel::GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index, | ||
| /*out*/ gsl::span<const float>& data, | ||
| /*out*/ std::vector<int64_t>& shape) const { | ||
| Ort::ConstValue input = kernel_context.GetInput(index); | ||
| auto type_shape = input.GetTensorTypeAndShapeInfo(); | ||
|
|
||
| ONNXTensorElementDataType elem_type = type_shape.GetElementType(); | ||
| if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) | ||
| throw Ort::Exception("EP Expected float32 inputs", ORT_EP_FAIL); | ||
| ONNXTensorElementDataType elem_type = type_shape.GetElementType(); | ||
| if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) | ||
| throw Ort::Exception("EP Expected float32 inputs", ORT_EP_FAIL); | ||
|
|
||
| const float* float_data = input.GetTensorData<float>(); | ||
| size_t num_elems = type_shape.GetElementCount(); | ||
| data = gsl::span<const float>(float_data, num_elems); | ||
| shape = type_shape.GetShape(); | ||
| } | ||
|
|
||
| OrtStatus* Compute(OrtKernelContext* kernel_ctx) { | ||
| RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, | ||
| OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, | ||
| "MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__)); | ||
| Ort::KernelContext kernel_context(kernel_ctx); | ||
| try { | ||
| gsl::span<const float> input0; | ||
| gsl::span<const float> input1; | ||
| std::vector<int64_t> shape0; | ||
| std::vector<int64_t> shape1; | ||
|
|
||
| size_t num_inputs = kernel_context.GetInputCount(); | ||
|
|
||
| if (num_inputs == 2) { | ||
| // Both inputs are non-constant. Get them from ORT's KernelContext. | ||
| GetInputDataAndShape(kernel_context, 0, input0, shape0); | ||
| GetInputDataAndShape(kernel_context, 1, input1, shape1); | ||
| } else if (num_inputs == 1) { | ||
| // ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs. | ||
| // Get the constant input from the initializers saved by the EP. | ||
| // Refer to "NodeFusionOptions_DropConstantInitializers()". | ||
|
|
||
| if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) { | ||
| GetInputDataAndShape(kernel_context, 0, input1, shape1); | ||
| input0 = gsl::span<const float>(const_input0->data); | ||
| shape0 = const_input0->shape; | ||
| } else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) { | ||
| GetInputDataAndShape(kernel_context, 0, input0, shape0); | ||
| input1 = gsl::span<const float>(const_input1->data); | ||
| shape1 = const_input1->shape; | ||
| } | ||
| } else { | ||
| // Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding) | ||
| // are disabled. | ||
| const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); | ||
| const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); | ||
| RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api, | ||
| "Expected 2 initializer inputs to be saved by EP"); | ||
| const float* float_data = input.GetTensorData<float>(); | ||
| size_t num_elems = type_shape.GetElementCount(); | ||
| data = gsl::span<const float>(float_data, num_elems); | ||
| shape = type_shape.GetShape(); | ||
| } | ||
|
|
||
| OrtStatus* MulKernel::Compute(OrtKernelContext* kernel_ctx) { | ||
| RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, | ||
| OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, | ||
| "MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__)); | ||
| Ort::KernelContext kernel_context(kernel_ctx); | ||
| try { | ||
| gsl::span<const float> input0; | ||
| gsl::span<const float> input1; | ||
| std::vector<int64_t> shape0; | ||
| std::vector<int64_t> shape1; | ||
|
|
||
| size_t num_inputs = kernel_context.GetInputCount(); | ||
|
|
||
| if (num_inputs == 2) { | ||
| // Both inputs are non-constant. Get them from ORT's KernelContext. | ||
| GetInputDataAndShape(kernel_context, 0, input0, shape0); | ||
| GetInputDataAndShape(kernel_context, 1, input1, shape1); | ||
| } else if (num_inputs == 1) { | ||
| // ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs. | ||
| // Get the constant input from the initializers saved by the EP. | ||
| // Refer to "NodeFusionOptions_DropConstantInitializers()". | ||
|
|
||
| if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) { | ||
| GetInputDataAndShape(kernel_context, 0, input1, shape1); | ||
| input0 = gsl::span<const float>(const_input0->data); | ||
| input1 = gsl::span<const float>(const_input1->data); | ||
| shape0 = const_input0->shape; | ||
| } else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) { | ||
| GetInputDataAndShape(kernel_context, 0, input0, shape0); | ||
| input1 = gsl::span<const float>(const_input1->data); | ||
| shape1 = const_input1->shape; | ||
| } | ||
| } else { | ||
| // Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding) | ||
| // are disabled. | ||
| const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); | ||
| const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); | ||
| RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api, | ||
| "Expected 2 initializer inputs to be saved by EP"); | ||
|
|
||
| input0 = gsl::span<const float>(const_input0->data); | ||
| input1 = gsl::span<const float>(const_input1->data); | ||
| shape0 = const_input0->shape; | ||
| shape1 = const_input1->shape; | ||
| } | ||
|
|
||
| if (shape0 != shape1) { | ||
| throw Ort::Exception("Expected same dimensions for both inputs", ORT_INVALID_ARGUMENT); | ||
| } | ||
| if (shape0 != shape1) { | ||
| throw Ort::Exception("Expected same dimensions for both inputs", ORT_INVALID_ARGUMENT); | ||
| } | ||
|
|
||
| size_t num_outputs = kernel_context.GetOutputCount(); | ||
| if (num_outputs != 1) { | ||
| throw Ort::Exception("Expected 1 output for MulKernel", ORT_INVALID_ARGUMENT); | ||
| } | ||
| size_t num_outputs = kernel_context.GetOutputCount(); | ||
| if (num_outputs != 1) { | ||
| throw Ort::Exception("Expected 1 output for MulKernel", ORT_INVALID_ARGUMENT); | ||
| } | ||
|
|
||
| auto output = kernel_context.GetOutput(0, shape0); | ||
| float* output_data = output.GetTensorMutableData<float>(); | ||
| auto output = kernel_context.GetOutput(0, shape0); | ||
| float* output_data = output.GetTensorMutableData<float>(); | ||
|
|
||
| for (size_t i = 0; i < input0.size(); ++i) { | ||
| output_data[i] = input0[i] * input1[i]; | ||
| } | ||
| } catch (const Ort::Exception& ex) { | ||
| Ort::Status status(ex); | ||
| return status.release(); | ||
| } catch (const std::exception& ex) { | ||
| Ort::Status status(ex.what(), ORT_EP_FAIL); | ||
| return status.release(); | ||
| for (size_t i = 0; i < input0.size(); ++i) { | ||
| output_data[i] = input0[i] * input1[i]; | ||
| } | ||
|
|
||
| return nullptr; | ||
| } catch (const Ort::Exception& ex) { | ||
| Ort::Status status(ex); | ||
| return status.release(); | ||
| } catch (const std::exception& ex) { | ||
| Ort::Status status(ex.what(), ORT_EP_FAIL); | ||
| return status.release(); | ||
| } | ||
|
|
||
| const OrtApi& ort_api; | ||
| const OrtLogger& logger; | ||
| const std::unordered_map<std::string, FloatInitializer>& float_initializers; | ||
| std::string input0_name; | ||
| std::string input1_name; | ||
| }; | ||
| return nullptr; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Example OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph. | ||
|
|
@@ -262,26 +242,35 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG | |
|
|
||
| supported_nodes.push_back(node); // Only support a single Mul for now. | ||
| break; | ||
| } else if (op_type == "Custom_Mul") { | ||
|
||
| supported_nodes.push_back(node); | ||
| } | ||
| } | ||
|
|
||
| if (supported_nodes.empty()) { | ||
| return nullptr; | ||
| } | ||
|
|
||
| // Create (optional) fusion options for the supported nodes to fuse. | ||
| OrtNodeFusionOptions node_fusion_options = {}; | ||
| node_fusion_options.ort_version_supported = ORT_API_VERSION; | ||
|
|
||
| // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers | ||
| // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. | ||
| // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use | ||
| // during inference. | ||
| node_fusion_options.drop_constant_initializers = true; | ||
| RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, | ||
| reinterpret_cast<const OrtNode* const*>(supported_nodes.data()), | ||
| supported_nodes.size(), | ||
| &node_fusion_options)); | ||
| if (supported_nodes[0].GetOperatorType() == "Mul") { | ||
| // Create (optional) fusion options for the supported nodes to fuse. | ||
| OrtNodeFusionOptions node_fusion_options = {}; | ||
| node_fusion_options.ort_version_supported = ORT_API_VERSION; | ||
|
|
||
| // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers | ||
| // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. | ||
| // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use | ||
| // during inference. | ||
| node_fusion_options.drop_constant_initializers = true; | ||
| RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, | ||
| reinterpret_cast<const OrtNode* const*>(supported_nodes.data()), | ||
| supported_nodes.size(), | ||
| &node_fusion_options)); | ||
| } else if (supported_nodes[0].GetOperatorType() == "Custom_Mul") { | ||
| // Calls EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled, | ||
| // as CustomMul has the concrete kernel implementation. | ||
| RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_nodes[0])); | ||
| } | ||
|
|
||
| } catch (const Ort::Exception& ex) { | ||
| Ort::Status status(ex); | ||
| return status.release(); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
since we renamed it to
GetCustomOpDomains, maybe replace "creates" with "provides" or "supplies".There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
replaced with "provides"