Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
337 changes: 337 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,20 @@ inline const OrtCompileApi& GetCompileApi() {
return *api;
}

/// <summary>
/// Accessor for the ORT C Interop API, which is used for external resource import with EPs.
/// Raises an ORT_FAIL error through ORT_CXX_API_THROW when the underlying API pointer is null.
/// </summary>
/// <returns>ORT C Interop API reference</returns>
inline const OrtInteropApi& GetInteropApi() {
  const auto* interop_api = GetApi().GetInteropApi();
  if (!interop_api) {
    // A null pointer here means the API was compiled out (minimal build).
    ORT_CXX_API_THROW("Interop API is not available in this build", ORT_FAIL);
  }

  return *interop_api;
}

/// <summary>
/// This returns a reference to the ORT C EP API. Used if authoring a plugin execution provider.
/// </summary>
Expand Down Expand Up @@ -1610,6 +1624,7 @@ struct ConstSessionImpl : Base<T> {
std::vector<ConstMemoryInfo> GetMemoryInfoForInputs() const; ///< Wrapper for OrtApi::SessionGetMemoryInfoForInputs
std::vector<ConstMemoryInfo> GetMemoryInfoForOutputs() const; ///< Wrapper for OrtApi::SessionGetMemoryInfoForOutputs
std::vector<ConstEpDevice> GetEpDeviceForInputs() const; ///< Wrapper for OrtApi::SessionGetEpDeviceForInputs
std::vector<ConstEpDevice> GetEpDeviceForOutputs() const; ///< Wrapper for OrtApi::SessionGetEpDeviceForOutputs

/** \brief Returns a copy of input name at the specified index.
*
Expand Down
13 changes: 13 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -1660,6 +1660,19 @@ inline std::vector<ConstEpDevice> ConstSessionImpl<T>::GetEpDeviceForInputs() co
return input_devices;
}

/// Returns one ConstEpDevice per session output (wrapper for OrtApi::SessionGetEpDeviceForOutputs).
/// An empty vector is returned, without calling into the C API, when the session has no outputs.
template <typename T>
inline std::vector<ConstEpDevice> ConstSessionImpl<T>::GetEpDeviceForOutputs() const {
  const auto output_count = GetOutputCount();
  if (output_count == 0) {
    return {};
  }

  // The C API fills a raw array of `const OrtEpDevice*`; the wrapper storage is handed over directly.
  // NOTE(review): this relies on ConstEpDevice being layout-compatible with a single pointer - confirm.
  std::vector<ConstEpDevice> devices(output_count);
  auto** raw_devices = reinterpret_cast<const OrtEpDevice**>(devices.data());
  ThrowOnError(GetApi().SessionGetEpDeviceForOutputs(this->p_, raw_devices, output_count));
  return devices;
}

template <typename T>
inline uint64_t ConstSessionImpl<T>::GetProfilingStartTimeNs() const {
uint64_t out;
Expand Down
260 changes: 260 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,66 @@ ORT_RUNTIME_CLASS(DataTransferImpl);
ORT_RUNTIME_CLASS(SyncNotificationImpl);
ORT_RUNTIME_CLASS(SyncStreamImpl);

ORT_RUNTIME_CLASS(ExternalResourceImporterImpl);

/** \brief Base struct for imported external memory handles.
*
* EPs derive from this struct to add EP-specific fields (e.g., CUdeviceptr for CUDA).
* EP is responsible for creating and releasing instances of the derived type.
* Code that only sees this base type works with the fields below; the EP recovers its own
* derived type by casting the base pointer (see Release).
*
* Example derived type for CUDA EP:
* \code
* struct MyCudaExternalMemoryHandle : OrtExternalMemoryHandle {
* CUexternalMemory ext_memory;
* CUdeviceptr mapped_ptr;
* bool is_dedicated;
* };
* \endcode
*
* NOTE(review): two release paths are visible for a memory handle - the Release callback below and
* OrtExternalResourceImporterImpl::ReleaseMemory. Confirm which one ORT invokes so that an EP does
* not free the same handle twice.
*
* \since Version 1.24.
*/
struct OrtExternalMemoryHandle {
uint32_t version; ///< Must be ORT_API_VERSION
const OrtEpDevice* ep_device; ///< EP device that created this handle
OrtExternalMemoryHandleType handle_type; ///< Original handle type for tracking
size_t size_bytes; ///< Size of the imported memory
size_t offset_bytes; ///< Offset into the imported memory

/** \brief Release callback for this handle. EP sets this to its release function.
*
* ORT calls this when ReleaseExternalMemoryHandle is invoked. The EP's callback
* should cast the handle to its derived type and delete it.
*
* \param[in] handle The handle to release, passed as a pointer to this base struct.
*/
void(ORT_API_CALL* Release)(_In_ OrtExternalMemoryHandle* handle);
};

/** \brief Base struct for imported external semaphore handles.
*
* EPs derive from this struct to add EP-specific fields (e.g., CUexternalSemaphore for CUDA).
* EP is responsible for creating and releasing instances of the derived type.
* Code that only sees this base type works with the fields below; the EP recovers its own
* derived type by casting the base pointer (see Release).
*
* Example derived type for CUDA EP:
* \code
* struct MyCudaExternalSemaphoreHandle : OrtExternalSemaphoreHandle {
* CUexternalSemaphore ext_semaphore;
* };
* \endcode
*
* NOTE(review): two release paths are visible for a semaphore handle - the Release callback below and
* OrtExternalResourceImporterImpl::ReleaseSemaphore. Confirm which one ORT invokes so that an EP does
* not free the same handle twice.
*
* \since Version 1.24.
*/
struct OrtExternalSemaphoreHandle {
uint32_t version; ///< Must be ORT_API_VERSION
const OrtEpDevice* ep_device; ///< EP device that created this handle
OrtExternalSemaphoreType type; ///< Original semaphore type

/** \brief Release callback for this handle. EP sets this to its release function.
*
* ORT calls this when ReleaseExternalSemaphoreHandle is invoked. The EP's callback
* should cast the handle to its derived type and delete it.
*
* \param[in] handle The handle to release, passed as a pointer to this base struct.
*/
void(ORT_API_CALL* Release)(_In_ OrtExternalSemaphoreHandle* handle);
};

// Opaque types for kernel-based EPs
ORT_RUNTIME_CLASS(KernelRegistry);
ORT_RUNTIME_CLASS(KernelDefBuilder);
Expand Down Expand Up @@ -191,6 +251,180 @@ struct OrtSyncStreamImpl {
ORT_API2_STATUS(OnSessionRunEnd, _In_ OrtSyncStreamImpl* this_ptr);
};

/** \brief Struct that an EP implements for external resource import (memory + semaphore import).
*
* This capability object provides methods for importing external GPU memory and semaphores
* for zero-copy import. EPs that support D3D12, CUDA, HIP, or Vulkan external resource APIs
* can implement this interface.
*
* The members fall into three groups: memory import (no stream involved), semaphore import
* (only WaitSemaphore and SignalSemaphore take an OrtSyncStream), and Release for the
* capability object itself.
*
* \since Version 1.24.
*/
struct OrtExternalResourceImporterImpl {
uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION

// Memory operations (stream-independent)

/** \brief Check if the implementation can import external memory of the given handle type.
*
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
* \param[in] handle_type The type of external memory handle to check.
* \return True if the handle type is supported.
*
* \since Version 1.24.
*/
ORT_API_T(bool, CanImportMemory,
_In_ const OrtExternalResourceImporterImpl* this_ptr,
_In_ OrtExternalMemoryHandleType handle_type);

/** \brief Import external memory.
*
* The EP creates a derived type of OrtExternalMemoryHandle and returns a pointer to the base.
* EP is responsible for the lifetime of the handle (release via ReleaseMemory).
*
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
* \param[in] desc Descriptor containing the external memory handle and properties.
* \param[out] out_handle Output parameter set to the created OrtExternalMemoryHandle (EP's derived type).
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(ImportMemory,
_In_ OrtExternalResourceImporterImpl* this_ptr,
_In_ const OrtExternalMemoryDescriptor* desc,
_Outptr_ OrtExternalMemoryHandle** out_handle);

/** \brief Release an imported external memory handle.
*
* The EP deletes its derived type instance.
*
* NOTE(review): OrtExternalMemoryHandle also carries its own Release callback; confirm how the two
* release paths relate so that a handle is never freed twice.
*
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
* \param[in] handle The OrtExternalMemoryHandle to release (EP casts to its derived type).
*
* \since Version 1.24.
*/
ORT_API_T(void, ReleaseMemory,
_In_ OrtExternalResourceImporterImpl* this_ptr,
_In_ OrtExternalMemoryHandle* handle);

/** \brief Create a tensor backed by imported external memory.
*
* The created tensor is a view over the imported memory and does not copy data.
* NOTE(review): because the tensor is only a view, the imported memory presumably has to stay
* alive for as long as the tensor is in use - confirm the required release order.
*
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
* \param[in] mem_handle The imported external memory handle (EP casts to its derived type).
* \param[in] tensor_desc Descriptor specifying tensor element type, shape, and optional offset.
* \param[out] out_tensor Output parameter set to the created OrtValue containing the tensor.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(CreateTensorFromMemory,
_In_ OrtExternalResourceImporterImpl* this_ptr,
_In_ const OrtExternalMemoryHandle* mem_handle,
_In_ const OrtExternalTensorDescriptor* tensor_desc,
_Outptr_ OrtValue** out_tensor);

// Semaphore operations (only WaitSemaphore and SignalSemaphore take a stream)

/** \brief Check if the implementation can import external semaphores of the given type.
*
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
* \param[in] type The type of external semaphore to check.
* \return True if the semaphore type is supported.
*
* \since Version 1.24.
*/
ORT_API_T(bool, CanImportSemaphore,
_In_ const OrtExternalResourceImporterImpl* this_ptr,
_In_ OrtExternalSemaphoreType type);

/** \brief Import an external semaphore.
*
* The EP creates a derived type of OrtExternalSemaphoreHandle and returns a pointer to the base.
* EP is responsible for the lifetime of the handle (release via ReleaseSemaphore).
*
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
* \param[in] desc Descriptor containing the external semaphore handle and type.
* \param[out] out_handle Output parameter set to the created OrtExternalSemaphoreHandle (EP's derived type).
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(ImportSemaphore,
_In_ OrtExternalResourceImporterImpl* this_ptr,
_In_ const OrtExternalSemaphoreDescriptor* desc,
_Outptr_ OrtExternalSemaphoreHandle** out_handle);

/** \brief Release an imported external semaphore handle.
*
* The EP deletes its derived type instance.
*
* NOTE(review): OrtExternalSemaphoreHandle also carries its own Release callback; confirm how the two
* release paths relate so that a handle is never freed twice.
*
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
* \param[in] handle The OrtExternalSemaphoreHandle to release (EP casts to its derived type).
*
* \since Version 1.24.
*/
ORT_API_T(void, ReleaseSemaphore,
_In_ OrtExternalResourceImporterImpl* this_ptr,
_In_ OrtExternalSemaphoreHandle* handle);

/** \brief Wait on an external semaphore on the EP's stream.
*
* Inserts a wait operation into the EP's stream that blocks until the semaphore
* reaches the specified value.
*
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
* \param[in] handle The imported external semaphore (EP casts to its derived type).
* \param[in] stream The OrtSyncStream into which the wait operation is inserted.
* \param[in] value The fence/semaphore value to wait for.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(WaitSemaphore,
_In_ OrtExternalResourceImporterImpl* this_ptr,
_In_ OrtExternalSemaphoreHandle* handle,
_In_ OrtSyncStream* stream,
_In_ uint64_t value);

/** \brief Signal an external semaphore from the EP's stream.
*
* Inserts a signal operation into the EP's stream that sets the semaphore
* to the specified value when reached.
*
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
* \param[in] handle The imported external semaphore (EP casts to its derived type).
* \param[in] stream The OrtSyncStream into which the signal operation is inserted.
* \param[in] value The fence/semaphore value to signal.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.24.
*/
ORT_API2_STATUS(SignalSemaphore,
_In_ OrtExternalResourceImporterImpl* this_ptr,
_In_ OrtExternalSemaphoreHandle* handle,
_In_ OrtSyncStream* stream,
_In_ uint64_t value);

// Release the capability object itself

/** \brief Release the OrtExternalResourceImporterImpl instance.
*
* This is called by ORT when the OrtExternalResourceImporterImpl instance is no longer needed.
* The implementation should release any resources held by the instance.
*
* \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance.
*
* \since Version 1.24.
*/
ORT_API_T(void, Release, _In_ OrtExternalResourceImporterImpl* this_ptr);
};

struct OrtNodeFusionOptions;
typedef struct OrtNodeFusionOptions OrtNodeFusionOptions;

Expand Down Expand Up @@ -1564,6 +1798,32 @@ struct OrtEpFactory {
* \since Version 1.24.
*/
ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options);

/** \brief Create an OrtExternalResourceImporterImpl for external resource import.
*
* This is used to create an external resource importer that enables zero-copy import of
* external GPU memory (e.g., D3D12 shared resources) and synchronization primitives
* (e.g., D3D12 timeline fences).
*
* EPs that support external resource import (via CUDA, HIP, Vulkan, or D3D12 APIs) can
* implement this to allow applications to share GPU resources without copies.
*
* ORT calls OrtExternalResourceImporterImpl::Release on the returned instance when it is no
* longer needed.
*
* \param[in] this_ptr The OrtEpFactory instance.
* \param[in] ep_device The OrtEpDevice to create the external resource importer for.
* \param[out] out_importer The created OrtExternalResourceImporterImpl instance.
* Set to nullptr if external resource import is not supported.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \note Implementation of this function is optional.
* An EP factory should only implement this if it supports external resource import.
* If not implemented or not supported, return ORT_NOT_IMPLEMENTED or set out_importer to nullptr.
* Callers therefore have to handle both an error status and a successful call that yields nullptr.
*
* \since Version 1.24.
*/
ORT_API2_STATUS(CreateExternalResourceImporterForDevice, _In_ OrtEpFactory* this_ptr,
_In_ const OrtEpDevice* ep_device,
_Outptr_result_maybenull_ OrtExternalResourceImporterImpl** out_importer);
};

#ifdef __cplusplus
Expand Down
42 changes: 42 additions & 0 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3480,6 +3480,48 @@ common::Status InferenceSession::GetEpDeviceForInputs(InlinedVector<const OrtEpD
#endif
}

// Fills `ep_devices` with one entry per model output, in the order returned by GetModelOutputs().
//
// `ep_devices` is always cleared first. An entry is nullptr when the output is not produced by any
// node, or when no OrtEpDevice registered with the environment has the same EP name as the node
// producing that output.
//
// Returns FAIL in a minimal build, INVALID_ARGUMENT if the session has not been initialized, or the
// error reported by GetModelOutputs() / SessionState::GetOutputNodeInfo().
common::Status InferenceSession::GetEpDeviceForOutputs(InlinedVector<const OrtEpDevice*>& ep_devices) const {
  ep_devices.clear();

#if defined(ORT_MINIMAL_BUILD)
  return common::Status(common::ONNXRUNTIME, common::FAIL,
                        "GetEpDeviceForOutputs is not available in a minimal build.");
#else
  if (!is_inited_) {
    return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Session has not been initialized.");
  }

  std::pair<common::Status, const OutputDefList*> outputs = GetModelOutputs();

  ORT_RETURN_IF_ERROR(outputs.first);

  const auto& def_list = *outputs.second;
  ep_devices.reserve(def_list.size());

  const auto& available_eps = environment_.GetOrtEpDevices();

  for (const auto* def : def_list) {
    InlinedVector<SessionState::NodeInfo> node_info_vec;
    ORT_RETURN_IF_ERROR(session_state_->GetOutputNodeInfo(def->Name(), node_info_vec));

    // The lookup is expected to yield at least one entry. The assert documents that invariant for
    // debug builds; the explicit empty() check keeps release builds from calling front() on an
    // empty container (undefined behavior) and reports "no device" instead.
    assert(!node_info_vec.empty());
    const auto* p_node = node_info_vec.empty() ? nullptr : node_info_vec.front().p_node;

    // If we have an output that is not produced by any node,
    // then we return nullptr.
    if (p_node == nullptr) {
      ep_devices.push_back(nullptr);
      continue;
    }

    // Bind by reference: the name is only compared, so there is no need to copy it per output.
    const auto& ep_name = p_node->GetExecutionProviderType();
    auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) {
      return entry->ep_name == ep_name;
    });
    ep_devices.push_back(it != available_eps.end() ? *it : nullptr);
  }

  return Status::OK();
#endif
}

common::Status InferenceSession::NewIOBinding(std::unique_ptr<IOBinding>* io_binding) {
{
std::lock_guard<std::mutex> l(session_mutex_);
Expand Down
9 changes: 9 additions & 0 deletions onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,15 @@ class InferenceSession {
* This is required for a user to know the location of the input/output when autoep selection is enabled.
*/
common::Status GetEpDeviceForInputs(InlinedVector<const OrtEpDevice*>& memory_info) const;

/**
* Get the OrtEpDevice (if available) for the outputs of the model.
*
* This is required for a user to validate that outputs will be placed on the expected device
* for external resource sharing.
*
* @param ep_devices Cleared on entry. On success it holds one entry per model output, in model output
*                   order; an entry is nullptr when no OrtEpDevice is available for that output.
* @return OK on success; an error status if the session has not been initialized or the build
*         does not support this query (minimal build).
*/
common::Status GetEpDeviceForOutputs(InlinedVector<const OrtEpDevice*>& ep_devices) const;

/**
* Get the current number of in-progress concurrent Run calls.
*/
Expand Down
Loading
Loading