Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -2134,7 +2134,13 @@ if (onnxruntime_BUILD_SHARED_LIB AND
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.cc")
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/scan.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/scan.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/loop.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/loop.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/if.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/if.cc")
onnxruntime_add_shared_library_module(example_plugin_ep_kernel_registry ${onnxruntime_autoep_test_example_plugin_ep_kernel_registry_src})
target_include_directories(example_plugin_ep_kernel_registry PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session)
target_link_libraries(example_plugin_ep_kernel_registry PRIVATE onnxruntime ${GSL_TARGET})
Expand Down
193 changes: 192 additions & 1 deletion include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -524,8 +524,11 @@ typedef struct OrtKernelImpl OrtKernelImpl;
*/
struct OrtKernelImpl {
uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION
uint32_t flags; ///< EP must initialize to 0. Used internally by ORT.

/** \brief Computation function called to execute the kernel on an EP.
*
* \note Implementation of this function is required.
*
* \param[in] this_ptr The OrtKernelImpl instance.
* \param[in] context The OrtKernelContext instance that provides access to the inputs and outputs.
Expand All @@ -537,6 +540,8 @@ struct OrtKernelImpl {
ORT_API2_STATUS(Compute, _In_ OrtKernelImpl* this_ptr, _In_ OrtKernelContext* context);

/** \brief Called by ORT to release the OrtKernelImpl instance and its resources.
*
* \note Implementation of this function is required.
*
* \param[in] this_ptr The OrtKernelImpl instance.
*
Expand Down Expand Up @@ -645,7 +650,9 @@ struct OrtKernelImpl {
* \param[in] kernel_create_func_state Opaque state initially provided by the EP that registered the kernel.
* Refer to OrtEpApi::KernelRegistry_AddKernel(). May be null.
* \param[in] info The OrtKernelInfo instance that provides access to the kernel's input and output characteristics.
* \param[out] kernel_out Output parameter set to the new OrtKernelImpl instance.
* \param[out] kernel_out Output parameter set to the new OrtKernelImpl instance. On success, ownership of this
* OrtKernelImpl instance transfers to ORT, which will call OrtKernelImpl::Release() to
* release the instance when it is no longer used.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
Expand All @@ -655,6 +662,89 @@ typedef OrtStatus*(ORT_API_CALL* OrtKernelCreateFunc)(_In_ void* kernel_create_f
_In_ const OrtKernelInfo* info,
_Outptr_result_maybenull_ OrtKernelImpl** kernel_out);

struct OrtLoopKernelHelper;
typedef struct OrtLoopKernelHelper OrtLoopKernelHelper;

/**
* \brief Contains helper functions for a Loop OrtKernelImpl created via OrtEpApi::CreateLoopKernel().
*
* The EP implements this struct and passes an instance to OrtEpApi::CreateLoopKernel(). ORT calls the helper
* functions during kernel execution to operate on tensors allocated with the EP's device memory, and calls
* OrtLoopKernelHelper::Release() to release the helper and its resources.
*
* \since Version 1.24.
*/
struct OrtLoopKernelHelper {
uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION

/** \brief Called by ORT to release the OrtLoopKernelHelper instance and its resources.
*
* \note ORT calls this function for a helper instance that was provided to OrtEpApi::CreateLoopKernel().
*
* \param[in] this_ptr The OrtLoopKernelHelper instance.
*
* \since Version 1.24.
*/
ORT_API_T(void, Release, _In_ OrtLoopKernelHelper* this_ptr);

/** \brief Helper function that concatenates OrtValue instances from each loop iteration into a single
* pre-allocated output buffer.
*
* \note Implementing this function is required for all Loop opset versions.
*
* \param[in] this_ptr The OrtLoopKernelHelper instance.
* \param[in] stream_handle Optional native stream handle that enables asynchronous operations. May be NULL.
* \param[in] per_iteration_outputs Array of OrtValue instances from each iteration. All OrtValue elements have the
* same shape. The array and its elements are passed as const.
* \param[in] num_per_iteration_outputs The number of OrtValue* elements in the `per_iteration_outputs` array.
* \param[out] output The pre-allocated output buffer (a raw buffer, not an OrtValue). Memory is allocated on the
* device for the EP running the Loop node.
* \param[in] output_size_in_bytes The size in bytes of the `output` buffer. It is guaranteed to be large enough
* to hold the concatenated data of each element in `per_iteration_outputs`.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(ConcatOutput, _In_ OrtLoopKernelHelper* this_ptr, _In_opt_ void* stream_handle,
_In_reads_(num_per_iteration_outputs) const OrtValue* const* per_iteration_outputs,
_In_ size_t num_per_iteration_outputs, _Out_writes_bytes_all_(output_size_in_bytes) void* output,
_In_ size_t output_size_in_bytes);
};

struct OrtScanKernelHelper;
typedef struct OrtScanKernelHelper OrtScanKernelHelper;

/**
* \brief Contains helper functions for a Scan OrtKernelImpl created via OrtEpApi::CreateScanKernel().
*
* The EP implements this struct and passes an instance to OrtEpApi::CreateScanKernel(). ORT calls the helper
* functions during kernel execution to operate on tensors allocated with the EP's device memory, and calls
* OrtScanKernelHelper::Release() to release the helper and its resources.
*
* \since Version 1.24.
*/
struct OrtScanKernelHelper {
uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION

/** \brief Called by ORT to release the OrtScanKernelHelper instance and its resources.
*
* \note ORT calls this function for a helper instance that was provided to OrtEpApi::CreateScanKernel().
*
* \param[in] this_ptr The OrtScanKernelHelper instance.
*
* \since Version 1.24.
*/
ORT_API_T(void, Release, _In_ OrtScanKernelHelper* this_ptr);

/** \brief Helper function that transposes an OrtValue instance during execution of a Scan kernel.
*
* \note Called for Scan (opset >= 9) when the 'scan_input_axes' or 'scan_output_axes' attributes contain
* non-zero values. Implementing this function is required for Scan opset versions >= 9.
*
* \param[in] this_ptr The OrtScanKernelHelper instance.
* \param[in] permutation An array of integers that defines how the input tensor's axes should be permuted.
* \param[in] num_permutation_elems The number of integer elements in the `permutation` array.
* \param[in] input The input OrtValue tensor to transpose.
* \param[in] stream An optional OrtSyncStream instance to be used for asynchronous operations. May be NULL.
* \param[in,out] output The pre-allocated output OrtValue instance into which to store the results of the
* transpose operation. Must not be released as it is owned by ORT.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(Transpose, _In_ OrtScanKernelHelper* this_ptr,
_In_reads_(num_permutation_elems) const size_t* permutation, _In_ size_t num_permutation_elems,
_In_ const OrtValue* input, _In_opt_ OrtSyncStream* stream, _Inout_ OrtValue* output);
};

/**
* \brief The OrtEpApi struct provides functions that are relevant to the implementation of an execution provider.
*
Expand Down Expand Up @@ -1217,6 +1307,107 @@ struct OrtEpApi {
* \since Version 1.24
*/
ORT_API2_STATUS(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ const OrtEp** ep);

/** \brief Creates an OrtKernelImpl instance for an If operator.
*
* Control flow operators require access to ORT session internals to orchestrate subgraph operations.
* This function allows an EP to create a properly configured OrtKernelImpl with access to ORT internals that
* the EP can add to its kernel registry.
*
* An EP is required to create an OrtKernelDef that keeps input[0] ('cond') on the CPU (i.e., OrtMemTypeCPUInput)
* as this input is used by CPU logic. The output should remain on the device (i.e., OrtMemTypeDefault), which is
* the default setting, to avoid copying to/from CPU.
*
* Example kernel definition (CXX API):
* Ort::KernelDef kernel_def = Ort::KernelDefBuilder()
* .SetDomain("").SetOperatorType("If").SetSinceVersion(21, 22)
* .SetExecutionProvider("MyEp")
* .SetInputMemType(0, OrtMemTypeCPUInput) // 'cond' on CPU
* .SetOutputMemType(0, OrtMemTypeDefault) // output on EP device
* .AddTypeConstraint("B", ...)
* .AddTypeConstraint("V", ...).Build();
*
* \param[in] kernel_info The ::OrtKernelInfo instance for an If node. This function returns error ORT_FAIL
* if the opset version specified by `kernel_info` is unsupported.
* \param[out] kernel_out Output parameter set to the OrtKernelImpl instance for the If node.
* Must be released via ::ReleaseKernelImpl, unless ownership is transferred
* to ORT (see OrtKernelCreateFunc and ::KernelRegistry_AddKernel()).
*
* \snippet{doc} snippets.dox OrtStatus Return Value
* \since Version 1.24
*/
ORT_API2_STATUS(CreateIfKernel, _In_ const OrtKernelInfo* kernel_info, _Outptr_ OrtKernelImpl** kernel_out);

/** \brief Creates an OrtKernelImpl instance for a Loop operator.
*
* Control flow operators require access to ORT session internals to orchestrate subgraph operations.
* This function allows an EP to create a properly configured OrtKernelImpl with access to ORT internals that
* the EP can add to its kernel registry.
*
* An EP is required to create an OrtKernelDef that keeps input[0] ('M') and input[1] ('cond') on the CPU
* (i.e., OrtMemTypeCPUInput) as these inputs are used by CPU logic. Input[2] ('v_initial') and the output should
* remain on the device (i.e., OrtMemTypeDefault), which is the default setting, to avoid copying to/from CPU.
*
* Example kernel definition (CXX API):
* Ort::KernelDef kernel_def = Ort::KernelDefBuilder()
* .SetDomain("").SetOperatorType("Loop").SetSinceVersion(21, 22)
* .SetExecutionProvider("MyEp")
* .SetInputMemType(0, OrtMemTypeCPUInput) // 'M' on CPU
* .SetInputMemType(1, OrtMemTypeCPUInput) // 'cond' on CPU
* .SetInputMemType(2, OrtMemTypeDefault) // 'v_initial' on EP device
* .SetOutputMemType(0, OrtMemTypeDefault) // output on EP device
* .AddTypeConstraint("I", ...)
* .AddTypeConstraint("B", ...)
* .AddTypeConstraint("V", ...).Build();
*
* \param[in] kernel_info The ::OrtKernelInfo instance for a Loop node. This function returns error ORT_FAIL
* if the opset version specified by `kernel_info` is unsupported.
* \param[in] helper An OrtLoopKernelHelper instance that contains helper functions that ORT calls during kernel
* execution to operate on tensors allocated with the EP's device memory.
* ORT will call OrtLoopKernelHelper::Release() to release the helper and its resources.
* \param[out] kernel_out Output parameter set to the OrtKernelImpl instance for the Loop node.
* Must be released via ::ReleaseKernelImpl, unless ownership is transferred
* to ORT (see OrtKernelCreateFunc and ::KernelRegistry_AddKernel()).
*
* \snippet{doc} snippets.dox OrtStatus Return Value
* \since Version 1.24
*/
ORT_API2_STATUS(CreateLoopKernel, _In_ const OrtKernelInfo* kernel_info, _In_ OrtLoopKernelHelper* helper,
_Outptr_ OrtKernelImpl** kernel_out);

/** \brief Creates an OrtKernelImpl instance for a Scan operator. Does not support opset versions older than 9.
*
* Control flow operators require access to ORT session internals to orchestrate subgraph operations.
* This function allows an EP to create a properly configured OrtKernelImpl with access to ORT internals that
* the EP can add to its kernel registry.
*
* It is recommended that an EP create an OrtKernelDef that keeps the inputs and outputs on the EP's
* device (i.e., OrtMemTypeDefault), which is the default setting, to avoid copying to/from CPU.
*
* Example kernel definition (CXX API):
* Ort::KernelDef kernel_def = Ort::KernelDefBuilder()
* .SetDomain("").SetOperatorType("Scan").SetSinceVersion(21, 22)
* .SetExecutionProvider("MyEp")
* .SetInputMemType(0, OrtMemTypeDefault) // input[0] on EP device
* .SetOutputMemType(0, OrtMemTypeDefault) // output[0] on EP device
* .AddTypeConstraint("V", ...).Build();
*
* \param[in] kernel_info The ::OrtKernelInfo instance for a Scan node. This function returns error ORT_FAIL
* if the opset version specified by `kernel_info` is unsupported.
* \param[in] helper An OrtScanKernelHelper instance that contains helper functions that ORT calls during kernel
* execution to operate on tensors allocated with the EP's device memory.
* ORT will call OrtScanKernelHelper::Release() to release the helper and its resources.
* \param[out] kernel_out Output parameter set to the OrtKernelImpl instance for the Scan node.
* Must be released via ::ReleaseKernelImpl, unless ownership is transferred
* to ORT (see OrtKernelCreateFunc and ::KernelRegistry_AddKernel()).
*
* \snippet{doc} snippets.dox OrtStatus Return Value
* \since Version 1.24
*/
ORT_API2_STATUS(CreateScanKernel, _In_ const OrtKernelInfo* kernel_info, _In_ OrtScanKernelHelper* helper,
_Outptr_ OrtKernelImpl** kernel_out);

ORT_CLASS_RELEASE(KernelImpl);
};

/**
Expand Down
14 changes: 9 additions & 5 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1183,11 +1183,15 @@ const InlinedHashSet<NodeIndex>* SessionState::GetToBeExecutedRange(
Status SessionState::CreateSubgraphSessionState() {
for (auto& node : graph_.Nodes()) {
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
const auto& ep = node.GetExecutionProviderType();
if (!ep.empty() &&
ep != kCpuExecutionProvider && ep != kCudaExecutionProvider &&
ep != kDmlExecutionProvider &&
ep != kJsExecutionProvider && ep != kWebGpuExecutionProvider) {
const auto& ep_type = node.GetExecutionProviderType();
const IExecutionProvider* ep = execution_providers_.Get(ep_type);
const bool is_plugin_ep = ep != nullptr && ep->GetOrtEp() != nullptr;

if (!ep_type.empty() &&
ep_type != kCpuExecutionProvider && ep_type != kCudaExecutionProvider &&
ep_type != kDmlExecutionProvider &&
ep_type != kJsExecutionProvider && ep_type != kWebGpuExecutionProvider &&
!is_plugin_ep) {
// SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow
// node containing the subgraph it will create whatever state it needs internally.
continue;
Expand Down
1 change: 0 additions & 1 deletion onnxruntime/core/providers/cpu/controlflow/loop.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class Loop : public controlflow::IControlFlowKernel {

static std::unique_ptr<OpKernel> Create(const OpKernelInfo& info, const ConcatOutput& concat_output_func, void* stream);

protected:
// derived class can provide implementation for handling concatenation of Loop output on a different device
void SetConcatOutputFunc(const ConcatOutput& concat_output_func) { concat_output_func_ = concat_output_func; }

Expand Down
Loading
Loading