diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 01c7fa3ac1126..668a141595f1b 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -2134,7 +2134,13 @@ if (onnxruntime_BUILD_SHARED_LIB AND "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.h" - "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.cc") + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.cc" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/scan.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/scan.cc" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/loop.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/loop.cc" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/if.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/if.cc") onnxruntime_add_shared_library_module(example_plugin_ep_kernel_registry ${onnxruntime_autoep_test_example_plugin_ep_kernel_registry_src}) target_include_directories(example_plugin_ep_kernel_registry PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session) target_link_libraries(example_plugin_ep_kernel_registry PRIVATE onnxruntime ${GSL_TARGET}) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 617788fcab8bb..0db06f23dcd4a 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -524,8 +524,11 @@ typedef struct OrtKernelImpl OrtKernelImpl; */ struct OrtKernelImpl { uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION + uint32_t flags; ///< EP must initialize to 0. Used internally by ORT. /** \brief Computation function called to execute the kernel on an EP. + * + * \note Implementation of this function is required. * * \param[in] this_ptr The OrtKernelImpl instance. * \param[in] context The OrtKernelContext instance that provides access to the inputs and outputs. @@ -537,6 +540,8 @@ struct OrtKernelImpl { ORT_API2_STATUS(Compute, _In_ OrtKernelImpl* this_ptr, _In_ OrtKernelContext* context); /** \brief Called by ORT to release the OrtKernelImpl instance and its resources. + * + * \note Implementation of this function is required. * * \param[in] this_ptr The OrtKernelImpl instance. * @@ -645,7 +650,9 @@ struct OrtKernelImpl { * \param[in] kernel_create_func_state Opaque state initially provided by the EP that registered the kernel. * Refer to OrtEpApi::KernelRegistry_AddKernel(). May be null. * \param[in] info The OrtKernelInfo instance that provides access to the kernel's input and output characteristics. - * \param[out] kernel_out Output parameter set to the new OrtKernelImpl instance. + * \param[out] kernel_out Output parameter set to the new OrtKernelImpl instance. On success, ownership of this + * OrtKernelImpl instance transfers to ORT, which will call OrtKernelImpl::Release() to + * release the instance when it is no longer used. * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -655,6 +662,89 @@ typedef OrtStatus*(ORT_API_CALL* OrtKernelCreateFunc)(_In_ void* kernel_create_f _In_ const OrtKernelInfo* info, _Outptr_result_maybenull_ OrtKernelImpl** kernel_out); +struct OrtLoopKernelHelper; +typedef struct OrtLoopKernelHelper OrtLoopKernelHelper; + +/** + * \brief Contains helper functions for a Loop OrtKernelImpl created via ::CreateLoopKernel(). + * \since Version 1.24. + */ +struct OrtLoopKernelHelper { + uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION + + /** \brief Called by ORT to release the OrtLoopKernelHelper instance and its resources. + * + * \param[in] this_ptr The OrtLoopKernelHelper instance. + * + * \since Version 1.24. + */ + ORT_API_T(void, Release, _In_ OrtLoopKernelHelper* this_ptr); + + /** \brief Helper function that concatenates OrtValue instances from each loop iteration into a single + * pre-allocated output buffer. + * + * \note Implementing this function is required for all Loop opset versions. + * + * \param[in] this_ptr The OrtLoopKernelHelper instance. + * \param[in] stream_handle Optional native stream handle that enables asynchronous operations. May be NULL. + * \param[in] per_iteration_outputs Array of OrtValue instances from each iteration. All OrtValue elements have the + * same shape. + * \param[in] num_per_iteration_outputs The number of OrtValue* elements in the `per_iteration_outputs` array. + * \param[out] output The pre-allocated output buffer. Memory is allocated on the device for the EP running the + * Loop node. + * \param[in] output_size_in_bytes The size in bytes of the `output` buffer. It is guaranteed to be large enough + * to hold the concatenated data of each element in `per_iteration_outputs`. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ConcatOutput, _In_ OrtLoopKernelHelper* this_ptr, _In_opt_ void* stream_handle, + _In_reads_(num_per_iteration_outputs) const OrtValue* const* per_iteration_outputs, + _In_ size_t num_per_iteration_outputs, _Out_writes_bytes_all_(output_size_in_bytes) void* output, + _In_ size_t output_size_in_bytes); +}; + +struct OrtScanKernelHelper; +typedef struct OrtScanKernelHelper OrtScanKernelHelper; + +/** + * \brief Contains helper functions for a Scan OrtKernelImpl created via ::CreateScanKernel(). + * \since Version 1.24. + */ +struct OrtScanKernelHelper { + uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION + + /** \brief Called by ORT to release the OrtScanKernelHelper instance and its resources. + * + * \param[in] this_ptr The OrtScanKernelHelper instance. + * + * \since Version 1.24. + */ + ORT_API_T(void, Release, _In_ OrtScanKernelHelper* this_ptr); + + /** \brief Helper function that transposes an OrtValue instance during execution of a Scan kernel. + * + * \note Called for Scan (opset >= 9) when the 'scan_input_axes' or 'scan_output_axes' attributes contain + * non-zero values. Implementing this function is required for Scan opset versions >= 9. + * + * \param[in] this_ptr The OrtScanKernelHelper instance. + * \param[in] permutation An array of integers that defines how the input tensor's axes should be permuted. + * \param[in] num_permutation_elems The number of integer elements in the `permutation` array. + * \param[in] input The input OrtValue tensor to transpose. + * \param[in] stream An optional OrtSyncStream instance to be used for asynchronous operations. May be NULL. + * \param[out] output The pre-allocated output OrtValue instance into which to store the results of the + * transpose operation. Must not be released as it is owned by ORT. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(Transpose, _In_ OrtScanKernelHelper* this_ptr, + _In_reads_(num_permutation_elems) const size_t* permutation, _In_ size_t num_permutation_elems, + _In_ const OrtValue* input, _In_opt_ OrtSyncStream* stream, _Inout_ OrtValue* output); +}; + /** * \brief The OrtEpApi struct provides functions that are relevant to the implementation of an execution provider. * @@ -1217,6 +1307,107 @@ struct OrtEpApi { * \since Version 1.24 */ ORT_API2_STATUS(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ const OrtEp** ep); + + /** \brief Creates an OrtKernelImpl instance for an If operator. + * + * Control flow operators require access to ORT session internals to orchestrate subgraph operations. + * This function allows an EP to create a properly configured OrtKernelImpl with access to ORT internals that + * the EP can add to its kernel registry. + * + * An EP is required to create an OrtKernelDef that keeps input[0] ('cond') on the CPU (i.e., OrtMemTypeCPUInput) + * as this input is used by CPU logic. The output should remain on the device (i.e., OrtMemTypeDefault), which is + * the default setting, to avoid copying to/from CPU. + * + * Example kernel definition (CXX API): + * Ort::KernelDef kernel_def = Ort::KernelDefBuilder() + * .SetDomain("").SetOperatorType("If").SetSinceVersion(21, 22) + * .SetExecutionProvider("MyEp") + * .SetInputMemType(0, OrtMemTypeCPUInput) // 'cond' on CPU + * .SetOutputMemType(0, OrtMemTypeDefault) // output on EP device + * .AddTypeConstraint("B", ...) + * .AddTypeConstraint("V", ...).Build(); + * + * \param[in] kernel_info The ::OrtKernelInfo instance for an If node. This function returns error ORT_FAIL + * if the opset version specified by `kernel_info` is unsupported. + * \param[out] kernel_out Output parameter set to the OrtKernelImpl instance for the If node. + * Must be released via ::ReleaseKernelImpl, unless ownership is transferred + * to ORT (see OrtKernelCreateFunc and ::KernelRegistry_AddKernel()). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.24 + */ + ORT_API2_STATUS(CreateIfKernel, _In_ const OrtKernelInfo* kernel_info, _Outptr_ OrtKernelImpl** kernel_out); + + /** \brief Creates an OrtKernelImpl instance for a Loop operator. + * + * Control flow operators require access to ORT session internals to orchestrate subgraph operations. + * This function allows an EP to create a properly configured OrtKernelImpl with access to ORT internals that + * the EP can add to its kernel registry. + * + * An EP is required to create an OrtKernelDef that keeps input[0] ('M') and input[1] ('cond') on the CPU + * (i.e., OrtMemTypeCPUInput) as these inputs are used by CPU logic. Input[2] ('v_initial') and the output should + * remain on the device (i.e., OrtMemTypeDefault), which is the default setting, to avoid copying to/from CPU. + * + * Example kernel definition (CXX API): + * Ort::KernelDef kernel_def = Ort::KernelDefBuilder() + * .SetDomain("").SetOperatorType("Loop").SetSinceVersion(21, 22) + * .SetExecutionProvider("MyEp") + * .SetInputMemType(0, OrtMemTypeCPUInput) // 'M' on CPU + * .SetInputMemType(1, OrtMemTypeCPUInput) // 'cond' on CPU + * .SetInputMemType(2, OrtMemTypeDefault) // 'v_initial' on EP device + * .SetOutputMemType(0, OrtMemTypeDefault) // output on EP device + * .AddTypeConstraint("I", ...) + * .AddTypeConstraint("B", ...) + * .AddTypeConstraint("V", ...).Build(); + * + * \param[in] kernel_info The ::OrtKernelInfo instance for a Loop node. This function returns error ORT_FAIL + * if the opset version specified by `kernel_info` is unsupported. + * \param[in] helper A OrtLoopKernelHelper instance that contains helper functions that ORT calls during kernel + * execution to operate on tensors allocated with the EP's device memory. + * ORT will call OrtLoopKernelHelper::Release() to release the helper and its resources. + * \param[out] kernel_out Output parameter set to the OrtKernelImpl instance for the Loop node. + * Must be released via ::ReleaseKernelImpl, unless ownership is transferred + * to ORT (see OrtKernelCreateFunc and ::KernelRegistry_AddKernel()). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.24 + */ + ORT_API2_STATUS(CreateLoopKernel, _In_ const OrtKernelInfo* kernel_info, _In_ OrtLoopKernelHelper* helper, + _Outptr_ OrtKernelImpl** kernel_out); + + /** \brief Creates an OrtKernelImpl instance for a Scan operator. Does not support opset versions older than 9. + * + * Control flow operators require access to ORT session internals to orchestrate subgraph operations. + * This function allows an EP to create a properly configured OrtKernelImpl with access to ORT internals that + * the EP can add to its kernel registry. + * + * It is recommended that an EP create an OrtKernelDef that keeps the inputs and outputs on the EP's + * device (i.e., OrtMemTypeDefault), which is the default setting, to avoid copying to/from CPU. + * + * Example kernel definition (CXX API): + * Ort::KernelDef kernel_def = Ort::KernelDefBuilder() + * .SetDomain("").SetOperatorType("Scan").SetSinceVersion(21, 22) + * .SetExecutionProvider("MyEp") + * .SetInputMemType(0, OrtMemTypeDefault) // input[0] on EP device + * .SetOutputMemType(0, OrtMemTypeDefault) // output[0] on EP device + * .AddTypeConstraint("V", ...).Build(); + * + * \param[in] kernel_info The ::OrtKernelInfo instance for a Scan node. This function returns error ORT_FAIL + * if the opset version specified by `kernel_info` is unsupported. + * \param[in] helper A OrtScanKernelHelper instance that contains helper functions that ORT calls during kernel + * execution to operate on tensors allocated with the EP's device memory. + * ORT will call OrtScanKernelHelper::Release() to release the helper and its resources. + * \param[out] kernel_out Output parameter set to the OrtKernelImpl instance for the Scan node. + * Must be released via ::ReleaseKernelImpl, unless ownership is transferred + * to ORT (see OrtKernelCreateFunc and ::KernelRegistry_AddKernel()). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.24 + */ + ORT_API2_STATUS(CreateScanKernel, _In_ const OrtKernelInfo* kernel_info, _In_ OrtScanKernelHelper* helper, + _Outptr_ OrtKernelImpl** kernel_out); + + ORT_CLASS_RELEASE(KernelImpl); }; /** diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index be51e19023037..c97c5c128e258 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1183,11 +1183,15 @@ const InlinedHashSet* SessionState::GetToBeExecutedRange( Status SessionState::CreateSubgraphSessionState() { for (auto& node : graph_.Nodes()) { for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { - const auto& ep = node.GetExecutionProviderType(); - if (!ep.empty() && - ep != kCpuExecutionProvider && ep != kCudaExecutionProvider && - ep != kDmlExecutionProvider && - ep != kJsExecutionProvider && ep != kWebGpuExecutionProvider) { + const auto& ep_type = node.GetExecutionProviderType(); + const IExecutionProvider* ep = execution_providers_.Get(ep_type); + const bool is_plugin_ep = ep != nullptr && ep->GetOrtEp() != nullptr; + + if (!ep_type.empty() && + ep_type != kCpuExecutionProvider && ep_type != kCudaExecutionProvider && + ep_type != kDmlExecutionProvider && + ep_type != kJsExecutionProvider && ep_type != kWebGpuExecutionProvider && + !is_plugin_ep) { // SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow // node containing the subgraph it will create whatever state it needs internally. continue; diff --git a/onnxruntime/core/providers/cpu/controlflow/loop.h b/onnxruntime/core/providers/cpu/controlflow/loop.h index a648f2181e9c8..544c8d5a7bcbd 100644 --- a/onnxruntime/core/providers/cpu/controlflow/loop.h +++ b/onnxruntime/core/providers/cpu/controlflow/loop.h @@ -48,7 +48,6 @@ class Loop : public controlflow::IControlFlowKernel { static std::unique_ptr Create(const OpKernelInfo& info, const ConcatOutput& concat_output_func, void* stream); - protected: // derived class can provide implementation for handling concatenation of Loop output on a different device void SetConcatOutputFunc(const ConcatOutput& concat_output_func) { concat_output_func_ = concat_output_func; } diff --git a/onnxruntime/core/session/plugin_ep/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc index 47b44b8eeba64..a598de3053133 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -24,6 +24,7 @@ #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/ort_apis.h" #include "core/session/plugin_ep/ep_kernel_registration.h" +#include "core/session/plugin_ep/ep_control_flow_kernel_impls.h" #include "core/session/utils.h" using namespace onnxruntime; @@ -655,6 +656,113 @@ ORT_API_STATUS_IMPL(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ c API_IMPL_END } +// Control flow kernel APIs +ORT_API_STATUS_IMPL(CreateIfKernel, _In_ const OrtKernelInfo* kernel_info, _Outptr_ OrtKernelImpl** kernel_out) { + API_IMPL_BEGIN + if (kernel_info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Must specify a non-null OrtKernelInfo instance to create an If OrtKernelImpl"); + } + + if (kernel_out == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Must specify a non-null output parameter to hold the OrtKernelImpl for If"); + } + + const auto* op_kernel_info = reinterpret_cast(kernel_info); + auto kernel_unique_ptr = std::make_unique(*op_kernel_info); + + *kernel_out = kernel_unique_ptr.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(CreateLoopKernel, _In_ const OrtKernelInfo* kernel_info, _In_ OrtLoopKernelHelper* helper, + _Outptr_ OrtKernelImpl** kernel_out) { + API_IMPL_BEGIN + if (kernel_info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Must specify a non-null OrtKernelInfo instance to create a Loop OrtKernelImpl"); + } + + if (helper == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Must specify a non-null OrtLoopKernelHelper instance to create a Loop OrtKernelImpl"); + } + + if (helper->Release == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "OrtLoopKernelHelper must have a non-null OrtLoopKernelHelper::Release function"); + } + + if (helper->ConcatOutput == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "OrtLoopKernelHelper must have a non-null OrtLoopKernelHelper::ConcatOutput function"); + } + + if (kernel_out == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Must specify a non-null output parameter to hold the OrtKernelImpl for Loop"); + } + + const auto* op_kernel_info = reinterpret_cast(kernel_info); + auto kernel_unique_ptr = std::make_unique(*op_kernel_info, helper); + + *kernel_out = kernel_unique_ptr.release(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(CreateScanKernel, _In_ const OrtKernelInfo* kernel_info, _In_ OrtScanKernelHelper* helper, + _Outptr_ OrtKernelImpl** kernel_out) { + API_IMPL_BEGIN + if (kernel_info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Must specify a non-null OrtKernelInfo instance to create a Scan OrtKernelImpl"); + } + + if (helper == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Must specify a non-null OrtScanKernelHelper instance to create a Scan OrtKernelImpl"); + } + + if (helper->Release == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "OrtScanKernelHelper must have a non-null OrtScanKernelHelper::Release function"); + } + + if (helper->Transpose == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "OrtScanKernelHelper must have a non-null OrtScanKernelHelper::Transpose function"); + } + + if (kernel_out == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Must specify a non-null output parameter to hold the OrtKernelImpl for Scan"); + } + + const auto* op_kernel_info = reinterpret_cast(kernel_info); + int opset = op_kernel_info->node().SinceVersion(); + + if (opset >= 9) { + // Note: CPU EP always uses Scan<9> for all opsets >= 9. + auto kernel_unique_ptr = std::make_unique(*op_kernel_info, helper); + *kernel_out = kernel_unique_ptr.release(); + } else { + return OrtApis::CreateStatus(ORT_FAIL, + "Kernel implementations for Scan older than opset version 9 are not supported"); + } + + return nullptr; + API_IMPL_END +} + +ORT_API(void, ReleaseKernelImpl, _Frees_ptr_opt_ OrtKernelImpl* kernel_impl) { + if (kernel_impl != nullptr && kernel_impl->Release != nullptr) { + kernel_impl->Release(kernel_impl); + } +} + static constexpr OrtEpApi ort_ep_api = { // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end, // and no functions can be removed (the implementation needs to change to return an error). @@ -711,6 +819,10 @@ static constexpr OrtEpApi ort_ep_api = { &OrtExecutionProviderApi::EpGraphSupportInfo_LookUpKernel, &OrtExecutionProviderApi::SharedPrePackedWeightCache_StoreWeightData, &OrtExecutionProviderApi::KernelInfo_GetEp, + &OrtExecutionProviderApi::CreateIfKernel, + &OrtExecutionProviderApi::CreateLoopKernel, + &OrtExecutionProviderApi::CreateScanKernel, + &OrtExecutionProviderApi::ReleaseKernelImpl, }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned diff --git a/onnxruntime/core/session/plugin_ep/ep_api.h b/onnxruntime/core/session/plugin_ep/ep_api.h index 2d504f5ad2a64..853342c2c3a53 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.h +++ b/onnxruntime/core/session/plugin_ep/ep_api.h @@ -108,4 +108,12 @@ ORT_API_STATUS_IMPL(SharedPrePackedWeightCache_StoreWeightData, // KernelInfo ORT_API_STATUS_IMPL(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ const OrtEp** ep); + +// Control flow kernel APIs +ORT_API_STATUS_IMPL(CreateIfKernel, _In_ const OrtKernelInfo* kernel_info, _Outptr_ OrtKernelImpl** kernel_out); +ORT_API_STATUS_IMPL(CreateLoopKernel, _In_ const OrtKernelInfo* kernel_info, _In_ OrtLoopKernelHelper* helper, + _Outptr_ OrtKernelImpl** kernel_out); +ORT_API_STATUS_IMPL(CreateScanKernel, _In_ const OrtKernelInfo* kernel_info, _In_ OrtScanKernelHelper* helper, + _Outptr_ OrtKernelImpl** kernel_out); +ORT_API(void, ReleaseKernelImpl, _Frees_ptr_opt_ OrtKernelImpl* kernel_impl); } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/plugin_ep/ep_control_flow_kernel_impls.cc b/onnxruntime/core/session/plugin_ep/ep_control_flow_kernel_impls.cc new file mode 100644 index 0000000000000..aa57beb6b16cf --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_control_flow_kernel_impls.cc @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_control_flow_kernel_impls.h" + +#include +#include + +#include "core/framework/error_code_helper.h" +#include "core/providers/cpu/controlflow/utils.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { + +// +// PluginEpControlFlowKernelImpl +// + +PluginEpControlFlowKernelImpl::PluginEpControlFlowKernelImpl() : OrtKernelImpl{} { + ort_version_supported = ORT_API_VERSION; + + // Indicate that this is a control flow OrtKernelImpl created by ORT. + // Without RTTI, this gives ORT some way to check that static casting a OrtKernelImpl to + // PluginEpControlFlowKernelImpl is valid. + flags = OrtKernelImplFlags::kIsControlFlowKernelImpl; +} + +// +// PluginEpIfKernelImpl +// + +PluginEpIfKernelImpl::PluginEpIfKernelImpl(const OpKernelInfo& info) : kernel_(info) { + Compute = ComputeImpl; + Release = ReleaseImpl; +} + +/*static*/ +OrtStatus* ORT_API_CALL PluginEpIfKernelImpl::ComputeImpl(OrtKernelImpl* this_ptr, + OrtKernelContext* kernel_ctx) noexcept { + API_IMPL_BEGIN + auto* plugin_ep_kernel = static_cast(this_ptr); + ORT_API_RETURN_IF_STATUS_NOT_OK(plugin_ep_kernel->kernel_.Compute(reinterpret_cast(kernel_ctx))); + + return nullptr; + API_IMPL_END +} + +/*static*/ +void ORT_API_CALL PluginEpIfKernelImpl::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +// +// PluginEpLoopKernelImpl +// + +PluginEpLoopKernelImpl::PluginEpLoopKernelImpl(const OpKernelInfo& info, gsl::not_null helper) + : kernel_(info), helper_(helper) { + Compute = ComputeImpl; + Release = ReleaseImpl; + + auto concat_output_func = [this](void* stream, std::vector& per_iteration_outputs, + void* output, size_t output_size_in_bytes) -> Status { + std::vector value_ptrs; + + value_ptrs.reserve(per_iteration_outputs.size()); + std::transform(per_iteration_outputs.begin(), per_iteration_outputs.end(), std::back_inserter(value_ptrs), + [](OrtValue& value) -> OrtValue* { return &value; }); + + return ToStatusAndRelease(helper_->ConcatOutput(helper_, stream, value_ptrs.data(), value_ptrs.size(), + output, output_size_in_bytes)); + }; + + kernel_.SetConcatOutputFunc(concat_output_func); +} + +PluginEpLoopKernelImpl::~PluginEpLoopKernelImpl() { + helper_->Release(helper_); +} + +/*static*/ +OrtStatus* ORT_API_CALL PluginEpLoopKernelImpl::ComputeImpl(OrtKernelImpl* this_ptr, + OrtKernelContext* kernel_ctx) noexcept { + API_IMPL_BEGIN + auto* plugin_ep_kernel = static_cast(this_ptr); + ORT_API_RETURN_IF_STATUS_NOT_OK(plugin_ep_kernel->kernel_.Compute(reinterpret_cast(kernel_ctx))); + + return nullptr; + API_IMPL_END +} + +/*static*/ +void ORT_API_CALL PluginEpLoopKernelImpl::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +// +// PluginEpScanKernelImpl +// + +PluginEpScanKernelImpl::PluginEpScanKernelImpl(const OpKernelInfo& info, gsl::not_null helper) + : kernel_(info), helper_(helper) { + Compute = ComputeImpl; + Release = ReleaseImpl; + + // Bundle EP's function + state into a functor. + auto transpose_func = [this](const gsl::span& permutation, + const Tensor& input, Tensor& output, Stream* stream) -> Status { + auto empty_tensor_deleter = [](void* /*data*/) -> void { /* do not delete Tensor (not owned) */ }; + const OrtValue ort_value_input(const_cast(&input), DataTypeImpl::GetType(), empty_tensor_deleter); + OrtValue ort_value_output(&output, DataTypeImpl::GetType(), empty_tensor_deleter); + OrtSyncStream* ort_stream = reinterpret_cast(stream); + + return ToStatusAndRelease(helper_->Transpose(helper_, permutation.data(), permutation.size(), + &ort_value_input, ort_stream, &ort_value_output)); + }; + + scan::detail::DeviceHelpers device_helpers{}; + device_helpers.transpose_func = transpose_func; + + kernel_.SetDeviceHelpers(device_helpers); +} + +PluginEpScanKernelImpl::~PluginEpScanKernelImpl() { + helper_->Release(helper_); +} + +/*static*/ +OrtStatus* ORT_API_CALL PluginEpScanKernelImpl::ComputeImpl(OrtKernelImpl* this_ptr, + OrtKernelContext* kernel_ctx) noexcept { + API_IMPL_BEGIN + auto* plugin_ep_kernel = static_cast(this_ptr); + ORT_API_RETURN_IF_STATUS_NOT_OK(plugin_ep_kernel->kernel_.Compute(reinterpret_cast(kernel_ctx))); + + return nullptr; + API_IMPL_END +} + +/*static*/ +void ORT_API_CALL PluginEpScanKernelImpl::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_control_flow_kernel_impls.h b/onnxruntime/core/session/plugin_ep/ep_control_flow_kernel_impls.h new file mode 100644 index 0000000000000..91d469b296ffe --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_control_flow_kernel_impls.h @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "core/session/onnxruntime_c_api.h" +#include "core/providers/cpu/controlflow/if.h" +#include "core/providers/cpu/controlflow/loop.h" +#include "core/providers/cpu/controlflow/scan.h" + +namespace onnxruntime { +/// +/// Flags that ORT can set on OrtKernelImpl instances. +/// Note: This enum can be moved to a more central location if/when we add other flags. +/// +/// IMPORTANT: When adding a new flag, update kOrtKernelImplFlags_MAX_VALUE. +/// +enum OrtKernelImplFlags : uint32_t { + // Denotes a control flow kernel created by ORT (i.e., a PluginEpControlFlowKernelImpl) + kIsControlFlowKernelImpl = 1 << 0, + + // The largest flag value. Used to validate that flags are within the expected range. + // Must be updated when a new flag is added. + kOrtKernelImplFlags_MAX_VALUE = kIsControlFlowKernelImpl +}; + +/// +/// Base class for ORT-defined OrtKernelImpl classes for control flow operators. +/// Provides polymorphic access to the controlflow::IControlFlowKernel interface, which allows setting up subgraph +/// session state. +/// +struct PluginEpControlFlowKernelImpl : public OrtKernelImpl { + PluginEpControlFlowKernelImpl(); + virtual ~PluginEpControlFlowKernelImpl() {} + virtual controlflow::IControlFlowKernel& GetIControlFlowKernel() = 0; +}; + +/// +/// OrtKernelImpl class for an If kernel. The OrtKernelImpl function calls are forwarded to an internal +/// onnxruntime::If operator kernel instance. +/// +/// An EP can create an instance of this class by calling OrtEpApi::CreateIfKernel(). +/// +class PluginEpIfKernelImpl final : public PluginEpControlFlowKernelImpl { + public: + PluginEpIfKernelImpl(const OpKernelInfo& info); + controlflow::IControlFlowKernel& GetIControlFlowKernel() override { return kernel_; } + + // Static functions assigned to the OrtKernelImpl fields: + static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept; + static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept; + + private: + If kernel_; +}; + +/// +/// OrtKernelImpl class for a Loop kernel. The OrtKernelImpl function calls are forwarded to an internal +/// onnxruntime::Loop operator kernel instance. +/// +/// An EP can create an instance of this class by calling OrtEpApi::CreateLoopKernel(). +/// +class PluginEpLoopKernelImpl final : public PluginEpControlFlowKernelImpl { + public: + PluginEpLoopKernelImpl(const OpKernelInfo& info, gsl::not_null helper); + ~PluginEpLoopKernelImpl(); + + controlflow::IControlFlowKernel& GetIControlFlowKernel() override { return kernel_; } + + // Static functions assigned to the OrtKernelImpl fields: + static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept; + static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept; + + private: + Loop kernel_; + gsl::not_null helper_; +}; + +/// +/// OrtKernelImpl class for a Scan kernel (opset >= 9). The OrtKernelImpl function calls are forwarded to an internal +/// onnxruntime::Scan operator kernel instance. +/// +/// An EP can create an instance of this class by calling OrtEpApi::CreateScanKernel(). +/// +class PluginEpScanKernelImpl final : public PluginEpControlFlowKernelImpl { + public: + PluginEpScanKernelImpl(const OpKernelInfo& info, gsl::not_null helper); + ~PluginEpScanKernelImpl(); + + controlflow::IControlFlowKernel& GetIControlFlowKernel() override { return kernel_; } + + // Static functions assigned to the OrtKernelImpl fields: + static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept; + static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept; + + private: + Scan<9> kernel_; + gsl::not_null helper_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc b/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc index fe96bb577d925..625645e71cfec 100644 --- a/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc +++ b/onnxruntime/core/session/plugin_ep/ep_kernel_registration.cc @@ -12,8 +12,10 @@ #include "core/framework/error_code_helper.h" #include "core/framework/kernel_registry.h" #include "core/framework/tensor.h" +#include "core/providers/cpu/controlflow/utils.h" #include "core/session/allocator_adapters.h" #include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_control_flow_kernel_impls.h" // // OrtSharedPrePackedWeightCache @@ -54,13 +56,14 @@ namespace onnxruntime { /// /// OpKernel that wraps a OrtKernelImpl provided by a plugin EP. /// -class PluginEpOpKernel final : public OpKernel { +class PluginEpOpKernel final : public controlflow::IControlFlowKernel { private: // Prevents calling constructor directly without having to make it private (required by std::make_unique). struct PrivateTag {}; public: - PluginEpOpKernel(const OpKernelInfo& info, PrivateTag) : OpKernel{info} {} // must use ::Create() + PluginEpOpKernel(const OpKernelInfo& info, PrivateTag) + : controlflow::IControlFlowKernel{info} {} // must use ::Create() static Status Create(FuncManager& fn_manager, const OpKernelInfo& info, OrtKernelCreateFunc kernel_create_func, void* kernel_create_func_state, @@ -149,6 +152,25 @@ class PluginEpOpKernel final : public OpKernel { return Status::OK(); } + Status SetupSubgraphExecutionInfo(const SessionState& session_state, + const std::string& attribute_name, + const SessionState& subgraph_session_state) override { + assert(kernel_impl_ != nullptr); // Should be ensured by PluginEpOpKernel::Create(). + + if ((kernel_impl_->flags & OrtKernelImplFlags::kIsControlFlowKernelImpl) == 0) { + // This is not a control flow OrtKernelImpl created by ORT, which prevents casting OrtKernelImpl to + // PluginEpControlFlowKernelImpl and setting up subgraph execution info. The plugin EP may have tried to create + // their own OrtKernelImpl, which is not supported for control flow ops. + const auto& op_type = Info().node().OpType(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "OrtKernelImpl instance for control flow operator ", op_type, + " was not originally created by ORT via an OrtEpApi function."); + } + + auto& cf_kernel = static_cast(*kernel_impl_); + return cf_kernel.GetIControlFlowKernel().SetupSubgraphExecutionInfo(session_state, attribute_name, + subgraph_session_state); + } + private: /// /// Gets the cached OrtAllocator for the given AllocatorPtr passed to PrePack(). @@ -193,7 +215,24 @@ Status PluginEpOpKernel::Create(FuncManager& /*fn_manager*/, const OpKernelInfo& ORT_RETURN_IF_ERROR(ToStatusAndRelease( kernel_create_func(kernel_create_func_state, kernel_info, &op_kernel->kernel_impl_))); - ORT_RETURN_IF(op_kernel->kernel_impl_ == nullptr, "OrtKernelCreateFunc returned a NULL OrtKernelImpl"); + + const auto& op_type = info.node().OpType(); + const auto& node_name = info.node().Name(); + const auto* ep = info.GetExecutionProvider(); + ORT_ENFORCE(ep != nullptr, "IExecutionProvider* retrieved from OpKernelInfo should never be nullptr"); + const auto& ep_name = ep->Type(); + + // Do some basic checks for the OrtKernelImpl provided by the EP. Other checks for missing function implementations + // that are only required in certain situations (e.g., pre-packing) happen later as soon as we know they are required. + ORT_RETURN_IF(op_kernel->kernel_impl_ == nullptr, "OrtKernelCreateFunc returned a NULL OrtKernelImpl for ", op_type, + " node named ", node_name, " assigned to ", ep_name); + ORT_RETURN_IF(op_kernel->kernel_impl_->flags > OrtKernelImplFlags::kOrtKernelImplFlags_MAX_VALUE, + "OrtKernelImpl::flags has been initialized to an unexpected value for ", op_type, + " node named ", node_name, " assigned to ", ep_name); + ORT_RETURN_IF(op_kernel->kernel_impl_->Compute == nullptr, "OrtKernelImpl is missing an implementation of the ", + " Compute() function for ", op_type, " node named ", node_name, " assigned to ", ep_name); + ORT_RETURN_IF(op_kernel->kernel_impl_->Release == nullptr, "OrtKernelImpl is missing an implementation of the ", + " Release() function for ", op_type, " node named ", node_name, " assigned to ", ep_name); return Status::OK(); } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc index 9ce5dbbe91d75..427be1fa3081c 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc @@ -66,7 +66,7 @@ OrtStatus* ORT_API_CALL ExampleKernelEp::GetCapabilityImpl(OrtEp* this_ptr, cons for (const auto& node : all_nodes) { std::string op_type = node.GetOperatorType(); - if (op_type == "Relu" || op_type == "Squeeze") { + if (op_type == "Relu" || op_type == "Squeeze" || op_type == "If" || op_type == "Loop" || op_type == "Scan") { candidate_nodes.push_back(node); } else if (op_type == "Mul" || op_type == "Sub") { std::vector inputs = node.GetInputs(); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc index 8b8cc35afe1ae..80926fac5f48e 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc @@ -22,6 +22,21 @@ static const BuildKernelCreateInfoFn build_kernel_create_info_funcs[] = { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + // If versions 21, 23, and 24. + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // Loop versions 21, 23, and 24. + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // Scan versions 21, 23, and 24. + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; size_t GetNumKernels() { diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.cc index fb3bf6cfdb347..04a9f8f69c52d 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.cc @@ -27,7 +27,7 @@ ONNX_OPERATOR_KERNEL_EX( BinaryOp) BinaryOp::BinaryOp(Ort::ConstKernelInfo info, void* state, PrivateTag) - : OrtKernelImpl{}, // Initialize all OrtKernelImpl functions to NULL + : OrtKernelImpl{}, // Initialize all OrtKernelImpl members to NULL/zero info_{info}, data_transfer_impl_{reinterpret_cast(state)} { ort_version_supported = ORT_API_VERSION; @@ -41,8 +41,7 @@ BinaryOp::BinaryOp(Ort::ConstKernelInfo info, void* state, PrivateTag) } /*static*/ -OrtStatus* BinaryOp::Create(const OrtKernelInfo* info, void* state, - /*out*/ std::unique_ptr& result) noexcept { +OrtStatus* BinaryOp::CreateKernelImpl(const OrtKernelInfo* info, void* state, /*out*/ OrtKernelImpl*& result) noexcept { EXCEPTION_TO_RETURNED_STATUS_BEGIN Ort::ConstKernelInfo kernel_info(info); @@ -58,7 +57,8 @@ OrtStatus* BinaryOp::Create(const OrtKernelInfo* info, void* state, return Ort::GetApi().CreateStatus(ORT_EP_FAIL, oss.str().c_str()); } - result = std::make_unique(kernel_info, state, PrivateTag{}); + auto binary_op = std::make_unique(kernel_info, state, PrivateTag{}); + result = binary_op.release(); return nullptr; EXCEPTION_TO_RETURNED_STATUS_END } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.h index b6cddccb22290..fcae3e5d08f0b 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.h @@ -29,7 +29,7 @@ class BinaryOp : public OrtKernelImpl { }; public: - static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel) noexcept; + static OrtStatus* CreateKernelImpl(const OrtKernelInfo* info, void* state, /*out*/ OrtKernelImpl*& kernel) noexcept; BinaryOp(Ort::ConstKernelInfo info, void* state, PrivateTag); // Static functions assigned to the OrtKernelImpl fields: diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/if.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/if.cc new file mode 100644 index 0000000000000..123d11ef076fe --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/if.cc @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "if.h" + +#include "utils.h" + +// Defines a kernel creation function for If opset 21 +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + If, + kOnnxDomain, + /*start version*/ 21, /*end version*/ 22, + (Ort::KernelDefBuilder() + .SetInputMemType(0, OrtMemTypeCPUInput) // 'cond' needs to be on CPU + .AddTypeConstraint("B", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)) + .AddTypeConstraint("V", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), + IfHelper) + +// Defines a kernel creation function for If opset 23 +ONNX_OPERATOR_KERNEL_EX( + If, + kOnnxDomain, + /*version*/ 23, + (Ort::KernelDefBuilder() + .SetInputMemType(0, OrtMemTypeCPUInput) // 'cond' needs to be on CPU + .AddTypeConstraint("B", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)) + .AddTypeConstraint("V", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), + IfHelper) + +// Defines a kernel creation function for If opset 24 +ONNX_OPERATOR_KERNEL_EX( + If, + kOnnxDomain, + /*version*/ 24, + (Ort::KernelDefBuilder() + .SetInputMemType(0, OrtMemTypeCPUInput) // 'cond' needs to be on CPU + .AddTypeConstraint("B", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)) + .AddTypeConstraint("V", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), + IfHelper) + +/*static*/ +OrtStatus* IfHelper::CreateKernelImpl(const OrtKernelInfo* info, void* /*state*/, + /*out*/ OrtKernelImpl*& kernel) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + RETURN_IF_ERROR(Ort::GetEpApi().CreateIfKernel(info, &kernel)); + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/if.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/if.h new file mode 100644 index 0000000000000..5d44a154dc2dd --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/if.h @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "../../plugin_ep_utils.h" + +struct IfHelper { + static OrtStatus* CreateKernelImpl(const OrtKernelInfo* info, void* state, /*out*/ OrtKernelImpl*& kernel) noexcept; +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/loop.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/loop.cc new file mode 100644 index 0000000000000..46ccafe623778 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/loop.cc @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "loop.h" + +#include +#include "utils.h" +#include "../ep.h" + +// Defines a kernel creation function for Loop opset 21 +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Loop, + kOnnxDomain, + /*start version*/ 21, /*end version*/ 22, + (Ort::KernelDefBuilder() + .SetInputMemType(0, OrtMemTypeCPUInput) // 'M' needs to be on CPU + .SetInputMemType(1, OrtMemTypeCPUInput) // 'cond' needs to be on CPU + .AddTypeConstraint("I", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) + .AddTypeConstraint("B", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)) + .AddTypeConstraint("V", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), + LoopHelper) + +// Defines a kernel creation function for Loop opset 23 +ONNX_OPERATOR_KERNEL_EX( + Loop, + kOnnxDomain, + /*version*/ 23, + (Ort::KernelDefBuilder() + .SetInputMemType(0, OrtMemTypeCPUInput) // 'M' needs to be on CPU + .SetInputMemType(1, OrtMemTypeCPUInput) // 'cond' needs to be on CPU + .AddTypeConstraint("I", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) + .AddTypeConstraint("B", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)) + .AddTypeConstraint("V", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), + LoopHelper) + +// Defines a kernel creation function for Loop opset 24 +ONNX_OPERATOR_KERNEL_EX( + Loop, + kOnnxDomain, + /*version*/ 24, + (Ort::KernelDefBuilder() + .SetInputMemType(0, OrtMemTypeCPUInput) // 'M' needs to be on CPU + .SetInputMemType(1, OrtMemTypeCPUInput) // 'cond' needs to be on CPU + .AddTypeConstraint("I", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) + .AddTypeConstraint("B", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)) + .AddTypeConstraint("V", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), + LoopHelper) + +/*static*/ +OrtStatus* LoopHelper::CreateKernelImpl(const OrtKernelInfo* ort_kernel_info, void* state, + /*out*/ OrtKernelImpl*& kernel) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + const OrtEpApi& ep_api = Ort::GetEpApi(); + Ort::ConstKernelInfo kernel_info(ort_kernel_info); + + // Ask ORT to create a OrtKernelImpl for Loop. + auto loop_helper = std::make_unique(kernel_info, state); + RETURN_IF_ERROR(ep_api.CreateLoopKernel(kernel_info, loop_helper.get(), &kernel)); + loop_helper.release(); // ORT owns this instance on successful call to CreateLoopKernel. + + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +LoopHelper::LoopHelper(Ort::ConstKernelInfo info, void* state) + : OrtLoopKernelHelper{}, // Initialize all OrtLoopKernelHelper members to NULL/zero + info_{info}, + data_transfer_impl_{reinterpret_cast(state)} { + ort_version_supported = ORT_API_VERSION; + Release = ReleaseImpl; + ConcatOutput = ConcatOutputImpl; +} + +/*static*/ +void ORT_API_CALL LoopHelper::ReleaseImpl(_In_ OrtLoopKernelHelper* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +/*static*/ +OrtStatus* ORT_API_CALL LoopHelper::ConcatOutputImpl( + _In_ OrtLoopKernelHelper* this_ptr, + _In_opt_ void* /*stream_handle*/, + _In_reads_(num_per_iteration_outputs) const OrtValue* const* per_iteration_outputs, + _In_ size_t num_per_iteration_outputs, + _Out_writes_bytes_all_(output_size_in_bytes) void* output, + _In_ size_t output_size_in_bytes) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + // Concatenates loop iteration outputs. Ignores native stream handle argument as this example EP kernel + // uses CPU memory. Based on the default implementation in CPU EP. + + // An actual implementation can retrieve state from the OrtLoopKernelHelper (e.g., OrtDataTransferImpl, etc.). + LoopHelper* loop_kernel_helper = static_cast(this_ptr); + (void)loop_kernel_helper->info_; // Unused in this example. + (void)loop_kernel_helper->data_transfer_impl_; // Unused in this example. + + Ort::ConstValue first_output{per_iteration_outputs[0]}; + Ort::TensorTypeAndShapeInfo type_shape = first_output.GetTensorTypeAndShapeInfo(); + std::vector per_iteration_shape = type_shape.GetShape(); + size_t bytes_per_iteration = first_output.GetTensorSizeInBytes(); + + gsl::span output_span = gsl::make_span(static_cast(output), + output_size_in_bytes); + + for (size_t i = 0; i < num_per_iteration_outputs; i++) { + Ort::ConstValue ort_value{per_iteration_outputs[i]}; + + // Sanity check that all OrtValue's have the same amount of data. + RETURN_IF(bytes_per_iteration != ort_value.GetTensorSizeInBytes(), Ort::GetApi(), + "OrtLoopConcatOutputFunc received outputs with different sizes."); + + auto src = gsl::make_span(static_cast(ort_value.GetTensorRawData()), + bytes_per_iteration); + auto dst = output_span.subspan(i * bytes_per_iteration, bytes_per_iteration); + gsl::copy(src, dst); + } + + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/loop.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/loop.h new file mode 100644 index 0000000000000..b1b52956afed2 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/loop.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "../../plugin_ep_utils.h" + +class LoopHelper : public OrtLoopKernelHelper { + public: + static OrtStatus* CreateKernelImpl(const OrtKernelInfo* info, void* state, /*out*/ OrtKernelImpl*& kernel) noexcept; + LoopHelper(Ort::ConstKernelInfo info, void* state); + + // Static functions assigned to the OrtLoopKernelHelper fields: + static void ORT_API_CALL ReleaseImpl(_In_ OrtLoopKernelHelper* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL ConcatOutputImpl( + _In_ OrtLoopKernelHelper* this_ptr, + _In_opt_ void* stream_handle, + _In_reads_(num_per_iteration_outputs) const OrtValue* const* per_iteration_outputs, + _In_ size_t num_per_iteration_outputs, + _Out_writes_bytes_all_(output_size_in_bytes) void* output, + _In_ size_t output_size_in_bytes) noexcept; + + private: + Ort::ConstKernelInfo info_; + OrtDataTransferImpl* data_transfer_impl_; // Custom state passed from OrtEp +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc index 89f52c4b53dc3..f199fdd087142 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc @@ -20,7 +20,7 @@ ONNX_OPERATOR_KERNEL_EX( Relu) Relu::Relu(const OrtKernelInfo* info, void* /*state*/, PrivateTag) - : OrtKernelImpl{}, // Initialize all OrtKernelImpl functions to NULL + : OrtKernelImpl{}, // Initialize all OrtKernelImpl members to NULL/zero info_{info} { ort_version_supported = ORT_API_VERSION; Compute = ComputeImpl; @@ -28,10 +28,12 @@ Relu::Relu(const OrtKernelInfo* info, void* /*state*/, PrivateTag) } /*static*/ -OrtStatus* Relu::Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel) noexcept { +OrtStatus* Relu::CreateKernelImpl(const OrtKernelInfo* info, void* state, /*out*/ OrtKernelImpl*& kernel) noexcept { EXCEPTION_TO_RETURNED_STATUS_BEGIN Ort::ConstKernelInfo kernel_info(info); - kernel = std::make_unique(info, state, PrivateTag{}); + auto relu_kernel = std::make_unique(info, state, PrivateTag{}); + + kernel = relu_kernel.release(); return nullptr; EXCEPTION_TO_RETURNED_STATUS_END } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h index cdeb450435c29..a1187bfd34521 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h @@ -10,7 +10,7 @@ class Relu : public OrtKernelImpl { struct PrivateTag {}; public: - static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel) noexcept; + static OrtStatus* CreateKernelImpl(const OrtKernelInfo* info, void* state, /*out*/ OrtKernelImpl*& kernel) noexcept; Relu(const OrtKernelInfo* info, void* state, PrivateTag); // Static functions assigned to the OrtKernelImpl fields: diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/scan.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/scan.cc new file mode 100644 index 0000000000000..5d5b4e8cff5cf --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/scan.cc @@ -0,0 +1,117 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "scan.h" + +#include +#include "utils.h" + +// Defines a kernel creation function for Scan opset 21 +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Scan, + kOnnxDomain, + /*start version*/ 21, /*end version*/ 22, + (Ort::KernelDefBuilder() + // 'I' is in the ONNX spec but is not used for any inputs or outputs + // .AddTypeConstraint("I", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) + .AddTypeConstraint("V", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), + ScanHelper) + +// Defines a kernel creation function for Scan opset 23 +ONNX_OPERATOR_KERNEL_EX( + Scan, + kOnnxDomain, + /*version*/ 23, + (Ort::KernelDefBuilder() + // 'I' is in the ONNX spec but is not used for any inputs or outputs + // .AddTypeConstraint("I", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) + .AddTypeConstraint("V", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), + ScanHelper) + +// Defines a kernel creation function for Scan opset 24 +ONNX_OPERATOR_KERNEL_EX( + Scan, + kOnnxDomain, + /*version*/ 24, + (Ort::KernelDefBuilder() + // 'I' is in the ONNX spec but is not used for any inputs or outputs + // .AddTypeConstraint("I", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)) + .AddTypeConstraint("V", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))), + ScanHelper) + +/*static*/ +OrtStatus* ScanHelper::CreateKernelImpl(const OrtKernelInfo* ort_kernel_info, void* state, + /*out*/ OrtKernelImpl*& kernel) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + const OrtEpApi& ep_api = Ort::GetEpApi(); + Ort::ConstKernelInfo kernel_info(ort_kernel_info); + + // Ask ORT to create a OrtKernelImpl for Scan. + auto scan_helper = std::make_unique(kernel_info, state); + RETURN_IF_ERROR(ep_api.CreateScanKernel(kernel_info, scan_helper.get(), &kernel)); + scan_helper.release(); // ORT owns this instance on successful call to CreateScanKernel. + + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + +ScanHelper::ScanHelper(Ort::ConstKernelInfo info, void* state) + : OrtScanKernelHelper{}, // Initialize all OrtScanKernelHelper members to NULL/zero + info_{info}, + data_transfer_impl_{reinterpret_cast(state)} { + ort_version_supported = ORT_API_VERSION; + Release = ReleaseImpl; + Transpose = TransposeImpl; +} + +/*static*/ +void ORT_API_CALL ScanHelper::ReleaseImpl(_In_ OrtScanKernelHelper* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +/*static*/ +OrtStatus* ORT_API_CALL ScanHelper::TransposeImpl(_In_ OrtScanKernelHelper* this_ptr, + _In_reads_(num_permutation_elems) const size_t* permutation, + _In_ size_t num_permutation_elems, + _In_ const OrtValue* ort_input, _In_opt_ OrtSyncStream* /*stream*/, + _Inout_ OrtValue* ort_output) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + // An actual implementation can retrieve state from the OrtScanKernelHelper (e.g., OrtDataTransferImpl, etc.). + ScanHelper* scan_kernel_helper = static_cast(this_ptr); + (void)scan_kernel_helper->info_; // Unused in this example. + (void)scan_kernel_helper->data_transfer_impl_; // Unused in this example. + + Ort::ConstValue input(ort_input); + Ort::UnownedValue output(ort_output); + gsl::span perm(permutation, num_permutation_elems); + + // Note: This example implementation only supports 2D transpose (perm: [1, 0]) for convenience. A correct implementation + // should support more general dimensions/permutations. + RETURN_IF(perm.size() != 2 || perm[0] != 1 || perm[1] != 0, Ort::GetApi(), + "Scan kernel for ExampleKernelEp only supports 2D transpose."); + + Ort::TensorTypeAndShapeInfo input_type_shape = input.GetTensorTypeAndShapeInfo(); + Ort::TensorTypeAndShapeInfo output_type_shape = output.GetTensorTypeAndShapeInfo(); + std::vector input_shape = input_type_shape.GetShape(); + size_t num_elems = input_type_shape.GetElementCount(); + + RETURN_IF(output_type_shape.GetElementCount() != num_elems, Ort::GetApi(), + "Expected input and output of Scan's transpose helper to have the same number of elements"); + + gsl::span src(input.GetTensorData(), num_elems); + gsl::span dst(output.GetTensorMutableData(), num_elems); + + size_t num_rows = static_cast(input_shape[0]); + size_t num_cols = static_cast(input_shape[1]); + + for (size_t r = 0; r < num_rows; r++) { + for (size_t c = 0; c < num_cols; c++) { + size_t src_idx = r * num_cols + c; + size_t dst_idx = c * num_rows + r; + dst[dst_idx] = src[src_idx]; + } + } + + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/scan.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/scan.h new file mode 100644 index 0000000000000..cdd7ff4283d2c --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/scan.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "../../plugin_ep_utils.h" + +class ScanHelper : public OrtScanKernelHelper { + public: + static OrtStatus* CreateKernelImpl(const OrtKernelInfo* info, void* state, /*out*/ OrtKernelImpl*& kernel) noexcept; + ScanHelper(Ort::ConstKernelInfo info, void* state); + + // Static functions assigned to the OrtScanKernelHelper fields: + static void ORT_API_CALL ReleaseImpl(_In_ OrtScanKernelHelper* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL TransposeImpl(_In_ OrtScanKernelHelper* this_ptr, + _In_reads_(num_permutation_elems) const size_t* permutation, + _In_ size_t num_permutation_elems, + _In_ const OrtValue* input, _In_opt_ OrtSyncStream* stream, + _Inout_ OrtValue* output) noexcept; + + private: + Ort::ConstKernelInfo info_; + OrtDataTransferImpl* data_transfer_impl_; // Custom state passed from OrtEp +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc index 3d6a2527476e8..eeef5e696a7b1 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc @@ -50,7 +50,7 @@ ONNX_OPERATOR_KERNEL_EX( Squeeze) Squeeze::Squeeze(const OrtKernelInfo* info, void* state, PrivateTag) - : OrtKernelImpl{}, // Initialize all OrtKernelImpl functions to NULL + : OrtKernelImpl{}, // Initialize all OrtKernelImpl members to NULL/zero info_{info}, data_transfer_impl_{reinterpret_cast(state)} { ort_version_supported = ORT_API_VERSION; @@ -59,10 +59,12 @@ Squeeze::Squeeze(const OrtKernelInfo* info, void* state, PrivateTag) } /*static*/ -OrtStatus* Squeeze::Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel) noexcept { +OrtStatus* Squeeze::CreateKernelImpl(const OrtKernelInfo* info, void* state, /*out*/ OrtKernelImpl*& kernel) noexcept { EXCEPTION_TO_RETURNED_STATUS_BEGIN Ort::ConstKernelInfo kernel_info(info); - kernel = std::make_unique(info, state, PrivateTag{}); + auto squeeze_kernel = std::make_unique(info, state, PrivateTag{}); + + kernel = squeeze_kernel.release(); return nullptr; EXCEPTION_TO_RETURNED_STATUS_END } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h index d179b95d73f80..3aa8d6d5b050a 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h @@ -10,7 +10,7 @@ class Squeeze : public OrtKernelImpl { struct PrivateTag {}; public: - static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr& kernel) noexcept; + static OrtStatus* CreateKernelImpl(const OrtKernelInfo* info, void* state, /*out*/ OrtKernelImpl*& kernel) noexcept; Squeeze(const OrtKernelInfo* info, void* state, PrivateTag); // Static functions assigned to the OrtKernelImpl fields: diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h index 506392abb6149..ffad70ec5ebda 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h @@ -106,11 +106,11 @@ static constexpr const char* kOnnxDomain = ""; \ auto kernel_create_func = [](void* state, const OrtKernelInfo* info, \ OrtKernelImpl** kernel_out) noexcept -> OrtStatus* { \ - *kernel_out = nullptr; \ + RETURN_IF(kernel_out == nullptr, Ort::GetApi(), \ + "OrtKernelCreateFunc received a NULL kernel_out argument"); \ \ - std::unique_ptr kernel; \ - RETURN_IF_ERROR(kernel_class::Create(info, state, kernel)); \ - *kernel_out = kernel.release(); \ + *kernel_out = nullptr; \ + RETURN_IF_ERROR(kernel_class::CreateKernelImpl(info, state, *kernel_out)); \ return nullptr; \ }; \ \ diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index 437ca37c1a7b6..c20a5455e5eae 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -110,6 +110,114 @@ void RunSubMulSubModel(const Ort::SessionOptions& session_options) { EXPECT_THAT(output_span, ::testing::ElementsAre(-3, -5, -7, -9, -11, -13)); } +void RunIfMulModel(const Ort::SessionOptions& session_options, bool if_condition) { + // Model graph does the following computation: + // if (A) { C = B * 2.0; } + // else { C = B * 3; } + Ort::Session session(*ort_env, ORT_TSTR("testdata/if_mul.onnx"), session_options); + + // Create inputs + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::array a_shape = {1}; + std::array b_shape = {3, 2}; + + std::array a_data = {if_condition}; + std::array b_data = {2.f, 3.f, 4.f, -5.f, 6.f, 7.f}; + + std::vector ort_inputs{}; + ort_inputs.emplace_back( + Ort::Value::CreateTensor(memory_info, a_data.data(), a_data.size(), a_shape.data(), a_shape.size())); + ort_inputs.emplace_back( + Ort::Value::CreateTensor(memory_info, b_data.data(), b_data.size(), b_shape.data(), b_shape.size())); + + std::array ort_input_names{"A", "B"}; + + // Run session and get outputs + std::array output_names{"C"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check expected output values + Ort::Value& ort_output = ort_outputs[0]; + const float* output_data = ort_output.GetTensorData(); + gsl::span output_span(output_data, 6); + + if (if_condition) { + // Expect that model multiplies input B elements by 2.0. + EXPECT_THAT(output_span, ::testing::ElementsAre(4.f, 6.f, 8.f, -10.f, 12.f, 14.f)); + } else { + // Expect that model multiplies input B elements by 3.0. + EXPECT_THAT(output_span, ::testing::ElementsAre(6.f, 9.f, 12.f, -15.f, 18.f, 21.f)); + } +} + +void RunLoopSubOneModel(const Ort::SessionOptions& session_options) { + // Model graph does the following computation: + // x = A + // for (int i = 0; i < MAX_ITERS; i++) { + // y = x - 1.0; + // user_val = x - 1.0; + // x = y; + // } + // C = x; + // D = user_val (will be concatenated result of each iteration) + Ort::Session session(*ort_env, ORT_TSTR("testdata/loop_sub_one.onnx"), session_options); + + // Create inputs + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::array max_iters_shape = {1}; + std::array a_shape = {1}; + + std::array max_iters_data = {3}; + std::array a_data = {10.0f}; + + std::vector ort_inputs{}; + ort_inputs.emplace_back( + Ort::Value::CreateTensor(memory_info, max_iters_data.data(), max_iters_data.size(), + max_iters_shape.data(), max_iters_shape.size())); + ort_inputs.emplace_back( + Ort::Value::CreateTensor(memory_info, a_data.data(), a_data.size(), a_shape.data(), a_shape.size())); + + std::array ort_input_names{"MAX_ITERS", "A"}; + + // Run session and get outputs + std::array output_names{"C", "D"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check expected output values + // Expect that input elems are all subtracted by 3 (1 each iteration). + gsl::span output_c_span(ort_outputs[0].GetTensorData(), 1); + EXPECT_THAT(output_c_span, ::testing::ElementsAre(7.f)); + + gsl::span output_d_span(ort_outputs[1].GetTensorData(), 3); + EXPECT_THAT(output_d_span, ::testing::ElementsAre(9.f, 8.f, 7.f)); +} + +void RunScanMulModel(const Ort::SessionOptions& session_options) { + Ort::Session session(*ort_env, ORT_TSTR("testdata/scan_mul.onnx"), session_options); + + // Create inputs + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::array x_shape = {3, 3}; + std::array x_data = {1.f, 2.f, 3.f, 10.f, 20.f, 30.f, 100.f, 200.f, 300.f}; + + std::vector ort_inputs{}; + ort_inputs.emplace_back( + Ort::Value::CreateTensor(memory_info, x_data.data(), x_data.size(), x_shape.data(), x_shape.size())); + + std::array ort_input_names{"X"}; + + // Run session and get outputs + std::array output_names{"Y"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check expected output values + gsl::span output_span(ort_outputs[0].GetTensorData(), 9); + EXPECT_THAT(output_span, ::testing::ElementsAre(2.f, 4.f, 6.f, 20.f, 40.f, 60.f, 200.f, 400.f, 600.f)); +} + void RunPartiallySupportedModelWithPluginEp(const Ort::SessionOptions& session_options) { // This model has Add -> Mul -> Add. The example plugin EP supports Mul but not Add. Ort::Session session(*ort_env, ORT_TSTR("testdata/add_mul_add.onnx"), session_options); @@ -352,5 +460,94 @@ TEST(OrtEpLibrary, KernelPluginEp_Inference) { ASSERT_NO_FATAL_FAILURE(RunSubMulSubModel(session_options)); } } + +TEST(OrtEpLibrary, KernelPluginEp_ControlFlow_If) { + RegisteredEpDeviceUniquePtr example_kernel_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_kernel_registry_info, + example_kernel_ep)); + Ort::ConstEpDevice plugin_ep_device(example_kernel_ep.get()); + + // Run model with If and Mul ops. + // No sharing of pre-packed weights. + { + std::unordered_map ep_options; + Ort::SessionOptions session_options; + + session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Fail if any node assigned to CPU EP. + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + ASSERT_NO_FATAL_FAILURE(RunIfMulModel(session_options, /*if_condition*/ true)); + ASSERT_NO_FATAL_FAILURE(RunIfMulModel(session_options, /*if_condition*/ false)); + } + + // Run model with If and Mul ops. + // Enable sharing of pre-packed weights. + { + std::unordered_map ep_options = {{"enable_prepack_weight_sharing", "1"}}; + Ort::SessionOptions session_options; + + session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Fail if any node assigned to CPU EP + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + ASSERT_NO_FATAL_FAILURE(RunIfMulModel(session_options, /*if_condition*/ true)); + ASSERT_NO_FATAL_FAILURE(RunIfMulModel(session_options, /*if_condition*/ false)); + } +} + +TEST(OrtEpLibrary, KernelPluginEp_ControlFlow_Loop) { + RegisteredEpDeviceUniquePtr example_kernel_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_kernel_registry_info, + example_kernel_ep)); + Ort::ConstEpDevice plugin_ep_device(example_kernel_ep.get()); + + // Run model with Loop and Sub ops. + // No sharing of pre-packed weights. + { + std::unordered_map ep_options; + Ort::SessionOptions session_options; + + session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Fail if any node assigned to CPU EP + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + ASSERT_NO_FATAL_FAILURE(RunLoopSubOneModel(session_options)); + } + + // Run model with Loop and Sub ops. + // Enable sharing of pre-packed weights. + { + std::unordered_map ep_options = {{"enable_prepack_weight_sharing", "1"}}; + Ort::SessionOptions session_options; + + session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Fail if any node assigned to CPU EP + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + ASSERT_NO_FATAL_FAILURE(RunLoopSubOneModel(session_options)); + } +} + +TEST(OrtEpLibrary, KernelPluginEp_ControlFlow_Scan) { + RegisteredEpDeviceUniquePtr example_kernel_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_kernel_registry_info, + example_kernel_ep)); + Ort::ConstEpDevice plugin_ep_device(example_kernel_ep.get()); + + // Run model with Scan and Mul ops. + // No sharing of pre-packed weights. + { + std::unordered_map ep_options; + Ort::SessionOptions session_options; + + session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Fail if any node assigned to CPU EP + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + ASSERT_NO_FATAL_FAILURE(RunScanMulModel(session_options)); + } + + // Run model with Scan and Mul ops. + // Enable sharing of pre-packed weights. + { + std::unordered_map ep_options = {{"enable_prepack_weight_sharing", "1"}}; + Ort::SessionOptions session_options; + + session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Fail if any node assigned to CPU EP + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + ASSERT_NO_FATAL_FAILURE(RunScanMulModel(session_options)); + } +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/if_mul.onnx b/onnxruntime/test/testdata/if_mul.onnx new file mode 100644 index 0000000000000..fddfa83f3c32d Binary files /dev/null and b/onnxruntime/test/testdata/if_mul.onnx differ diff --git a/onnxruntime/test/testdata/if_mul.py b/onnxruntime/test/testdata/if_mul.py new file mode 100644 index 0000000000000..f2a557dedcdd2 --- /dev/null +++ b/onnxruntime/test/testdata/if_mul.py @@ -0,0 +1,70 @@ +from onnx import TensorProto, checker, helper, save, shape_inference + +# if A: C = B * 2 +# else: C = B * 3 + +if_then_branch = helper.make_graph( + nodes=[ + helper.make_node( + "Mul", + inputs=["B", "ConstTwo"], + outputs=["if_output"], + name="mul_0", + ), + ], + name="if_then_branch", + inputs=[ + # No explicit inputs + ], + outputs=[ + helper.make_tensor_value_info("if_output", TensorProto.FLOAT, [3, 2]), + ], +) + +if_else_branch = helper.make_graph( + nodes=[ + helper.make_node( + "Mul", + inputs=["B", "ConstThree"], + outputs=["if_output"], + name="mul_1", + ), + ], + name="if_else_branch", + inputs=[ + # No explicit inputs + ], + outputs=[ + helper.make_tensor_value_info("if_output", TensorProto.FLOAT, [3, 2]), + ], +) + +graph_proto = helper.make_graph( + nodes=[ + helper.make_node( + "If", + inputs=["A"], + outputs=["C"], + name="if_0", + then_branch=if_then_branch, + else_branch=if_else_branch, + ), + ], + name="Main_graph", + inputs=[ + helper.make_tensor_value_info("A", TensorProto.BOOL, [1]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [3, 2]), + ], + outputs=[ + helper.make_tensor_value_info("C", TensorProto.FLOAT, [3, 2]), + ], + initializer=[ + helper.make_tensor("ConstTwo", TensorProto.FLOAT, [3, 2], [2.0] * 6), + helper.make_tensor("ConstThree", TensorProto.FLOAT, [3, 2], [3.0] * 6), + ], +) + +model = helper.make_model(graph_proto) +model = shape_inference.infer_shapes(model) +checker.check_model(model, True) +save(model, "if_mul.onnx") diff --git a/onnxruntime/test/testdata/loop_sub_one.onnx b/onnxruntime/test/testdata/loop_sub_one.onnx new file mode 100644 index 0000000000000..838a26e594be0 Binary files /dev/null and b/onnxruntime/test/testdata/loop_sub_one.onnx differ diff --git a/onnxruntime/test/testdata/loop_sub_one.py b/onnxruntime/test/testdata/loop_sub_one.py new file mode 100644 index 0000000000000..f7bbe0f3fc068 --- /dev/null +++ b/onnxruntime/test/testdata/loop_sub_one.py @@ -0,0 +1,69 @@ +from onnx import TensorProto, checker, helper, save, shape_inference + +""" +x = A +for (int i = 0; i < MAX_ITERS; i++) { + y = x - 1.0 + user_val = x - 1.0 + x = y +} +C = x +D = user_val (will be the concatenated result of all iterations) +""" + +loop_body = helper.make_graph( + nodes=[ + helper.make_node( + "Sub", + inputs=["loop_state_in", "ConstOne"], + outputs=["loop_state_out"], + name="sub_0", + ), + helper.make_node( + "Sub", + inputs=["loop_state_in", "ConstOne"], + outputs=["user_defined_val"], + name="sub_1", + ), + ], + name="loop_body", + inputs=[ + helper.make_tensor_value_info("index", TensorProto.INT64, [1]), + helper.make_tensor_value_info("subgraph_keep_going_in", TensorProto.BOOL, [1]), + helper.make_tensor_value_info("loop_state_in", TensorProto.FLOAT, [1]), + ], + outputs=[ + helper.make_tensor_value_info("subgraph_keep_going_in", TensorProto.BOOL, [1]), + helper.make_tensor_value_info("loop_state_out", TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("user_defined_val", TensorProto.FLOAT, [1]), + ], + initializer=[ + helper.make_tensor("ConstOne", TensorProto.FLOAT, [1], [1.0]), + ], +) + +graph_proto = helper.make_graph( + nodes=[ + helper.make_node( + "Loop", + inputs=["MAX_ITERS", "", "A"], + outputs=["C", "D"], + name="loop_0", + body=loop_body, + ), + ], + name="Main_graph", + inputs=[ + helper.make_tensor_value_info("MAX_ITERS", TensorProto.INT64, [1]), + helper.make_tensor_value_info("A", TensorProto.FLOAT, [1]), + ], + outputs=[ + helper.make_tensor_value_info("C", TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("D", TensorProto.FLOAT, None), + ], +) + +model = helper.make_model(graph_proto) +model = shape_inference.infer_shapes(model) +checker.check_model(model, True) +save(model, "loop_sub_one.onnx") diff --git a/onnxruntime/test/testdata/scan_mul.onnx b/onnxruntime/test/testdata/scan_mul.onnx new file mode 100644 index 0000000000000..ab75d117fe564 Binary files /dev/null and b/onnxruntime/test/testdata/scan_mul.onnx differ diff --git a/onnxruntime/test/testdata/scan_mul.py b/onnxruntime/test/testdata/scan_mul.py new file mode 100644 index 0000000000000..62e3274f32e13 --- /dev/null +++ b/onnxruntime/test/testdata/scan_mul.py @@ -0,0 +1,51 @@ +from onnx import TensorProto, checker, helper, save, shape_inference + +# Scan body: y_t = x_t * 2.0 +scan_body = helper.make_graph( + nodes=[ + helper.make_node( + "Mul", + inputs=["x_t", "ConstTwo"], + outputs=["y_t"], + name="mul_0", + ), + ], + name="scan_body", + inputs=[ + helper.make_tensor_value_info("x_t", TensorProto.FLOAT, [3]), + ], + outputs=[ + helper.make_tensor_value_info("y_t", TensorProto.FLOAT, [3]), + ], + initializer=[ + helper.make_tensor("ConstTwo", TensorProto.FLOAT, [3], [2.0] * 3), + ], +) + +# Top graph: Y = Scan(X) +graph_proto = helper.make_graph( + nodes=[ + helper.make_node( + "Scan", + inputs=["X"], + outputs=["Y"], + name="scan_0", + body=scan_body, + num_scan_inputs=1, + scan_input_axes=[1], + scan_output_axes=[1], + ), + ], + name="Main_graph", + inputs=[ + helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 3]), + ], + outputs=[ + helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 3]), + ], +) + +model = helper.make_model(graph_proto) +model = shape_inference.infer_shapes(model) +checker.check_model(model, True) +save(model, "scan_mul.onnx")