diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 3c85491837793..303bb5411ffd9 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -330,6 +330,9 @@ ORT_RUNTIME_CLASS(EpDevice); ORT_RUNTIME_CLASS(KeyValuePairs); ORT_RUNTIME_CLASS(SyncStream); // Opaque class to create an onnxruntime::Stream. ORT_RUNTIME_CLASS(ExternalInitializerInfo); +ORT_RUNTIME_CLASS(ExternalResourceImporter); // Capability object for external resource import +ORT_RUNTIME_CLASS(ExternalMemoryHandle); // EP-imported view of shared external allocation +ORT_RUNTIME_CLASS(ExternalSemaphoreHandle); // EP-imported view of shared external semaphore #ifdef _MSC_VER typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -878,6 +881,9 @@ typedef struct OrtModelEditorApi OrtModelEditorApi; struct OrtCompileApi; typedef struct OrtCompileApi OrtCompileApi; +struct OrtInteropApi; +typedef struct OrtInteropApi OrtInteropApi; + struct OrtEpApi; typedef struct OrtEpApi OrtEpApi; @@ -955,6 +961,80 @@ typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t n * * \nosubgrouping */ + +/** \addtogroup Global + * @{ + */ + +/** \brief External memory handle type for importing GPU resources. + * + * \todo Add OPAQUE_WIN32 for Windows Vulkan-specific memory handles + * \todo Add POSIX file descriptor (OPAQUE_FD) for Linux Vulkan/CUDA/OpenCL interop + * \todo Add Linux DMA-BUF file descriptor for embedded GPU memory sharing + * + * \since Version 1.24. + */ +typedef enum OrtExternalMemoryHandleType { + ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE = 0, /**< Shared HANDLE from ID3D12Device::CreateSharedHandle(resource) */ + ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP = 1, /**< Shared HANDLE from ID3D12Device::CreateSharedHandle(heap) */ +} OrtExternalMemoryHandleType; + +/** \brief Descriptor for importing external memory. + * + * \note The version field must be set to ORT_API_VERSION. + * This ensures forward compatibility as fields may be added in future versions. + * + * \since Version 1.24. + */ +typedef struct OrtExternalMemoryDescriptor { + uint32_t version; /**< Must be ORT_API_VERSION */ + OrtExternalMemoryHandleType handle_type; /**< Type of the external memory handle */ + void* native_handle; /**< Platform-specific handle (e.g., Windows HANDLE) */ + size_t size_bytes; /**< Total size in bytes of the external allocation */ + size_t offset_bytes; /**< Offset in bytes into the allocation (default 0). + Base offset for the imported memory region. */ +} OrtExternalMemoryDescriptor; + +/** \brief External semaphore type for GPU synchronization. + * + * \since Version 1.24. + */ +typedef enum OrtExternalSemaphoreType { + ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE = 0, /**< Shared HANDLE from ID3D12Device::CreateSharedHandle(fence) */ +} OrtExternalSemaphoreType; + +/** \brief Descriptor for importing external semaphores. + * + * \note The version field must be set to ORT_API_VERSION. + * This ensures forward compatibility as fields may be added in future versions. + * + * \since Version 1.24. + */ +typedef struct OrtExternalSemaphoreDescriptor { + uint32_t version; /**< Must be ORT_API_VERSION */ + OrtExternalSemaphoreType type; /**< Type of the external semaphore */ + void* native_handle; /**< Platform-specific handle (e.g., Windows HANDLE) */ +} OrtExternalSemaphoreDescriptor; + +/** \brief Descriptor for creating a tensor from imported external memory. + * + * \note The version field must be set to ORT_API_VERSION. + * This ensures forward compatibility as fields may be added in future versions. + * + * \since Version 1.24. + */ +typedef struct OrtExternalTensorDescriptor { + uint32_t version; /**< Must be ORT_API_VERSION */ + ONNXTensorElementDataType element_type; /**< Data type of tensor elements */ + const int64_t* shape; /**< Array of dimension sizes */ + size_t rank; /**< Number of dimensions */ + size_t offset_bytes; /**< Additional offset within imported memory (default 0). + Applied relative to OrtExternalMemoryDescriptor::offset_bytes. + Enables multiple tensors from the same imported memory handle. */ +} OrtExternalTensorDescriptor; + +/// @} + /* * Public enum for compiled model compatibility across EPs. */ @@ -6667,6 +6747,45 @@ struct OrtApi { */ ORT_API2_STATUS(KernelInfo_GetOperatorSinceVersion, _In_ const OrtKernelInfo* info, _Out_ int* since_version); + + /** \brief Get the EP Interop API instance. + * + * Get the Interop API instance to work with external resources. This API provides functions + * for importing external GPU memory and semaphores for zero-copy sharing between ORT inference + * and other GPU workloads. + * + * \return Interop API struct instance. + * + * \since Version 1.24. + */ + const OrtInteropApi*(ORT_API_CALL* GetInteropApi)(void); + + /** \brief Get the EP device assigned to each session output. + * + * Returns the OrtEpDevice assigned to each output of the session after graph partitioning. + * This allows validation that outputs are placed on the expected device for external resource sharing. + * + * The EP device for each output is determined by which execution provider will produce that output + * during inferencing. This information is useful for: + * - Validating that outputs will be placed on the expected device for external resource sharing + * - Deciding whether to use external memory handles for outputs + * + * \param[in] session The OrtSession instance to query. + * \param[out] outputs_ep_devices An array to be filled with the EP device for each output. + * The array must be allocated by the caller with space for + * OrtEpDevice* values for each output. + * The order is the same as returned by SessionGetOutputName. + * \param[in] num_outputs The number of outputs in the session. Must match SessionGetOutputCount. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24 + */ + ORT_API2_STATUS(SessionGetEpDeviceForOutputs, _In_ const OrtSession* session, + _Out_writes_(num_outputs) const OrtEpDevice** outputs_ep_devices, + _In_ size_t num_outputs); + + /// @} }; /* @@ -7464,6 +7583,224 @@ struct OrtCompileApi { _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state); }; +/** + * \brief The OrtInteropApi struct provides functions for external resource interop with execution providers. + * + * This API enables importing external GPU resources (memory and semaphores) for zero-copy sharing + * between ORT inference and other GPU workloads (e.g., D3D12 applications, media pipelines). + * + * The API is designed to be EP-agnostic and can be extended to support various GPU interop mechanisms + * (D3D12 shared handles, CUDA external memory, Vulkan, etc.). + * + * Example usage (error handling not shown): + * const OrtInteropApi* interop_api = ort_api->GetInteropApi(); + * OrtExternalResourceImporter* importer = NULL; + * + * status = interop_api->CreateExternalResourceImporterForDevice(ep_device, &importer); + * if (importer == nullptr) { + * // External resource import is optional for EPs to implement + * return; + * } + * bool can_import = false; + * status = interop_api->CanImportMemory(importer, ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE, &can_import); + * if (can_import) { + * OrtExternalMemoryHandle* mem_handle = NULL; + * status = interop_api->ImportMemory(importer, &mem_desc, &mem_handle); + * // ... use mem_handle to create tensors ... + * interop_api->ReleaseExternalMemoryHandle(mem_handle); + * } + * interop_api->ReleaseExternalResourceImporter(importer); + * + * \since Version 1.24. + */ +struct OrtInteropApi { + /// \name OrtExternalResourceImporter + /// @{ + + /** \brief Create an external resource importer for a specific EP device. + * + * The external resource importer is a capability object that provides methods for importing + * external GPU memory and semaphores for zero-copy import with an execution provider. + * + * This is an optional EP capability. If the EP does not support external resource import, + * out_importer is set to nullptr and the function returns success (nullptr status). + * This allows callers to use the simple "if (status != nullptr) handle_error()" pattern + * and check out_importer separately for capability detection. + * + * \param[in] ep_device The OrtEpDevice instance to create the importer for. + * \param[out] out_importer Output parameter set to the created OrtExternalResourceImporter instance, + * or nullptr if the EP does not support external resource import. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CreateExternalResourceImporterForDevice, + _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporter** out_importer); + + /** \brief Release an OrtExternalResourceImporter instance. + * + * \param[in] importer The OrtExternalResourceImporter instance to release. May be nullptr. + * + * \since Version 1.24. + */ + ORT_CLASS_RELEASE(ExternalResourceImporter); + + /// @} + /// \name Memory Import + /// @{ + + /** \brief Check if the external resource importer can import a specific memory handle type. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] handle_type The type of external memory handle to check. + * \param[out] out_supported Set to true if the handle type is supported. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CanImportMemory, + _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalMemoryHandleType handle_type, + _Out_ bool* out_supported); + + /** \brief Import external memory into the execution provider. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] desc Descriptor containing the external memory handle and properties. + * \param[out] out_handle Output parameter set to the created OrtExternalMemoryHandle. + * The caller owns the returned handle and must call ReleaseExternalMemoryHandle to free it. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ImportMemory, + _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle); + + /** \brief Release an OrtExternalMemoryHandle instance. + * + * \param[in] handle The OrtExternalMemoryHandle instance to release. May be nullptr. + * + * \since Version 1.24. + */ + ORT_CLASS_RELEASE(ExternalMemoryHandle); + + /** \brief Create a tensor backed by imported external memory. + * + * The created tensor is a view over the imported memory and does not copy data. + * The OrtExternalMemoryHandle must remain valid for the lifetime of the tensor. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] mem_handle The imported external memory handle. + * \param[in] tensor_desc Descriptor specifying tensor element type, shape, and optional offset. + * \param[out] out_tensor Output parameter set to the created OrtValue containing the tensor. + * The caller owns the returned tensor and must call ReleaseValue to free it. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CreateTensorFromMemory, + _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _Outptr_ OrtValue** out_tensor); + + /// @} + /// \name Semaphore Import + /// @{ + + /** \brief Check if the external resource importer can import a specific semaphore type. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] type The type of external semaphore to check. + * \param[out] out_supported Set to true if the semaphore type is supported. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CanImportSemaphore, + _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreType type, + _Out_ bool* out_supported); + + /** \brief Import an external semaphore into the execution provider. + * + * The returned OrtExternalSemaphoreHandle can be used with WaitSemaphore and an OrtSyncStream + * to synchronize execution with external operations. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] desc Descriptor containing the external semaphore handle and type. + * \param[out] out_handle Output parameter set to the created OrtExternalSemaphoreHandle. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ImportSemaphore, + _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle); + + /** \brief Release an OrtExternalSemaphoreHandle instance. + * + * \param[in] handle The OrtExternalSemaphoreHandle instance to release. May be nullptr. + * + * \since Version 1.24. + */ + ORT_CLASS_RELEASE(ExternalSemaphoreHandle); + + /** \brief Wait on an external semaphore on the EP's stream. + * + * Inserts a wait operation into the EP's stream that blocks until the semaphore + * reaches the specified value. This is used to synchronize with external GPU work + * (e.g., D3D12 timeline fence). + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] semaphore_handle The imported external semaphore. + * \param[in] stream The OrtSyncStream to wait on. + * \param[in] value The fence/semaphore value to wait for. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(WaitSemaphore, + _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + + /** \brief Signal an external semaphore from the EP's stream. + * + * Inserts a signal operation into the EP's stream that sets the semaphore + * to the specified value when reached. This is used to notify external GPU work + * (e.g., D3D12 timeline fence) that ORT inference is complete. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] semaphore_handle The imported external semaphore. + * \param[in] stream The OrtSyncStream to signal from. + * \param[in] value The fence/semaphore value to signal. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(SignalSemaphore, + _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + + /// @} +}; + /* * This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality * This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index d98757b2379a8..d24594f590619 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -236,6 +236,20 @@ inline const OrtCompileApi& GetCompileApi() { return *api; } +/// +/// This returns a reference to the ORT C Interop API. Used for external resource import with EPs. +/// +/// ORT C Interop API reference +inline const OrtInteropApi& GetInteropApi() { + auto* api = GetApi().GetInteropApi(); + if (api == nullptr) { + // minimal build + ORT_CXX_API_THROW("Interop API is not available in this build", ORT_FAIL); + } + + return *api; +} + /// /// This returns a reference to the ORT C EP API. Used if authoring a plugin execution provider. /// @@ -1610,6 +1624,7 @@ struct ConstSessionImpl : Base { std::vector GetMemoryInfoForInputs() const; ///< Wrapper for OrtApi::SessionGetMemoryInfoForInputs std::vector GetMemoryInfoForOutputs() const; ///< Wrapper for OrtApi::SessionGetMemoryInfoForOutputs std::vector GetEpDeviceForInputs() const; ///< Wrapper for OrtApi::SessionGetEpDeviceForInputs + std::vector GetEpDeviceForOutputs() const; ///< Wrapper for OrtApi::SessionGetEpDeviceForOutputs /** \brief Returns a copy of input name at the specified index. * diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index b7e1156f38a34..18299d2e49343 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1660,6 +1660,19 @@ inline std::vector ConstSessionImpl::GetEpDeviceForInputs() co return input_devices; } +template +inline std::vector ConstSessionImpl::GetEpDeviceForOutputs() const { + auto num_outputs = GetOutputCount(); + std::vector output_devices; + if (num_outputs > 0) { + output_devices.resize(num_outputs); + ThrowOnError(GetApi().SessionGetEpDeviceForOutputs(this->p_, + reinterpret_cast(output_devices.data()), + num_outputs)); + } + return output_devices; +} + template inline uint64_t ConstSessionImpl::GetProfilingStartTimeNs() const { uint64_t out; diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index bd7f63f53ed8c..617788fcab8bb 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -24,6 +24,66 @@ ORT_RUNTIME_CLASS(DataTransferImpl); ORT_RUNTIME_CLASS(SyncNotificationImpl); ORT_RUNTIME_CLASS(SyncStreamImpl); +ORT_RUNTIME_CLASS(ExternalResourceImporterImpl); + +/** \brief Base struct for imported external memory handles. + * + * EPs derive from this struct to add EP-specific fields (e.g., CUdeviceptr for CUDA). + * EP is responsible for creating and releasing instances of the derived type. + * + * Example derived type for CUDA EP: + * \code + * struct MyCudaExternalMemoryHandle : OrtExternalMemoryHandle { + * CUexternalMemory ext_memory; + * CUdeviceptr mapped_ptr; + * bool is_dedicated; + * }; + * \endcode + * + * \since Version 1.24. + */ +struct OrtExternalMemoryHandle { + uint32_t version; ///< Must be ORT_API_VERSION + const OrtEpDevice* ep_device; ///< EP device that created this handle + OrtExternalMemoryHandleType handle_type; ///< Original handle type for tracking + size_t size_bytes; ///< Size of the imported memory + size_t offset_bytes; ///< Offset into the imported memory + + /** \brief Release callback for this handle. EP sets this to its release function. + * + * ORT calls this when ReleaseExternalMemoryHandle is invoked. The EP's callback + * should cast the handle to its derived type and delete it. + */ + void(ORT_API_CALL* Release)(_In_ OrtExternalMemoryHandle* handle); +}; + +/** \brief Base struct for imported external semaphore handles. + * + * EPs derive from this struct to add EP-specific fields (e.g., CUexternalSemaphore for CUDA). + * EP is responsible for creating and releasing instances of the derived type. + * + * Example derived type for CUDA EP: + * \code + * struct MyCudaExternalSemaphoreHandle : OrtExternalSemaphoreHandle { + * CUexternalSemaphore ext_semaphore; + * }; + * \endcode + * + * \since Version 1.24. + */ +struct OrtExternalSemaphoreHandle { + uint32_t version; ///< Must be ORT_API_VERSION + const OrtEpDevice* ep_device; ///< EP device that created this handle + OrtExternalSemaphoreType type; ///< Original semaphore type + + /** \brief Release callback for this handle. EP sets this to its release function. + * + * ORT calls this when ReleaseExternalSemaphoreHandle is invoked. The EP's callback + * should cast the handle to its derived type and delete it. + */ + void(ORT_API_CALL* Release)(_In_ OrtExternalSemaphoreHandle* handle); +}; + // Opaque types for kernel-based EPs ORT_RUNTIME_CLASS(KernelRegistry); ORT_RUNTIME_CLASS(KernelDefBuilder); @@ -191,6 +251,180 @@ struct OrtSyncStreamImpl { ORT_API2_STATUS(OnSessionRunEnd, _In_ OrtSyncStreamImpl* this_ptr); }; +/** \brief Struct that an EP implements for external resource import (memory + semaphore import). + * + * This capability object provides methods for importing external GPU memory and semaphores + * for zero-copy import. EPs that support D3D12, CUDA, HIP, or Vulkan external resource APIs + * can implement this interface. + * + * \since Version 1.24. + */ +struct OrtExternalResourceImporterImpl { + uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION + + // Memory operations (stream-independent) + + /** \brief Check if the implementation can import external memory of the given handle type. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] handle_type The type of external memory handle to check. + * \return True if the handle type is supported. + * + * \since Version 1.24. + */ + ORT_API_T(bool, CanImportMemory, + _In_ const OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalMemoryHandleType handle_type); + + /** \brief Import external memory. + * + * The EP creates a derived type of OrtExternalMemoryHandle and returns a pointer to the base. + * EP is responsible for the lifetime of the handle (release via ReleaseMemory). + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] desc Descriptor containing the external memory handle and properties. + * \param[out] out_handle Output parameter set to the created OrtExternalMemoryHandle (EP's derived type). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ImportMemory, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle); + + /** \brief Release an imported external memory handle. + * + * The EP deletes its derived type instance. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] handle The OrtExternalMemoryHandle to release (EP casts to its derived type). + * + * \since Version 1.24. + */ + ORT_API_T(void, ReleaseMemory, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalMemoryHandle* handle); + + /** \brief Create a tensor backed by imported external memory. + * + * The created tensor is a view over the imported memory and does not copy data. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] mem_handle The imported external memory handle (EP casts to its derived type). + * \param[in] tensor_desc Descriptor specifying tensor element type, shape, and optional offset. + * \param[out] out_tensor Output parameter set to the created OrtValue containing the tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CreateTensorFromMemory, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _Outptr_ OrtValue** out_tensor); + + // Semaphore operations (require stream) + + /** \brief Check if the implementation can import external semaphores of the given type. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] type The type of external semaphore to check. + * \return True if the semaphore type is supported. + * + * \since Version 1.24. + */ + ORT_API_T(bool, CanImportSemaphore, + _In_ const OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreType type); + + /** \brief Import an external semaphore. + * + * The EP creates a derived type of OrtExternalSemaphoreHandle and returns a pointer to the base. + * EP is responsible for the lifetime of the handle (release via ReleaseSemaphore). + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] desc Descriptor containing the external semaphore handle and type. + * \param[out] out_handle Output parameter set to the created OrtExternalSemaphoreHandle (EP's derived type). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ImportSemaphore, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle); + + /** \brief Release an imported external semaphore handle. + * + * The EP deletes its derived type instance. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] handle The OrtExternalSemaphoreHandle to release (EP casts to its derived type). + * + * \since Version 1.24. + */ + ORT_API_T(void, ReleaseSemaphore, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle); + + /** \brief Wait on an external semaphore on the EP's stream. + * + * Inserts a wait operation into the EP's stream that blocks until the semaphore + * reaches the specified value. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] handle The imported external semaphore (EP casts to its derived type). + * \param[in] stream The OrtSyncStream to wait on. + * \param[in] value The fence/semaphore value to wait for. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(WaitSemaphore, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + + /** \brief Signal an external semaphore from the EP's stream. + * + * Inserts a signal operation into the EP's stream that sets the semaphore + * to the specified value when reached. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] handle The imported external semaphore (EP casts to its derived type). + * \param[in] stream The OrtSyncStream to signal from. + * \param[in] value The fence/semaphore value to signal. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(SignalSemaphore, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + + // Release the capability object itself + + /** \brief Release the OrtExternalResourceImporterImpl instance. + * + * This is called by ORT when the OrtExternalResourceImporterImpl instance is no longer needed. + * The implementation should release any resources held by the instance. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * + * \since Version 1.24. + */ + ORT_API_T(void, Release, _In_ OrtExternalResourceImporterImpl* this_ptr); +}; + struct OrtNodeFusionOptions; typedef struct OrtNodeFusionOptions OrtNodeFusionOptions; @@ -1564,6 +1798,32 @@ struct OrtEpFactory { * \since Version 1.24. */ ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options); + + /** \brief Create an OrtExternalResourceImporterImpl for external resource import. + * + * This is used to create an external resource importer that enables zero-copy import of + * external GPU memory (e.g., D3D12 shared resources) and synchronization primitives + * (e.g., D3D12 timeline fences). + * + * EPs that support external resource import (via CUDA, HIP, Vulkan, or D3D12 APIs) can + * implement this to allow applications to share GPU resources without copies. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] ep_device The OrtEpDevice to create the external resource importer for. + * \param[out] out_importer The created OrtExternalResourceImporterImpl instance. + * Set to nullptr if external resource import is not supported. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \note Implementation of this function is optional. + * An EP factory should only implement this if it supports external resource import. + * If not implemented or not supported, return ORT_NOT_IMPLEMENTED or set out_importer to nullptr. + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CreateExternalResourceImporterForDevice, _In_ OrtEpFactory* this_ptr, + _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** out_importer); }; #ifdef __cplusplus diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e5523dc78b5d2..e9b0c2f230263 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -3480,6 +3480,48 @@ common::Status InferenceSession::GetEpDeviceForInputs(InlinedVector& ep_devices) const { + ep_devices.clear(); + +#if defined(ORT_MINIMAL_BUILD) + return common::Status(common::ONNXRUNTIME, common::FAIL, + "GetEpDeviceForOutputs is not available in a minimal build."); +#else + if (!is_inited_) { + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Session has not been initialized."); + } + + std::pair outputs = GetModelOutputs(); + + ORT_RETURN_IF_ERROR(outputs.first); + + const auto& def_list = *outputs.second; + ep_devices.reserve(def_list.size()); + + const auto& available_eps = environment_.GetOrtEpDevices(); + + for (const auto* def : def_list) { + InlinedVector node_info_vec; + ORT_RETURN_IF_ERROR(session_state_->GetOutputNodeInfo(def->Name(), node_info_vec)); + assert(!node_info_vec.empty()); + // If we have an output that is not produced by any node, + // then we return nullptr. + const auto* p_node = node_info_vec.front().p_node; + if (p_node != nullptr) { + const auto ep_name = p_node->GetExecutionProviderType(); + auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) { + return entry->ep_name == ep_name; + }); + ep_devices.push_back(it != available_eps.end() ? *it : nullptr); + } else { + ep_devices.push_back(nullptr); + } + } + + return Status::OK(); +#endif +} + common::Status InferenceSession::NewIOBinding(std::unique_ptr* io_binding) { { std::lock_guard l(session_mutex_); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 8bea15c169ed4..ff54f6fa7bca0 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -484,6 +484,15 @@ class InferenceSession { * This is required for a user to know the location of the input/output when autoep selection is enabled. */ common::Status GetEpDeviceForInputs(InlinedVector& memory_info) const; + + /** + * Get the OrtEpDevice (if available) for the outputs of the model. + * + * This is required for a user to validate that outputs will be placed on the expected device + * for external resource sharing. + */ + common::Status GetEpDeviceForOutputs(InlinedVector& memory_info) const; + /** * Get the current number of in-progress concurrent Run calls. */ diff --git a/onnxruntime/core/session/interop_api.cc b/onnxruntime/core/session/interop_api.cc new file mode 100644 index 0000000000000..144cd82c02045 --- /dev/null +++ b/onnxruntime/core/session/interop_api.cc @@ -0,0 +1,386 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/interop_api.h" + +#if !defined(ORT_MINIMAL_BUILD) +#include + +#include "core/session/ort_apis.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_env.h" +#include "core/session/abi_devices.h" +#include "core/common/logging/logging.h" +#include "core/framework/error_code_helper.h" +#else +#include "core/framework/error_code_helper.h" +#include "core/session/ort_apis.h" +#endif // !defined(ORT_MINIMAL_BUILD) + +using namespace onnxruntime; + +#if !defined(ORT_MINIMAL_BUILD) + +// Wrapper class for OrtExternalResourceImporterImpl +namespace { +struct ExternalResourceImporterWrapper { + const OrtEpDevice* ep_device; + OrtExternalResourceImporterImpl* impl; + + ExternalResourceImporterWrapper(const OrtEpDevice* device, OrtExternalResourceImporterImpl* importer) + : ep_device(device), impl(importer) {} + + ~ExternalResourceImporterWrapper() { + if (impl && impl->Release) { + impl->Release(impl); + } + } + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ExternalResourceImporterWrapper); +}; + +} // namespace + +ORT_API_STATUS_IMPL(OrtInteropAPI::CreateExternalResourceImporterForDevice, _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporter** out_importer) { + API_IMPL_BEGIN + if (ep_device == nullptr || out_importer == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device and out_importer must be provided."); + } + + *out_importer = nullptr; + + // OrtEpFactory::CreateExternalResourceImporterForDevice was added in ORT 1.24. + const auto* factory = ep_device->ep_factory; + if (factory == nullptr || + factory->ort_version_supported < 24 || + factory->CreateExternalResourceImporterForDevice == nullptr) { + // EP doesn't support this optional feature + return nullptr; + } + + OrtExternalResourceImporterImpl* impl = nullptr; + ORT_API_RETURN_IF_ERROR(factory->CreateExternalResourceImporterForDevice( + ep_device->GetMutableFactory(), + ep_device, + &impl)); + + if (impl == nullptr) { + // EP doesn't support this for the specific device + return nullptr; + } + + auto wrapper = std::make_unique(ep_device, impl); + *out_importer = reinterpret_cast(wrapper.release()); + + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtInteropAPI::ReleaseExternalResourceImporter, _Frees_ptr_opt_ OrtExternalResourceImporter* importer) { +#if !defined(ORT_MINIMAL_BUILD) + delete reinterpret_cast(importer); +#else + ORT_UNUSED_PARAMETER(importer); +#endif // !defined(ORT_MINIMAL_BUILD) +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::CanImportMemory, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalMemoryHandleType handle_type, + _Out_ bool* out_supported) { + API_IMPL_BEGIN + if (importer == nullptr || out_supported == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer and out_supported must be provided."); + } + + auto* wrapper = reinterpret_cast(importer); + if (wrapper->impl == nullptr || wrapper->impl->CanImportMemory == nullptr) { + *out_supported = false; + return nullptr; + } + + *out_supported = wrapper->impl->CanImportMemory(wrapper->impl, handle_type); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::ImportMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle) { + API_IMPL_BEGIN + if (importer == nullptr || desc == nullptr || out_handle == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, desc, and out_handle must be provided."); + } + + auto* wrapper = reinterpret_cast(importer); + if (wrapper->impl == nullptr || wrapper->impl->ImportMemory == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External memory import is not supported by this EP."); + } + + // EP creates derived type and returns base pointer. EP owns the handle lifetime. + OrtExternalMemoryHandle* handle = nullptr; + ORT_API_RETURN_IF_ERROR(wrapper->impl->ImportMemory(wrapper->impl, desc, &handle)); + + if (handle == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, "ImportMemory returned null handle."); + } + + *out_handle = handle; + + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtInteropAPI::ReleaseExternalMemoryHandle, _Frees_ptr_opt_ OrtExternalMemoryHandle* handle) { +#if !defined(ORT_MINIMAL_BUILD) + if (handle != nullptr && handle->Release != nullptr) { + handle->Release(handle); + } +#else + ORT_UNUSED_PARAMETER(handle); +#endif // !defined(ORT_MINIMAL_BUILD) +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::CreateTensorFromMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _Outptr_ OrtValue** out_tensor) { + API_IMPL_BEGIN + if (importer == nullptr || mem_handle == nullptr || tensor_desc == nullptr || out_tensor == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, mem_handle, tensor_desc, and out_tensor must be provided."); + } + + auto* imp_wrapper = reinterpret_cast(importer); + + if (imp_wrapper->impl == nullptr || imp_wrapper->impl->CreateTensorFromMemory == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateTensorFromMemory is not supported by this EP."); + } + + OrtValue* tensor = nullptr; + ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->CreateTensorFromMemory(imp_wrapper->impl, mem_handle, tensor_desc, &tensor)); + + *out_tensor = tensor; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::CanImportSemaphore, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreType type, + _Out_ bool* out_supported) { + API_IMPL_BEGIN + if (importer == nullptr || out_supported == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer and out_supported must be provided."); + } + + auto* wrapper = reinterpret_cast(importer); + if (wrapper->impl == nullptr || wrapper->impl->CanImportSemaphore == nullptr) { + *out_supported = false; + return nullptr; + } + + *out_supported = wrapper->impl->CanImportSemaphore(wrapper->impl, type); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::ImportSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle) { + API_IMPL_BEGIN + if (importer == nullptr || desc == nullptr || out_handle == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, desc, and out_handle must be provided."); + } + + auto* wrapper = reinterpret_cast(importer); + if (wrapper->impl == nullptr || wrapper->impl->ImportSemaphore == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External semaphore import is not supported by this EP."); + } + + // EP creates derived type and returns base pointer. EP owns the handle lifetime. + OrtExternalSemaphoreHandle* handle = nullptr; + ORT_API_RETURN_IF_ERROR(wrapper->impl->ImportSemaphore(wrapper->impl, desc, &handle)); + + if (handle == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, "ImportSemaphore returned null handle."); + } + + *out_handle = handle; + + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtInteropAPI::ReleaseExternalSemaphoreHandle, _Frees_ptr_opt_ OrtExternalSemaphoreHandle* handle) { +#if !defined(ORT_MINIMAL_BUILD) + if (handle != nullptr && handle->Release != nullptr) { + handle->Release(handle); + } +#else + ORT_UNUSED_PARAMETER(handle); +#endif // !defined(ORT_MINIMAL_BUILD) +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::WaitSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) { + API_IMPL_BEGIN + if (importer == nullptr || semaphore_handle == nullptr || stream == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, semaphore_handle, and stream must be provided."); + } + + auto* imp_wrapper = reinterpret_cast(importer); + + if (imp_wrapper->impl == nullptr || imp_wrapper->impl->WaitSemaphore == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "WaitSemaphore is not supported by this EP."); + } + + ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->WaitSemaphore(imp_wrapper->impl, semaphore_handle, stream, value)); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::SignalSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) { + API_IMPL_BEGIN + if (importer == nullptr || semaphore_handle == nullptr || stream == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, semaphore_handle, and stream must be provided."); + } + + auto* imp_wrapper = reinterpret_cast(importer); + + if (imp_wrapper->impl == nullptr || imp_wrapper->impl->SignalSemaphore == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SignalSemaphore is not supported by this EP."); + } + + ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->SignalSemaphore(imp_wrapper->impl, semaphore_handle, stream, value)); + + return nullptr; + API_IMPL_END +} + +#else // defined(ORT_MINIMAL_BUILD) + +ORT_API_STATUS_IMPL(OrtInteropAPI::CreateExternalResourceImporterForDevice, _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporter** out_importer) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(ep_device); + ORT_UNUSED_PARAMETER(out_importer); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::CanImportMemory, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalMemoryHandleType handle_type, + _Out_ bool* out_supported) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(handle_type); + ORT_UNUSED_PARAMETER(out_supported); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::ImportMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(desc); + ORT_UNUSED_PARAMETER(out_handle); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::CreateTensorFromMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _Outptr_ OrtValue** out_tensor) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(mem_handle); + ORT_UNUSED_PARAMETER(tensor_desc); + ORT_UNUSED_PARAMETER(out_tensor); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::CanImportSemaphore, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreType type, + _Out_ bool* out_supported) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(type); + ORT_UNUSED_PARAMETER(out_supported); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::ImportSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(desc); + ORT_UNUSED_PARAMETER(out_handle); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::WaitSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(semaphore_handle); + ORT_UNUSED_PARAMETER(stream); + ORT_UNUSED_PARAMETER(value); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::SignalSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(semaphore_handle); + ORT_UNUSED_PARAMETER(stream); + ORT_UNUSED_PARAMETER(value); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +#endif // !defined(ORT_MINIMAL_BUILD) + +static constexpr OrtInteropApi ort_interop_api = { + // NOTE: Application compatibility with newer versions of ORT depends on the Api order within this struct so + // all new functions must be added at the end, and no functions that already exist in an officially released version + // of ORT can be reordered or removed. + + &OrtInteropAPI::CreateExternalResourceImporterForDevice, + &OrtInteropAPI::ReleaseExternalResourceImporter, + &OrtInteropAPI::CanImportMemory, + &OrtInteropAPI::ImportMemory, + &OrtInteropAPI::ReleaseExternalMemoryHandle, + &OrtInteropAPI::CreateTensorFromMemory, + &OrtInteropAPI::CanImportSemaphore, + &OrtInteropAPI::ImportSemaphore, + &OrtInteropAPI::ReleaseExternalSemaphoreHandle, + &OrtInteropAPI::WaitSemaphore, + &OrtInteropAPI::SignalSemaphore, + // End of Version 24 - DO NOT MODIFY ABOVE +}; + +// Checks that we don't violate the rule that the functions must remain in the slots they were originally assigned +static_assert(offsetof(OrtInteropApi, SignalSemaphore) / sizeof(void*) == 10, + "Size of version 24 Api cannot change"); // initial version in ORT 1.24 + +ORT_API(const OrtInteropApi*, OrtInteropAPI::GetInteropApi) { + return &ort_interop_api; +} diff --git a/onnxruntime/core/session/interop_api.h b/onnxruntime/core/session/interop_api.h new file mode 100644 index 0000000000000..f0822dafd7077 --- /dev/null +++ b/onnxruntime/core/session/interop_api.h @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/session/onnxruntime_c_api.h" + +namespace OrtInteropAPI { + +// implementation that returns the API struct +ORT_API(const OrtInteropApi*, GetInteropApi); + +ORT_API_STATUS_IMPL(CreateExternalResourceImporterForDevice, _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporter** out_importer); + +ORT_API(void, ReleaseExternalResourceImporter, _Frees_ptr_opt_ OrtExternalResourceImporter* importer); + +ORT_API_STATUS_IMPL(CanImportMemory, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalMemoryHandleType handle_type, + _Out_ bool* out_supported); + +ORT_API_STATUS_IMPL(ImportMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle); + +ORT_API(void, ReleaseExternalMemoryHandle, _Frees_ptr_opt_ OrtExternalMemoryHandle* handle); + +ORT_API_STATUS_IMPL(CreateTensorFromMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _Outptr_ OrtValue** out_tensor); + +ORT_API_STATUS_IMPL(CanImportSemaphore, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreType type, + _Out_ bool* out_supported); + +ORT_API_STATUS_IMPL(ImportSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle); + +ORT_API(void, ReleaseExternalSemaphoreHandle, _Frees_ptr_opt_ OrtExternalSemaphoreHandle* handle); + +ORT_API_STATUS_IMPL(WaitSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + +ORT_API_STATUS_IMPL(SignalSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + +} // namespace OrtInteropAPI diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index b1c2e07b9ffb7..c3bf74a4607e8 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -38,6 +38,7 @@ #include "core/session/allocator_adapters.h" #include "core/session/compile_api.h" #include "core/session/environment.h" +#include "core/session/interop_api.h" #include "core/session/plugin_ep/ep_api.h" #include "core/session/plugin_ep/ep_library_internal.h" #include "core/session/inference_session.h" @@ -3372,6 +3373,34 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForInputs, _In_ const OrtSession* API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForOutputs, _In_ const OrtSession* ort_session, + _Out_writes_(num_values) const OrtEpDevice** outputs_ep_devices, + _In_ size_t num_values) { + API_IMPL_BEGIN + if (ort_session == nullptr || outputs_ep_devices == nullptr || num_values == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid argument provided to SessionGetEpDeviceForOutputs."); + } + + auto session = reinterpret_cast(ort_session); + + InlinedVector ep_devices; + + ORT_API_RETURN_IF_STATUS_NOT_OK(session->GetEpDeviceForOutputs(ep_devices)); + + auto num_found = ep_devices.size(); + if (num_found > num_values) { + auto msg = MakeString("Number of outputs ", num_found, " exceeds the provided size of ", num_values); + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, msg.c_str()); + } + + for (size_t i = 0; i < num_values; ++i) { + outputs_ep_devices[i] = (i < num_found) ? ep_devices[i] : nullptr; + } + + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice* ep_device, _In_opt_ const OrtKeyValuePairs* stream_options, _Outptr_ OrtSyncStream** ort_stream) { @@ -3536,6 +3565,11 @@ ORT_API_STATUS_IMPL(OrtApis::GetModelCompatibilityForEpDevices, API_IMPL_END } +// GetInteropApi - returns the Interop API struct +ORT_API(const OrtInteropApi*, OrtApis::GetInteropApi) { + return OrtInteropAPI::GetInteropApi(); +} + #else // defined(ORT_MINIMAL_BUILD) ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, _In_ const char* /*registration_name*/, const ORTCHAR_T* /*path*/) { @@ -3621,6 +3655,19 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* /*env*/, API_IMPL_END } +ORT_API(const OrtInteropApi*, OrtApis::GetInteropApi) { + fprintf(stderr, "The Interop API is not supported in a minimal build.\n"); + return nullptr; +} + +ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForOutputs, _In_ const OrtSession* /*ort_session*/, + _Out_writes_(num_values) const OrtEpDevice** /*outputs_ep_devices*/, + _In_ size_t /*num_values*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SessionGetEpDeviceForOutputs is not supported in a minimal build."); + API_IMPL_END +} + #endif // !defined(ORT_MINIMAL_BUILD) // OrtEpDevice accessors @@ -4241,6 +4288,9 @@ static constexpr OrtApi ort_api_1_to_24 = { &OrtApis::KernelInfo_GetOperatorDomain, &OrtApis::KernelInfo_GetOperatorType, &OrtApis::KernelInfo_GetOperatorSinceVersion, + + &OrtApis::GetInteropApi, + &OrtApis::SessionGetEpDeviceForOutputs, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index ccdfa53e1b225..7aa09adfd32d1 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -761,4 +761,11 @@ ORT_API_STATUS_IMPL(KernelInfo_GetOperatorType, _In_ const OrtKernelInfo* info, ORT_API_STATUS_IMPL(KernelInfo_GetOperatorSinceVersion, _In_ const OrtKernelInfo* info, _Out_ int* since_version); +// Interop API +ORT_API(const OrtInteropApi*, GetInteropApi); + +ORT_API_STATUS_IMPL(SessionGetEpDeviceForOutputs, _In_ const OrtSession* session, + _Out_writes_(num_outputs) const OrtEpDevice** outputs_ep_devices, + _In_ size_t num_outputs); + } // namespace OrtApis diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc index 364bab471ddbe..8dc92802aa84b 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc @@ -33,6 +33,7 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl OrtEpFactory::IsStreamAware = Forward::IsStreamAware; OrtEpFactory::CreateSyncStreamForDevice = Forward::CreateSyncStreamForDevice; OrtEpFactory::SetEnvironmentOptions = Forward::SetEnvironmentOptions; + OrtEpFactory::CreateExternalResourceImporterForDevice = Forward::CreateExternalResourceImporterForDevice; } InternalExecutionProviderFactory::InternalExecutionProviderFactory(EpFactoryInternal& ep_factory, diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h index 6eb83a117fb63..ae98f2c0ac589 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -91,6 +91,11 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->SetEnvironmentOptions(options); } + OrtStatus* CreateExternalResourceImporterForDevice(_In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** importer) noexcept { + return impl_->CreateExternalResourceImporterForDevice(ep_device, importer); + } + // Function ORT calls to release an EP instance. void ReleaseEp(OrtEp* /*ep*/) noexcept { // we never create an OrtEp so we should never be trying to release one diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index de9e2d44431bf..20a47715df2b8 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -88,6 +88,14 @@ class EpFactoryInternalImpl { return nullptr; } + virtual OrtStatus* CreateExternalResourceImporterForDevice( + _In_ const OrtEpDevice* /*ep_device*/, + _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** importer) noexcept { + // Default implementation does not support external resource import + *importer = nullptr; + return nullptr; + } + // Function ORT calls to release an EP instance. void ReleaseEp(OrtEp* ep); diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h index 8c5ef526baba1..26173f0055ed7 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h @@ -65,6 +65,18 @@ class ProviderBridgeEpFactory : public EpFactoryInternalImpl { return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); } + OrtStatus* CreateExternalResourceImporterForDevice( + const OrtEpDevice* ep_device, + OrtExternalResourceImporterImpl** importer) noexcept override { + // OrtEpFactory::CreateExternalResourceImporterForDevice was added in ORT 1.24. + if (ep_factory_.ort_version_supported < 24 || + ep_factory_.CreateExternalResourceImporterForDevice == nullptr) { + *importer = nullptr; + return nullptr; + } + return ep_factory_.CreateExternalResourceImporterForDevice(&ep_factory_, ep_device, importer); + } + OrtEpFactory& ep_factory_; ProviderLibrary& provider_library_; std::optional library_path_; diff --git a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index 65c396181f0a7..2530ae8eb3c2b 100644 --- a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -87,6 +87,13 @@ struct ForwardToFactoryImpl { return static_cast(this_ptr)->SetEnvironmentOptions(options); } + static OrtStatus* ORT_API_CALL CreateExternalResourceImporterForDevice( + _In_ OrtEpFactory* this_ptr, + _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** importer) noexcept { + return static_cast(this_ptr)->CreateExternalResourceImporterForDevice(ep_device, importer); + } + static void ORT_API_CALL ReleaseEp(OrtEpFactory* this_ptr, OrtEp* ep) noexcept { static_cast(this_ptr)->ReleaseEp(ep); } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc new file mode 100644 index 0000000000000..5230064138d03 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc @@ -0,0 +1,275 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ep_external_resource_importer.h" + +#include +#include +#include +#include +#include + +ExampleExternalResourceImporter::ExampleExternalResourceImporter(const ApiPtrs& apis) + : OrtExternalResourceImporterImpl{}, apis_{apis} { + ort_version_supported = ORT_API_VERSION; + + // Memory operations + CanImportMemory = CanImportMemoryImpl; + ImportMemory = ImportMemoryImpl; + ReleaseMemory = ReleaseMemoryImpl; + CreateTensorFromMemory = CreateTensorFromMemoryImpl; + + // Semaphore operations + CanImportSemaphore = CanImportSemaphoreImpl; + ImportSemaphore = ImportSemaphoreImpl; + ReleaseSemaphore = ReleaseSemaphoreImpl; + WaitSemaphore = WaitSemaphoreImpl; + SignalSemaphore = SignalSemaphoreImpl; + + // Release + Release = ReleaseImpl; +} + +/*static*/ +bool ORT_API_CALL ExampleExternalResourceImporter::CanImportMemoryImpl( + _In_ const OrtExternalResourceImporterImpl* /*this_ptr*/, + _In_ OrtExternalMemoryHandleType handle_type) noexcept { + // The example EP supports both D3D12 resource and heap handle types for testing + return handle_type == ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE || + handle_type == ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::ImportMemoryImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle) noexcept { + auto& impl = *static_cast(this_ptr); + + if (desc == nullptr || out_handle == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to ImportMemory"); + } + + *out_handle = nullptr; + + // Validate handle type + if (!CanImportMemoryImpl(this_ptr, desc->handle_type)) { + return impl.apis_.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, + "Unsupported external memory handle type"); + } + + // In a real implementation, you would: + // 1. Open/import the native handle (e.g., cuImportExternalMemory for CUDA) + // 2. Map the memory to get a device pointer + // + // For testing purposes, we simulate this by allocating CPU memory + // that mirrors the size of the external allocation. + + auto* handle = new (std::nothrow) ExampleExternalMemoryHandle(); + if (handle == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_FAIL, "Failed to allocate external memory handle"); + } + + // Allocate simulated memory (using CPU memory for the example) + size_t effective_size = desc->size_bytes - desc->offset_bytes; + handle->simulated_ptr = std::make_unique(effective_size); + + handle->size_bytes = desc->size_bytes; + handle->offset_bytes = desc->offset_bytes; + handle->handle_type = desc->handle_type; + + *out_handle = handle; + return nullptr; +} + +/*static*/ +void ORT_API_CALL ExampleExternalResourceImporter::ReleaseMemoryImpl( + _In_ OrtExternalResourceImporterImpl* /*this_ptr*/, + _In_ OrtExternalMemoryHandle* handle) noexcept { + if (handle == nullptr) { + return; + } + + auto* mem_handle = static_cast(handle); + delete mem_handle; // destructor frees simulated_ptr +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::CreateTensorFromMemoryImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _Outptr_ OrtValue** out_tensor) noexcept { + auto& impl = *static_cast(this_ptr); + + if (mem_handle == nullptr || tensor_desc == nullptr || out_tensor == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to CreateTensorFromMemory"); + } + + *out_tensor = nullptr; + + auto* handle = static_cast(mem_handle); + + // Calculate the data pointer with tensor offset + void* data_ptr = handle->simulated_ptr.get() + tensor_desc->offset_bytes; + + // For the example EP, we use CPU memory info since we're simulating with CPU memory + // In a real implementation, you would use the appropriate GPU memory info + OrtMemoryInfo* memory_info = nullptr; + OrtStatus* status = impl.apis_.ort_api.CreateMemoryInfo( + "Cpu", // For testing, we use CPU memory + OrtDeviceAllocator, + 0, // device ID + OrtMemTypeDefault, + &memory_info); + + if (status != nullptr) { + return status; + } + + // Calculate buffer size + // NOTE: This is a simplified calculation for testing. Production code should: + // 1. Calculate actual tensor size from shape + element_type + // 2. Validate it fits within available memory region + // 3. Use that validated size rather than subtracting offsets + size_t buffer_size = handle->size_bytes - handle->offset_bytes - tensor_desc->offset_bytes; + + // Create tensor with pre-allocated memory + status = impl.apis_.ort_api.CreateTensorWithDataAsOrtValue( + memory_info, + data_ptr, + buffer_size, + tensor_desc->shape, + tensor_desc->rank, + tensor_desc->element_type, + out_tensor); + + impl.apis_.ort_api.ReleaseMemoryInfo(memory_info); + return status; +} + +/*static*/ +bool ORT_API_CALL ExampleExternalResourceImporter::CanImportSemaphoreImpl( + _In_ const OrtExternalResourceImporterImpl* /*this_ptr*/, + _In_ OrtExternalSemaphoreType type) noexcept { + // The example EP supports D3D12 fence for testing + return type == ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::ImportSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle) noexcept { + auto& impl = *static_cast(this_ptr); + + if (desc == nullptr || out_handle == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to ImportSemaphore"); + } + + *out_handle = nullptr; + + // Validate semaphore type + if (!CanImportSemaphoreImpl(this_ptr, desc->type)) { + return impl.apis_.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, + "Unsupported external semaphore type"); + } + + // In a real implementation, you would: + // 1. Import the native fence handle (e.g., cuImportExternalSemaphore for CUDA) + // + // For testing purposes, we create a simulated semaphore using an atomic counter + + auto* handle = new (std::nothrow) ExampleExternalSemaphoreHandle(); + if (handle == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_FAIL, "Failed to allocate external semaphore handle"); + } + + handle->type = desc->type; + handle->value.store(0); + + *out_handle = handle; + return nullptr; +} + +/*static*/ +void ORT_API_CALL ExampleExternalResourceImporter::ReleaseSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* /*this_ptr*/, + _In_ OrtExternalSemaphoreHandle* handle) noexcept { + if (handle == nullptr) { + return; + } + + auto* sem_handle = static_cast(handle); + delete sem_handle; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::WaitSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) noexcept { + auto& impl = *static_cast(this_ptr); + + if (handle == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to WaitSemaphore"); + } + + // stream can be nullptr for synchronous wait + (void)stream; + + auto* sem_handle = static_cast(handle); + + // In a real implementation, you would: + // 1. Queue a wait operation on the GPU stream (e.g., cuWaitExternalSemaphoresAsync) + // + // For testing, we do a simple spin-wait on the atomic counter + // with a reasonable timeout to prevent infinite loops in tests + + const int max_iterations = 10000; + int iterations = 0; + while (sem_handle->value.load() < value && iterations < max_iterations) { + std::this_thread::sleep_for(std::chrono::microseconds(100)); + ++iterations; + } + + if (iterations >= max_iterations) { + return impl.apis_.ort_api.CreateStatus(ORT_FAIL, "WaitSemaphore timed out"); + } + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::SignalSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) noexcept { + auto& impl = *static_cast(this_ptr); + + if (handle == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to SignalSemaphore"); + } + + // stream can be nullptr for synchronous signal + (void)stream; + + auto* sem_handle = static_cast(handle); + + // In a real implementation, you would: + // 1. Queue a signal operation on the GPU stream (e.g., cuSignalExternalSemaphoresAsync) + // + // For testing, we simply update the atomic counter + + sem_handle->value.store(value); + + return nullptr; +} + +/*static*/ +void ORT_API_CALL ExampleExternalResourceImporter::ReleaseImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h new file mode 100644 index 0000000000000..4721367c68963 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h @@ -0,0 +1,130 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "../plugin_ep_utils.h" + +#include +#include +#include + +/** + * @brief Example derived handle for imported external memory. + * + * Derives from OrtExternalMemoryHandle and adds example-specific fields. + * This mock implementation simulates imported external memory for testing purposes. + * In a real EP, this would hold a GPU-mapped pointer from an imported D3D12/Vulkan/CUDA resource. + */ +struct ExampleExternalMemoryHandle : OrtExternalMemoryHandle { + std::unique_ptr simulated_ptr; ///< Simulated mapped pointer (CPU memory for testing) + + ExampleExternalMemoryHandle() + : simulated_ptr(nullptr) { + // Initialize base struct fields + version = ORT_API_VERSION; + ep_device = nullptr; + handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + size_bytes = 0; + offset_bytes = 0; + Release = ReleaseCallback; + } + + ~ExampleExternalMemoryHandle() = default; + + static void ORT_API_CALL ReleaseCallback(_In_ OrtExternalMemoryHandle* handle) noexcept { + if (handle == nullptr) return; + delete static_cast(handle); + } +}; + +/** + * @brief Example derived handle for imported external semaphore. + * + * Derives from OrtExternalSemaphoreHandle and adds example-specific fields. + * This mock implementation simulates imported external semaphores for testing purposes. + * In a real EP, this would hold an imported D3D12 fence / Vulkan semaphore / CUDA external semaphore. + */ +struct ExampleExternalSemaphoreHandle : OrtExternalSemaphoreHandle { + std::atomic value; ///< Simulated fence value for testing + + ExampleExternalSemaphoreHandle() + : value(0) { + // Initialize base struct fields + version = ORT_API_VERSION; + ep_device = nullptr; + type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + Release = ReleaseCallback; + } + + static void ORT_API_CALL ReleaseCallback(_In_ OrtExternalSemaphoreHandle* handle) noexcept { + if (handle == nullptr) return; + delete static_cast(handle); + } +}; + +/** + * @brief Example implementation of OrtExternalResourceImporterImpl. + * + * This is a mock implementation that simulates external resource interop for testing + * the ORT public API without requiring actual D3D12/CUDA/Vulkan hardware. + * + * Key features: + * - Reports support for D3D12 resource/heap and fence handle types + * - Creates simulated memory mappings using CPU memory + * - Simulates fence wait/signal operations + * - Allows tensor creation from "imported" memory + */ +class ExampleExternalResourceImporter : public OrtExternalResourceImporterImpl { + public: + ExampleExternalResourceImporter(const ApiPtrs& apis); + + static bool ORT_API_CALL CanImportMemoryImpl( + _In_ const OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalMemoryHandleType handle_type) noexcept; + + static OrtStatus* ORT_API_CALL ImportMemoryImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle) noexcept; + + static void ORT_API_CALL ReleaseMemoryImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalMemoryHandle* handle) noexcept; + + static OrtStatus* ORT_API_CALL CreateTensorFromMemoryImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _Outptr_ OrtValue** out_tensor) noexcept; + + static bool ORT_API_CALL CanImportSemaphoreImpl( + _In_ const OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreType type) noexcept; + + static OrtStatus* ORT_API_CALL ImportSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle) noexcept; + + static void ORT_API_CALL ReleaseSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle) noexcept; + + static OrtStatus* ORT_API_CALL WaitSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) noexcept; + + static OrtStatus* ORT_API_CALL SignalSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) noexcept; + + static void ORT_API_CALL ReleaseImpl(_In_ OrtExternalResourceImporterImpl* this_ptr) noexcept; + + private: + ApiPtrs apis_; +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 85f9504b14a4d..7c2b8e59ade89 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -37,6 +37,8 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL IsStreamAware = IsStreamAwareImpl; CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; + CreateExternalResourceImporterForDevice = CreateExternalResourceImporterForDeviceImpl; + // setup the OrtMemoryInfo instances required by the EP. // We pretend the device the EP is running on is GPU. default_memory_info_ = Ort::MemoryInfo{"ExampleEP GPU", @@ -308,3 +310,25 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateSyncStreamForDeviceImpl(OrtEpFac return nullptr; } + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::CreateExternalResourceImporterForDeviceImpl( + OrtEpFactory* this_ptr, + const OrtEpDevice* /*ep_device*/, + OrtExternalResourceImporterImpl** out_importer) noexcept { + auto& factory = *static_cast(this_ptr); + + if (out_importer == nullptr) { + return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "out_importer cannot be nullptr"); + } + + // Create the external resource importer + // NOTE: For production multi-GPU EPs, you should capture ep_device in the importer + // to enable proper device validation and support multiple physical devices. + // This example EP only supports a single device, so we don't store it. + auto importer = std::make_unique(factory); + *out_importer = importer.release(); + + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index 196e67fc5c558..230fdef772e2f 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -7,6 +7,7 @@ #include "ep_arena.h" #include "ep_data_transfer.h" +#include "ep_external_resource_importer.h" #include "../plugin_ep_utils.h" /// @@ -67,6 +68,11 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { const OrtKeyValuePairs* stream_options, OrtSyncStreamImpl** stream) noexcept; + static OrtStatus* ORT_API_CALL CreateExternalResourceImporterForDeviceImpl( + OrtEpFactory* this_ptr, + const OrtEpDevice* ep_device, + OrtExternalResourceImporterImpl** out_importer) noexcept; + const OrtLogger& default_logger_; // default logger for the EP factory const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name diff --git a/onnxruntime/test/autoep/test_external_resource_importer.cc b/onnxruntime/test/autoep/test_external_resource_importer.cc new file mode 100644 index 0000000000000..b6f3014e1145e --- /dev/null +++ b/onnxruntime/test/autoep/test_external_resource_importer.cc @@ -0,0 +1,356 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Tests for the External Resource Interop API using the example_plugin_ep. +// This tests the public ORT API without requiring actual D3D12/CUDA hardware. + +#include +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" + +#include "test/autoep/test_autoep_utils.h" +#include "test/util/include/api_asserts.h" +#include "test/util/include/asserts.h" + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { + +class ExternalResourceImporterTest : public ::testing::Test { + protected: + void SetUp() override { + // Register the example EP and get the device using shared utility + Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, registered_ep_); + ASSERT_NE(registered_ep_.get(), nullptr) << "Example EP device not found"; + ep_device_ = registered_ep_.get(); + } + + const OrtInteropApi& GetInteropApi() const { + return Ort::GetInteropApi(); + } + + RegisteredEpDeviceUniquePtr registered_ep_; + const OrtEpDevice* ep_device_ = nullptr; +}; + +// Test: Create External Resource Importer +TEST_F(ExternalResourceImporterTest, CreateExternalResourceImporter) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + + // Status should be nullptr on success (even if importer is null for unsupported EPs) + ASSERT_EQ(status, nullptr) << "CreateExternalResourceImporterForDevice should succeed"; + + // importer may be nullptr if EP doesn't support this optional feature + if (importer == nullptr) { + GTEST_SKIP() << "External resource interop not supported by this EP"; + } + + // Release the importer + GetInteropApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Memory Import Capability +TEST_F(ExternalResourceImporterTest, CanImportMemory) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + ASSERT_EQ(status, nullptr) << "CreateExternalResourceImporterForDevice should succeed"; + if (importer == nullptr) { + GTEST_SKIP() << "External resource interop not supported"; + } + + // Check D3D12 Resource support + bool can_import_resource = false; + status = GetInteropApi().CanImportMemory( + importer, ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE, &can_import_resource); + ASSERT_EQ(status, nullptr) << "CanImportMemory for D3D12_RESOURCE should succeed"; + EXPECT_TRUE(can_import_resource) << "Example EP should support D3D12 Resource import"; + + // Check D3D12 Heap support + bool can_import_heap = false; + status = GetInteropApi().CanImportMemory( + importer, ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP, &can_import_heap); + ASSERT_EQ(status, nullptr) << "CanImportMemory for D3D12_HEAP should succeed"; + EXPECT_TRUE(can_import_heap) << "Example EP should support D3D12 Heap import"; + + GetInteropApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Semaphore Import Capability +TEST_F(ExternalResourceImporterTest, CanImportSemaphore) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + ASSERT_EQ(status, nullptr) << "CreateExternalResourceImporterForDevice should succeed"; + if (importer == nullptr) { + GTEST_SKIP() << "External resource interop not supported"; + } + + // Check D3D12 Fence support + bool can_import_fence = false; + status = GetInteropApi().CanImportSemaphore( + importer, ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE, &can_import_fence); + ASSERT_EQ(status, nullptr) << "CanImportSemaphore for D3D12_FENCE should succeed"; + EXPECT_TRUE(can_import_fence) << "Example EP should support D3D12 Fence import"; + + GetInteropApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Import Memory (Simulated) +TEST_F(ExternalResourceImporterTest, ImportMemory) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + ASSERT_EQ(status, nullptr) << "CreateExternalResourceImporterForDevice should succeed"; + if (importer == nullptr) { + GTEST_SKIP() << "External resource interop not supported"; + } + + // Import memory (using a dummy handle for testing) + const size_t buffer_size = 1024 * sizeof(float); + void* dummy_handle = reinterpret_cast(static_cast(0x12345678)); // Simulated handle + + OrtExternalMemoryDescriptor mem_desc = {}; + mem_desc.version = ORT_API_VERSION; + mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + mem_desc.native_handle = dummy_handle; + mem_desc.size_bytes = buffer_size; + mem_desc.offset_bytes = 0; + + OrtExternalMemoryHandle* mem_handle = nullptr; + status = GetInteropApi().ImportMemory(importer, &mem_desc, &mem_handle); + ASSERT_EQ(status, nullptr) << "ImportMemory should succeed"; + ASSERT_NE(mem_handle, nullptr) << "Memory handle should not be null"; + + // Release memory handle + GetInteropApi().ReleaseExternalMemoryHandle(mem_handle); + + GetInteropApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Create Tensor from Imported Memory +TEST_F(ExternalResourceImporterTest, CreateTensorFromMemory) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + ASSERT_EQ(status, nullptr) << "CreateExternalResourceImporterForDevice should succeed"; + if (importer == nullptr) { + GTEST_SKIP() << "External resource interop not supported"; + } + + // Create tensor shape: [1, 3, 32, 32] + const int64_t batch = 1, channels = 3, height = 32, width = 32; + const int64_t shape[] = {batch, channels, height, width}; + const size_t num_elements = batch * channels * height * width; + const size_t buffer_size = num_elements * sizeof(float); + + // Import memory + void* dummy_handle = reinterpret_cast(static_cast(0x12345678)); + + OrtExternalMemoryDescriptor mem_desc = {}; + mem_desc.version = ORT_API_VERSION; + mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + mem_desc.native_handle = dummy_handle; + mem_desc.size_bytes = buffer_size; + mem_desc.offset_bytes = 0; + + OrtExternalMemoryHandle* mem_handle = nullptr; + status = GetInteropApi().ImportMemory(importer, &mem_desc, &mem_handle); + ASSERT_EQ(status, nullptr); + + // Create tensor from imported memory + OrtExternalTensorDescriptor tensor_desc = {}; + tensor_desc.version = ORT_API_VERSION; + tensor_desc.element_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + tensor_desc.shape = shape; + tensor_desc.rank = 4; + tensor_desc.offset_bytes = 0; + + OrtValue* tensor = nullptr; + status = GetInteropApi().CreateTensorFromMemory( + importer, mem_handle, &tensor_desc, &tensor); + ASSERT_EQ(status, nullptr) << "CreateTensorFromMemory should succeed"; + ASSERT_NE(tensor, nullptr) << "Tensor should not be null"; + + // Verify tensor properties + OrtTensorTypeAndShapeInfo* type_info = nullptr; + status = Ort::GetApi().GetTensorTypeAndShape(tensor, &type_info); + ASSERT_EQ(status, nullptr); + + size_t rank = 0; + status = Ort::GetApi().GetDimensionsCount(type_info, &rank); + ASSERT_EQ(status, nullptr); + EXPECT_EQ(rank, 4u); + + std::vector actual_shape(rank); + status = Ort::GetApi().GetDimensions(type_info, actual_shape.data(), rank); + ASSERT_EQ(status, nullptr); + EXPECT_EQ(actual_shape[0], batch); + EXPECT_EQ(actual_shape[1], channels); + EXPECT_EQ(actual_shape[2], height); + EXPECT_EQ(actual_shape[3], width); + + ONNXTensorElementDataType elem_type; + status = Ort::GetApi().GetTensorElementType(type_info, &elem_type); + ASSERT_EQ(status, nullptr); + EXPECT_EQ(elem_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + + Ort::GetApi().ReleaseTensorTypeAndShapeInfo(type_info); + + // Cleanup + Ort::GetApi().ReleaseValue(tensor); + GetInteropApi().ReleaseExternalMemoryHandle(mem_handle); + GetInteropApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Import Semaphore (Simulated) +TEST_F(ExternalResourceImporterTest, ImportSemaphore) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + ASSERT_EQ(status, nullptr) << "CreateExternalResourceImporterForDevice should succeed"; + if (importer == nullptr) { + GTEST_SKIP() << "External resource interop not supported"; + } + + // Import semaphore (using a dummy handle for testing) + void* dummy_handle = reinterpret_cast(static_cast(0xABCDEF00)); + + OrtExternalSemaphoreDescriptor sem_desc = {}; + sem_desc.version = ORT_API_VERSION; + sem_desc.type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + sem_desc.native_handle = dummy_handle; + + OrtExternalSemaphoreHandle* sem_handle = nullptr; + status = GetInteropApi().ImportSemaphore(importer, &sem_desc, &sem_handle); + ASSERT_EQ(status, nullptr) << "ImportSemaphore should succeed"; + ASSERT_NE(sem_handle, nullptr) << "Semaphore handle should not be null"; + + // Release semaphore handle + GetInteropApi().ReleaseExternalSemaphoreHandle(sem_handle); + + GetInteropApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Wait and Signal Semaphore (Simulated) +TEST_F(ExternalResourceImporterTest, WaitAndSignalSemaphore) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + ASSERT_EQ(status, nullptr) << "CreateExternalResourceImporterForDevice should succeed"; + if (importer == nullptr) { + GTEST_SKIP() << "External resource interop not supported"; + } + + // Create a stream for the EP + OrtSyncStream* stream = nullptr; + status = Ort::GetApi().CreateSyncStreamForEpDevice(ep_device_, nullptr, &stream); + ASSERT_EQ(status, nullptr) << "CreateSyncStreamForEpDevice should succeed"; + + // Import semaphore + void* dummy_handle = reinterpret_cast(static_cast(0xABCDEF00)); + + OrtExternalSemaphoreDescriptor sem_desc = {}; + sem_desc.version = ORT_API_VERSION; + sem_desc.type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + sem_desc.native_handle = dummy_handle; + + OrtExternalSemaphoreHandle* sem_handle = nullptr; + status = GetInteropApi().ImportSemaphore(importer, &sem_desc, &sem_handle); + ASSERT_EQ(status, nullptr); + + // Signal the semaphore with value 1 + status = GetInteropApi().SignalSemaphore(importer, sem_handle, stream, 1); + ASSERT_EQ(status, nullptr) << "SignalSemaphore should succeed"; + + // Wait for value 1 (should succeed immediately since we just signaled it) + status = GetInteropApi().WaitSemaphore(importer, sem_handle, stream, 1); + ASSERT_EQ(status, nullptr) << "WaitSemaphore should succeed"; + + // Signal with value 5 + status = GetInteropApi().SignalSemaphore(importer, sem_handle, stream, 5); + ASSERT_EQ(status, nullptr); + + // Wait for value 3 (should succeed since current value is 5) + status = GetInteropApi().WaitSemaphore(importer, sem_handle, stream, 3); + ASSERT_EQ(status, nullptr); + + // Cleanup + GetInteropApi().ReleaseExternalSemaphoreHandle(sem_handle); + Ort::GetApi().ReleaseSyncStream(stream); + GetInteropApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Multiple Memory Imports +TEST_F(ExternalResourceImporterTest, MultipleMemoryImports) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + ASSERT_EQ(status, nullptr) << "CreateExternalResourceImporterForDevice should succeed"; + if (importer == nullptr) { + GTEST_SKIP() << "External resource interop not supported"; + } + + constexpr int kNumBuffers = 5; + std::vector handles(kNumBuffers); + + // Import multiple memory regions + for (int i = 0; i < kNumBuffers; ++i) { + OrtExternalMemoryDescriptor mem_desc = {}; + mem_desc.version = ORT_API_VERSION; + mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + mem_desc.native_handle = reinterpret_cast(static_cast(0x10000000 + i * 0x1000)); + mem_desc.size_bytes = (i + 1) * 1024; + mem_desc.offset_bytes = 0; + + status = GetInteropApi().ImportMemory(importer, &mem_desc, &handles[i]); + ASSERT_EQ(status, nullptr) << "ImportMemory " << i << " should succeed"; + ASSERT_NE(handles[i], nullptr); + } + + // Release all handles + for (int i = 0; i < kNumBuffers; ++i) { + GetInteropApi().ReleaseExternalMemoryHandle(handles[i]); + } + + GetInteropApi().ReleaseExternalResourceImporter(importer); +} + +// Test: SessionGetEpDeviceForOutputs +TEST_F(ExternalResourceImporterTest, SessionGetEpDeviceForOutputs) { + // Load a simple model with the example EP + Ort::SessionOptions session_options; + + // Add the example EP to the session + const OrtEpDevice* devices[] = {ep_device_}; + OrtStatus* status = Ort::GetApi().SessionOptionsAppendExecutionProvider_V2( + session_options, *ort_env, devices, 1, nullptr, nullptr, 0); + if (status != nullptr) { + std::string error = Ort::GetApi().GetErrorMessage(status); + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "Example EP not available: " << error; + } + + // Create session with test model (mul_1.onnx - a simple model the example EP supports) + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); + + // Get output count + size_t num_outputs = session.GetOutputCount(); + ASSERT_GT(num_outputs, 0U) << "Model should have at least one output"; + + // Get EP devices for outputs + std::vector output_devices(num_outputs); + status = Ort::GetApi().SessionGetEpDeviceForOutputs( + session, output_devices.data(), num_outputs); + ASSERT_EQ(status, nullptr) << "SessionGetEpDeviceForOutputs should succeed"; + + // Validate that we got EP devices (may be nullptr if not assigned to EP) + // At least verify the call succeeded and returned valid array + for (size_t i = 0; i < num_outputs; ++i) { + if (output_devices[i] != nullptr) { + // If an EP device is returned, validate it has a name + const char* ep_name = Ort::GetApi().EpDevice_EpName(output_devices[i]); + ASSERT_NE(ep_name, nullptr) << "EP device should have a name"; + } + } +} + +} // namespace test +} // namespace onnxruntime