Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -2134,7 +2134,13 @@ if (onnxruntime_BUILD_SHARED_LIB AND
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.cc")
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/binary_op.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/scan.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/scan.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/loop.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/loop.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/if.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/if.cc")
onnxruntime_add_shared_library_module(example_plugin_ep_kernel_registry ${onnxruntime_autoep_test_example_plugin_ep_kernel_registry_src})
target_include_directories(example_plugin_ep_kernel_registry PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session)
target_link_libraries(example_plugin_ep_kernel_registry PRIVATE onnxruntime ${GSL_TARGET})
Expand Down
154 changes: 152 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,16 @@ struct OrtNodeComputeInfo {
void(ORT_API_CALL* ReleaseState)(_In_ OrtNodeComputeInfo* this_ptr, _Frees_ptr_opt_ void* compute_state);
};

/**
* \brief Used to denote the creator of an OrtKernelImpl. Most OrtKernelImpl instances are created by
* EPs. However, OrtKernelImpl instances for control flow operators (e.g., If, Loop, and Scan) are created by ORT.
*
* The value is stored in the `creator` field of OrtKernelImpl.
*
* \since Version 1.24.
*/
typedef enum OrtKernelImplCreator {
ORT_KERNEL_IMPL_CREATOR_EP = 0, ///< The OrtKernelImpl was created by an EP.
ORT_KERNEL_IMPL_CREATOR_ORT = 1, ///< The OrtKernelImpl was created by ORT (e.g., for If, Loop, or Scan).
} OrtKernelImplCreator;

struct OrtKernelImpl;
typedef struct OrtKernelImpl OrtKernelImpl;

Expand All @@ -290,6 +300,7 @@ typedef struct OrtKernelImpl OrtKernelImpl;
*/
struct OrtKernelImpl {
uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION
uint32_t creator; ///< EP must set to ORT_KERNEL_IMPL_CREATOR_EP (0) if it created the OrtKernelImpl

/** \brief Computation function called to execute the kernel on an EP.
*
Expand Down Expand Up @@ -411,7 +422,9 @@ struct OrtKernelImpl {
* \param[in] kernel_create_func_state Opaque state initially provided by the EP that registered the kernel.
* Refer to OrtEpApi::KernelRegistry_AddKernel(). May be null.
* \param[in] info The OrtKernelInfo instance that provides access to the kernel's input and output characteristics.
* \param[out] kernel_out Output parameter set to the new OrtKernelImpl instance.
* \param[out] kernel_out Output parameter set to the new OrtKernelImpl instance. On success, ownership of this
* OrtKernelImpl instance transfers to ORT, which will call OrtKernelImpl::Release() to
* release the instance when it is no longer used.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
Expand All @@ -421,6 +434,82 @@ typedef OrtStatus*(ORT_API_CALL* OrtKernelCreateFunc)(_In_ void* kernel_create_f
_In_ const OrtKernelInfo* info,
_Outptr_result_maybenull_ OrtKernelImpl** kernel_out);

struct OrtLoopKernelHelper;
typedef struct OrtLoopKernelHelper OrtLoopKernelHelper;

/**
* \brief Contains helper functions for a Loop OrtKernelImpl created via OrtEpApi::CreateLoopKernel().
*
* The EP passes an instance of this struct to OrtEpApi::CreateLoopKernel(). Both the Release and ConcatOutput
* function pointers must be non-null; OrtEpApi::CreateLoopKernel() returns an ORT_INVALID_ARGUMENT error otherwise.
*
* \since Version 1.24.
*/
struct OrtLoopKernelHelper {
uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION

/** \brief Called by ORT to release the OrtLoopKernelHelper instance and its resources.
*
* \note Must not be NULL.
*
* \param[in] this_ptr The OrtLoopKernelHelper instance.
*
* \since Version 1.24.
*/
ORT_API_T(void, Release, _In_ OrtLoopKernelHelper* this_ptr);

/** \brief Helper function that concatenates OrtValue instances from each loop iteration into a single
* pre-allocated output buffer.
*
* \note Must not be NULL.
*
* \param[in] this_ptr The OrtLoopKernelHelper instance.
* \param[in] stream_handle Optional native stream handle that enables asynchronous operations. May be NULL.
* \param[in] per_iteration_output Array of OrtValue instances from each iteration. All OrtValue elements have the
* same shape.
* \param[in] num_iteration_outputs The number of OrtValue* elements in the `per_iteration_output` array.
* \param[out] output The pre-allocated output buffer. Memory is allocated on the device for the EP running the
* Loop node.
* \param[in] output_size_in_bytes The size in bytes of the `output` buffer.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(ConcatOutput, _In_ OrtLoopKernelHelper* this_ptr, _In_opt_ void* stream_handle,
_In_ OrtValue* const* per_iteration_output, _In_ size_t num_iteration_outputs,
_Out_writes_bytes_all_(output_size_in_bytes) void* output, _In_ size_t output_size_in_bytes);
};

struct OrtScanKernelHelper;
typedef struct OrtScanKernelHelper OrtScanKernelHelper;

/**
* \brief Contains helper functions for a Scan OrtKernelImpl created via OrtEpApi::CreateScanKernel().
*
* The EP passes an instance of this struct to OrtEpApi::CreateScanKernel(). Both the Release and Transpose
* function pointers must be non-null; OrtEpApi::CreateScanKernel() returns an ORT_INVALID_ARGUMENT error otherwise.
*
* \since Version 1.24.
*/
struct OrtScanKernelHelper {
uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION

/** \brief Called by ORT to release the OrtScanKernelHelper instance and its resources.
*
* \note Must not be NULL.
*
* \param[in] this_ptr The OrtScanKernelHelper instance.
*
* \since Version 1.24.
*/
ORT_API_T(void, Release, _In_ OrtScanKernelHelper* this_ptr);

/** \brief Helper function that transposes an OrtValue instance during execution of a Scan kernel.
*
* \note Must not be NULL.
*
* \param[in] this_ptr The OrtScanKernelHelper instance.
* \param[in] permutation An array of integers that defines how the input tensor's axes should be permuted.
* \param[in] num_permutation_elems The number of integer elements in the `permutation` array.
* \param[in] input The input OrtValue tensor to transpose.
* \param[in] stream An optional OrtSyncStream instance to be used for asynchronous operations. May be NULL.
* \param[out] output The pre-allocated output OrtValue instance into which to store the results of the
* transpose operation. Must not be released as it is owned by ORT.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(Transpose, _In_ OrtScanKernelHelper* this_ptr,
_In_reads_(num_permutation_elems) const size_t* permutation, _In_ size_t num_permutation_elems,
_In_ const OrtValue* input, _In_opt_ OrtSyncStream* stream, _Inout_ OrtValue* output);
};

/**
* \brief The OrtEpApi struct provides functions that are relevant to the implementation of an execution provider.
*
Expand Down Expand Up @@ -976,13 +1065,74 @@ struct OrtEpApi {
*
* \note Used within OrtKernelImpl implementations to obtain a reference to the OrtEp.
*
* \param[in] info The ::OrtKernelInfo instance.
* \param[in] info The OrtApi::OrtKernelInfo instance.
* \param[out] ep Output parameter set to the OrtEp instance associated with the OrtKernelInfo.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
* \since Version 1.24
*/
ORT_API2_STATUS(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ const OrtEp** ep);

/** \brief Creates an OrtKernelImpl instance for an If operator.
*
* \note Control flow operators normally require access to ORT session internals to orchestrate subgraph operations.
* This function allows an EP to create a properly configured kernel for an If operator that the EP can then add
* to its kernel registry.
*
* \param[in] kernel_info The OrtApi::OrtKernelInfo instance for an If node. This function returns error ORT_FAIL
* if the opset version specified by `kernel_info` is unsupported.
* \param[out] kernel_out Output parameter set to the OrtKernelImpl instance for the If node.
* Must be released via ::ReleaseKernelImpl, unless ownership is transferred
* to ORT (see OrtKernelCreateFunc and ::KernelRegistry_AddKernel()).
*
* \snippet{doc} snippets.dox OrtStatus Return Value
* \since Version 1.24
*/
ORT_API2_STATUS(CreateIfKernel, _In_ const OrtKernelInfo* kernel_info, _Outptr_ OrtKernelImpl** kernel_out);

/** \brief Creates an OrtKernelImpl instance for a Loop operator.
*
* \note Control flow operators normally require access to ORT session internals to orchestrate subgraph operations.
* This function allows an EP to create a properly configured kernel for a Loop operator that the EP can then
* add to its kernel registry.
*
* \param[in] kernel_info The OrtApi::OrtKernelInfo instance for a Loop node. This function returns error ORT_FAIL
* if the opset version specified by `kernel_info` is unsupported.
* \param[in] helper An OrtLoopKernelHelper instance that contains helper functions that ORT calls during kernel
* execution to operate on tensors allocated with the EP's device memory.
* ORT will call OrtLoopKernelHelper::Release() to release the helper and its resources.
* \param[out] kernel_out Output parameter set to the OrtKernelImpl instance for the Loop node.
* Must be released via ::ReleaseKernelImpl, unless ownership is transferred
* to ORT (see OrtKernelCreateFunc and ::KernelRegistry_AddKernel()).
*
* \snippet{doc} snippets.dox OrtStatus Return Value
* \since Version 1.24
*/
ORT_API2_STATUS(CreateLoopKernel, _In_ const OrtKernelInfo* kernel_info, _In_ OrtLoopKernelHelper* helper,
_Outptr_ OrtKernelImpl** kernel_out);

/** \brief Creates an OrtKernelImpl instance for a Scan operator. Does not support opset versions older than 9.
*
* \note Control flow operators normally require access to ORT session internals to orchestrate subgraph operations.
* This function allows an EP to create a properly configured kernel for a Scan operator that the EP can then
* add to its kernel registry.
*
* \param[in] kernel_info The OrtApi::OrtKernelInfo instance for a Scan node. This function returns error ORT_FAIL
* if the opset version specified by `kernel_info` is unsupported.
* \param[in] helper An OrtScanKernelHelper instance that contains helper functions that ORT calls during kernel
* execution to operate on tensors allocated with the EP's device memory.
* ORT will call OrtScanKernelHelper::Release() to release the helper and its resources.
* \param[out] kernel_out Output parameter set to the OrtKernelImpl instance for the Scan node.
* Must be released via ::ReleaseKernelImpl, unless ownership is transferred
* to ORT (see OrtKernelCreateFunc and ::KernelRegistry_AddKernel()).
*
* \snippet{doc} snippets.dox OrtStatus Return Value
* \since Version 1.24
*/
ORT_API2_STATUS(CreateScanKernel, _In_ const OrtKernelInfo* kernel_info, _In_ OrtScanKernelHelper* helper,
_Outptr_ OrtKernelImpl** kernel_out);

ORT_CLASS_RELEASE(KernelImpl);
};

/**
Expand Down
13 changes: 8 additions & 5 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1183,11 +1183,14 @@ const InlinedHashSet<NodeIndex>* SessionState::GetToBeExecutedRange(
Status SessionState::CreateSubgraphSessionState() {
for (auto& node : graph_.Nodes()) {
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
const auto& ep = node.GetExecutionProviderType();
if (!ep.empty() &&
ep != kCpuExecutionProvider && ep != kCudaExecutionProvider &&
ep != kDmlExecutionProvider &&
ep != kJsExecutionProvider && ep != kWebGpuExecutionProvider) {
const auto& ep_type = node.GetExecutionProviderType();
const bool is_plugin_ep = execution_providers_.Get(ep_type)->GetOrtEp() != nullptr;

if (!ep_type.empty() &&
ep_type != kCpuExecutionProvider && ep_type != kCudaExecutionProvider &&
ep_type != kDmlExecutionProvider &&
ep_type != kJsExecutionProvider && ep_type != kWebGpuExecutionProvider &&
!is_plugin_ep) {
// SessionState is only used when ORT is executing the subgraph. If a non-ORT EP has taken the control flow
// node containing the subgraph it will create whatever state it needs internally.
continue;
Expand Down
1 change: 0 additions & 1 deletion onnxruntime/core/providers/cpu/controlflow/loop.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class Loop : public controlflow::IControlFlowKernel {

static std::unique_ptr<OpKernel> Create(const OpKernelInfo& info, const ConcatOutput& concat_output_func, void* stream);

protected:
// derived class can provide implementation for handling concatenation of Loop output on a different device
void SetConcatOutputFunc(const ConcatOutput& concat_output_func) { concat_output_func_ = concat_output_func; }

Expand Down
107 changes: 107 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h"
#include "core/session/ort_apis.h"
#include "core/session/plugin_ep/ep_kernel_registration.h"
#include "core/session/plugin_ep/ep_control_flow_kernel_impls.h"
#include "core/session/utils.h"

using namespace onnxruntime;
Expand Down Expand Up @@ -655,6 +656,108 @@ ORT_API_STATUS_IMPL(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ c
API_IMPL_END
}

// Control flow kernel APIs
// Creates an OrtKernelImpl for an If node from the given OrtKernelInfo.
//
// kernel_info: The OrtKernelInfo for the If node. Must not be null.
// kernel_out:  Output parameter set to the newly created OrtKernelImpl. Must not be null.
//              The caller receives a raw owning pointer to the new instance.
//
// Returns nullptr on success, or an ORT_INVALID_ARGUMENT status if either argument is null.
ORT_API_STATUS_IMPL(CreateIfKernel, _In_ const OrtKernelInfo* kernel_info, _Outptr_ OrtKernelImpl** kernel_out) {
  API_IMPL_BEGIN
  if (kernel_info == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Must specify a non-null OrtKernelInfo instance to create an If OrtKernelImpl");
  }

  // Validate the output parameter before writing through it, matching CreateLoopKernel and CreateScanKernel.
  if (kernel_out == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Must specify a non-null output parameter to hold the OrtKernelImpl for If");
  }

  // OrtKernelInfo is the opaque C API handle for onnxruntime::OpKernelInfo.
  const auto* op_kernel_info = reinterpret_cast<const onnxruntime::OpKernelInfo*>(kernel_info);
  auto kernel_unique_ptr = std::make_unique<PluginEpIfKernelImpl>(*op_kernel_info);

  // Hand the raw pointer to the caller.
  *kernel_out = kernel_unique_ptr.release();
  return nullptr;
  API_IMPL_END
}

// Creates an OrtKernelImpl for a Loop node.
//
// kernel_info: The OrtKernelInfo for the Loop node. Must not be null.
// helper:      EP-provided helper. Must not be null and must provide both Release and ConcatOutput.
// kernel_out:  Output parameter set to the newly created OrtKernelImpl. Must not be null.
//
// Returns nullptr on success, or an ORT_INVALID_ARGUMENT status describing the first invalid argument.
ORT_API_STATUS_IMPL(CreateLoopKernel, _In_ const OrtKernelInfo* kernel_info, _In_ OrtLoopKernelHelper* helper,
                    _Outptr_ OrtKernelImpl** kernel_out) {
  API_IMPL_BEGIN
  // Arguments are validated in a fixed order so that the reported error is deterministic.
  if (kernel_info == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Must specify a non-null OrtKernelInfo instance to create a Loop OrtKernelImpl");
  }

  if (helper == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Must specify a non-null OrtLoopKernelHelper instance to create a Loop OrtKernelImpl");
  }

  // The helper's function table must be fully populated.
  if (helper->Release == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "OrtLoopKernelHelper must have a non-null OrtLoopKernelHelper::Release function");
  }

  if (helper->ConcatOutput == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "OrtLoopKernelHelper must have a non-null OrtLoopKernelHelper::ConcatOutput function");
  }

  if (kernel_out == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Must specify a non-null output parameter to hold the OrtKernelImpl for Loop");
  }

  // OrtKernelInfo is the opaque C API handle for onnxruntime::OpKernelInfo.
  const auto& loop_kernel_info = *reinterpret_cast<const onnxruntime::OpKernelInfo*>(kernel_info);

  // Build the kernel and hand the raw pointer to the caller.
  *kernel_out = std::make_unique<PluginEpLoopKernelImpl>(loop_kernel_info, helper).release();
  return nullptr;
  API_IMPL_END
}

// Creates an OrtKernelImpl for a Scan node. Only Scan nodes with opset version >= 9 are supported.
//
// kernel_info: The OrtKernelInfo for the Scan node. Must not be null.
// helper:      EP-provided helper. Must not be null and must provide both Release and Transpose.
// kernel_out:  Output parameter set to the newly created OrtKernelImpl. Must not be null.
//
// Returns nullptr on success, an ORT_INVALID_ARGUMENT status describing the first invalid argument,
// or an ORT_FAIL status if the node's opset version is older than 9.
ORT_API_STATUS_IMPL(CreateScanKernel, _In_ const OrtKernelInfo* kernel_info, _In_ OrtScanKernelHelper* helper,
                    _Outptr_ OrtKernelImpl** kernel_out) {
  API_IMPL_BEGIN
  // Arguments are validated in a fixed order so that the reported error is deterministic.
  if (kernel_info == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Must specify a non-null OrtKernelInfo instance to create a Scan OrtKernelImpl");
  }

  if (helper == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Must specify a non-null OrtScanKernelHelper instance to create a Scan OrtKernelImpl");
  }

  // The helper's function table must be fully populated.
  if (helper->Release == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "OrtScanKernelHelper must have a non-null OrtScanKernelHelper::Release function");
  }

  if (helper->Transpose == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "OrtScanKernelHelper must have a non-null OrtScanKernelHelper::Transpose function");
  }

  if (kernel_out == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "Must specify a non-null output parameter to hold the OrtKernelImpl for Scan");
  }

  // OrtKernelInfo is the opaque C API handle for onnxruntime::OpKernelInfo.
  const auto& scan_kernel_info = *reinterpret_cast<const onnxruntime::OpKernelInfo*>(kernel_info);
  const int since_version = scan_kernel_info.node().SinceVersion();

  // Reject opsets older than 9; kernel_out is left untouched in that case.
  if (since_version < 9) {
    return OrtApis::CreateStatus(ORT_FAIL,
                                 "Kernel implementations for Scan older than opset version 9 are not supported");
  }

  // Note: CPU EP always uses Scan<9> for all opsets >= 9.
  *kernel_out = std::make_unique<PluginEpScanKernelImpl>(scan_kernel_info, helper).release();
  return nullptr;
  API_IMPL_END
}

ORT_API(void, ReleaseKernelImpl, _Frees_ptr_opt_ OrtKernelImpl* kernel_impl) {
if (kernel_impl != nullptr && kernel_impl->Release != nullptr) {
kernel_impl->Release(kernel_impl);
}
}

static constexpr OrtEpApi ort_ep_api = {
// NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end,
// and no functions can be removed (the implementation needs to change to return an error).
Expand Down Expand Up @@ -711,6 +814,10 @@ static constexpr OrtEpApi ort_ep_api = {
&OrtExecutionProviderApi::EpGraphSupportInfo_LookUpKernel,
&OrtExecutionProviderApi::SharedPrePackedWeightCache_StoreWeightData,
&OrtExecutionProviderApi::KernelInfo_GetEp,
&OrtExecutionProviderApi::CreateIfKernel,
&OrtExecutionProviderApi::CreateLoopKernel,
&OrtExecutionProviderApi::CreateScanKernel,
&OrtExecutionProviderApi::ReleaseKernelImpl,
};

// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/session/plugin_ep/ep_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,12 @@ ORT_API_STATUS_IMPL(SharedPrePackedWeightCache_StoreWeightData,

// KernelInfo
ORT_API_STATUS_IMPL(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ const OrtEp** ep);

// Control flow kernel APIs
// Each Create*Kernel function outputs a new OrtKernelImpl via `kernel_out` for the given OrtKernelInfo.
// The Loop and Scan variants additionally take an EP-provided helper struct.
ORT_API_STATUS_IMPL(CreateIfKernel, _In_ const OrtKernelInfo* kernel_info, _Outptr_ OrtKernelImpl** kernel_out);
ORT_API_STATUS_IMPL(CreateLoopKernel, _In_ const OrtKernelInfo* kernel_info, _In_ OrtLoopKernelHelper* helper,
_Outptr_ OrtKernelImpl** kernel_out);
ORT_API_STATUS_IMPL(CreateScanKernel, _In_ const OrtKernelInfo* kernel_info, _In_ OrtScanKernelHelper* helper,
_Outptr_ OrtKernelImpl** kernel_out);
// Releases an OrtKernelImpl instance; `kernel_impl` may be null.
ORT_API(void, ReleaseKernelImpl, _Frees_ptr_opt_ OrtKernelImpl* kernel_impl);
} // namespace OrtExecutionProviderApi
Loading
Loading