From 8e2e4f3602d42e93531e199ed7fac863c4cd7f2a Mon Sep 17 00:00:00 2001
From: Nick Eubanks
Date: Thu, 18 Dec 2025 15:35:08 -0500
Subject: [PATCH 01/10] Add OrtExternalResourceImporter API for D3D12 shared resource import

Introduces the OrtExternalResourceImporter API enabling execution providers
to import D3D12 shared resources and timeline fences for zero-copy GPU-to-GPU
data sharing with ORT inference.

Public API additions:
- OrtExternalResourceImporter capability object
- OrtExternalMemoryHandle for imported D3D12 allocations
- OrtExternalSemaphoreHandle for imported D3D12 timeline fences
- SessionGetEpDeviceForOutputs to query output EP device placement
- RunOptions_SetSyncStream to associate a sync stream for async execution

EP Plugin API:
- OrtExternalResourceImporterImpl interface for EP implementations
- OrtEpFactory::CreateExternalResourceImporterForDevice extension

Design:
- No GPU virtual addresses in public API
- EP-agnostic design allows any EP to implement import
- Capability discovery with explicit ORT_NOT_IMPLEMENTED
- Follows existing patterns (Allocator, DataTransfer, SyncStream)

Includes example_plugin_ep mock implementation and autoep tests.
--- .../onnxruntime/core/framework/run_options.h | 5 + .../core/session/onnxruntime_c_api.h | 301 ++++++++++++ .../core/session/onnxruntime_cxx_api.h | 11 + .../core/session/onnxruntime_cxx_inline.h | 18 + .../core/session/onnxruntime_ep_c_api.h | 194 ++++++++ onnxruntime/core/session/inference_session.cc | 42 ++ onnxruntime/core/session/inference_session.h | 9 + onnxruntime/core/session/onnxruntime_c_api.cc | 428 ++++++++++++++++++ onnxruntime/core/session/ort_apis.h | 49 ++ .../session/plugin_ep/ep_factory_internal.cc | 1 + .../session/plugin_ep/ep_factory_internal.h | 5 + .../plugin_ep/ep_factory_internal_impl.h | 8 + .../plugin_ep/ep_factory_provider_bridge.h | 10 + .../plugin_ep/forward_to_factory_impl.h | 7 + .../ep_external_resource_importer.cc | 279 ++++++++++++ .../ep_external_resource_importer.h | 121 +++++ .../library/example_plugin_ep/ep_factory.cc | 32 ++ .../library/example_plugin_ep/ep_factory.h | 6 + .../autoep/test_external_resource_importer.cc | 424 +++++++++++++++++ 19 files changed, 1950 insertions(+) create mode 100644 onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc create mode 100644 onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h create mode 100644 onnxruntime/test/autoep/test_external_resource_importer.cc diff --git a/include/onnxruntime/core/framework/run_options.h b/include/onnxruntime/core/framework/run_options.h index e63ab044834f5..001fa158345ab 100644 --- a/include/onnxruntime/core/framework/run_options.h +++ b/include/onnxruntime/core/framework/run_options.h @@ -51,6 +51,11 @@ struct OrtRunOptions { onnxruntime::InlinedVector active_adapters; + // Optional sync stream for external resource import. + // When set, the EP uses this stream for execution, enabling proper + // synchronization with imported external semaphores. 
+ OrtSyncStream* sync_stream = nullptr; + OrtRunOptions() = default; ~OrtRunOptions() = default; }; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index d1b652229e4b6..9225cfe6ba1c7 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -330,6 +330,9 @@ ORT_RUNTIME_CLASS(EpDevice); ORT_RUNTIME_CLASS(KeyValuePairs); ORT_RUNTIME_CLASS(SyncStream); // Opaque class to create an onnxruntime::Stream. ORT_RUNTIME_CLASS(ExternalInitializerInfo); +ORT_RUNTIME_CLASS(ExternalResourceImporter); // Capability object for external resource import +ORT_RUNTIME_CLASS(ExternalMemoryHandle); // EP-imported view of shared external allocation +ORT_RUNTIME_CLASS(ExternalSemaphoreHandle); // EP-imported view of shared external semaphore #ifdef _MSC_VER typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -955,6 +958,87 @@ typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t n * * \nosubgrouping */ + +/** \addtogroup Global + * @{ + */ + +/** \brief External memory handle type for importing GPU resources. + * + * \since Version 1.24. + */ +typedef enum OrtExternalMemoryHandleType { + ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE = 0, /**< Shared HANDLE from ID3D12Device::CreateSharedHandle(resource) */ + ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP = 1, /**< Shared HANDLE from ID3D12Device::CreateSharedHandle(heap) */ +} OrtExternalMemoryHandleType; + +/** \brief Access mode for imported external memory. + * + * \since Version 1.24. + */ +typedef enum OrtExternalMemoryAccessMode { + ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE = 0, /**< Memory can be read and written */ + ORT_EXTERNAL_MEMORY_ACCESS_READ_ONLY = 1, /**< Memory is read-only */ + ORT_EXTERNAL_MEMORY_ACCESS_WRITE_ONLY = 2, /**< Memory is write-only */ +} OrtExternalMemoryAccessMode; + +/** \brief Descriptor for importing external memory. 
+ * + * \note The version field must be set to ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION. + * This ensures forward compatibility as fields may be added in future versions. + * + * \since Version 1.24. + */ +#define ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION 1 +typedef struct OrtExternalMemoryDescriptor { + uint32_t version; /**< Must be ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION */ + OrtExternalMemoryHandleType handle_type; /**< Type of the external memory handle */ + void* native_handle; /**< Platform-specific handle (e.g., Windows HANDLE) */ + size_t size_bytes; /**< Total size in bytes of the external allocation */ + size_t offset_bytes; /**< Offset in bytes into the allocation (default 0) */ + OrtExternalMemoryAccessMode access_mode; /**< Access mode for the imported memory */ +} OrtExternalMemoryDescriptor; + +/** \brief External semaphore type for GPU synchronization. + * + * \since Version 1.24. + */ +typedef enum OrtExternalSemaphoreType { + ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE = 0, /**< Shared HANDLE from ID3D12Device::CreateSharedHandle(fence) */ +} OrtExternalSemaphoreType; + +/** \brief Descriptor for importing external semaphores. + * + * \note The version field must be set to ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION. + * This ensures forward compatibility as fields may be added in future versions. + * + * \since Version 1.24. + */ +#define ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION 1 +typedef struct OrtExternalSemaphoreDescriptor { + uint32_t version; /**< Must be ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION */ + OrtExternalSemaphoreType type; /**< Type of the external semaphore */ + void* native_handle; /**< Platform-specific handle (e.g., Windows HANDLE) */ +} OrtExternalSemaphoreDescriptor; + +/** \brief Descriptor for creating a tensor from imported external memory. + * + * \note The version field must be set to ORT_EXTERNAL_TENSOR_DESCRIPTOR_VERSION. + * This ensures forward compatibility as fields may be added in future versions. + * + * \since Version 1.24. 
+ */ +#define ORT_EXTERNAL_TENSOR_DESCRIPTOR_VERSION 1 +typedef struct OrtExternalTensorDescriptor { + uint32_t version; /**< Must be ORT_EXTERNAL_TENSOR_DESCRIPTOR_VERSION */ + ONNXTensorElementDataType element_type; /**< Data type of tensor elements */ + const int64_t* shape; /**< Array of dimension sizes */ + size_t rank; /**< Number of dimensions */ + size_t offset_bytes; /**< Optional offset within imported memory (default 0) */ +} OrtExternalTensorDescriptor; + +/// @} + /* * Public enum for compiled model compatibility across EPs. */ @@ -6608,6 +6692,223 @@ struct OrtApi { * \since Version 1.24 */ ORT_API2_STATUS(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out); + + /// \name External Resource Import + /// @{ + + /** \brief Create an external resource importer for a specific EP device. + * + * The external resource importer is a capability object that provides methods for importing + * external GPU memory and semaphores for zero-copy import with an execution provider. + * + * \param[in] ep_device The OrtEpDevice instance to create the importer for. + * \param[out] out_importer Output parameter set to the created OrtExternalResourceImporter instance. + * Returns nullptr if the EP does not support external resource import. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CreateExternalResourceImporterForDevice, + _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporter** out_importer); + + /** \brief Release an OrtExternalResourceImporter instance. + * + * \param[in] importer The OrtExternalResourceImporter instance to release. May be nullptr. Release every OrtExternalMemoryHandle and OrtExternalSemaphoreHandle created through this importer first: each handle keeps a non-owning pointer to the importer and calls back into it when the handle is released. + * + * \since Version 1.24. + */ + ORT_CLASS_RELEASE(ExternalResourceImporter); + + /** \brief Check if the external resource importer can import a specific memory handle type. + * + * \param[in] importer The OrtExternalResourceImporter instance. 
+ * \param[in] handle_type The type of external memory handle to check. + * \param[out] out_supported Set to true if the handle type is supported. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ExternalResourceImporter_CanImportMemory, + _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalMemoryHandleType handle_type, + _Out_ bool* out_supported); + + /** \brief Import external memory into the execution provider. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] desc Descriptor containing the external memory handle and properties. + * \param[out] out_handle Output parameter set to the created OrtExternalMemoryHandle. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ExternalResourceImporter_ImportMemory, + _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle); + + /** \brief Release an OrtExternalMemoryHandle instance. + * + * \param[in] handle The OrtExternalMemoryHandle instance to release. May be nullptr. + * + * \since Version 1.24. + */ + ORT_CLASS_RELEASE(ExternalMemoryHandle); + + /** \brief Create a tensor backed by imported external memory. + * + * The created tensor is a view over the imported memory and does not copy data. + * The OrtExternalMemoryHandle must remain valid for the lifetime of the tensor. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] mem_handle The imported external memory handle. + * \param[in] tensor_desc Descriptor specifying tensor element type, shape, and optional offset. + * \param[in] tensor_location Optional OrtMemoryInfo for the tensor location. May be nullptr. + * \param[out] out_tensor Output parameter set to the created OrtValue containing the tensor. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ExternalResourceImporter_CreateTensorFromMemory, + _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _In_opt_ const OrtMemoryInfo* tensor_location, + _Outptr_ OrtValue** out_tensor); + + /** \brief Check if the external resource importer can import a specific semaphore type. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] type The type of external semaphore to check. + * \param[out] out_supported Set to true if the semaphore type is supported. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ExternalResourceImporter_CanImportSemaphore, + _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreType type, + _Out_ bool* out_supported); + + /** \brief Import an external semaphore into the execution provider. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] desc Descriptor containing the external semaphore handle and type. + * \param[out] out_handle Output parameter set to the created OrtExternalSemaphoreHandle. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ExternalResourceImporter_ImportSemaphore, + _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle); + + /** \brief Release an OrtExternalSemaphoreHandle instance. + * + * \param[in] handle The OrtExternalSemaphoreHandle instance to release. May be nullptr. + * + * \since Version 1.24. + */ + ORT_CLASS_RELEASE(ExternalSemaphoreHandle); + + /** \brief Wait on an external semaphore on the EP's stream. 
+ * + * Inserts a wait operation into the EP's stream that blocks until the semaphore + * reaches the specified value. This is used to synchronize with external GPU work + * (e.g., D3D12 timeline fence). + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] semaphore_handle The imported external semaphore. + * \param[in] stream The OrtSyncStream to wait on. + * \param[in] value The fence/semaphore value to wait for. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ExternalResourceImporter_WaitSemaphore, + _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + + /** \brief Signal an external semaphore from the EP's stream. + * + * Inserts a signal operation into the EP's stream that sets the semaphore + * to the specified value when reached. This is used to notify external GPU work + * (e.g., D3D12 timeline fence) that ORT inference is complete. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] semaphore_handle The imported external semaphore. + * \param[in] stream The OrtSyncStream to signal from. + * \param[in] value The fence/semaphore value to signal. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ExternalResourceImporter_SignalSemaphore, + _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + + /** \brief Get the EP device assigned to each session output. + * + * Returns the OrtEpDevice assigned to each output of the session after graph partitioning. + * This allows validation that outputs are placed on the expected device for external resource sharing. + * + * The EP device for each output is determined by which execution provider claims that output + * during graph partitioning. 
This information is useful for: + * - Validating that outputs will be placed on the expected device for external resource sharing + * - Deciding whether to use external memory handles for outputs + * + * \param[in] session The OrtSession instance to query. + * \param[out] outputs_ep_devices An array to be filled with the EP device for each output. + * The array must be allocated by the caller with space for + * OrtEpDevice* values for each output. + * The order is the same as returned by SessionGetOutputName. + * \param[in] num_outputs The number of outputs in the session. Must match SessionGetOutputCount. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24 + */ + ORT_API2_STATUS(SessionGetEpDeviceForOutputs, _In_ const OrtSession* session, + _Out_writes_(num_outputs) const OrtEpDevice** outputs_ep_devices, + _In_ size_t num_outputs); + + /** \brief Associate an OrtSyncStream with run options. + * + * Associates an OrtSyncStream with OrtRunOptions for use with Run() or RunWithBinding(). + * When a sync stream is set, the EP uses this stream for execution, enabling proper + * synchronization with imported external semaphores. + * + * This approach: + * - Works with both Run() and RunWithBinding() — no IOBinding requirement + * - Allows different Run calls to use different streams for concurrent inference + * - Integrates cleanly with the external semaphore wait/signal pattern + * + * \param[in] run_options The OrtRunOptions instance to modify. + * \param[in] stream The OrtSyncStream to associate with the run options. May be nullptr to clear. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24 + */ + ORT_API2_STATUS(RunOptions_SetSyncStream, + _Inout_ OrtRunOptions* run_options, + _In_opt_ OrtSyncStream* stream); + + /// @} }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index bc75aabc7e229..b4dcdcc7bcca3 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1287,6 +1287,16 @@ struct RunOptions : detail::Base { * \param adapter The LoraAdapter to be used as the active adapter */ RunOptions& AddActiveLoraAdapter(const LoraAdapter& adapter); + + /** \brief Associate a sync stream with the run options. + * + * When set, the EP uses this stream for execution, enabling proper + * synchronization with imported external semaphores. + * + * Wraps OrtApi::RunOptions_SetSyncStream + * \param stream The OrtSyncStream to associate with these run options. May be nullptr to clear. + */ + RunOptions& SetSyncStream(OrtSyncStream* stream); }; namespace detail { @@ -1607,6 +1617,7 @@ struct ConstSessionImpl : Base { std::vector GetMemoryInfoForInputs() const; ///< Wrapper for OrtApi::SessionGetMemoryInfoForInputs std::vector GetMemoryInfoForOutputs() const; ///< Wrapper for OrtApi::SessionGetMemoryInfoForOutputs std::vector GetEpDeviceForInputs() const; ///< Wrapper for OrtApi::SessionGetEpDeviceForInputs + std::vector GetEpDeviceForOutputs() const; ///< Wrapper for OrtApi::SessionGetEpDeviceForOutputs /** \brief Returns a copy of input name at the specified index. 
* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index aff1061a67fea..0622afb681ddb 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -994,6 +994,11 @@ inline RunOptions& RunOptions::AddActiveLoraAdapter(const LoraAdapter& adapter) return *this; } +inline RunOptions& RunOptions::SetSyncStream(OrtSyncStream* stream) { + ThrowOnError(GetApi().RunOptions_SetSyncStream(p_, stream)); + return *this; +} + inline ModelCompilationOptions::ModelCompilationOptions(const Env& env, const SessionOptions& session_options) { ThrowOnError(GetCompileApi().CreateModelCompilationOptionsFromSessionOptions(env, session_options, &this->p_)); } @@ -1660,6 +1665,19 @@ inline std::vector ConstSessionImpl::GetEpDeviceForInputs() co return input_devices; } +template +inline std::vector ConstSessionImpl::GetEpDeviceForOutputs() const { + auto num_outputs = GetOutputCount(); + std::vector output_devices; + if (num_outputs > 0) { + output_devices.resize(num_outputs); + ThrowOnError(GetApi().SessionGetEpDeviceForOutputs(this->p_, + reinterpret_cast(output_devices.data()), + num_outputs)); + } + return output_devices; +} + template inline uint64_t ConstSessionImpl::GetProfilingStartTimeNs() const { uint64_t out; diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 6fa5c8dea04e6..edd7ab657c8c4 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -24,6 +24,10 @@ ORT_RUNTIME_CLASS(DataTransferImpl); ORT_RUNTIME_CLASS(SyncNotificationImpl); ORT_RUNTIME_CLASS(SyncStreamImpl); +ORT_RUNTIME_CLASS(ExternalResourceImporterImpl); +ORT_RUNTIME_CLASS(ExternalMemoryHandleImpl); +ORT_RUNTIME_CLASS(ExternalSemaphoreHandleImpl); + // Opaque types for kernel-based EPs 
ORT_RUNTIME_CLASS(KernelRegistry); ORT_RUNTIME_CLASS(KernelDefBuilder); @@ -190,6 +194,170 @@ struct OrtSyncStreamImpl { ORT_API2_STATUS(OnSessionRunEnd, _In_ OrtSyncStreamImpl* this_ptr); }; +/** \brief Struct that an EP implements for external resource import (memory + semaphore import). + * + * This capability object provides methods for importing external GPU memory and semaphores + * for zero-copy import. EPs that support D3D12, CUDA, HIP, or Vulkan external resource APIs + * can implement this interface. + * + * \since Version 1.24. + */ +struct OrtExternalResourceImporterImpl { + uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION + + // Memory operations (stream-independent) + + /** \brief Check if the implementation can import external memory of the given handle type. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] handle_type The type of external memory handle to check. + * \return True if the handle type is supported. + * + * \since Version 1.24. + */ + ORT_API_T(bool, CanImportMemory, + _In_ const OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalMemoryHandleType handle_type); + + /** \brief Import external memory. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] desc Descriptor containing the external memory handle and properties. + * \param[out] out_handle Output parameter set to the created OrtExternalMemoryHandleImpl. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ImportMemory, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandleImpl** out_handle); + + /** \brief Release an imported external memory handle. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] handle The OrtExternalMemoryHandleImpl to release. 
+ * + * \since Version 1.24. + */ + ORT_API_T(void, ReleaseMemory, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalMemoryHandleImpl* handle); + + /** \brief Create a tensor backed by imported external memory. + * + * The created tensor is a view over the imported memory and does not copy data. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] mem_handle The imported external memory handle. + * \param[in] tensor_desc Descriptor specifying tensor element type, shape, and optional offset. + * \param[out] out_tensor Output parameter set to the created OrtValue containing the tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CreateTensorFromMemory, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryHandleImpl* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _Outptr_ OrtValue** out_tensor); + + // Semaphore operations (require stream) + + /** \brief Check if the implementation can import external semaphores of the given type. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] type The type of external semaphore to check. + * \return True if the semaphore type is supported. + * + * \since Version 1.24. + */ + ORT_API_T(bool, CanImportSemaphore, + _In_ const OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreType type); + + /** \brief Import an external semaphore. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] desc Descriptor containing the external semaphore handle and type. + * \param[out] out_handle Output parameter set to the created OrtExternalSemaphoreHandleImpl. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. 
+ */ + ORT_API2_STATUS(ImportSemaphore, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandleImpl** out_handle); + + /** \brief Release an imported external semaphore handle. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] handle The OrtExternalSemaphoreHandleImpl to release. + * + * \since Version 1.24. + */ + ORT_API_T(void, ReleaseSemaphore, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandleImpl* handle); + + /** \brief Wait on an external semaphore on the EP's stream. + * + * Inserts a wait operation into the EP's stream that blocks until the semaphore + * reaches the specified value. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] handle The imported external semaphore. + * \param[in] stream The OrtSyncStream to wait on. + * \param[in] value The fence/semaphore value to wait for. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(WaitSemaphore, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandleImpl* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + + /** \brief Signal an external semaphore from the EP's stream. + * + * Inserts a signal operation into the EP's stream that sets the semaphore + * to the specified value when reached. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] handle The imported external semaphore. + * \param[in] stream The OrtSyncStream to signal from. + * \param[in] value The fence/semaphore value to signal. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. 
+ */ + ORT_API2_STATUS(SignalSemaphore, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandleImpl* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + + // Release the capability object itself + + /** \brief Release the OrtExternalResourceImporterImpl instance. + * + * This is called by ORT when the OrtExternalResourceImporterImpl instance is no longer needed. + * The implementation should release any resources held by the instance. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * + * \since Version 1.24. + */ + ORT_API_T(void, Release, _In_ OrtExternalResourceImporterImpl* this_ptr); +}; + struct OrtNodeFusionOptions; typedef struct OrtNodeFusionOptions OrtNodeFusionOptions; @@ -1413,6 +1581,32 @@ struct OrtEpFactory { * \since Version 1.24. */ ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options); + + /** \brief Create an OrtExternalResourceImporterImpl for external resource import. + * + * This is used to create an external resource importer that enables zero-copy import of + * external GPU memory (e.g., D3D12 shared resources) and synchronization primitives + * (e.g., D3D12 timeline fences). + * + * EPs that support external resource import (via CUDA, HIP, Vulkan, or D3D12 APIs) can + * implement this to allow applications to share GPU resources without copies. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] memory_device The OrtMemoryDevice to create the external resource importer for. + * \param[out] out_importer The created OrtExternalResourceImporterImpl instance. + * Set to nullptr if external resource import is not supported. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \note Implementation of this function is optional. + * An EP factory should only implement this if it supports external resource import. 
+ * If not implemented or not supported, return ORT_NOT_IMPLEMENTED or set out_importer to nullptr. + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CreateExternalResourceImporterForDevice, _In_ OrtEpFactory* this_ptr, + _In_ const OrtMemoryDevice* memory_device, + _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** out_importer); }; #ifdef __cplusplus diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e5523dc78b5d2..e9b0c2f230263 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -3480,6 +3480,48 @@ common::Status InferenceSession::GetEpDeviceForInputs(InlinedVector& ep_devices) const { + ep_devices.clear(); + +#if defined(ORT_MINIMAL_BUILD) + return common::Status(common::ONNXRUNTIME, common::FAIL, + "GetEpDeviceForOutputs is not available in a minimal build."); +#else + if (!is_inited_) { + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Session has not been initialized."); + } + + std::pair outputs = GetModelOutputs(); + + ORT_RETURN_IF_ERROR(outputs.first); + + const auto& def_list = *outputs.second; + ep_devices.reserve(def_list.size()); + + const auto& available_eps = environment_.GetOrtEpDevices(); + + for (const auto* def : def_list) { + InlinedVector node_info_vec; + ORT_RETURN_IF_ERROR(session_state_->GetOutputNodeInfo(def->Name(), node_info_vec)); + assert(!node_info_vec.empty()); + // If we have an output that is not produced by any node, + // then we return nullptr. + const auto* p_node = node_info_vec.front().p_node; + if (p_node != nullptr) { + const auto ep_name = p_node->GetExecutionProviderType(); + auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) { + return entry->ep_name == ep_name; + }); + ep_devices.push_back(it != available_eps.end() ? 
*it : nullptr); + } else { + ep_devices.push_back(nullptr); + } + } + + return Status::OK(); +#endif +} + common::Status InferenceSession::NewIOBinding(std::unique_ptr* io_binding) { { std::lock_guard l(session_mutex_); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 8bea15c169ed4..ff54f6fa7bca0 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -484,6 +484,15 @@ class InferenceSession { * This is required for a user to know the location of the input/output when autoep selection is enabled. */ common::Status GetEpDeviceForInputs(InlinedVector& memory_info) const; + + /** + * Get the OrtEpDevice for each output of the model, in model output order. `ep_devices` is cleared first; an entry is nullptr when no OrtEpDevice in the environment matches the execution provider of the node producing that output. + * + * This is required for a user to validate that outputs will be placed on the expected device + * for external resource sharing. + */ + common::Status GetEpDeviceForOutputs(InlinedVector& ep_devices) const; + /** * Get the current number of in-progress concurrent Run calls. 
*/ diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 82f7cef4aec49..191969832db40 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3372,6 +3372,34 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForInputs, _In_ const OrtSession* API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForOutputs, _In_ const OrtSession* ort_session, + _Out_writes_(num_values) const OrtEpDevice** outputs_ep_devices, + _In_ size_t num_values) { + API_IMPL_BEGIN + if (ort_session == nullptr || outputs_ep_devices == nullptr || num_values == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid argument provided to SessionGetEpDeviceForOutputs."); + } + + auto session = reinterpret_cast(ort_session); + + InlinedVector ep_devices; + + ORT_API_RETURN_IF_STATUS_NOT_OK(session->GetEpDeviceForOutputs(ep_devices)); + + auto num_found = ep_devices.size(); + if (num_found > num_values) { + auto msg = MakeString("Number of outputs ", num_found, " exceeds the provided size of ", num_values); + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, msg.c_str()); + } + + for (size_t i = 0; i < num_values; ++i) { + outputs_ep_devices[i] = (i < num_found) ? 
ep_devices[i] : nullptr; + } + + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice* ep_device, _In_opt_ const OrtKeyValuePairs* stream_options, _Outptr_ OrtSyncStream** ort_stream) { @@ -3536,6 +3564,297 @@ ORT_API_STATUS_IMPL(OrtApis::GetModelCompatibilityForEpDevices, API_IMPL_END } +// ============================================================================ +// External Resource Importer APIs (Version 1.24) +// ============================================================================ + +// Wrapper structs around the EP-provided impl objects. Each destructor hands its impl back to the EP through the matching release callback (Release / ReleaseMemory / ReleaseSemaphore) and skips the call when the pointer or the callback is null. +namespace { +struct ExternalResourceImporterWrapper { + const OrtEpDevice* ep_device; + OrtExternalResourceImporterImpl* impl; // Released through impl->Release in the destructor + + ExternalResourceImporterWrapper(const OrtEpDevice* device, OrtExternalResourceImporterImpl* importer) + : ep_device(device), impl(importer) {} + + ~ExternalResourceImporterWrapper() { + if (impl && impl->Release) { + impl->Release(impl); + } + } + + // Non-copyable + ExternalResourceImporterWrapper(const ExternalResourceImporterWrapper&) = delete; + ExternalResourceImporterWrapper& operator=(const ExternalResourceImporterWrapper&) = delete; +}; + +struct ExternalMemoryHandleWrapper { + OrtExternalResourceImporterImpl* importer_impl; // Not owned; must still be alive when this handle is destroyed because the destructor calls importer_impl->ReleaseMemory + OrtExternalMemoryHandleImpl* impl; + + ExternalMemoryHandleWrapper(OrtExternalResourceImporterImpl* importer, OrtExternalMemoryHandleImpl* handle) + : importer_impl(importer), impl(handle) {} + + ~ExternalMemoryHandleWrapper() { + if (importer_impl && impl && importer_impl->ReleaseMemory) { + importer_impl->ReleaseMemory(importer_impl, impl); + } + } + + // Non-copyable + ExternalMemoryHandleWrapper(const ExternalMemoryHandleWrapper&) = delete; + ExternalMemoryHandleWrapper& operator=(const ExternalMemoryHandleWrapper&) = delete; +}; + +struct ExternalSemaphoreHandleWrapper { + OrtExternalResourceImporterImpl* importer_impl; // Not owned; must still be alive when this handle is destroyed because the destructor calls importer_impl->ReleaseSemaphore + OrtExternalSemaphoreHandleImpl* impl; + + 
ExternalSemaphoreHandleWrapper(OrtExternalResourceImporterImpl* importer, OrtExternalSemaphoreHandleImpl* handle) + : importer_impl(importer), impl(handle) {} + + ~ExternalSemaphoreHandleWrapper() { + if (importer_impl && impl && importer_impl->ReleaseSemaphore) { + importer_impl->ReleaseSemaphore(importer_impl, impl); + } + } + + // Non-copyable + ExternalSemaphoreHandleWrapper(const ExternalSemaphoreHandleWrapper&) = delete; + ExternalSemaphoreHandleWrapper& operator=(const ExternalSemaphoreHandleWrapper&) = delete; +}; +} // namespace + +ORT_API_STATUS_IMPL(OrtApis::CreateExternalResourceImporterForDevice, _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporter** out_importer) { + API_IMPL_BEGIN + if (ep_device == nullptr || out_importer == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device and out_importer must be provided."); + } + + *out_importer = nullptr; + + const OrtDevice* device = ep_device->device_memory_info ? &ep_device->device_memory_info->device : nullptr; + if (device == nullptr || device->MemType() != OrtDevice::MemType::DEFAULT) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device does not use DEFAULT memory of a non-CPU device."); + } + + const auto* factory = ep_device->ep_factory; + if (factory == nullptr || factory->CreateExternalResourceImporterForDevice == nullptr) { + // EP doesn't support external resource import - not an error, just return nullptr + return nullptr; + } + + OrtExternalResourceImporterImpl* impl = nullptr; + ORT_API_RETURN_IF_ERROR(factory->CreateExternalResourceImporterForDevice( + ep_device->GetMutableFactory(), + static_cast(device), + &impl)); + + if (impl == nullptr) { + // EP supports the factory method but returned null - not supported for this device + return nullptr; + } + + auto wrapper = std::make_unique(ep_device, impl); + *out_importer = reinterpret_cast(wrapper.release()); + + return nullptr; + API_IMPL_END +} + +ORT_API(void, 
OrtApis::ReleaseExternalResourceImporter, _Frees_ptr_opt_ OrtExternalResourceImporter* importer) { + if (importer != nullptr) { + delete reinterpret_cast(importer); + } +} + +ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_CanImportMemory, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalMemoryHandleType handle_type, + _Out_ bool* out_supported) { + API_IMPL_BEGIN + if (importer == nullptr || out_supported == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer and out_supported must be provided."); + } + + auto* wrapper = reinterpret_cast(importer); + if (wrapper->impl == nullptr || wrapper->impl->CanImportMemory == nullptr) { + *out_supported = false; + return nullptr; + } + + *out_supported = wrapper->impl->CanImportMemory(wrapper->impl, handle_type); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_ImportMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle) { + API_IMPL_BEGIN + if (importer == nullptr || desc == nullptr || out_handle == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, desc, and out_handle must be provided."); + } + + auto* wrapper = reinterpret_cast(importer); + if (wrapper->impl == nullptr || wrapper->impl->ImportMemory == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External memory import is not supported by this EP."); + } + + OrtExternalMemoryHandleImpl* impl = nullptr; + ORT_API_RETURN_IF_ERROR(wrapper->impl->ImportMemory(wrapper->impl, desc, &impl)); + + if (impl == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, "ImportMemory returned null handle."); + } + + auto mem_wrapper = std::make_unique(wrapper->impl, impl); + *out_handle = reinterpret_cast(mem_wrapper.release()); + + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtApis::ReleaseExternalMemoryHandle, _Frees_ptr_opt_ OrtExternalMemoryHandle* 
handle) { + if (handle != nullptr) { + delete reinterpret_cast(handle); + } +} + +ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_CreateTensorFromMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _In_opt_ const OrtMemoryInfo* /*tensor_location*/, + _Outptr_ OrtValue** out_tensor) { + API_IMPL_BEGIN + if (importer == nullptr || mem_handle == nullptr || tensor_desc == nullptr || out_tensor == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, mem_handle, tensor_desc, and out_tensor must be provided."); + } + + auto* imp_wrapper = reinterpret_cast(importer); + auto* mem_wrapper = reinterpret_cast(mem_handle); + + if (imp_wrapper->impl == nullptr || imp_wrapper->impl->CreateTensorFromMemory == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateTensorFromMemory is not supported by this EP."); + } + + OrtValue* tensor = nullptr; + ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->CreateTensorFromMemory(imp_wrapper->impl, mem_wrapper->impl, tensor_desc, &tensor)); + + *out_tensor = tensor; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_CanImportSemaphore, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreType type, + _Out_ bool* out_supported) { + API_IMPL_BEGIN + if (importer == nullptr || out_supported == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer and out_supported must be provided."); + } + + auto* wrapper = reinterpret_cast(importer); + if (wrapper->impl == nullptr || wrapper->impl->CanImportSemaphore == nullptr) { + *out_supported = false; + return nullptr; + } + + *out_supported = wrapper->impl->CanImportSemaphore(wrapper->impl, type); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_ImportSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ const 
OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle) { + API_IMPL_BEGIN + if (importer == nullptr || desc == nullptr || out_handle == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, desc, and out_handle must be provided."); + } + + auto* wrapper = reinterpret_cast(importer); + if (wrapper->impl == nullptr || wrapper->impl->ImportSemaphore == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External semaphore import is not supported by this EP."); + } + + OrtExternalSemaphoreHandleImpl* impl = nullptr; + ORT_API_RETURN_IF_ERROR(wrapper->impl->ImportSemaphore(wrapper->impl, desc, &impl)); + + if (impl == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, "ImportSemaphore returned null handle."); + } + + auto sem_wrapper = std::make_unique(wrapper->impl, impl); + *out_handle = reinterpret_cast(sem_wrapper.release()); + + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtApis::ReleaseExternalSemaphoreHandle, _Frees_ptr_opt_ OrtExternalSemaphoreHandle* handle) { + if (handle != nullptr) { + delete reinterpret_cast(handle); + } +} + +ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_WaitSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) { + API_IMPL_BEGIN + if (importer == nullptr || semaphore_handle == nullptr || stream == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, semaphore_handle, and stream must be provided."); + } + + auto* imp_wrapper = reinterpret_cast(importer); + auto* sem_wrapper = reinterpret_cast(semaphore_handle); + + if (imp_wrapper->impl == nullptr || imp_wrapper->impl->WaitSemaphore == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "WaitSemaphore is not supported by this EP."); + } + + ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->WaitSemaphore(imp_wrapper->impl, sem_wrapper->impl, stream, value)); + + return 
nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_SignalSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) { + API_IMPL_BEGIN + if (importer == nullptr || semaphore_handle == nullptr || stream == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, semaphore_handle, and stream must be provided."); + } + + auto* imp_wrapper = reinterpret_cast(importer); + auto* sem_wrapper = reinterpret_cast(semaphore_handle); + + if (imp_wrapper->impl == nullptr || imp_wrapper->impl->SignalSemaphore == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SignalSemaphore is not supported by this EP."); + } + + ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->SignalSemaphore(imp_wrapper->impl, sem_wrapper->impl, stream, value)); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::RunOptions_SetSyncStream, _Inout_ OrtRunOptions* run_options, + _In_opt_ OrtSyncStream* stream) { + API_IMPL_BEGIN + if (run_options == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "run_options must be provided."); + } + + run_options->sync_stream = stream; + + return nullptr; + API_IMPL_END +} + #else // defined(ORT_MINIMAL_BUILD) ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, _In_ const char* /*registration_name*/, const ORTCHAR_T* /*path*/) { @@ -3621,6 +3940,101 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* /*env*/, API_IMPL_END } +// External Resource Importer minimal build stubs +ORT_API_STATUS_IMPL(OrtApis::CreateExternalResourceImporterForDevice, _In_ const OrtEpDevice* /*ep_device*/, + _Outptr_result_maybenull_ OrtExternalResourceImporter** /*out_importer*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateExternalResourceImporterForDevice is not supported in a minimal build."); + API_IMPL_END +} + 
+ORT_API(void, OrtApis::ReleaseExternalResourceImporter, _Frees_ptr_opt_ OrtExternalResourceImporter* /*importer*/) { + fprintf(stderr, "External resource import is not supported in a minimal build.\n"); +} + +ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_CanImportMemory, _In_ const OrtExternalResourceImporter* /*importer*/, + _In_ OrtExternalMemoryHandleType /*handle_type*/, + _Out_ bool* /*out_supported*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External resource import is not supported in a minimal build."); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_ImportMemory, _In_ OrtExternalResourceImporter* /*importer*/, + _In_ const OrtExternalMemoryDescriptor* /*desc*/, + _Outptr_ OrtExternalMemoryHandle** /*out_handle*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External resource import is not supported in a minimal build."); + API_IMPL_END +} + +ORT_API(void, OrtApis::ReleaseExternalMemoryHandle, _Frees_ptr_opt_ OrtExternalMemoryHandle* /*handle*/) { + fprintf(stderr, "External resource import is not supported in a minimal build.\n"); +} + +ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_CreateTensorFromMemory, _In_ OrtExternalResourceImporter* /*importer*/, + _In_ const OrtExternalMemoryHandle* /*mem_handle*/, + _In_ const OrtExternalTensorDescriptor* /*tensor_desc*/, + _In_opt_ const OrtMemoryInfo* /*tensor_location*/, + _Outptr_ OrtValue** /*out_tensor*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External resource import is not supported in a minimal build."); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_CanImportSemaphore, _In_ const OrtExternalResourceImporter* /*importer*/, + _In_ OrtExternalSemaphoreType /*type*/, + _Out_ bool* /*out_supported*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External resource import is not supported in a minimal build."); + API_IMPL_END 
+} + +ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_ImportSemaphore, _In_ OrtExternalResourceImporter* /*importer*/, + _In_ const OrtExternalSemaphoreDescriptor* /*desc*/, + _Outptr_ OrtExternalSemaphoreHandle** /*out_handle*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External resource import is not supported in a minimal build."); + API_IMPL_END +} + +ORT_API(void, OrtApis::ReleaseExternalSemaphoreHandle, _Frees_ptr_opt_ OrtExternalSemaphoreHandle* /*handle*/) { + fprintf(stderr, "External resource import is not supported in a minimal build.\n"); +} + +ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_WaitSemaphore, _In_ OrtExternalResourceImporter* /*importer*/, + _In_ OrtExternalSemaphoreHandle* /*semaphore_handle*/, + _In_ OrtSyncStream* /*stream*/, + _In_ uint64_t /*value*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External resource import is not supported in a minimal build."); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_SignalSemaphore, _In_ OrtExternalResourceImporter* /*importer*/, + _In_ OrtExternalSemaphoreHandle* /*semaphore_handle*/, + _In_ OrtSyncStream* /*stream*/, + _In_ uint64_t /*value*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External resource import is not supported in a minimal build."); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForOutputs, _In_ const OrtSession* /*ort_session*/, + _Out_writes_(num_values) const OrtEpDevice** /*outputs_ep_devices*/, + _In_ size_t /*num_values*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SessionGetEpDeviceForOutputs is not supported in a minimal build."); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::RunOptions_SetSyncStream, _Inout_ OrtRunOptions* /*run_options*/, + _In_opt_ OrtSyncStream* /*stream*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "RunOptions_SetSyncStream is not supported in 
a minimal build."); + API_IMPL_END +} + #endif // !defined(ORT_MINIMAL_BUILD) // OrtEpDevice accessors @@ -4238,6 +4652,20 @@ static constexpr OrtApi ort_api_1_to_24 = { &OrtApis::TensorTypeAndShape_HasShape, &OrtApis::KernelInfo_GetConfigEntries, + + &OrtApis::CreateExternalResourceImporterForDevice, + &OrtApis::ReleaseExternalResourceImporter, + &OrtApis::ExternalResourceImporter_CanImportMemory, + &OrtApis::ExternalResourceImporter_ImportMemory, + &OrtApis::ReleaseExternalMemoryHandle, + &OrtApis::ExternalResourceImporter_CreateTensorFromMemory, + &OrtApis::ExternalResourceImporter_CanImportSemaphore, + &OrtApis::ExternalResourceImporter_ImportSemaphore, + &OrtApis::ReleaseExternalSemaphoreHandle, + &OrtApis::ExternalResourceImporter_WaitSemaphore, + &OrtApis::ExternalResourceImporter_SignalSemaphore, + &OrtApis::SessionGetEpDeviceForOutputs, + &OrtApis::RunOptions_SetSyncStream, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. 
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index f3525d8de7b95..9ada8dd66a010 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -755,4 +755,53 @@ ORT_API_STATUS_IMPL(CopyTensors, _In_ const OrtEnv* env, ORT_API_STATUS_IMPL(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out); +// External Resource Importer APIs +ORT_API_STATUS_IMPL(CreateExternalResourceImporterForDevice, _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporter** out_importer); + +ORT_API(void, ReleaseExternalResourceImporter, _Frees_ptr_opt_ OrtExternalResourceImporter* importer); + +ORT_API_STATUS_IMPL(ExternalResourceImporter_CanImportMemory, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalMemoryHandleType handle_type, + _Out_ bool* out_supported); + +ORT_API_STATUS_IMPL(ExternalResourceImporter_ImportMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle); + +ORT_API(void, ReleaseExternalMemoryHandle, _Frees_ptr_opt_ OrtExternalMemoryHandle* handle); + +ORT_API_STATUS_IMPL(ExternalResourceImporter_CreateTensorFromMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _In_opt_ const OrtMemoryInfo* tensor_location, + _Outptr_ OrtValue** out_tensor); + +ORT_API_STATUS_IMPL(ExternalResourceImporter_CanImportSemaphore, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreType type, + _Out_ bool* out_supported); + +ORT_API_STATUS_IMPL(ExternalResourceImporter_ImportSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle); + +ORT_API(void, ReleaseExternalSemaphoreHandle, _Frees_ptr_opt_ 
OrtExternalSemaphoreHandle* handle); + +ORT_API_STATUS_IMPL(ExternalResourceImporter_WaitSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + +ORT_API_STATUS_IMPL(ExternalResourceImporter_SignalSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + +ORT_API_STATUS_IMPL(SessionGetEpDeviceForOutputs, _In_ const OrtSession* session, + _Out_writes_(num_outputs) const OrtEpDevice** outputs_ep_devices, + _In_ size_t num_outputs); + +ORT_API_STATUS_IMPL(RunOptions_SetSyncStream, _Inout_ OrtRunOptions* run_options, + _In_opt_ OrtSyncStream* stream); + } // namespace OrtApis diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc index 364bab471ddbe..8dc92802aa84b 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc @@ -33,6 +33,7 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl OrtEpFactory::IsStreamAware = Forward::IsStreamAware; OrtEpFactory::CreateSyncStreamForDevice = Forward::CreateSyncStreamForDevice; OrtEpFactory::SetEnvironmentOptions = Forward::SetEnvironmentOptions; + OrtEpFactory::CreateExternalResourceImporterForDevice = Forward::CreateExternalResourceImporterForDevice; } InternalExecutionProviderFactory::InternalExecutionProviderFactory(EpFactoryInternal& ep_factory, diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h index 6eb83a117fb63..dbe5bc20a876a 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -91,6 +91,11 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->SetEnvironmentOptions(options); } + 
OrtStatus* CreateExternalResourceImporterForDevice(_In_ const OrtMemoryDevice* device, + _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** importer) noexcept { + return impl_->CreateExternalResourceImporterForDevice(device, importer); + } + // Function ORT calls to release an EP instance. void ReleaseEp(OrtEp* /*ep*/) noexcept { // we never create an OrtEp so we should never be trying to release one diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index de9e2d44431bf..b5240f48847d4 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -88,6 +88,14 @@ class EpFactoryInternalImpl { return nullptr; } + virtual OrtStatus* CreateExternalResourceImporterForDevice( + _In_ const OrtMemoryDevice* /*device*/, + _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** importer) noexcept { + // Default implementation does not support external resource import + *importer = nullptr; + return nullptr; + } + // Function ORT calls to release an EP instance. 
void ReleaseEp(OrtEp* ep); diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h index 8c5ef526baba1..ed69cd001b120 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h @@ -65,6 +65,16 @@ class ProviderBridgeEpFactory : public EpFactoryInternalImpl { return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); } + OrtStatus* CreateExternalResourceImporterForDevice( + const OrtMemoryDevice* device, + OrtExternalResourceImporterImpl** importer) noexcept override { + if (ep_factory_.CreateExternalResourceImporterForDevice == nullptr) { + *importer = nullptr; + return nullptr; + } + return ep_factory_.CreateExternalResourceImporterForDevice(&ep_factory_, device, importer); + } + OrtEpFactory& ep_factory_; ProviderLibrary& provider_library_; std::optional library_path_; diff --git a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index 65c396181f0a7..8e08a64ff8c96 100644 --- a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -87,6 +87,13 @@ struct ForwardToFactoryImpl { return static_cast(this_ptr)->SetEnvironmentOptions(options); } + static OrtStatus* ORT_API_CALL CreateExternalResourceImporterForDevice( + _In_ OrtEpFactory* this_ptr, + _In_ const OrtMemoryDevice* device, + _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** importer) noexcept { + return static_cast(this_ptr)->CreateExternalResourceImporterForDevice(device, importer); + } + static void ORT_API_CALL ReleaseEp(OrtEpFactory* this_ptr, OrtEp* ep) noexcept { static_cast(this_ptr)->ReleaseEp(ep); } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc 
b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc new file mode 100644 index 0000000000000..d53cf8be800f8 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc @@ -0,0 +1,279 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ep_external_resource_importer.h" + +#include +#include +#include +#include +#include + +ExampleExternalResourceImporter::ExampleExternalResourceImporter(int device_id, const ApiPtrs& apis) + : OrtExternalResourceImporterImpl{}, device_id_{device_id}, apis_{apis} { + ort_version_supported = ORT_API_VERSION; + + // Memory operations + CanImportMemory = CanImportMemoryImpl; + ImportMemory = ImportMemoryImpl; + ReleaseMemory = ReleaseMemoryImpl; + CreateTensorFromMemory = CreateTensorFromMemoryImpl; + + // Semaphore operations + CanImportSemaphore = CanImportSemaphoreImpl; + ImportSemaphore = ImportSemaphoreImpl; + ReleaseSemaphore = ReleaseSemaphoreImpl; + WaitSemaphore = WaitSemaphoreImpl; + SignalSemaphore = SignalSemaphoreImpl; + + // Release + Release = ReleaseImpl; +} + +/*static*/ +bool ORT_API_CALL ExampleExternalResourceImporter::CanImportMemoryImpl( + _In_ const OrtExternalResourceImporterImpl* /*this_ptr*/, + _In_ OrtExternalMemoryHandleType handle_type) noexcept { + // The example EP supports both D3D12 resource and heap handle types for testing + return handle_type == ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE || + handle_type == ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::ImportMemoryImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandleImpl** out_handle) noexcept { + auto& impl = *static_cast(this_ptr); + + if (desc == nullptr || out_handle == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid 
arguments to ImportMemory"); + } + + *out_handle = nullptr; + + // Validate handle type + if (!CanImportMemoryImpl(this_ptr, desc->handle_type)) { + return impl.apis_.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, + "Unsupported external memory handle type"); + } + + // In a real implementation, you would: + // 1. Open/import the native handle (e.g., cuImportExternalMemory for CUDA) + // 2. Map the memory to get a device pointer + // + // For testing purposes, we simulate this by allocating CPU memory + // that mirrors the size of the external allocation. + + auto* handle = new (std::nothrow) ExampleExternalMemoryHandle(); + if (handle == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_FAIL, "Failed to allocate external memory handle"); + } + + // Allocate simulated memory (using CPU memory for the example) + size_t effective_size = desc->size_bytes - desc->offset_bytes; + handle->simulated_ptr = malloc(effective_size); + if (handle->simulated_ptr == nullptr) { + delete handle; + return impl.apis_.ort_api.CreateStatus(ORT_FAIL, "Failed to allocate simulated memory"); + } + + // Initialize to zero + memset(handle->simulated_ptr, 0, effective_size); + + handle->size_bytes = desc->size_bytes; + handle->offset_bytes = desc->offset_bytes; + handle->handle_type = desc->handle_type; + handle->access_mode = desc->access_mode; + + *out_handle = reinterpret_cast(handle); + return nullptr; +} + +/*static*/ +void ORT_API_CALL ExampleExternalResourceImporter::ReleaseMemoryImpl( + _In_ OrtExternalResourceImporterImpl* /*this_ptr*/, + _In_ OrtExternalMemoryHandleImpl* handle) noexcept { + if (handle == nullptr) { + return; + } + + auto* mem_handle = reinterpret_cast(handle); + delete mem_handle; // destructor frees simulated_ptr +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::CreateTensorFromMemoryImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryHandleImpl* mem_handle, + _In_ const OrtExternalTensorDescriptor* 
tensor_desc, + _Outptr_ OrtValue** out_tensor) noexcept { + auto& impl = *static_cast(this_ptr); + + if (mem_handle == nullptr || tensor_desc == nullptr || out_tensor == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to CreateTensorFromMemory"); + } + + *out_tensor = nullptr; + + auto* handle = reinterpret_cast(mem_handle); + + // Calculate the data pointer with tensor offset + void* data_ptr = static_cast(handle->simulated_ptr) + tensor_desc->offset_bytes; + + // For the example EP, we use CPU memory info since we're simulating with CPU memory + // In a real implementation, you would use the appropriate GPU memory info + OrtMemoryInfo* memory_info = nullptr; + OrtStatus* status = impl.apis_.ort_api.CreateMemoryInfo( + "Cpu", // For testing, we use CPU memory + OrtDeviceAllocator, + 0, // device ID + OrtMemTypeDefault, + &memory_info); + + if (status != nullptr) { + return status; + } + + // Calculate buffer size + size_t buffer_size = handle->size_bytes - handle->offset_bytes - tensor_desc->offset_bytes; + + // Create tensor with pre-allocated memory + status = impl.apis_.ort_api.CreateTensorWithDataAsOrtValue( + memory_info, + data_ptr, + buffer_size, + tensor_desc->shape, + tensor_desc->rank, + tensor_desc->element_type, + out_tensor); + + impl.apis_.ort_api.ReleaseMemoryInfo(memory_info); + return status; +} + +/*static*/ +bool ORT_API_CALL ExampleExternalResourceImporter::CanImportSemaphoreImpl( + _In_ const OrtExternalResourceImporterImpl* /*this_ptr*/, + _In_ OrtExternalSemaphoreType type) noexcept { + // The example EP supports D3D12 fence for testing + return type == ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::ImportSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandleImpl** out_handle) noexcept { + auto& impl = *static_cast(this_ptr); + + if 
(desc == nullptr || out_handle == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to ImportSemaphore"); + } + + *out_handle = nullptr; + + // Validate semaphore type + if (!CanImportSemaphoreImpl(this_ptr, desc->type)) { + return impl.apis_.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, + "Unsupported external semaphore type"); + } + + // In a real implementation, you would: + // 1. Import the native fence handle (e.g., cuImportExternalSemaphore for CUDA) + // + // For testing purposes, we create a simulated semaphore using an atomic counter + + auto* handle = new (std::nothrow) ExampleExternalSemaphoreHandle(); + if (handle == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_FAIL, "Failed to allocate external semaphore handle"); + } + + handle->type = desc->type; + handle->value.store(0); + + *out_handle = reinterpret_cast(handle); + return nullptr; +} + +/*static*/ +void ORT_API_CALL ExampleExternalResourceImporter::ReleaseSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* /*this_ptr*/, + _In_ OrtExternalSemaphoreHandleImpl* handle) noexcept { + if (handle == nullptr) { + return; + } + + auto* sem_handle = reinterpret_cast(handle); + delete sem_handle; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::WaitSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandleImpl* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) noexcept { + auto& impl = *static_cast(this_ptr); + + if (handle == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to WaitSemaphore"); + } + + // stream can be nullptr for synchronous wait + (void)stream; + + auto* sem_handle = reinterpret_cast(handle); + + // In a real implementation, you would: + // 1. 
Queue a wait operation on the GPU stream (e.g., cuWaitExternalSemaphoresAsync) + // + // For testing, we do a simple spin-wait on the atomic counter + // with a reasonable timeout to prevent infinite loops in tests + + const int max_iterations = 10000; + int iterations = 0; + while (sem_handle->value.load() < value && iterations < max_iterations) { + std::this_thread::sleep_for(std::chrono::microseconds(100)); + ++iterations; + } + + if (iterations >= max_iterations) { + return impl.apis_.ort_api.CreateStatus(ORT_FAIL, "WaitSemaphore timed out"); + } + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::SignalSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandleImpl* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) noexcept { + auto& impl = *static_cast(this_ptr); + + if (handle == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to SignalSemaphore"); + } + + // stream can be nullptr for synchronous signal + (void)stream; + + auto* sem_handle = reinterpret_cast(handle); + + // In a real implementation, you would: + // 1. Queue a signal operation on the GPU stream (e.g., cuSignalExternalSemaphoresAsync) + // + // For testing, we simply update the atomic counter + + sem_handle->value.store(value); + + return nullptr; +} + +/*static*/ +void ORT_API_CALL ExampleExternalResourceImporter::ReleaseImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h new file mode 100644 index 0000000000000..7dcd8f42313f3 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License. + +#pragma once + +#include "../plugin_ep_utils.h" + +#include +#include +#include + +/** + * @brief Example implementation of external memory handle. + * + * This mock implementation simulates imported external memory for testing purposes. + * In a real EP, this would hold a GPU-mapped pointer from an imported D3D12/Vulkan/CUDA resource. + */ +struct ExampleExternalMemoryHandle { + void* simulated_ptr; ///< Simulated mapped pointer (CPU memory for testing) + size_t size_bytes; ///< Size of the imported memory + size_t offset_bytes; ///< Offset into the imported memory + OrtExternalMemoryHandleType handle_type; ///< Original handle type + OrtExternalMemoryAccessMode access_mode; ///< Access mode for the imported memory + + ExampleExternalMemoryHandle() + : simulated_ptr(nullptr), size_bytes(0), offset_bytes(0), handle_type(ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE), access_mode(ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE) {} + + ~ExampleExternalMemoryHandle() { + // Free the simulated pointer if allocated + if (simulated_ptr != nullptr) { + free(simulated_ptr); + } + } +}; + +/** + * @brief Example implementation of external semaphore handle. + * + * This mock implementation simulates imported external semaphores for testing purposes. + * In a real EP, this would hold an imported D3D12 fence / Vulkan semaphore / CUDA external semaphore. + */ +struct ExampleExternalSemaphoreHandle { + OrtExternalSemaphoreType type; ///< Original semaphore type + std::atomic value; ///< Simulated fence value for testing + + ExampleExternalSemaphoreHandle() + : type(ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE), value(0) {} +}; + +/** + * @brief Example implementation of OrtExternalResourceImporterImpl. + * + * This is a mock implementation that simulates external resource interop for testing + * the ORT public API without requiring actual D3D12/CUDA/Vulkan hardware. 
+ * + * Key features: + * - Reports support for D3D12 resource/heap and fence handle types + * - Creates simulated memory mappings using CPU memory + * - Simulates fence wait/signal operations + * - Allows tensor creation from "imported" memory + */ +class ExampleExternalResourceImporter : public OrtExternalResourceImporterImpl { + public: + ExampleExternalResourceImporter(int device_id, const ApiPtrs& apis); + + // ──────────────── Memory operations ──────────────── + + static bool ORT_API_CALL CanImportMemoryImpl( + _In_ const OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalMemoryHandleType handle_type) noexcept; + + static OrtStatus* ORT_API_CALL ImportMemoryImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandleImpl** out_handle) noexcept; + + static void ORT_API_CALL ReleaseMemoryImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalMemoryHandleImpl* handle) noexcept; + + static OrtStatus* ORT_API_CALL CreateTensorFromMemoryImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryHandleImpl* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _Outptr_ OrtValue** out_tensor) noexcept; + + // ──────────────── Semaphore operations ──────────────── + + static bool ORT_API_CALL CanImportSemaphoreImpl( + _In_ const OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreType type) noexcept; + + static OrtStatus* ORT_API_CALL ImportSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandleImpl** out_handle) noexcept; + + static void ORT_API_CALL ReleaseSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandleImpl* handle) noexcept; + + static OrtStatus* ORT_API_CALL WaitSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ 
OrtExternalSemaphoreHandleImpl* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) noexcept; + + static OrtStatus* ORT_API_CALL SignalSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandleImpl* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) noexcept; + + // ──────────────── Release ──────────────── + + static void ORT_API_CALL ReleaseImpl(_In_ OrtExternalResourceImporterImpl* this_ptr) noexcept; + + private: + int device_id_; + ApiPtrs apis_; +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 85f9504b14a4d..3f0d8b335e361 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -37,6 +37,8 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL IsStreamAware = IsStreamAwareImpl; CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; + CreateExternalResourceImporterForDevice = CreateExternalResourceImporterForDeviceImpl; + // setup the OrtMemoryInfo instances required by the EP. // We pretend the device the EP is running on is GPU. 
default_memory_info_ = Ort::MemoryInfo{"ExampleEP GPU", @@ -308,3 +310,33 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateSyncStreamForDeviceImpl(OrtEpFac return nullptr; } + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::CreateExternalResourceImporterForDeviceImpl( + OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + OrtExternalResourceImporterImpl** out_importer) noexcept { + auto& factory = *static_cast(this_ptr); + + if (out_importer == nullptr) { + return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "out_importer cannot be nullptr"); + } + + *out_importer = nullptr; + + // For the example EP, we support external resource import on the default (GPU-simulated) device memory + if (factory.ep_api.MemoryDevice_GetMemoryType(memory_device) != OrtDeviceMemoryType_DEFAULT) { + return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, + "External resource import only supported for DEFAULT device memory"); + } + + // Get the device ID from the memory device + auto device_id = factory.ep_api.MemoryDevice_GetDeviceId(memory_device); + + // Create the external resource importer + auto importer = std::make_unique(device_id, factory); + *out_importer = importer.release(); + + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index 196e67fc5c558..c7bfbcfb918a1 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -7,6 +7,7 @@ #include "ep_arena.h" #include "ep_data_transfer.h" +#include "ep_external_resource_importer.h" #include "../plugin_ep_utils.h" /// @@ -67,6 +68,11 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { const OrtKeyValuePairs* stream_options, OrtSyncStreamImpl** stream) noexcept; + static OrtStatus* ORT_API_CALL CreateExternalResourceImporterForDeviceImpl( + OrtEpFactory* this_ptr, + const OrtMemoryDevice* 
memory_device, + OrtExternalResourceImporterImpl** out_importer) noexcept; + const OrtLogger& default_logger_; // default logger for the EP factory const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name diff --git a/onnxruntime/test/autoep/test_external_resource_importer.cc b/onnxruntime/test/autoep/test_external_resource_importer.cc new file mode 100644 index 0000000000000..f13934e9e94ce --- /dev/null +++ b/onnxruntime/test/autoep/test_external_resource_importer.cc @@ -0,0 +1,424 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Tests for the External Resource Interop API using the example_plugin_ep. +// This tests the public ORT API without requiring actual D3D12/CUDA hardware. + +#include +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" + +#include "test/autoep/test_autoep_utils.h" +#include "test/util/include/api_asserts.h" +#include "test/util/include/asserts.h" + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { + +class ExternalResourceImporterTest : public ::testing::Test { + protected: + void SetUp() override { + // Register the example EP and get the device using shared utility + Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, registered_ep_); + ASSERT_NE(registered_ep_.get(), nullptr) << "Example EP device not found"; + ep_device_ = registered_ep_.get(); + } + + RegisteredEpDeviceUniquePtr registered_ep_; + const OrtEpDevice* ep_device_ = nullptr; +}; + +// Test: Create External Resource Importer +TEST_F(ExternalResourceImporterTest, CreateExternalResourceImporter) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = Ort::GetApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + + if (status != nullptr) { + std::string error = Ort::GetApi().GetErrorMessage(status); + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << 
"CreateExternalResourceImporterForDevice not supported: " << error; + } + + ASSERT_NE(importer, nullptr) << "External resource importer should not be null"; + + // Release the importer + Ort::GetApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Memory Import Capability +TEST_F(ExternalResourceImporterTest, CanImportMemory) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = Ort::GetApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "External resource interop not supported"; + } + + // Check D3D12 Resource support + bool can_import_resource = false; + status = Ort::GetApi().ExternalResourceImporter_CanImportMemory( + importer, ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE, &can_import_resource); + ASSERT_EQ(status, nullptr) << "CanImportMemory for D3D12_RESOURCE should succeed"; + EXPECT_TRUE(can_import_resource) << "Example EP should support D3D12 Resource import"; + + // Check D3D12 Heap support + bool can_import_heap = false; + status = Ort::GetApi().ExternalResourceImporter_CanImportMemory( + importer, ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP, &can_import_heap); + ASSERT_EQ(status, nullptr) << "CanImportMemory for D3D12_HEAP should succeed"; + EXPECT_TRUE(can_import_heap) << "Example EP should support D3D12 Heap import"; + + Ort::GetApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Semaphore Import Capability +TEST_F(ExternalResourceImporterTest, CanImportSemaphore) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = Ort::GetApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "External resource interop not supported"; + } + + // Check D3D12 Fence support + bool can_import_fence = false; + status = Ort::GetApi().ExternalResourceImporter_CanImportSemaphore( + importer, 
ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE, &can_import_fence); + ASSERT_EQ(status, nullptr) << "CanImportSemaphore for D3D12_FENCE should succeed"; + EXPECT_TRUE(can_import_fence) << "Example EP should support D3D12 Fence import"; + + Ort::GetApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Import Memory (Simulated) +TEST_F(ExternalResourceImporterTest, ImportMemory) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = Ort::GetApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "External resource interop not supported"; + } + + // Import memory (using a dummy handle for testing) + const size_t buffer_size = 1024 * sizeof(float); + void* dummy_handle = reinterpret_cast(static_cast(0x12345678)); // Simulated handle + + OrtExternalMemoryDescriptor mem_desc = {}; + mem_desc.version = ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION; + mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + mem_desc.native_handle = dummy_handle; + mem_desc.size_bytes = buffer_size; + mem_desc.offset_bytes = 0; + mem_desc.access_mode = ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE; + + OrtExternalMemoryHandle* mem_handle = nullptr; + status = Ort::GetApi().ExternalResourceImporter_ImportMemory(importer, &mem_desc, &mem_handle); + ASSERT_EQ(status, nullptr) << "ImportMemory should succeed"; + ASSERT_NE(mem_handle, nullptr) << "Memory handle should not be null"; + + // Release memory handle + Ort::GetApi().ReleaseExternalMemoryHandle(mem_handle); + + Ort::GetApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Create Tensor from Imported Memory +TEST_F(ExternalResourceImporterTest, CreateTensorFromMemory) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = Ort::GetApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "External 
resource interop not supported"; + } + + // Create tensor shape: [1, 3, 32, 32] + const int64_t batch = 1, channels = 3, height = 32, width = 32; + const int64_t shape[] = {batch, channels, height, width}; + const size_t num_elements = batch * channels * height * width; + const size_t buffer_size = num_elements * sizeof(float); + + // Import memory + void* dummy_handle = reinterpret_cast(static_cast(0x12345678)); + + OrtExternalMemoryDescriptor mem_desc = {}; + mem_desc.version = ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION; + mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + mem_desc.native_handle = dummy_handle; + mem_desc.size_bytes = buffer_size; + mem_desc.offset_bytes = 0; + mem_desc.access_mode = ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE; + + OrtExternalMemoryHandle* mem_handle = nullptr; + status = Ort::GetApi().ExternalResourceImporter_ImportMemory(importer, &mem_desc, &mem_handle); + ASSERT_EQ(status, nullptr); + + // Create tensor from imported memory + OrtExternalTensorDescriptor tensor_desc = {}; + tensor_desc.version = ORT_EXTERNAL_TENSOR_DESCRIPTOR_VERSION; + tensor_desc.element_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + tensor_desc.shape = shape; + tensor_desc.rank = 4; + tensor_desc.offset_bytes = 0; + + OrtValue* tensor = nullptr; + status = Ort::GetApi().ExternalResourceImporter_CreateTensorFromMemory( + importer, mem_handle, &tensor_desc, nullptr, &tensor); + ASSERT_EQ(status, nullptr) << "CreateTensorFromMemory should succeed"; + ASSERT_NE(tensor, nullptr) << "Tensor should not be null"; + + // Verify tensor properties + OrtTensorTypeAndShapeInfo* type_info = nullptr; + status = Ort::GetApi().GetTensorTypeAndShape(tensor, &type_info); + ASSERT_EQ(status, nullptr); + + size_t rank = 0; + Ort::GetApi().GetDimensionsCount(type_info, &rank); + EXPECT_EQ(rank, 4u); + + std::vector actual_shape(rank); + Ort::GetApi().GetDimensions(type_info, actual_shape.data(), rank); + EXPECT_EQ(actual_shape[0], batch); + EXPECT_EQ(actual_shape[1], 
channels); + EXPECT_EQ(actual_shape[2], height); + EXPECT_EQ(actual_shape[3], width); + + ONNXTensorElementDataType elem_type; + Ort::GetApi().GetTensorElementType(type_info, &elem_type); + EXPECT_EQ(elem_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + + Ort::GetApi().ReleaseTensorTypeAndShapeInfo(type_info); + + // Cleanup + Ort::GetApi().ReleaseValue(tensor); + Ort::GetApi().ReleaseExternalMemoryHandle(mem_handle); + Ort::GetApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Import Semaphore (Simulated) +TEST_F(ExternalResourceImporterTest, ImportSemaphore) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = Ort::GetApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "External resource interop not supported"; + } + + // Import semaphore (using a dummy handle for testing) + void* dummy_handle = reinterpret_cast(static_cast(0xABCDEF00)); + + OrtExternalSemaphoreDescriptor sem_desc = {}; + sem_desc.version = ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION; + sem_desc.type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + sem_desc.native_handle = dummy_handle; + + OrtExternalSemaphoreHandle* sem_handle = nullptr; + status = Ort::GetApi().ExternalResourceImporter_ImportSemaphore(importer, &sem_desc, &sem_handle); + ASSERT_EQ(status, nullptr) << "ImportSemaphore should succeed"; + ASSERT_NE(sem_handle, nullptr) << "Semaphore handle should not be null"; + + // Release semaphore handle + Ort::GetApi().ReleaseExternalSemaphoreHandle(sem_handle); + + Ort::GetApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Wait and Signal Semaphore (Simulated) +TEST_F(ExternalResourceImporterTest, WaitAndSignalSemaphore) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = Ort::GetApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << 
"External resource interop not supported"; + } + + // Create a stream for the EP + OrtSyncStream* stream = nullptr; + status = Ort::GetApi().CreateSyncStreamForEpDevice(ep_device_, nullptr, &stream); + ASSERT_EQ(status, nullptr) << "CreateSyncStreamForEpDevice should succeed"; + + // Import semaphore + void* dummy_handle = reinterpret_cast(static_cast(0xABCDEF00)); + + OrtExternalSemaphoreDescriptor sem_desc = {}; + sem_desc.version = ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION; + sem_desc.type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + sem_desc.native_handle = dummy_handle; + + OrtExternalSemaphoreHandle* sem_handle = nullptr; + status = Ort::GetApi().ExternalResourceImporter_ImportSemaphore(importer, &sem_desc, &sem_handle); + ASSERT_EQ(status, nullptr); + + // Signal the semaphore with value 1 + status = Ort::GetApi().ExternalResourceImporter_SignalSemaphore(importer, sem_handle, stream, 1); + ASSERT_EQ(status, nullptr) << "SignalSemaphore should succeed"; + + // Wait for value 1 (should succeed immediately since we just signaled it) + status = Ort::GetApi().ExternalResourceImporter_WaitSemaphore(importer, sem_handle, stream, 1); + ASSERT_EQ(status, nullptr) << "WaitSemaphore should succeed"; + + // Signal with value 5 + status = Ort::GetApi().ExternalResourceImporter_SignalSemaphore(importer, sem_handle, stream, 5); + ASSERT_EQ(status, nullptr); + + // Wait for value 3 (should succeed since current value is 5) + status = Ort::GetApi().ExternalResourceImporter_WaitSemaphore(importer, sem_handle, stream, 3); + ASSERT_EQ(status, nullptr); + + // Cleanup + Ort::GetApi().ReleaseExternalSemaphoreHandle(sem_handle); + Ort::GetApi().ReleaseSyncStream(stream); + Ort::GetApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Multiple Memory Imports +TEST_F(ExternalResourceImporterTest, MultipleMemoryImports) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = Ort::GetApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + if 
(status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "External resource interop not supported"; + } + + constexpr int kNumBuffers = 5; + std::vector handles(kNumBuffers); + + // Import multiple memory regions + for (int i = 0; i < kNumBuffers; ++i) { + OrtExternalMemoryDescriptor mem_desc = {}; + mem_desc.version = ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION; + mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + mem_desc.native_handle = reinterpret_cast(static_cast(0x10000000 + i * 0x1000)); + mem_desc.size_bytes = (i + 1) * 1024; + mem_desc.offset_bytes = 0; + mem_desc.access_mode = ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE; + + status = Ort::GetApi().ExternalResourceImporter_ImportMemory(importer, &mem_desc, &handles[i]); + ASSERT_EQ(status, nullptr) << "ImportMemory " << i << " should succeed"; + ASSERT_NE(handles[i], nullptr); + } + + // Release all handles + for (int i = 0; i < kNumBuffers; ++i) { + Ort::GetApi().ReleaseExternalMemoryHandle(handles[i]); + } + + Ort::GetApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Access Mode Variations +TEST_F(ExternalResourceImporterTest, AccessModeVariations) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = Ort::GetApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "External resource interop not supported"; + } + + const OrtExternalMemoryAccessMode access_modes[] = { + ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE, + ORT_EXTERNAL_MEMORY_ACCESS_READ_ONLY, + ORT_EXTERNAL_MEMORY_ACCESS_WRITE_ONLY}; + + for (auto access_mode : access_modes) { + OrtExternalMemoryDescriptor mem_desc = {}; + mem_desc.version = ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION; + mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + mem_desc.native_handle = reinterpret_cast(static_cast(0x12345678)); + mem_desc.size_bytes = 4096; + mem_desc.offset_bytes = 0; + 
mem_desc.access_mode = access_mode; + + OrtExternalMemoryHandle* mem_handle = nullptr; + status = Ort::GetApi().ExternalResourceImporter_ImportMemory(importer, &mem_desc, &mem_handle); + ASSERT_EQ(status, nullptr) << "ImportMemory with access_mode " << access_mode << " should succeed"; + ASSERT_NE(mem_handle, nullptr); + + Ort::GetApi().ReleaseExternalMemoryHandle(mem_handle); + } + + Ort::GetApi().ReleaseExternalResourceImporter(importer); +} + +// Test: SessionGetEpDeviceForOutputs +TEST_F(ExternalResourceImporterTest, SessionGetEpDeviceForOutputs) { + // Load a simple model with the example EP + Ort::SessionOptions session_options; + + // Add the example EP to the session + const OrtEpDevice* devices[] = {ep_device_}; + OrtStatus* status = Ort::GetApi().SessionOptionsAppendExecutionProvider_V2( + session_options, *ort_env, devices, 1, nullptr, nullptr, 0); + if (status != nullptr) { + std::string error = Ort::GetApi().GetErrorMessage(status); + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "Example EP not available: " << error; + } + + // Create session with test model (mul_1.onnx - a simple model the example EP supports) + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); + + // Get output count + size_t num_outputs = session.GetOutputCount(); + ASSERT_GT(num_outputs, 0U) << "Model should have at least one output"; + + // Get EP devices for outputs + std::vector output_devices(num_outputs); + status = Ort::GetApi().SessionGetEpDeviceForOutputs( + session, output_devices.data(), num_outputs); + ASSERT_EQ(status, nullptr) << "SessionGetEpDeviceForOutputs should succeed"; + + // Validate that we got EP devices (may be nullptr if not assigned to EP) + // At least verify the call succeeded and returned valid array + for (size_t i = 0; i < num_outputs; ++i) { + if (output_devices[i] != nullptr) { + // If an EP device is returned, validate it has a name + const char* ep_name = Ort::GetApi().EpDevice_EpName(output_devices[i]); + 
ASSERT_NE(ep_name, nullptr) << "EP device should have a name"; + } + } +} + +// Test: RunOptions_SetSyncStream +TEST_F(ExternalResourceImporterTest, RunOptionsSetSyncStream) { + // Create run options + Ort::RunOptions run_options; + + // Set sync stream to nullptr (which is valid - clears the stream) + OrtStatus* status = Ort::GetApi().RunOptions_SetSyncStream(run_options, nullptr); + ASSERT_EQ(status, nullptr) << "RunOptions_SetSyncStream with nullptr should succeed"; + + // Try to get a real sync stream from the EP device + OrtSyncStream* stream = nullptr; + status = Ort::GetApi().CreateSyncStreamForEpDevice(ep_device_, nullptr, &stream); + if (status != nullptr) { + std::string error = Ort::GetApi().GetErrorMessage(status); + Ort::GetApi().ReleaseStatus(status); + // Sync stream not supported - just test with nullptr + return; + } + + // Set the sync stream on run options + status = Ort::GetApi().RunOptions_SetSyncStream(run_options, stream); + ASSERT_EQ(status, nullptr) << "RunOptions_SetSyncStream with stream should succeed"; + + // Clean up + Ort::GetApi().ReleaseSyncStream(stream); +} + +// Test: RunOptions_SetSyncStream with Invalid Arguments +TEST_F(ExternalResourceImporterTest, RunOptionsSetSyncStreamInvalidArgs) { + // Test with nullptr run_options + OrtStatus* status = Ort::GetApi().RunOptions_SetSyncStream(nullptr, nullptr); + ASSERT_NE(status, nullptr) << "RunOptions_SetSyncStream with nullptr run_options should fail"; + + OrtErrorCode error_code = Ort::GetApi().GetErrorCode(status); + EXPECT_EQ(error_code, ORT_INVALID_ARGUMENT); + Ort::GetApi().ReleaseStatus(status); +} + +} // namespace test +} // namespace onnxruntime From 5865b01e466662c426770e022b5b4b3a926a7e3f Mon Sep 17 00:00:00 2001 From: Nick Eubanks Date: Thu, 18 Dec 2025 16:02:54 -0500 Subject: [PATCH 02/10] [TRTRTX EP] Implement OrtExternalResourceImporter for D3D12-CUDA interop This commit adds the OrtExternalResourceImporter implementation for the NvTensorRtRtx execution provider, 
enabling zero-copy D3D12 to CUDA memory sharing and GPU synchronization. Implementation: - NvTrtRtxExternalResourceImporterImpl: Full implementation of the OrtExternalResourceImporter interface using CUDA Driver APIs - Memory import: cuImportExternalMemory for D3D12_RESOURCE and D3D12_HEAP - Semaphore import: cuImportExternalSemaphore for D3D12_FENCE - Tensor creation: CreateTensorFromMemory wraps imported CUDA device pointers - Synchronization: WaitSemaphore/SignalSemaphore using cuWaitExternalSemaphoresAsync/cuSignalExternalSemaphoresAsync Tests (nv_external_resource_importer_test.cc): - CreateExternalResourceImporter: Basic importer creation - CanImportMemoryCapabilities: D3D12 Resource/Heap capability queries - CanImportSemaphoreCapabilities: D3D12 Fence capability queries - ImportD3D12SharedResource: Memory import validation - CreateTensorFromImportedMemory: Tensor creation with CUDA device ptr verification - ImportD3D12Fence: Semaphore import validation - WaitAndSignalSemaphore: Bidirectional D3D12-CUDA sync - FullInferenceWithExternalMemory: E2E test with ReLU model verifying D3D12 upload -> CUDA inference -> D3D12 readback pipeline --- cmake/onnxruntime_providers_nv.cmake | 4 +- .../nv_tensorrt_rtx/nv_provider_factory.cc | 519 ++++++++++- .../nv_external_resource_importer_test.cc | 848 ++++++++++++++++++ 3 files changed, 1367 insertions(+), 4 deletions(-) create mode 100644 onnxruntime/test/providers/nv_tensorrt_rtx/nv_external_resource_importer_test.cc diff --git a/cmake/onnxruntime_providers_nv.cmake b/cmake/onnxruntime_providers_nv.cmake index e59463b6b91f1..5ec45a64e46bb 100644 --- a/cmake/onnxruntime_providers_nv.cmake +++ b/cmake/onnxruntime_providers_nv.cmake @@ -146,9 +146,9 @@ endif () target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE Eigen3::Eigen onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface Eigen3::Eigen) add_dependencies(onnxruntime_providers_nv_tensorrt_rtx onnxruntime_providers_shared 
${onnxruntime_EXTERNAL_DEPENDENCIES}) if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) - target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart) + target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart CUDA::cuda_driver) else() - target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart) + target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart CUDA::cuda_driver) endif() target_include_directories(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${TENSORRT_RTX_INCLUDE_DIR} ${onnx_tensorrt_SOURCE_DIR} PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index e5015e705958d..77154bfcabe71 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -4,6 +4,7 @@ #include #include +#include #include "core/providers/shared_library/provider_api.h" #include "core/framework/provider_options.h" @@ -19,6 +20,8 @@ #include "nv_data_transfer.h" #include "nv_allocator.h" +#include + using namespace onnxruntime; namespace onnxruntime { @@ -516,32 +519,481 @@ struct NvTrtRtxSyncStreamImpl : OrtSyncStreamImpl { const OrtApi& ort_api; }; +// External Resource Import 
Implementation (D3D12 ↔ CUDA) +/** + * @brief Wrapper for imported external memory from D3D12 to CUDA. + * + * This struct holds the CUDA external memory object and the mapped device pointer + * that can be used for zero-copy tensor creation. + */ +struct NvTrtRtxExternalMemoryHandle { + CUexternalMemory ext_memory; ///< CUDA external memory object + CUdeviceptr mapped_ptr; ///< Mapped device pointer for tensor access + size_t size_bytes; ///< Size of the imported memory + size_t offset_bytes; ///< Offset into the imported memory + OrtExternalMemoryHandleType handle_type; ///< Original handle type for tracking + bool is_dedicated; ///< Whether the D3D12 resource is a dedicated allocation + + NvTrtRtxExternalMemoryHandle() + : ext_memory(nullptr), mapped_ptr(0), size_bytes(0), offset_bytes(0), handle_type(ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE), is_dedicated(true) {} +}; + +/** + * @brief Wrapper for imported external semaphore from D3D12 fence to CUDA. + * + * D3D12 timeline fences are imported as CUDA external semaphores, enabling + * GPU-GPU synchronization between D3D12 and CUDA streams. + */ +struct NvTrtRtxExternalSemaphoreHandle { + CUexternalSemaphore ext_semaphore; ///< CUDA external semaphore object + OrtExternalSemaphoreType type; ///< Original semaphore type + + NvTrtRtxExternalSemaphoreHandle() + : ext_semaphore(nullptr), type(ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE) {} +}; + +/** + * @brief Implementation of OrtExternalResourceImporterImpl for NvTensorRtRtx EP. + * + * This struct implements the external resource importer interface using CUDA Driver APIs + * to import D3D12 shared resources and timeline fences for zero-copy import. 
+ *
+ * Supported handle types:
+ * - ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE → CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE
+ * - ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP → CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP
+ * - ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE → CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE
+ *
+ * Ownership notes:
+ * - The device pointer mapped by ImportMemory is released with cuMemFree before the owning
+ *   CUexternalMemory object is destroyed (in ReleaseMemory and on the import failure path).
+ * - Tensors created by CreateTensorFromMemory wrap the mapped pointer without owning it.
+ */
+struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl {
+  // Stores the target CUDA device and API tables, then fills in the callback table.
+  NvTrtRtxExternalResourceImporterImpl(int device_id, const OrtApi& ort_api_in)
+      : device_id_{device_id}, ort_api{ort_api_in}, ep_api{*ort_api_in.GetEpApi()} {
+    ort_version_supported = ORT_API_VERSION;
+
+    // Memory operations
+    CanImportMemory = CanImportMemoryImpl;
+    ImportMemory = ImportMemoryImpl;
+    ReleaseMemory = ReleaseMemoryImpl;
+    CreateTensorFromMemory = CreateTensorFromMemoryImpl;
+
+    // Semaphore operations
+    CanImportSemaphore = CanImportSemaphoreImpl;
+    ImportSemaphore = ImportSemaphoreImpl;
+    ReleaseSemaphore = ReleaseSemaphoreImpl;
+    WaitSemaphore = WaitSemaphoreImpl;
+    SignalSemaphore = SignalSemaphoreImpl;
+
+    // Release
+    Release = ReleaseImpl;
+  }
+
+  // ──────────────── Memory operations ────────────────
+
+  // Returns true for the two D3D12 memory handle types this importer maps to CUDA.
+  static bool ORT_API_CALL CanImportMemoryImpl(
+      _In_ const OrtExternalResourceImporterImpl* this_ptr,
+      _In_ OrtExternalMemoryHandleType handle_type) noexcept {
+    (void)this_ptr;
+    // CUDA supports both D3D12 resource and heap handles
+    return handle_type == ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE ||
+           handle_type == ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP;
+  }
+
+  // Imports a D3D12 resource/heap described by `desc` and maps it to a CUDA device pointer.
+  //
+  // On success `*out_handle` receives a heap-allocated NvTrtRtxExternalMemoryHandle that must be
+  // released through ReleaseMemoryImpl. On every failure path `*out_handle` is nullptr and all
+  // CUDA objects created so far are released.
+  static OrtStatus* ORT_API_CALL ImportMemoryImpl(
+      _In_ OrtExternalResourceImporterImpl* this_ptr,
+      _In_ const OrtExternalMemoryDescriptor* desc,
+      _Outptr_ OrtExternalMemoryHandleImpl** out_handle) noexcept {
+    auto& impl = *static_cast<NvTrtRtxExternalResourceImporterImpl*>(this_ptr);
+
+    if (desc == nullptr || out_handle == nullptr) {
+      return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to ImportMemory");
+    }
+
+    // Clear the output before any other validation so no error path leaves it uninitialized.
+    *out_handle = nullptr;
+
+    // Validate descriptor version
+    if (desc->version != ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION) {
+      return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT,
+                                       "Invalid OrtExternalMemoryDescriptor version");
+    }
+
+    // Validate handle type
+    if (!CanImportMemoryImpl(this_ptr, desc->handle_type)) {
+      return impl.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED,
+                                       "Unsupported external memory handle type for CUDA import");
+    }
+
+    // Validate offset does not exceed allocation size
+    if (desc->offset_bytes > desc->size_bytes) {
+      return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT,
+                                       "offset_bytes exceeds size_bytes in OrtExternalMemoryDescriptor");
+    }
+
+    // Reset the thread's current context, then select the target device.
+    // NOTE(review): the result of cuCtxSetCurrent was never inspected; kept that way on purpose,
+    // confirm that a failure here is really meant to be ignored.
+    (void)cuCtxSetCurrent(nullptr);
+    CUDA_RETURN_IF_ERROR(cudaSetDevice(impl.device_id_));
+
+    // Map ORT handle type to CUDA handle type
+    CUexternalMemoryHandleType cu_handle_type;
+    bool is_dedicated = true;
+    switch (desc->handle_type) {
+      case ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE:
+        cu_handle_type = CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE;
+        is_dedicated = true;  // D3D12 committed resources are dedicated
+        break;
+      case ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP:
+        cu_handle_type = CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP;
+        is_dedicated = false;  // D3D12 heaps are not dedicated
+        break;
+      default:
+        return impl.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, "Unknown external memory handle type");
+    }
+
+    // Setup external memory handle descriptor
+    CUDA_EXTERNAL_MEMORY_HANDLE_DESC ext_mem_desc = {};
+    ext_mem_desc.type = cu_handle_type;
+    ext_mem_desc.handle.win32.handle = desc->native_handle;
+    ext_mem_desc.size = desc->size_bytes;
+    ext_mem_desc.flags = is_dedicated ? CUDA_EXTERNAL_MEMORY_DEDICATED : 0;
+
+    // Import the external memory
+    CUexternalMemory ext_memory = nullptr;
+    CUresult cu_result = cuImportExternalMemory(&ext_memory, &ext_mem_desc);
+    if (cu_result != CUDA_SUCCESS) {
+      return DriverCallFailed(impl, "cuImportExternalMemory", cu_result);
+    }
+
+    // Map the external memory to get a device pointer. The mapping starts at offset_bytes, so
+    // the mapped range is (size_bytes - offset_bytes) bytes long.
+    CUDA_EXTERNAL_MEMORY_BUFFER_DESC buffer_desc = {};
+    buffer_desc.offset = desc->offset_bytes;
+    buffer_desc.size = desc->size_bytes - desc->offset_bytes;
+    buffer_desc.flags = 0;
+
+    CUdeviceptr mapped_ptr = 0;
+    cu_result = cuExternalMemoryGetMappedBuffer(&mapped_ptr, ext_memory, &buffer_desc);
+    if (cu_result != CUDA_SUCCESS) {
+      cuDestroyExternalMemory(ext_memory);
+      return DriverCallFailed(impl, "cuExternalMemoryGetMappedBuffer", cu_result);
+    }
+
+    // Create and return the handle wrapper
+    auto* handle = new (std::nothrow) NvTrtRtxExternalMemoryHandle();
+    if (handle == nullptr) {
+      // The mapping must be freed explicitly; destroying the external memory object does not
+      // release outstanding mapped buffers.
+      cuMemFree(mapped_ptr);
+      cuDestroyExternalMemory(ext_memory);
+      return impl.ort_api.CreateStatus(ORT_FAIL, "Failed to allocate external memory handle");
+    }
+
+    handle->ext_memory = ext_memory;
+    handle->mapped_ptr = mapped_ptr;
+    handle->size_bytes = desc->size_bytes;
+    handle->offset_bytes = desc->offset_bytes;
+    handle->handle_type = desc->handle_type;
+    handle->is_dedicated = is_dedicated;
+
+    *out_handle = reinterpret_cast<OrtExternalMemoryHandleImpl*>(handle);
+    return nullptr;
+  }
+
+  // Releases a handle produced by ImportMemoryImpl. A null handle is a no-op.
+  static void ORT_API_CALL ReleaseMemoryImpl(
+      _In_ OrtExternalResourceImporterImpl* this_ptr,
+      _In_ OrtExternalMemoryHandleImpl* handle) noexcept {
+    (void)this_ptr;
+
+    if (handle == nullptr) {
+      return;
+    }
+
+    auto* mem_handle = reinterpret_cast<NvTrtRtxExternalMemoryHandle*>(handle);
+
+    // Free the mapped buffer first: outstanding mappings are not released by
+    // cuDestroyExternalMemory and would otherwise leak.
+    if (mem_handle->mapped_ptr != 0) {
+      cuMemFree(mem_handle->mapped_ptr);
+    }
+
+    // Destroy the external memory object
+    if (mem_handle->ext_memory != nullptr) {
+      cuDestroyExternalMemory(mem_handle->ext_memory);
+    }
+
+    delete mem_handle;
+  }
+
+  // Wraps (part of) an imported mapping in an OrtValue tensor without copying.
+  //
+  // `tensor_desc->offset_bytes` is relative to the start of the mapped range. The byte length
+  // handed to ORT is what remains of the mapped range after that offset, so ORT's own size
+  // validation cannot accept a tensor that extends past the mapping.
+  static OrtStatus* ORT_API_CALL CreateTensorFromMemoryImpl(
+      _In_ OrtExternalResourceImporterImpl* this_ptr,
+      _In_ const OrtExternalMemoryHandleImpl* mem_handle,
+      _In_ const OrtExternalTensorDescriptor* tensor_desc,
+      _Outptr_ OrtValue** out_tensor) noexcept {
+    auto& impl = *static_cast<NvTrtRtxExternalResourceImporterImpl*>(this_ptr);
+
+    if (mem_handle == nullptr || tensor_desc == nullptr || out_tensor == nullptr) {
+      return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to CreateTensorFromMemory");
+    }
+
+    // Clear the output before any other validation so no error path leaves it uninitialized.
+    *out_tensor = nullptr;
+
+    // Validate descriptor version
+    if (tensor_desc->version != ORT_EXTERNAL_TENSOR_DESCRIPTOR_VERSION) {
+      return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT,
+                                       "Invalid OrtExternalTensorDescriptor version");
+    }
+
+    auto* handle = reinterpret_cast<const NvTrtRtxExternalMemoryHandle*>(mem_handle);
+
+    // Validate tensor offset does not exceed available buffer size.
+    // ImportMemoryImpl guarantees offset_bytes <= size_bytes, so this cannot underflow.
+    const size_t available_size = handle->size_bytes - handle->offset_bytes;
+    if (tensor_desc->offset_bytes > available_size) {
+      return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT,
+                                       "tensor offset_bytes exceeds available imported memory size");
+    }
+
+    // Bytes of the mapped range that remain after the tensor offset.
+    const size_t tensor_bytes_available = available_size - tensor_desc->offset_bytes;
+
+    // Calculate the data pointer with tensor offset
+    void* data_ptr = reinterpret_cast<void*>(handle->mapped_ptr + tensor_desc->offset_bytes);
+
+    // Create memory info for CUDA device
+    OrtMemoryInfo* memory_info = nullptr;
+    OrtStatus* status = impl.ort_api.CreateMemoryInfo_V2(
+        "NvTensorRTRTX",
+        OrtMemoryInfoDeviceType_GPU,
+        OrtDevice::VendorIds::NVIDIA,
+        impl.device_id_,
+        OrtDeviceMemoryType_DEFAULT,
+        0,  // alignment
+        OrtDeviceAllocator,
+        &memory_info);
+
+    if (status != nullptr) {
+      return status;
+    }
+
+    // Create tensor with pre-allocated memory
+    status = impl.ort_api.CreateTensorWithDataAsOrtValue(
+        memory_info,
+        data_ptr,
+        tensor_bytes_available,
+        tensor_desc->shape,
+        tensor_desc->rank,
+        tensor_desc->element_type,
+        out_tensor);
+
+    impl.ort_api.ReleaseMemoryInfo(memory_info);
+    return status;
+  }
+
+  // ──────────────── Semaphore operations ────────────────
+
+  // Returns true only for D3D12 fences, the single semaphore type this importer handles.
+  static bool ORT_API_CALL CanImportSemaphoreImpl(
+      _In_ const OrtExternalResourceImporterImpl* this_ptr,
+      _In_ OrtExternalSemaphoreType type) noexcept {
+    (void)this_ptr;
+    // CUDA supports D3D12 timeline fences
+    return type == ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE;
+  }
+
+  // Imports a D3D12 fence as a CUDA external semaphore.
+  //
+  // On success `*out_handle` receives a heap-allocated NvTrtRtxExternalSemaphoreHandle that must
+  // be released through ReleaseSemaphoreImpl. On failure `*out_handle` is nullptr.
+  static OrtStatus* ORT_API_CALL ImportSemaphoreImpl(
+      _In_ OrtExternalResourceImporterImpl* this_ptr,
+      _In_ const OrtExternalSemaphoreDescriptor* desc,
+      _Outptr_ OrtExternalSemaphoreHandleImpl** out_handle) noexcept {
+    auto& impl = *static_cast<NvTrtRtxExternalResourceImporterImpl*>(this_ptr);
+
+    if (desc == nullptr || out_handle == nullptr) {
+      return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to ImportSemaphore");
+    }
+
+    // Clear the output before any other validation so no error path leaves it uninitialized.
+    *out_handle = nullptr;
+
+    // Validate descriptor version
+    if (desc->version != ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION) {
+      return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT,
+                                       "Invalid OrtExternalSemaphoreDescriptor version");
+    }
+
+    // Validate semaphore type
+    if (!CanImportSemaphoreImpl(this_ptr, desc->type)) {
+      return impl.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED,
+                                       "Unsupported external semaphore type for CUDA import");
+    }
+
+    // Set CUDA device
+    CUDA_RETURN_IF_ERROR(cudaSetDevice(impl.device_id_));
+
+    // Setup external semaphore handle descriptor for D3D12 fence
+    CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC ext_sem_desc = {};
+    ext_sem_desc.type = CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE;
+    ext_sem_desc.handle.win32.handle = desc->native_handle;
+    ext_sem_desc.flags = 0;
+
+    // Import the external semaphore
+    CUexternalSemaphore ext_semaphore = nullptr;
+    CUresult cu_result = cuImportExternalSemaphore(&ext_semaphore, &ext_sem_desc);
+    if (cu_result != CUDA_SUCCESS) {
+      return DriverCallFailed(impl, "cuImportExternalSemaphore", cu_result);
+    }
+
+    // Create and return the handle wrapper
+    auto* handle = new (std::nothrow) NvTrtRtxExternalSemaphoreHandle();
+    if (handle == nullptr) {
+      cuDestroyExternalSemaphore(ext_semaphore);
+      return impl.ort_api.CreateStatus(ORT_FAIL, "Failed to allocate external semaphore handle");
+    }
+
+    handle->ext_semaphore = ext_semaphore;
+    handle->type = desc->type;
+
+    *out_handle = reinterpret_cast<OrtExternalSemaphoreHandleImpl*>(handle);
+    return nullptr;
+  }
+
+  // Releases a handle produced by ImportSemaphoreImpl. A null handle is a no-op.
+  static void ORT_API_CALL ReleaseSemaphoreImpl(
+      _In_ OrtExternalResourceImporterImpl* this_ptr,
+      _In_ OrtExternalSemaphoreHandleImpl* handle) noexcept {
+    (void)this_ptr;
+
+    if (handle == nullptr) {
+      return;
+    }
+
+    auto* sem_handle = reinterpret_cast<NvTrtRtxExternalSemaphoreHandle*>(handle);
+
+    if (sem_handle->ext_semaphore != nullptr) {
+      cuDestroyExternalSemaphore(sem_handle->ext_semaphore);
+    }
+
+    delete sem_handle;
+  }
+
+  // Enqueues a wait on `stream` until the imported fence reaches `value`.
+  // The wait is queued asynchronously; this call does not block the host.
+  static OrtStatus* ORT_API_CALL WaitSemaphoreImpl(
+      _In_ OrtExternalResourceImporterImpl* this_ptr,
+      _In_ OrtExternalSemaphoreHandleImpl* handle,
+      _In_ OrtSyncStream* stream,
+      _In_ uint64_t value) noexcept {
+    auto& impl = *static_cast<NvTrtRtxExternalResourceImporterImpl*>(this_ptr);
+
+    if (handle == nullptr || stream == nullptr) {
+      return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to WaitSemaphore");
+    }
+
+    auto* sem_handle = reinterpret_cast<NvTrtRtxExternalSemaphoreHandle*>(handle);
+
+    // Get the CUDA stream from OrtSyncStream
+    cudaStream_t cuda_stream = static_cast<cudaStream_t>(impl.ort_api.SyncStream_GetHandle(stream));
+
+    // Setup wait parameters for D3D12 fence (timeline semaphore)
+    CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS wait_params = {};
+    wait_params.params.fence.value = value;
+    wait_params.flags = 0;
+
+    // Wait on the external semaphore asynchronously
+    const CUresult cu_result = cuWaitExternalSemaphoresAsync(
+        &sem_handle->ext_semaphore,
+        &wait_params,
+        1,  // numExtSems
+        cuda_stream);
+
+    if (cu_result != CUDA_SUCCESS) {
+      return DriverCallFailed(impl, "cuWaitExternalSemaphoresAsync", cu_result);
+    }
+
+    return nullptr;
+  }
+
+  // Enqueues a signal on `stream` that sets the imported fence to `value`.
+  // The signal is queued asynchronously; this call does not block the host.
+  static OrtStatus* ORT_API_CALL SignalSemaphoreImpl(
+      _In_ OrtExternalResourceImporterImpl* this_ptr,
+      _In_ OrtExternalSemaphoreHandleImpl* handle,
+      _In_ OrtSyncStream* stream,
+      _In_ uint64_t value) noexcept {
+    auto& impl = *static_cast<NvTrtRtxExternalResourceImporterImpl*>(this_ptr);
+
+    if (handle == nullptr || stream == nullptr) {
+      return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to SignalSemaphore");
+    }
+
+    auto* sem_handle = reinterpret_cast<NvTrtRtxExternalSemaphoreHandle*>(handle);
+
+    // Get the CUDA stream from OrtSyncStream
+    cudaStream_t cuda_stream = static_cast<cudaStream_t>(impl.ort_api.SyncStream_GetHandle(stream));
+
+    // Setup signal parameters for D3D12 fence (timeline semaphore)
+    CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS signal_params = {};
+    signal_params.params.fence.value = value;
+    signal_params.flags = 0;
+
+    // Signal the external semaphore asynchronously
+    const CUresult cu_result = cuSignalExternalSemaphoresAsync(
+        &sem_handle->ext_semaphore,
+        &signal_params,
+        1,  // numExtSems
+        cuda_stream);
+
+    if (cu_result != CUDA_SUCCESS) {
+      return DriverCallFailed(impl, "cuSignalExternalSemaphoresAsync", cu_result);
+    }
+
+    return nullptr;
+  }
+
+  // ──────────────── Release ────────────────
+
+  // Destroys the importer itself; handles it produced are released separately.
+  static void ORT_API_CALL ReleaseImpl(_In_ OrtExternalResourceImporterImpl* this_ptr) noexcept {
+    delete static_cast<NvTrtRtxExternalResourceImporterImpl*>(this_ptr);
+  }
+
+ private:
+  // Builds an ORT_FAIL status reading "<api_name> failed: <driver error text>", falling back to
+  // "unknown error" when the driver has no string for `cu_result`.
+  static OrtStatus* DriverCallFailed(const NvTrtRtxExternalResourceImporterImpl& impl,
+                                     const char* api_name,
+                                     CUresult cu_result) noexcept {
+    const char* error_str = nullptr;
+    cuGetErrorString(cu_result, &error_str);
+    std::string error_msg = api_name;
+    error_msg += " failed: ";
+    error_msg += error_str ? error_str : "unknown error";
+    return impl.ort_api.CreateStatus(ORT_FAIL, error_msg.c_str());
+  }
+
+  int device_id_;
+  const OrtApi& ort_api;
+  const OrtEpApi& ep_api;
+};
+
 // OrtEpApi infrastructure to be able to use the NvTensorRTRTX EP as an OrtEpFactory for auto EP selection.
struct NvTensorRtRtxEpFactory : OrtEpFactory {
   using MemoryInfoUniquePtr = std::unique_ptr>;
   NvTensorRtRtxEpFactory(const OrtApi& ort_api_in,
-                         const OrtLogger& default_logger_in) : ort_api{ort_api_in},
+                         const OrtLogger& default_logger_in) : OrtEpFactory{},  // Zero-initialize base struct
+                                                               ort_api{ort_api_in},
                                                                ep_api{*ort_api_in.GetEpApi()},
                                                                default_logger{default_logger_in},
                                                                data_transfer_impl{ort_api_in} {
+    // Initialize all OrtEpFactory function pointers explicitly to avoid garbage values
+    // Required members
     GetName = GetNameImpl;
     GetVendor = GetVendorImpl;
     GetVendorId = GetVendorIdImpl;
     GetVersion = GetVersionImpl;
-    GetVendorId = GetVendorIdImpl;
     GetSupportedDevices = GetSupportedDevicesImpl;
     CreateEp = CreateEpImpl;
     ReleaseEp = ReleaseEpImpl;
+    // Allocator
     CreateAllocator = CreateAllocatorImpl;
     ReleaseAllocator = ReleaseAllocatorImpl;
+    // Data transfer
     CreateDataTransfer = CreateDataTransferImpl;
+    // Stream support
     IsStreamAware = IsStreamAwareImpl;
     CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl;
+    // Optional members - explicitly set to nullptr if not implemented
+    ValidateCompiledModelCompatibilityInfo = nullptr;  // Not implemented
+    SetEnvironmentOptions = nullptr;                   // Not implemented
+
+    // External resource import (D3D12 to CUDA)
+    CreateExternalResourceImporterForDevice = CreateExternalResourceImporterForDeviceImpl;
+
     ort_version_supported = ORT_API_VERSION;  // Set to the ORT version we were compiled with.
   }
@@ -735,6 +1187,85 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory {
     return nullptr;
   }
 
+  /** @brief Create an external resource importer for D3D12 to CUDA import.
+   *
+   * This enables zero-copy import of D3D12 shared resources and timeline fences.
+   * The implementation uses CUDA Driver APIs (cuImportExternalMemory, cuImportExternalSemaphore).
+   *
+   * The importer is allocated with the non-throwing form of new, so an allocation failure is
+   * returned as an OrtStatus rather than thrown out of this noexcept callback.
+   *
+   * @param this_ptr The OrtEpFactory instance.
+   * @param memory_device The OrtMemoryDevice to create the importer for. Must not be nullptr.
+   * @param out_importer Output parameter set to the created OrtExternalResourceImporterImpl.
+   *                     Set to nullptr on failure.
+   * @return nullptr on success, OrtStatus with error on failure.
+   */
+  static OrtStatus* ORT_API_CALL CreateExternalResourceImporterForDeviceImpl(
+      OrtEpFactory* this_ptr,
+      const OrtMemoryDevice* memory_device,
+      OrtExternalResourceImporterImpl** out_importer) noexcept {
+    auto& factory = *static_cast<NvTensorRtRtxEpFactory*>(this_ptr);
+
+    if (out_importer == nullptr) {
+      return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT,
+                                          "out_importer cannot be nullptr");
+    }
+
+    *out_importer = nullptr;
+
+    // memory_device is passed straight to the MemoryDevice_* accessors below; reject null first.
+    if (memory_device == nullptr) {
+      return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT,
+                                          "memory_device cannot be nullptr");
+    }
+
+    // Check memory type - only DEFAULT device memory is supported
+    auto mem_type = factory.ep_api.MemoryDevice_GetMemoryType(memory_device);
+    if (mem_type != OrtDeviceMemoryType_DEFAULT) {
+      return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED,
+                                          "External resource import only supported for DEFAULT device memory");
+    }
+
+    // Validate that this is a GPU device
+    OrtMemoryInfoDeviceType device_type = factory.ep_api.MemoryDevice_GetDeviceType(memory_device);
+    if (device_type != OrtMemoryInfoDeviceType_GPU) {
+      return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED,
+                                          "External resource import only supported for GPU devices");
+    }
+
+    // Get the CUDA device ID
+    auto device_id = factory.ep_api.MemoryDevice_GetDeviceId(memory_device);
+
+    // Verify the device is an NVIDIA GPU
+    auto vendor_id = factory.ep_api.MemoryDevice_GetVendorId(memory_device);
+    if (vendor_id != OrtDevice::VendorIds::NVIDIA) {
+      return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED,
+                                          "External resource import only supported for NVIDIA GPUs");
+    }
+
+    // Verify the device ID refers to a CUDA device visible to this process
+    int cuda_device_count = 0;
+    cudaError_t cuda_err = cudaGetDeviceCount(&cuda_device_count);
+    if (cuda_err != cudaSuccess || cuda_device_count <= 0 ||
+        device_id >= static_cast<decltype(device_id)>(cuda_device_count)) {
+      return factory.ort_api.CreateStatus(ORT_FAIL,
+                                          "Invalid CUDA device ID for external resource import");
+    }
+
+    // Create the external resource importer. std::make_unique could throw std::bad_alloc,
+    // which would terminate the process from this noexcept function, so use nothrow new.
+    auto* importer = new (std::nothrow)
+        NvTrtRtxExternalResourceImporterImpl(static_cast<int>(device_id), factory.ort_api);
+    if (importer == nullptr) {
+      return factory.ort_api.CreateStatus(ORT_FAIL,
+                                          "Failed to allocate external resource importer");
+    }
+    *out_importer = importer;
+
+    return nullptr;
+  }
+
   OrtStatus* CreateMemoryInfoForDevices(int num_devices) {
     gpu_memory_infos.reserve(num_devices);
     host_accessible_memory_infos.reserve(num_devices);
diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_external_resource_importer_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_external_resource_importer_test.cc
new file mode 100644
index 0000000000000..edb886abdd0c8
--- /dev/null
+++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_external_resource_importer_test.cc
@@ -0,0 +1,848 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// This test validates the D3D12 ↔ CUDA external resource import functionality
+// for the NvTensorRtRtx execution provider.
+//
+// Test Coverage:
+// 1. External Resource Importer creation and destruction
+// 2. Memory import capability check (D3D12 Resource & Heap)
+// 3. Semaphore import capability check (D3D12 Fence)
+// 4. D3D12 shared resource import to CUDA
+// 5. Tensor creation from imported external memory
+// 6. D3D12 timeline fence import for GPU synchronization
+// 7. Wait/Signal semaphore operations
+// 8.
Full inference pipeline with zero-copy external memory + +#include "core/graph/onnx_protobuf.h" +#include "core/session/inference_session.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "test/providers/provider_test_utils.h" +#include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" +#include "test/unittest_util/framework_test_utils.h" +#include "test/util/include/scoped_env_vars.h" +#include "test/common/random_generator.h" + +#include +#include +#include +#include + +#if defined(_WIN32) +#include +#include +#include +using Microsoft::WRL::ComPtr; +#endif + +// Include CUDA headers for pointer attribute verification +#include + +using namespace std; +using namespace ONNX_NAMESPACE; +using namespace ::onnxruntime::logging; + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { + +#if defined(_WIN32) + +// Helper functions for D3D12 resource creation +class D3D12ResourceHelper { + public: + static void CreateSharedBuffer(ID3D12Device* device, + size_t size, + ID3D12Resource** out_resource, + D3D12_RESOURCE_STATES initial_state = D3D12_RESOURCE_STATE_COMMON) { + D3D12_RESOURCE_DESC desc = {}; + desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER; + desc.Alignment = 0; + desc.Width = size; + desc.Height = 1; + desc.DepthOrArraySize = 1; + desc.MipLevels = 1; + desc.Format = DXGI_FORMAT_UNKNOWN; + desc.SampleDesc.Count = 1; + desc.SampleDesc.Quality = 0; + desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR; + desc.Flags = D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS; + + D3D12_HEAP_PROPERTIES heap_props = {}; + heap_props.Type = D3D12_HEAP_TYPE_DEFAULT; + heap_props.CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN; + heap_props.MemoryPoolPreference = D3D12_MEMORY_POOL_UNKNOWN; + heap_props.CreationNodeMask = 1; + heap_props.VisibleNodeMask = 1; + + // Create with SHARED heap flag for cross-API import + HRESULT hr = device->CreateCommittedResource( + &heap_props, + D3D12_HEAP_FLAG_SHARED, + &desc, + initial_state, + nullptr, + 
IID_PPV_ARGS(out_resource)); + + if (FAILED(hr)) { + GTEST_FAIL() << "Failed to create shared D3D12 buffer, HRESULT: 0x" << std::hex << hr; + } + } + + static void CreateUploadBuffer(ID3D12Device* device, size_t size, ID3D12Resource** out_resource) { + D3D12_RESOURCE_DESC desc = {}; + desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER; + desc.Alignment = 0; + desc.Width = size; + desc.Height = 1; + desc.DepthOrArraySize = 1; + desc.MipLevels = 1; + desc.Format = DXGI_FORMAT_UNKNOWN; + desc.SampleDesc.Count = 1; + desc.SampleDesc.Quality = 0; + desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR; + desc.Flags = D3D12_RESOURCE_FLAG_NONE; + + D3D12_HEAP_PROPERTIES heap_props = {}; + heap_props.Type = D3D12_HEAP_TYPE_UPLOAD; + + HRESULT hr = device->CreateCommittedResource( + &heap_props, + D3D12_HEAP_FLAG_NONE, + &desc, + D3D12_RESOURCE_STATE_GENERIC_READ, + nullptr, + IID_PPV_ARGS(out_resource)); + + if (FAILED(hr)) { + GTEST_FAIL() << "Failed to create upload buffer, HRESULT: 0x" << std::hex << hr; + } + } + + static void CreateReadbackBuffer(ID3D12Device* device, size_t size, ID3D12Resource** out_resource) { + D3D12_RESOURCE_DESC desc = {}; + desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER; + desc.Alignment = 0; + desc.Width = size; + desc.Height = 1; + desc.DepthOrArraySize = 1; + desc.MipLevels = 1; + desc.Format = DXGI_FORMAT_UNKNOWN; + desc.SampleDesc.Count = 1; + desc.SampleDesc.Quality = 0; + desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR; + desc.Flags = D3D12_RESOURCE_FLAG_NONE; + + D3D12_HEAP_PROPERTIES heap_props = {}; + heap_props.Type = D3D12_HEAP_TYPE_READBACK; + + HRESULT hr = device->CreateCommittedResource( + &heap_props, + D3D12_HEAP_FLAG_NONE, + &desc, + D3D12_RESOURCE_STATE_COPY_DEST, + nullptr, + IID_PPV_ARGS(out_resource)); + + if (FAILED(hr)) { + GTEST_FAIL() << "Failed to create readback buffer, HRESULT: 0x" << std::hex << hr; + } + } + + static void FlushAndWait(ID3D12Device* device, ID3D12CommandQueue* queue) { + ComPtr fence; + HRESULT hr = 
device->CreateFence(0, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(&fence)); + if (FAILED(hr)) { + GTEST_FAIL() << "Failed to create fence for flush, HRESULT: 0x" << std::hex << hr; + } + + HANDLE event = CreateEvent(nullptr, FALSE, FALSE, nullptr); + queue->Signal(fence.Get(), 1); + fence->SetEventOnCompletion(1, event); + WaitForSingleObject(event, INFINITE); + CloseHandle(event); + } +}; + +// Test Fixture +class ExternalResourceImporterTest : public testing::Test { + protected: + void SetUp() override { + // Get the ORT API + ort_api_ = &Ort::GetApi(); + + // Try to create D3D12 device + HRESULT hr = D3D12CreateDevice(nullptr, D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&d3d12_device_)); + if (FAILED(hr)) { + d3d12_available_ = false; + return; + } + d3d12_available_ = true; + + // Create command queue + D3D12_COMMAND_QUEUE_DESC queue_desc = {}; + queue_desc.Type = D3D12_COMMAND_LIST_TYPE_COMPUTE; + queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE; + hr = d3d12_device_->CreateCommandQueue(&queue_desc, IID_PPV_ARGS(&command_queue_)); + if (FAILED(hr)) { + d3d12_available_ = false; + return; + } + + // Create command allocator and list + hr = d3d12_device_->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_COMPUTE, + IID_PPV_ARGS(&command_allocator_)); + if (FAILED(hr)) { + d3d12_available_ = false; + return; + } + + hr = d3d12_device_->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_COMPUTE, + command_allocator_.Get(), nullptr, + IID_PPV_ARGS(&command_list_)); + if (FAILED(hr)) { + d3d12_available_ = false; + return; + } + command_list_->Close(); + + // Register NvTensorRtRtx EP + Utils::RegisterAndGetNvTensorRtRtxEp(*ort_env, registered_ep_); + if (registered_ep_.get() == nullptr) { + ep_available_ = false; + return; + } + ep_available_ = true; + ep_device_ = registered_ep_.get(); + } + + void TearDown() override { + // Release resources + command_list_.Reset(); + command_allocator_.Reset(); + command_queue_.Reset(); + d3d12_device_.Reset(); + } + + bool IsD3D12Available() const { 
return d3d12_available_; } + bool IsEPAvailable() const { return ep_available_; } + + ComPtr d3d12_device_; + ComPtr command_queue_; + ComPtr command_allocator_; + ComPtr command_list_; + const OrtApi* ort_api_ = nullptr; + RegisteredEpDeviceUniquePtr registered_ep_; // RAII - auto-unregisters EP + const OrtEpDevice* ep_device_ = nullptr; + bool d3d12_available_ = false; + bool ep_available_ = false; +}; + +// Test: External Resource Importer Creation +TEST_F(ExternalResourceImporterTest, CreateExternalResourceImporter) { + if (!IsD3D12Available()) { + GTEST_SKIP() << "D3D12 not available"; + } + if (!IsEPAvailable()) { + GTEST_SKIP() << "NvTensorRtRtx EP not available"; + } + + // Create external resource importer + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = ort_api_->CreateExternalResourceImporterForDevice(ep_device_, &importer); + + if (status != nullptr) { + std::string error = ort_api_->GetErrorMessage(status); + ort_api_->ReleaseStatus(status); + GTEST_SKIP() << "CreateExternalResourceImporterForDevice not supported: " << error; + } + + if (importer == nullptr) { + // EP doesn't support external resource import yet + GTEST_SKIP() << "External resource import not yet implemented by this EP"; + } + + // Release the importer + ort_api_->ReleaseExternalResourceImporter(importer); +} + +// Test: Memory Import Capability Check +TEST_F(ExternalResourceImporterTest, CanImportMemoryCapabilities) { + if (!IsD3D12Available()) { + GTEST_SKIP() << "D3D12 not available"; + } + if (!IsEPAvailable()) { + GTEST_SKIP() << "NvTensorRtRtx EP not available"; + } + + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = ort_api_->CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr || importer == nullptr) { + if (status != nullptr) ort_api_->ReleaseStatus(status); + GTEST_SKIP() << "External resource import not supported"; + } + + // Check D3D12 Resource support + bool can_import_resource = false; + 
status = ort_api_->ExternalResourceImporter_CanImportMemory( + importer, ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE, &can_import_resource); + ASSERT_EQ(status, nullptr) << "CanImportMemory for D3D12_RESOURCE should succeed"; + EXPECT_TRUE(can_import_resource) << "Should support D3D12 Resource import"; + + // Check D3D12 Heap support + bool can_import_heap = false; + status = ort_api_->ExternalResourceImporter_CanImportMemory( + importer, ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP, &can_import_heap); + ASSERT_EQ(status, nullptr) << "CanImportMemory for D3D12_HEAP should succeed"; + EXPECT_TRUE(can_import_heap) << "Should support D3D12 Heap import"; + + ort_api_->ReleaseExternalResourceImporter(importer); +} + +// Test: Semaphore Import Capability Check +TEST_F(ExternalResourceImporterTest, CanImportSemaphoreCapabilities) { + if (!IsD3D12Available()) { + GTEST_SKIP() << "D3D12 not available"; + } + if (!IsEPAvailable()) { + GTEST_SKIP() << "NvTensorRtRtx EP not available"; + } + + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = ort_api_->CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr || importer == nullptr) { + if (status != nullptr) ort_api_->ReleaseStatus(status); + GTEST_SKIP() << "External resource import not supported"; + } + + // Check D3D12 Fence support + bool can_import_fence = false; + status = ort_api_->ExternalResourceImporter_CanImportSemaphore( + importer, ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE, &can_import_fence); + ASSERT_EQ(status, nullptr) << "CanImportSemaphore for D3D12_FENCE should succeed"; + EXPECT_TRUE(can_import_fence) << "Should support D3D12 Fence import"; + + ort_api_->ReleaseExternalResourceImporter(importer); +} + +// Test: Import D3D12 Shared Resource +TEST_F(ExternalResourceImporterTest, ImportD3D12SharedResource) { + if (!IsD3D12Available()) { + GTEST_SKIP() << "D3D12 not available"; + } + if (!IsEPAvailable()) { + GTEST_SKIP() << "NvTensorRtRtx EP not available"; + } + + 
OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = ort_api_->CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr || importer == nullptr) { + if (status != nullptr) ort_api_->ReleaseStatus(status); + GTEST_SKIP() << "External resource import not supported"; + } + + // Create a shared D3D12 buffer + const size_t buffer_size = 1024 * sizeof(float); + ComPtr d3d12_buffer; + D3D12ResourceHelper::CreateSharedBuffer(d3d12_device_.Get(), buffer_size, &d3d12_buffer); + + // Create shared handle + HANDLE shared_handle = nullptr; + HRESULT hr = d3d12_device_->CreateSharedHandle(d3d12_buffer.Get(), nullptr, GENERIC_ALL, nullptr, &shared_handle); + ASSERT_TRUE(SUCCEEDED(hr)) << "Failed to create shared handle"; + + // Import the memory + OrtExternalMemoryDescriptor mem_desc = {}; + mem_desc.version = ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION; + mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + mem_desc.native_handle = shared_handle; + mem_desc.size_bytes = buffer_size; + mem_desc.offset_bytes = 0; + mem_desc.access_mode = ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE; + + OrtExternalMemoryHandle* mem_handle = nullptr; + status = ort_api_->ExternalResourceImporter_ImportMemory(importer, &mem_desc, &mem_handle); + ASSERT_EQ(status, nullptr) << "ImportMemory should succeed"; + ASSERT_NE(mem_handle, nullptr) << "Memory handle should not be null"; + + // Release memory handle + ort_api_->ReleaseExternalMemoryHandle(mem_handle); + + // Close shared handle + CloseHandle(shared_handle); + + ort_api_->ReleaseExternalResourceImporter(importer); +} + +// Test: Create Tensor from Imported Memory +TEST_F(ExternalResourceImporterTest, CreateTensorFromImportedMemory) { + if (!IsD3D12Available()) { + GTEST_SKIP() << "D3D12 not available"; + } + if (!IsEPAvailable()) { + GTEST_SKIP() << "NvTensorRtRtx EP not available"; + } + + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = 
ort_api_->CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr || importer == nullptr) { + if (status != nullptr) ort_api_->ReleaseStatus(status); + GTEST_SKIP() << "External resource import not supported"; + } + + // Create tensor shape: [1, 3, 32, 32] (batch, channels, height, width) + const int64_t batch = 1, channels = 3, height = 32, width = 32; + const int64_t shape[] = {batch, channels, height, width}; + const size_t num_elements = batch * channels * height * width; + const size_t buffer_size = num_elements * sizeof(float); + + // Create shared D3D12 buffer + ComPtr d3d12_buffer; + D3D12ResourceHelper::CreateSharedBuffer(d3d12_device_.Get(), buffer_size, &d3d12_buffer); + + // Create shared handle + HANDLE shared_handle = nullptr; + HRESULT hr = d3d12_device_->CreateSharedHandle(d3d12_buffer.Get(), nullptr, GENERIC_ALL, nullptr, &shared_handle); + ASSERT_TRUE(SUCCEEDED(hr)); + + // Import the memory + OrtExternalMemoryDescriptor mem_desc = {}; + mem_desc.version = ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION; + mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + mem_desc.native_handle = shared_handle; + mem_desc.size_bytes = buffer_size; + mem_desc.offset_bytes = 0; + mem_desc.access_mode = ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE; + + OrtExternalMemoryHandle* mem_handle = nullptr; + status = ort_api_->ExternalResourceImporter_ImportMemory(importer, &mem_desc, &mem_handle); + ASSERT_EQ(status, nullptr); + + // Create tensor from imported memory + OrtExternalTensorDescriptor tensor_desc = {}; + tensor_desc.version = ORT_EXTERNAL_TENSOR_DESCRIPTOR_VERSION; + tensor_desc.element_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + tensor_desc.shape = shape; + tensor_desc.rank = 4; + tensor_desc.offset_bytes = 0; + + OrtValue* tensor = nullptr; + status = ort_api_->ExternalResourceImporter_CreateTensorFromMemory(importer, mem_handle, &tensor_desc, nullptr, &tensor); + ASSERT_EQ(status, nullptr) << "CreateTensorFromMemory 
should succeed"; + ASSERT_NE(tensor, nullptr) << "Tensor should not be null"; + + // Verify tensor properties + OrtTensorTypeAndShapeInfo* type_info = nullptr; + status = ort_api_->GetTensorTypeAndShape(tensor, &type_info); + ASSERT_EQ(status, nullptr); + + size_t rank = 0; + ort_api_->GetDimensionsCount(type_info, &rank); + EXPECT_EQ(rank, 4u); + + std::vector actual_shape(rank); + ort_api_->GetDimensions(type_info, actual_shape.data(), rank); + EXPECT_EQ(actual_shape[0], batch); + EXPECT_EQ(actual_shape[1], channels); + EXPECT_EQ(actual_shape[2], height); + EXPECT_EQ(actual_shape[3], width); + + ONNXTensorElementDataType elem_type; + ort_api_->GetTensorElementType(type_info, &elem_type); + EXPECT_EQ(elem_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + + ort_api_->ReleaseTensorTypeAndShapeInfo(type_info); + + // Get the tensor's data pointer and verify it's CUDA device memory + // This proves the D3D12 to CUDA memory import actually happened + void* tensor_data = nullptr; + status = ort_api_->GetTensorMutableData(tensor, &tensor_data); + ASSERT_EQ(status, nullptr) << "GetTensorMutableData should succeed"; + ASSERT_NE(tensor_data, nullptr) << "Tensor data pointer should not be null"; + + // Use cudaPointerGetAttributes to verify this is CUDA device memory + cudaPointerAttributes attrs; + cudaError_t cuda_err = cudaPointerGetAttributes(&attrs, tensor_data); + ASSERT_EQ(cuda_err, cudaSuccess) << "cudaPointerGetAttributes failed: " << cudaGetErrorString(cuda_err); + EXPECT_EQ(attrs.type, cudaMemoryTypeDevice) + << "Memory should be CUDA device memory, but got type " << attrs.type + << " (cudaMemoryTypeDevice=" << cudaMemoryTypeDevice << ")"; + EXPECT_NE(attrs.device, -1) << "Device should be valid"; + + // Cleanup + ort_api_->ReleaseValue(tensor); + ort_api_->ReleaseExternalMemoryHandle(mem_handle); + CloseHandle(shared_handle); + ort_api_->ReleaseExternalResourceImporter(importer); +} + +// Test: Import D3D12 Timeline Fence +TEST_F(ExternalResourceImporterTest, 
ImportD3D12Fence) { + if (!IsD3D12Available()) { + GTEST_SKIP() << "D3D12 not available"; + } + if (!IsEPAvailable()) { + GTEST_SKIP() << "NvTensorRtRtx EP not available"; + } + + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = ort_api_->CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr || importer == nullptr) { + if (status != nullptr) ort_api_->ReleaseStatus(status); + GTEST_SKIP() << "External resource import not supported"; + } + + // Create a D3D12 fence with SHARED flag for cross-API import + ComPtr d3d12_fence; + HRESULT hr = d3d12_device_->CreateFence(0, D3D12_FENCE_FLAG_SHARED, IID_PPV_ARGS(&d3d12_fence)); + ASSERT_TRUE(SUCCEEDED(hr)) << "Failed to create D3D12 fence"; + + // Create shared handle + HANDLE shared_handle = nullptr; + hr = d3d12_device_->CreateSharedHandle(d3d12_fence.Get(), nullptr, GENERIC_ALL, nullptr, &shared_handle); + ASSERT_TRUE(SUCCEEDED(hr)) << "Failed to create shared fence handle"; + + // Import the semaphore + OrtExternalSemaphoreDescriptor sem_desc = {}; + sem_desc.version = ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION; + sem_desc.type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + sem_desc.native_handle = shared_handle; + + OrtExternalSemaphoreHandle* sem_handle = nullptr; + status = ort_api_->ExternalResourceImporter_ImportSemaphore(importer, &sem_desc, &sem_handle); + ASSERT_EQ(status, nullptr) << "ImportSemaphore should succeed"; + ASSERT_NE(sem_handle, nullptr) << "Semaphore handle should not be null"; + + // Release semaphore handle + ort_api_->ReleaseExternalSemaphoreHandle(sem_handle); + + // Close shared handle + CloseHandle(shared_handle); + + ort_api_->ReleaseExternalResourceImporter(importer); +} + +// Test: Wait and Signal Semaphore +TEST_F(ExternalResourceImporterTest, WaitAndSignalSemaphore) { + if (!IsD3D12Available()) { + GTEST_SKIP() << "D3D12 not available"; + } + if (!IsEPAvailable()) { + GTEST_SKIP() << "NvTensorRtRtx EP not available"; + } + + 
OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = ort_api_->CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr || importer == nullptr) { + if (status != nullptr) ort_api_->ReleaseStatus(status); + GTEST_SKIP() << "External resource import not supported"; + } + + // Create a CUDA stream via ORT + OrtSyncStream* ort_stream = nullptr; + status = ort_api_->CreateSyncStreamForEpDevice(ep_device_, nullptr, &ort_stream); + ASSERT_EQ(status, nullptr) << "CreateSyncStreamForEpDevice should succeed"; + + // Create a D3D12 fence + ComPtr d3d12_fence; + HRESULT hr = d3d12_device_->CreateFence(0, D3D12_FENCE_FLAG_SHARED, IID_PPV_ARGS(&d3d12_fence)); + ASSERT_TRUE(SUCCEEDED(hr)); + + HANDLE shared_handle = nullptr; + hr = d3d12_device_->CreateSharedHandle(d3d12_fence.Get(), nullptr, GENERIC_ALL, nullptr, &shared_handle); + ASSERT_TRUE(SUCCEEDED(hr)); + + // Import semaphore + OrtExternalSemaphoreDescriptor sem_desc = {}; + sem_desc.version = ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION; + sem_desc.type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + sem_desc.native_handle = shared_handle; + + OrtExternalSemaphoreHandle* sem_handle = nullptr; + status = ort_api_->ExternalResourceImporter_ImportSemaphore(importer, &sem_desc, &sem_handle); + ASSERT_EQ(status, nullptr); + + // Signal the fence from D3D12 side + uint64_t signal_value = 1; + command_queue_->Signal(d3d12_fence.Get(), signal_value); + + // Wait on the fence from CUDA side + status = ort_api_->ExternalResourceImporter_WaitSemaphore(importer, sem_handle, ort_stream, signal_value); + ASSERT_EQ(status, nullptr) << "WaitSemaphore should succeed"; + + // Signal from CUDA side + uint64_t cuda_signal_value = 2; + status = ort_api_->ExternalResourceImporter_SignalSemaphore(importer, sem_handle, ort_stream, cuda_signal_value); + ASSERT_EQ(status, nullptr) << "SignalSemaphore should succeed"; + + // Synchronize by getting the native stream handle and calling cudaStreamSynchronize + // 
(In real code, the signal enqueued on the stream will complete when the stream is flushed) + void* stream_handle = ort_api_->SyncStream_GetHandle(ort_stream); + ASSERT_NE(stream_handle, nullptr); + // Note: In production code, you'd call cudaStreamSynchronize((cudaStream_t)stream_handle) + // For this test, we just verify the handle is valid + + // Verify D3D12 can see the signaled value + HANDLE wait_event = CreateEvent(nullptr, FALSE, FALSE, nullptr); + d3d12_fence->SetEventOnCompletion(cuda_signal_value, wait_event); + DWORD wait_result = WaitForSingleObject(wait_event, 5000); // 5 second timeout + CloseHandle(wait_event); + EXPECT_EQ(wait_result, WAIT_OBJECT_0) << "D3D12 should see the fence signaled by CUDA"; + + // Cleanup + ort_api_->ReleaseExternalSemaphoreHandle(sem_handle); + CloseHandle(shared_handle); + ort_api_->ReleaseSyncStream(ort_stream); + ort_api_->ReleaseExternalResourceImporter(importer); +} + +// Test: Full Inference with External Memory (E2E) +// This test validates the complete D3D12 to CUDA interop pipeline: +// 1. Create D3D12 shared resources and fences +// 2. Import them into CUDA via OrtExternalResourceImporter +// 3. Create ORT tensors from imported memory +// 4. Run inference with proper synchronization +// 5. 
Verify output correctness +TEST_F(ExternalResourceImporterTest, FullInferenceWithExternalMemory) { + if (!IsD3D12Available()) { + GTEST_SKIP() << "D3D12 not available"; + } + if (!IsEPAvailable()) { + GTEST_SKIP() << "NvTensorRtRtx EP not available"; + } + + // Create a simple ReLU model using shared utility pattern + PathString model_path = ORT_TSTR("external_mem_relu_test.onnx"); + { + onnxruntime::Model model("relu_test", false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto tensor_type; + tensor_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3); + tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(64); + tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(64); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &tensor_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &tensor_type); + graph.AddNode("relu", "Relu", "ReLU operation", {&input_arg}, {&output_arg}); + + ASSERT_STATUS_OK(graph.Resolve()); + ASSERT_STATUS_OK(onnxruntime::Model::Save(model, model_path)); + } + + const int64_t batch = 1, channels = 3, dim = 64; + const int64_t shape[] = {batch, channels, dim, dim}; + const size_t num_elements = batch * channels * dim * dim; + const size_t buffer_size = num_elements * sizeof(float); + + // Create external resource importer + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = ort_api_->CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr || importer == nullptr) { + if (status != nullptr) ort_api_->ReleaseStatus(status); + clearFileIfExists(model_path); + GTEST_SKIP() << "External resource import not supported"; + } + + // Create CUDA stream via ORT + OrtSyncStream* ort_stream = nullptr; 
+ status = ort_api_->CreateSyncStreamForEpDevice(ep_device_, nullptr, &ort_stream); + ASSERT_EQ(status, nullptr); + + // Create shared D3D12 buffers for input and output + ComPtr input_buffer, output_buffer; + D3D12ResourceHelper::CreateSharedBuffer(d3d12_device_.Get(), buffer_size, &input_buffer); + D3D12ResourceHelper::CreateSharedBuffer(d3d12_device_.Get(), buffer_size, &output_buffer); + + // Create shared handles for cross-API import + HANDLE input_handle = nullptr, output_handle = nullptr; + d3d12_device_->CreateSharedHandle(input_buffer.Get(), nullptr, GENERIC_ALL, nullptr, &input_handle); + d3d12_device_->CreateSharedHandle(output_buffer.Get(), nullptr, GENERIC_ALL, nullptr, &output_handle); + + // Import memory into CUDA + OrtExternalMemoryDescriptor input_mem_desc = {}; + input_mem_desc.version = ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION; + input_mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + input_mem_desc.native_handle = input_handle; + input_mem_desc.size_bytes = buffer_size; + input_mem_desc.offset_bytes = 0; + input_mem_desc.access_mode = ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE; + + OrtExternalMemoryDescriptor output_mem_desc = input_mem_desc; + output_mem_desc.native_handle = output_handle; + + OrtExternalMemoryHandle *input_mem = nullptr, *output_mem = nullptr; + status = ort_api_->ExternalResourceImporter_ImportMemory(importer, &input_mem_desc, &input_mem); + ASSERT_EQ(status, nullptr) << "ImportMemory for input should succeed (proves cuImportExternalMemory called)"; + status = ort_api_->ExternalResourceImporter_ImportMemory(importer, &output_mem_desc, &output_mem); + ASSERT_EQ(status, nullptr) << "ImportMemory for output should succeed"; + + // Create ORT tensors from imported memory + OrtExternalTensorDescriptor tensor_desc = {}; + tensor_desc.version = ORT_EXTERNAL_TENSOR_DESCRIPTOR_VERSION; + tensor_desc.element_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + tensor_desc.shape = shape; + tensor_desc.rank = 4; + 
tensor_desc.offset_bytes = 0; + + OrtValue *input_tensor = nullptr, *output_tensor = nullptr; + status = ort_api_->ExternalResourceImporter_CreateTensorFromMemory(importer, input_mem, &tensor_desc, nullptr, &input_tensor); + ASSERT_EQ(status, nullptr); + status = ort_api_->ExternalResourceImporter_CreateTensorFromMemory(importer, output_mem, &tensor_desc, nullptr, &output_tensor); + ASSERT_EQ(status, nullptr); + + // Verify the tensor data pointers are CUDA device memory + void* input_data_ptr = nullptr; + void* output_data_ptr = nullptr; + status = ort_api_->GetTensorMutableData(input_tensor, &input_data_ptr); + ASSERT_EQ(status, nullptr); + status = ort_api_->GetTensorMutableData(output_tensor, &output_data_ptr); + ASSERT_EQ(status, nullptr); + + cudaPointerAttributes input_attrs, output_attrs; + ASSERT_EQ(cudaPointerGetAttributes(&input_attrs, input_data_ptr), cudaSuccess); + ASSERT_EQ(cudaPointerGetAttributes(&output_attrs, output_data_ptr), cudaSuccess); + EXPECT_EQ(input_attrs.type, cudaMemoryTypeDevice) << "Input tensor must be CUDA device memory"; + EXPECT_EQ(output_attrs.type, cudaMemoryTypeDevice) << "Output tensor must be CUDA device memory"; + + // Create D3D12 fence for bidirectional synchronization + ComPtr sync_fence; + d3d12_device_->CreateFence(0, D3D12_FENCE_FLAG_SHARED, IID_PPV_ARGS(&sync_fence)); + HANDLE fence_handle = nullptr; + d3d12_device_->CreateSharedHandle(sync_fence.Get(), nullptr, GENERIC_ALL, nullptr, &fence_handle); + + OrtExternalSemaphoreDescriptor sem_desc = {}; + sem_desc.version = ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION; + sem_desc.type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + sem_desc.native_handle = fence_handle; + + OrtExternalSemaphoreHandle* sem_handle = nullptr; + status = ort_api_->ExternalResourceImporter_ImportSemaphore(importer, &sem_desc, &sem_handle); + ASSERT_EQ(status, nullptr) << "ImportSemaphore should succeed"; + + // Setup test data via D3D12 upload buffer + ComPtr upload_buffer; + 
D3D12ResourceHelper::CreateUploadBuffer(d3d12_device_.Get(), buffer_size, &upload_buffer); + + // Generate test data: alternating positive and negative values for ReLU verification + std::vector test_data(num_elements); + for (size_t i = 0; i < num_elements; ++i) { + test_data[i] = (i % 2 == 0) ? static_cast(i + 1) : -static_cast(i + 1); + } + + void* upload_ptr = nullptr; + upload_buffer->Map(0, nullptr, &upload_ptr); + memcpy(upload_ptr, test_data.data(), buffer_size); + upload_buffer->Unmap(0, nullptr); + + // Copy upload buffer to input buffer via D3D12 + command_allocator_->Reset(); + command_list_->Reset(command_allocator_.Get(), nullptr); + command_list_->CopyBufferRegion(input_buffer.Get(), 0, upload_buffer.Get(), 0, buffer_size); + command_list_->Close(); + + ID3D12CommandList* cmd_lists[] = {command_list_.Get()}; + command_queue_->ExecuteCommandLists(1, cmd_lists); + + // Signal fence after D3D12 upload completes + uint64_t upload_complete_value = 1; + command_queue_->Signal(sync_fence.Get(), upload_complete_value); + + // Make CUDA wait for D3D12 upload to complete + status = ort_api_->ExternalResourceImporter_WaitSemaphore(importer, sem_handle, ort_stream, upload_complete_value); + ASSERT_EQ(status, nullptr) << "WaitSemaphore should succeed"; + + // Setup ORT session with user_compute_stream (like the patch PR test) + Ort::SessionOptions session_options; + session_options.SetExecutionMode(ORT_SEQUENTIAL); + session_options.DisableMemPattern(); + session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL); + + // Configure to use our CUDA stream + char stream_address[32]; + size_t stream_addr_val = reinterpret_cast(ort_api_->SyncStream_GetHandle(ort_stream)); + sprintf(stream_address, "%llu", static_cast(stream_addr_val)); + const char* option_keys[] = {"user_compute_stream", "has_user_compute_stream"}; + const char* option_values[] = {stream_address, "1"}; + + // Add the NvTensorRtRtx EP with user stream + status = 
ort_api_->SessionOptionsAppendExecutionProvider_V2( + session_options, *ort_env, &ep_device_, 1, option_keys, option_values, 2); + ASSERT_EQ(status, nullptr); + + // Create session + Ort::Session session(*ort_env, model_path.c_str(), session_options); + + // Create IoBinding and bind external tensors + Ort::IoBinding io_binding(session); + Ort::AllocatorWithDefaultOptions allocator; + + Ort::AllocatedStringPtr input_name = session.GetInputNameAllocated(0, allocator); + Ort::AllocatedStringPtr output_name = session.GetOutputNameAllocated(0, allocator); + + io_binding.BindInput(input_name.get(), Ort::Value(input_tensor)); + io_binding.BindOutput(output_name.get(), Ort::Value(output_tensor)); + io_binding.SynchronizeInputs(); + + // Run inference with synchronization disabled (we handle it manually) + Ort::RunOptions run_options; + run_options.AddConfigEntry("disable_synchronize_execution_providers", "1"); + session.Run(run_options, io_binding); + + // Signal from CUDA that inference is complete + uint64_t inference_complete_value = 2; + status = ort_api_->ExternalResourceImporter_SignalSemaphore(importer, sem_handle, ort_stream, inference_complete_value); + ASSERT_EQ(status, nullptr) << "SignalSemaphore should succeed"; + + // Wait on D3D12 for CUDA inference to complete + command_queue_->Wait(sync_fence.Get(), inference_complete_value); + + // Copy output to readback buffer + ComPtr readback_buffer; + D3D12ResourceHelper::CreateReadbackBuffer(d3d12_device_.Get(), buffer_size, &readback_buffer); + + command_allocator_->Reset(); + command_list_->Reset(command_allocator_.Get(), nullptr); + command_list_->CopyBufferRegion(readback_buffer.Get(), 0, output_buffer.Get(), 0, buffer_size); + command_list_->Close(); + + command_queue_->ExecuteCommandLists(1, cmd_lists); + D3D12ResourceHelper::FlushAndWait(d3d12_device_.Get(), command_queue_.Get()); + + // Read back and verify ReLU output: max(0, x) + std::vector output_data(num_elements); + void* readback_ptr = nullptr; + 
readback_buffer->Map(0, nullptr, &readback_ptr); + memcpy(output_data.data(), readback_ptr, buffer_size); + readback_buffer->Unmap(0, nullptr); + + // Verify ReLU correctness + for (size_t i = 0; i < num_elements; ++i) { + float expected = std::max(0.0f, test_data[i]); + EXPECT_FLOAT_EQ(output_data[i], expected) + << "Mismatch at index " << i << ": input=" << test_data[i] + << ", expected=" << expected << ", got=" << output_data[i]; + } + + // Note: io_binding takes ownership of input_tensor and output_tensor, so don't release them manually + + // Cleanup + ort_api_->ReleaseExternalSemaphoreHandle(sem_handle); + ort_api_->ReleaseExternalMemoryHandle(output_mem); + ort_api_->ReleaseExternalMemoryHandle(input_mem); + CloseHandle(fence_handle); + CloseHandle(output_handle); + CloseHandle(input_handle); + ort_api_->ReleaseSyncStream(ort_stream); + ort_api_->ReleaseExternalResourceImporter(importer); + + clearFileIfExists(model_path); +} + +#endif // _WIN32 + +} // namespace test +} // namespace onnxruntime From 558dbd57df7675ef51439c62cdfd00a1b8c4ab11 Mon Sep 17 00:00:00 2001 From: Nick Eubanks Date: Fri, 19 Dec 2025 14:19:06 -0500 Subject: [PATCH 03/10] Address PR feedback - Deleted the sync_stream member from OrtRunOptions structure. - Removed the RunOptions_SetSyncStream API and its implementation. - Updated related C++ API and example implementations to reflect the removal of sync stream functionality. - Adjusted tests to remove references to RunOptions_SetSyncStream. - Introduced new structures for external memory and semaphore handles to improve resource management. - Ensured backward compatibility by checking EP version support for external resource import. 
--- .../onnxruntime/core/framework/run_options.h | 5 - .../core/session/onnxruntime_c_api.h | 24 +--- .../core/session/onnxruntime_cxx_api.h | 10 -- .../core/session/onnxruntime_cxx_inline.h | 5 - .../core/session/onnxruntime_ep_c_api.h | 104 ++++++++++++++--- onnxruntime/core/session/onnxruntime_c_api.cc | 109 ++++-------------- onnxruntime/core/session/ort_apis.h | 3 - .../session/plugin_ep/ep_factory_internal.h | 4 +- .../plugin_ep/ep_factory_internal_impl.h | 2 +- .../plugin_ep/ep_factory_provider_bridge.h | 8 +- .../plugin_ep/forward_to_factory_impl.h | 4 +- .../ep_external_resource_importer.cc | 32 ++--- .../ep_external_resource_importer.h | 68 ++++++----- .../library/example_plugin_ep/ep_factory.cc | 15 +-- .../library/example_plugin_ep/ep_factory.h | 2 +- .../autoep/test_external_resource_importer.cc | 47 +------- 16 files changed, 188 insertions(+), 254 deletions(-) diff --git a/include/onnxruntime/core/framework/run_options.h b/include/onnxruntime/core/framework/run_options.h index 001fa158345ab..e63ab044834f5 100644 --- a/include/onnxruntime/core/framework/run_options.h +++ b/include/onnxruntime/core/framework/run_options.h @@ -51,11 +51,6 @@ struct OrtRunOptions { onnxruntime::InlinedVector active_adapters; - // Optional sync stream for external resource import. - // When set, the EP uses this stream for execution, enabling proper - // synchronization with imported external semaphores. - OrtSyncStream* sync_stream = nullptr; - OrtRunOptions() = default; ~OrtRunOptions() = default; }; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 9225cfe6ba1c7..7414d3732e83b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6741,6 +6741,7 @@ struct OrtApi { * \param[in] importer The OrtExternalResourceImporter instance. * \param[in] desc Descriptor containing the external memory handle and properties. 
* \param[out] out_handle Output parameter set to the created OrtExternalMemoryHandle. + * The caller owns the returned handle and must call ReleaseExternalMemoryHandle to free it. * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -6769,6 +6770,7 @@ struct OrtApi { * \param[in] tensor_desc Descriptor specifying tensor element type, shape, and optional offset. * \param[in] tensor_location Optional OrtMemoryInfo for the tensor location. May be nullptr. * \param[out] out_tensor Output parameter set to the created OrtValue containing the tensor. + * The caller owns the returned tensor and must call ReleaseValue to free it. * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -6886,28 +6888,6 @@ struct OrtApi { _Out_writes_(num_outputs) const OrtEpDevice** outputs_ep_devices, _In_ size_t num_outputs); - /** \brief Associate an OrtSyncStream with run options. - * - * Associates an OrtSyncStream with OrtRunOptions for use with Run() or RunWithBinding(). - * When a sync stream is set, the EP uses this stream for execution, enabling proper - * synchronization with imported external semaphores. - * - * This approach: - * - Works with both Run() and RunWithBinding() — no IOBinding requirement - * - Allows different Run calls to use different streams for concurrent inference - * - Integrates cleanly with the external semaphore wait/signal pattern - * - * \param[in] run_options The OrtRunOptions instance to modify. - * \param[in] stream The OrtSyncStream to associate with the run options. May be nullptr to clear. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.24 - */ - ORT_API2_STATUS(RunOptions_SetSyncStream, - _Inout_ OrtRunOptions* run_options, - _In_opt_ OrtSyncStream* stream); - /// @} }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index b4dcdcc7bcca3..c5de8b3e40a23 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1287,16 +1287,6 @@ struct RunOptions : detail::Base { * \param adapter The LoraAdapter to be used as the active adapter */ RunOptions& AddActiveLoraAdapter(const LoraAdapter& adapter); - - /** \brief Associate a sync stream with the run options. - * - * When set, the EP uses this stream for execution, enabling proper - * synchronization with imported external semaphores. - * - * Wraps OrtApi::RunOptions_SetSyncStream - * \param stream The OrtSyncStream to associate with these run options. May be nullptr to clear. 
- */ - RunOptions& SetSyncStream(OrtSyncStream* stream); }; namespace detail { diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 0622afb681ddb..dc73614ef0445 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -994,11 +994,6 @@ inline RunOptions& RunOptions::AddActiveLoraAdapter(const LoraAdapter& adapter) return *this; } -inline RunOptions& RunOptions::SetSyncStream(OrtSyncStream* stream) { - ThrowOnError(GetApi().RunOptions_SetSyncStream(p_, stream)); - return *this; -} - inline ModelCompilationOptions::ModelCompilationOptions(const Env& env, const SessionOptions& session_options) { ThrowOnError(GetCompileApi().CreateModelCompilationOptionsFromSessionOptions(env, session_options, &this->p_)); } diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index edd7ab657c8c4..3130cc4da3238 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -25,8 +25,66 @@ ORT_RUNTIME_CLASS(SyncNotificationImpl); ORT_RUNTIME_CLASS(SyncStreamImpl); ORT_RUNTIME_CLASS(ExternalResourceImporterImpl); -ORT_RUNTIME_CLASS(ExternalMemoryHandleImpl); -ORT_RUNTIME_CLASS(ExternalSemaphoreHandleImpl); + +/** \brief Base struct for imported external memory handles. + * + * EPs derive from this struct to add EP-specific fields (e.g., CUdeviceptr for CUDA). + * EP is responsible for creating and releasing instances of the derived type. + * + * Example derived type for CUDA EP: + * \code + * struct MyCudaExternalMemoryHandle : OrtExternalMemoryHandle { + * CUexternalMemory ext_memory; + * CUdeviceptr mapped_ptr; + * bool is_dedicated; + * }; + * \endcode + * + * \since Version 1.24. 
+ */ +#define ORT_EXTERNAL_MEMORY_HANDLE_VERSION 1 +struct OrtExternalMemoryHandle { + uint32_t version; ///< Must be ORT_EXTERNAL_MEMORY_HANDLE_VERSION + const OrtEpDevice* ep_device; ///< EP device that created this handle + OrtExternalMemoryHandleType handle_type; ///< Original handle type for tracking + size_t size_bytes; ///< Size of the imported memory + size_t offset_bytes; ///< Offset into the imported memory + + /** \brief Release callback for this handle. EP sets this to its release function. + * + * ORT calls this when ReleaseExternalMemoryHandle is invoked. The EP's callback + * should cast the handle to its derived type and delete it. + */ + void(ORT_API_CALL* Release)(_In_ OrtExternalMemoryHandle* handle); +}; + +/** \brief Base struct for imported external semaphore handles. + * + * EPs derive from this struct to add EP-specific fields (e.g., CUexternalSemaphore for CUDA). + * EP is responsible for creating and releasing instances of the derived type. + * + * Example derived type for CUDA EP: + * \code + * struct MyCudaExternalSemaphoreHandle : OrtExternalSemaphoreHandle { + * CUexternalSemaphore ext_semaphore; + * }; + * \endcode + * + * \since Version 1.24. + */ +#define ORT_EXTERNAL_SEMAPHORE_HANDLE_VERSION 1 +struct OrtExternalSemaphoreHandle { + uint32_t version; ///< Must be ORT_EXTERNAL_SEMAPHORE_HANDLE_VERSION + const OrtEpDevice* ep_device; ///< EP device that created this handle + OrtExternalSemaphoreType type; ///< Original semaphore type + + /** \brief Release callback for this handle. EP sets this to its release function. + * + * ORT calls this when ReleaseExternalSemaphoreHandle is invoked. The EP's callback + * should cast the handle to its derived type and delete it. 
+ */ + void(ORT_API_CALL* Release)(_In_ OrtExternalSemaphoreHandle* handle); +}; // Opaque types for kernel-based EPs ORT_RUNTIME_CLASS(KernelRegistry); @@ -220,10 +278,13 @@ struct OrtExternalResourceImporterImpl { _In_ OrtExternalMemoryHandleType handle_type); /** \brief Import external memory. + * + * The EP creates a derived type of OrtExternalMemoryHandle and returns a pointer to the base. + * EP is responsible for the lifetime of the handle (release via ReleaseMemory). * * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. * \param[in] desc Descriptor containing the external memory handle and properties. - * \param[out] out_handle Output parameter set to the created OrtExternalMemoryHandleImpl. + * \param[out] out_handle Output parameter set to the created OrtExternalMemoryHandle (EP's derived type). * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -232,25 +293,27 @@ struct OrtExternalResourceImporterImpl { ORT_API2_STATUS(ImportMemory, _In_ OrtExternalResourceImporterImpl* this_ptr, _In_ const OrtExternalMemoryDescriptor* desc, - _Outptr_ OrtExternalMemoryHandleImpl** out_handle); + _Outptr_ OrtExternalMemoryHandle** out_handle); /** \brief Release an imported external memory handle. + * + * The EP deletes its derived type instance. * * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. - * \param[in] handle The OrtExternalMemoryHandleImpl to release. + * \param[in] handle The OrtExternalMemoryHandle to release (EP casts to its derived type). * * \since Version 1.24. */ ORT_API_T(void, ReleaseMemory, _In_ OrtExternalResourceImporterImpl* this_ptr, - _In_ OrtExternalMemoryHandleImpl* handle); + _In_ OrtExternalMemoryHandle* handle); /** \brief Create a tensor backed by imported external memory. * * The created tensor is a view over the imported memory and does not copy data. * * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. 
- * \param[in] mem_handle The imported external memory handle. + * \param[in] mem_handle The imported external memory handle (EP casts to its derived type). * \param[in] tensor_desc Descriptor specifying tensor element type, shape, and optional offset. * \param[out] out_tensor Output parameter set to the created OrtValue containing the tensor. * @@ -260,7 +323,7 @@ struct OrtExternalResourceImporterImpl { */ ORT_API2_STATUS(CreateTensorFromMemory, _In_ OrtExternalResourceImporterImpl* this_ptr, - _In_ const OrtExternalMemoryHandleImpl* mem_handle, + _In_ const OrtExternalMemoryHandle* mem_handle, _In_ const OrtExternalTensorDescriptor* tensor_desc, _Outptr_ OrtValue** out_tensor); @@ -279,10 +342,13 @@ struct OrtExternalResourceImporterImpl { _In_ OrtExternalSemaphoreType type); /** \brief Import an external semaphore. + * + * The EP creates a derived type of OrtExternalSemaphoreHandle and returns a pointer to the base. + * EP is responsible for the lifetime of the handle (release via ReleaseSemaphore). * * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. * \param[in] desc Descriptor containing the external semaphore handle and type. - * \param[out] out_handle Output parameter set to the created OrtExternalSemaphoreHandleImpl. + * \param[out] out_handle Output parameter set to the created OrtExternalSemaphoreHandle (EP's derived type). * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -291,18 +357,20 @@ struct OrtExternalResourceImporterImpl { ORT_API2_STATUS(ImportSemaphore, _In_ OrtExternalResourceImporterImpl* this_ptr, _In_ const OrtExternalSemaphoreDescriptor* desc, - _Outptr_ OrtExternalSemaphoreHandleImpl** out_handle); + _Outptr_ OrtExternalSemaphoreHandle** out_handle); /** \brief Release an imported external semaphore handle. + * + * The EP deletes its derived type instance. * * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. 
- * \param[in] handle The OrtExternalSemaphoreHandleImpl to release. + * \param[in] handle The OrtExternalSemaphoreHandle to release (EP casts to its derived type). * * \since Version 1.24. */ ORT_API_T(void, ReleaseSemaphore, _In_ OrtExternalResourceImporterImpl* this_ptr, - _In_ OrtExternalSemaphoreHandleImpl* handle); + _In_ OrtExternalSemaphoreHandle* handle); /** \brief Wait on an external semaphore on the EP's stream. * @@ -310,7 +378,7 @@ struct OrtExternalResourceImporterImpl { * reaches the specified value. * * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. - * \param[in] handle The imported external semaphore. + * \param[in] handle The imported external semaphore (EP casts to its derived type). * \param[in] stream The OrtSyncStream to wait on. * \param[in] value The fence/semaphore value to wait for. * @@ -320,7 +388,7 @@ struct OrtExternalResourceImporterImpl { */ ORT_API2_STATUS(WaitSemaphore, _In_ OrtExternalResourceImporterImpl* this_ptr, - _In_ OrtExternalSemaphoreHandleImpl* handle, + _In_ OrtExternalSemaphoreHandle* handle, _In_ OrtSyncStream* stream, _In_ uint64_t value); @@ -330,7 +398,7 @@ struct OrtExternalResourceImporterImpl { * to the specified value when reached. * * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. - * \param[in] handle The imported external semaphore. + * \param[in] handle The imported external semaphore (EP casts to its derived type). * \param[in] stream The OrtSyncStream to signal from. * \param[in] value The fence/semaphore value to signal. * @@ -340,7 +408,7 @@ struct OrtExternalResourceImporterImpl { */ ORT_API2_STATUS(SignalSemaphore, _In_ OrtExternalResourceImporterImpl* this_ptr, - _In_ OrtExternalSemaphoreHandleImpl* handle, + _In_ OrtExternalSemaphoreHandle* handle, _In_ OrtSyncStream* stream, _In_ uint64_t value); @@ -1592,7 +1660,7 @@ struct OrtEpFactory { * implement this to allow applications to share GPU resources without copies. 
* * \param[in] this_ptr The OrtEpFactory instance. - * \param[in] memory_device The OrtMemoryDevice to create the external resource importer for. + * \param[in] ep_device The OrtEpDevice to create the external resource importer for. * \param[out] out_importer The created OrtExternalResourceImporterImpl instance. * Set to nullptr if external resource import is not supported. * @@ -1605,7 +1673,7 @@ struct OrtEpFactory { * \since Version 1.24. */ ORT_API2_STATUS(CreateExternalResourceImporterForDevice, _In_ OrtEpFactory* this_ptr, - _In_ const OrtMemoryDevice* memory_device, + _In_ const OrtEpDevice* ep_device, _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** out_importer); }; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 191969832db40..c658365adedec 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3583,46 +3583,9 @@ struct ExternalResourceImporterWrapper { } } - // Non-copyable - ExternalResourceImporterWrapper(const ExternalResourceImporterWrapper&) = delete; - ExternalResourceImporterWrapper& operator=(const ExternalResourceImporterWrapper&) = delete; + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ExternalResourceImporterWrapper); }; -struct ExternalMemoryHandleWrapper { - OrtExternalResourceImporterImpl* importer_impl; // Not owned - OrtExternalMemoryHandleImpl* impl; - - ExternalMemoryHandleWrapper(OrtExternalResourceImporterImpl* importer, OrtExternalMemoryHandleImpl* handle) - : importer_impl(importer), impl(handle) {} - - ~ExternalMemoryHandleWrapper() { - if (importer_impl && impl && importer_impl->ReleaseMemory) { - importer_impl->ReleaseMemory(importer_impl, impl); - } - } - - // Non-copyable - ExternalMemoryHandleWrapper(const ExternalMemoryHandleWrapper&) = delete; - ExternalMemoryHandleWrapper& operator=(const ExternalMemoryHandleWrapper&) = delete; -}; - -struct ExternalSemaphoreHandleWrapper { - 
OrtExternalResourceImporterImpl* importer_impl; // Not owned - OrtExternalSemaphoreHandleImpl* impl; - - ExternalSemaphoreHandleWrapper(OrtExternalResourceImporterImpl* importer, OrtExternalSemaphoreHandleImpl* handle) - : importer_impl(importer), impl(handle) {} - - ~ExternalSemaphoreHandleWrapper() { - if (importer_impl && impl && importer_impl->ReleaseSemaphore) { - importer_impl->ReleaseSemaphore(importer_impl, impl); - } - } - - // Non-copyable - ExternalSemaphoreHandleWrapper(const ExternalSemaphoreHandleWrapper&) = delete; - ExternalSemaphoreHandleWrapper& operator=(const ExternalSemaphoreHandleWrapper&) = delete; -}; } // namespace ORT_API_STATUS_IMPL(OrtApis::CreateExternalResourceImporterForDevice, _In_ const OrtEpDevice* ep_device, @@ -3634,13 +3597,11 @@ ORT_API_STATUS_IMPL(OrtApis::CreateExternalResourceImporterForDevice, _In_ const *out_importer = nullptr; - const OrtDevice* device = ep_device->device_memory_info ? &ep_device->device_memory_info->device : nullptr; - if (device == nullptr || device->MemType() != OrtDevice::MemType::DEFAULT) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device does not use DEFAULT memory of a non-CPU device."); - } - + // OrtEpFactory::CreateExternalResourceImporterForDevice was added in ORT 1.24. 
const auto* factory = ep_device->ep_factory; - if (factory == nullptr || factory->CreateExternalResourceImporterForDevice == nullptr) { + if (factory == nullptr || + factory->ort_version_supported < 24 || + factory->CreateExternalResourceImporterForDevice == nullptr) { // EP doesn't support external resource import - not an error, just return nullptr return nullptr; } @@ -3648,7 +3609,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateExternalResourceImporterForDevice, _In_ const OrtExternalResourceImporterImpl* impl = nullptr; ORT_API_RETURN_IF_ERROR(factory->CreateExternalResourceImporterForDevice( ep_device->GetMutableFactory(), - static_cast(device), + ep_device, &impl)); if (impl == nullptr) { @@ -3701,23 +3662,23 @@ ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_ImportMemory, _In_ OrtExte return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External memory import is not supported by this EP."); } - OrtExternalMemoryHandleImpl* impl = nullptr; - ORT_API_RETURN_IF_ERROR(wrapper->impl->ImportMemory(wrapper->impl, desc, &impl)); + // EP creates derived type and returns base pointer. EP owns the handle lifetime. 
+ OrtExternalMemoryHandle* handle = nullptr; + ORT_API_RETURN_IF_ERROR(wrapper->impl->ImportMemory(wrapper->impl, desc, &handle)); - if (impl == nullptr) { + if (handle == nullptr) { return OrtApis::CreateStatus(ORT_FAIL, "ImportMemory returned null handle."); } - auto mem_wrapper = std::make_unique(wrapper->impl, impl); - *out_handle = reinterpret_cast(mem_wrapper.release()); + *out_handle = handle; return nullptr; API_IMPL_END } ORT_API(void, OrtApis::ReleaseExternalMemoryHandle, _Frees_ptr_opt_ OrtExternalMemoryHandle* handle) { - if (handle != nullptr) { - delete reinterpret_cast(handle); + if (handle != nullptr && handle->Release != nullptr) { + handle->Release(handle); } } @@ -3732,14 +3693,13 @@ ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_CreateTensorFromMemory, _I } auto* imp_wrapper = reinterpret_cast(importer); - auto* mem_wrapper = reinterpret_cast(mem_handle); if (imp_wrapper->impl == nullptr || imp_wrapper->impl->CreateTensorFromMemory == nullptr) { return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateTensorFromMemory is not supported by this EP."); } OrtValue* tensor = nullptr; - ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->CreateTensorFromMemory(imp_wrapper->impl, mem_wrapper->impl, tensor_desc, &tensor)); + ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->CreateTensorFromMemory(imp_wrapper->impl, mem_handle, tensor_desc, &tensor)); *out_tensor = tensor; return nullptr; @@ -3778,23 +3738,23 @@ ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_ImportSemaphore, _In_ OrtE return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External semaphore import is not supported by this EP."); } - OrtExternalSemaphoreHandleImpl* impl = nullptr; - ORT_API_RETURN_IF_ERROR(wrapper->impl->ImportSemaphore(wrapper->impl, desc, &impl)); + // EP creates derived type and returns base pointer. EP owns the handle lifetime. 
+ OrtExternalSemaphoreHandle* handle = nullptr; + ORT_API_RETURN_IF_ERROR(wrapper->impl->ImportSemaphore(wrapper->impl, desc, &handle)); - if (impl == nullptr) { + if (handle == nullptr) { return OrtApis::CreateStatus(ORT_FAIL, "ImportSemaphore returned null handle."); } - auto sem_wrapper = std::make_unique(wrapper->impl, impl); - *out_handle = reinterpret_cast(sem_wrapper.release()); + *out_handle = handle; return nullptr; API_IMPL_END } ORT_API(void, OrtApis::ReleaseExternalSemaphoreHandle, _Frees_ptr_opt_ OrtExternalSemaphoreHandle* handle) { - if (handle != nullptr) { - delete reinterpret_cast(handle); + if (handle != nullptr && handle->Release != nullptr) { + handle->Release(handle); } } @@ -3808,13 +3768,12 @@ ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_WaitSemaphore, _In_ OrtExt } auto* imp_wrapper = reinterpret_cast(importer); - auto* sem_wrapper = reinterpret_cast(semaphore_handle); if (imp_wrapper->impl == nullptr || imp_wrapper->impl->WaitSemaphore == nullptr) { return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "WaitSemaphore is not supported by this EP."); } - ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->WaitSemaphore(imp_wrapper->impl, sem_wrapper->impl, stream, value)); + ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->WaitSemaphore(imp_wrapper->impl, semaphore_handle, stream, value)); return nullptr; API_IMPL_END @@ -3830,26 +3789,12 @@ ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_SignalSemaphore, _In_ OrtE } auto* imp_wrapper = reinterpret_cast(importer); - auto* sem_wrapper = reinterpret_cast(semaphore_handle); if (imp_wrapper->impl == nullptr || imp_wrapper->impl->SignalSemaphore == nullptr) { return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SignalSemaphore is not supported by this EP."); } - ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->SignalSemaphore(imp_wrapper->impl, sem_wrapper->impl, stream, value)); - - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::RunOptions_SetSyncStream, _Inout_ OrtRunOptions* 
run_options, - _In_opt_ OrtSyncStream* stream) { - API_IMPL_BEGIN - if (run_options == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "run_options must be provided."); - } - - run_options->sync_stream = stream; + ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->SignalSemaphore(imp_wrapper->impl, semaphore_handle, stream, value)); return nullptr; API_IMPL_END @@ -4028,13 +3973,6 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForOutputs, _In_ const OrtSession API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::RunOptions_SetSyncStream, _Inout_ OrtRunOptions* /*run_options*/, - _In_opt_ OrtSyncStream* /*stream*/) { - API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "RunOptions_SetSyncStream is not supported in a minimal build."); - API_IMPL_END -} - #endif // !defined(ORT_MINIMAL_BUILD) // OrtEpDevice accessors @@ -4665,7 +4603,6 @@ static constexpr OrtApi ort_api_1_to_24 = { &OrtApis::ExternalResourceImporter_WaitSemaphore, &OrtApis::ExternalResourceImporter_SignalSemaphore, &OrtApis::SessionGetEpDeviceForOutputs, - &OrtApis::RunOptions_SetSyncStream, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. 
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 9ada8dd66a010..96ea33c8027d0 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -801,7 +801,4 @@ ORT_API_STATUS_IMPL(SessionGetEpDeviceForOutputs, _In_ const OrtSession* session _Out_writes_(num_outputs) const OrtEpDevice** outputs_ep_devices, _In_ size_t num_outputs); -ORT_API_STATUS_IMPL(RunOptions_SetSyncStream, _Inout_ OrtRunOptions* run_options, - _In_opt_ OrtSyncStream* stream); - } // namespace OrtApis diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h index dbe5bc20a876a..ae98f2c0ac589 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -91,9 +91,9 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->SetEnvironmentOptions(options); } - OrtStatus* CreateExternalResourceImporterForDevice(_In_ const OrtMemoryDevice* device, + OrtStatus* CreateExternalResourceImporterForDevice(_In_ const OrtEpDevice* ep_device, _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** importer) noexcept { - return impl_->CreateExternalResourceImporterForDevice(device, importer); + return impl_->CreateExternalResourceImporterForDevice(ep_device, importer); } // Function ORT calls to release an EP instance. 
diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index b5240f48847d4..20a47715df2b8 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -89,7 +89,7 @@ class EpFactoryInternalImpl { } virtual OrtStatus* CreateExternalResourceImporterForDevice( - _In_ const OrtMemoryDevice* /*device*/, + _In_ const OrtEpDevice* /*ep_device*/, _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** importer) noexcept { // Default implementation does not support external resource import *importer = nullptr; diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h index ed69cd001b120..26173f0055ed7 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h @@ -66,13 +66,15 @@ class ProviderBridgeEpFactory : public EpFactoryInternalImpl { } OrtStatus* CreateExternalResourceImporterForDevice( - const OrtMemoryDevice* device, + const OrtEpDevice* ep_device, OrtExternalResourceImporterImpl** importer) noexcept override { - if (ep_factory_.CreateExternalResourceImporterForDevice == nullptr) { + // OrtEpFactory::CreateExternalResourceImporterForDevice was added in ORT 1.24. 
+ if (ep_factory_.ort_version_supported < 24 || + ep_factory_.CreateExternalResourceImporterForDevice == nullptr) { *importer = nullptr; return nullptr; } - return ep_factory_.CreateExternalResourceImporterForDevice(&ep_factory_, device, importer); + return ep_factory_.CreateExternalResourceImporterForDevice(&ep_factory_, ep_device, importer); } OrtEpFactory& ep_factory_; diff --git a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index 8e08a64ff8c96..2530ae8eb3c2b 100644 --- a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -89,9 +89,9 @@ struct ForwardToFactoryImpl { static OrtStatus* ORT_API_CALL CreateExternalResourceImporterForDevice( _In_ OrtEpFactory* this_ptr, - _In_ const OrtMemoryDevice* device, + _In_ const OrtEpDevice* ep_device, _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** importer) noexcept { - return static_cast(this_ptr)->CreateExternalResourceImporterForDevice(device, importer); + return static_cast(this_ptr)->CreateExternalResourceImporterForDevice(ep_device, importer); } static void ORT_API_CALL ReleaseEp(OrtEpFactory* this_ptr, OrtEp* ep) noexcept { diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc index d53cf8be800f8..e8347bc28cb84 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc @@ -9,8 +9,8 @@ #include #include -ExampleExternalResourceImporter::ExampleExternalResourceImporter(int device_id, const ApiPtrs& apis) - : OrtExternalResourceImporterImpl{}, device_id_{device_id}, apis_{apis} { +ExampleExternalResourceImporter::ExampleExternalResourceImporter(const ApiPtrs& apis) + : OrtExternalResourceImporterImpl{}, 
apis_{apis} { ort_version_supported = ORT_API_VERSION; // Memory operations @@ -43,7 +43,7 @@ bool ORT_API_CALL ExampleExternalResourceImporter::CanImportMemoryImpl( OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::ImportMemoryImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, _In_ const OrtExternalMemoryDescriptor* desc, - _Outptr_ OrtExternalMemoryHandleImpl** out_handle) noexcept { + _Outptr_ OrtExternalMemoryHandle** out_handle) noexcept { auto& impl = *static_cast(this_ptr); if (desc == nullptr || out_handle == nullptr) { @@ -86,26 +86,26 @@ OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::ImportMemoryImpl( handle->handle_type = desc->handle_type; handle->access_mode = desc->access_mode; - *out_handle = reinterpret_cast(handle); + *out_handle = handle; return nullptr; } /*static*/ void ORT_API_CALL ExampleExternalResourceImporter::ReleaseMemoryImpl( _In_ OrtExternalResourceImporterImpl* /*this_ptr*/, - _In_ OrtExternalMemoryHandleImpl* handle) noexcept { + _In_ OrtExternalMemoryHandle* handle) noexcept { if (handle == nullptr) { return; } - auto* mem_handle = reinterpret_cast(handle); + auto* mem_handle = static_cast(handle); delete mem_handle; // destructor frees simulated_ptr } /*static*/ OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::CreateTensorFromMemoryImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, - _In_ const OrtExternalMemoryHandleImpl* mem_handle, + _In_ const OrtExternalMemoryHandle* mem_handle, _In_ const OrtExternalTensorDescriptor* tensor_desc, _Outptr_ OrtValue** out_tensor) noexcept { auto& impl = *static_cast(this_ptr); @@ -116,7 +116,7 @@ OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::CreateTensorFromMemoryI *out_tensor = nullptr; - auto* handle = reinterpret_cast(mem_handle); + auto* handle = static_cast(mem_handle); // Calculate the data pointer with tensor offset void* data_ptr = static_cast(handle->simulated_ptr) + tensor_desc->offset_bytes; @@ -164,7 +164,7 @@ bool ORT_API_CALL 
ExampleExternalResourceImporter::CanImportSemaphoreImpl( OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::ImportSemaphoreImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, _In_ const OrtExternalSemaphoreDescriptor* desc, - _Outptr_ OrtExternalSemaphoreHandleImpl** out_handle) noexcept { + _Outptr_ OrtExternalSemaphoreHandle** out_handle) noexcept { auto& impl = *static_cast(this_ptr); if (desc == nullptr || out_handle == nullptr) { @@ -192,26 +192,26 @@ OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::ImportSemaphoreImpl( handle->type = desc->type; handle->value.store(0); - *out_handle = reinterpret_cast(handle); + *out_handle = handle; return nullptr; } /*static*/ void ORT_API_CALL ExampleExternalResourceImporter::ReleaseSemaphoreImpl( _In_ OrtExternalResourceImporterImpl* /*this_ptr*/, - _In_ OrtExternalSemaphoreHandleImpl* handle) noexcept { + _In_ OrtExternalSemaphoreHandle* handle) noexcept { if (handle == nullptr) { return; } - auto* sem_handle = reinterpret_cast(handle); + auto* sem_handle = static_cast(handle); delete sem_handle; } /*static*/ OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::WaitSemaphoreImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, - _In_ OrtExternalSemaphoreHandleImpl* handle, + _In_ OrtExternalSemaphoreHandle* handle, _In_ OrtSyncStream* stream, _In_ uint64_t value) noexcept { auto& impl = *static_cast(this_ptr); @@ -223,7 +223,7 @@ OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::WaitSemaphoreImpl( // stream can be nullptr for synchronous wait (void)stream; - auto* sem_handle = reinterpret_cast(handle); + auto* sem_handle = static_cast(handle); // In a real implementation, you would: // 1. 
Queue a wait operation on the GPU stream (e.g., cuWaitExternalSemaphoresAsync) @@ -248,7 +248,7 @@ OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::WaitSemaphoreImpl( /*static*/ OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::SignalSemaphoreImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, - _In_ OrtExternalSemaphoreHandleImpl* handle, + _In_ OrtExternalSemaphoreHandle* handle, _In_ OrtSyncStream* stream, _In_ uint64_t value) noexcept { auto& impl = *static_cast(this_ptr); @@ -260,7 +260,7 @@ OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::SignalSemaphoreImpl( // stream can be nullptr for synchronous signal (void)stream; - auto* sem_handle = reinterpret_cast(handle); + auto* sem_handle = static_cast(handle); // In a real implementation, you would: // 1. Queue a signal operation on the GPU stream (e.g., cuSignalExternalSemaphoresAsync) diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h index 7dcd8f42313f3..64b42611b5402 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h @@ -10,41 +10,64 @@ #include /** - * @brief Example implementation of external memory handle. + * @brief Example derived handle for imported external memory. * + * Derives from OrtExternalMemoryHandle and adds example-specific fields. * This mock implementation simulates imported external memory for testing purposes. * In a real EP, this would hold a GPU-mapped pointer from an imported D3D12/Vulkan/CUDA resource. 
*/ -struct ExampleExternalMemoryHandle { +struct ExampleExternalMemoryHandle : OrtExternalMemoryHandle { void* simulated_ptr; ///< Simulated mapped pointer (CPU memory for testing) - size_t size_bytes; ///< Size of the imported memory - size_t offset_bytes; ///< Offset into the imported memory - OrtExternalMemoryHandleType handle_type; ///< Original handle type OrtExternalMemoryAccessMode access_mode; ///< Access mode for the imported memory ExampleExternalMemoryHandle() - : simulated_ptr(nullptr), size_bytes(0), offset_bytes(0), handle_type(ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE), access_mode(ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE) {} + : simulated_ptr(nullptr), access_mode(ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE) { + // Initialize base struct fields + version = ORT_EXTERNAL_MEMORY_HANDLE_VERSION; + ep_device = nullptr; + handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + size_bytes = 0; + offset_bytes = 0; + Release = ReleaseCallback; + } ~ExampleExternalMemoryHandle() { // Free the simulated pointer if allocated if (simulated_ptr != nullptr) { free(simulated_ptr); + simulated_ptr = nullptr; } } + + static void ORT_API_CALL ReleaseCallback(_In_ OrtExternalMemoryHandle* handle) noexcept { + if (handle == nullptr) return; + delete static_cast(handle); + } }; /** - * @brief Example implementation of external semaphore handle. + * @brief Example derived handle for imported external semaphore. * + * Derives from OrtExternalSemaphoreHandle and adds example-specific fields. * This mock implementation simulates imported external semaphores for testing purposes. * In a real EP, this would hold an imported D3D12 fence / Vulkan semaphore / CUDA external semaphore. 
*/ -struct ExampleExternalSemaphoreHandle { - OrtExternalSemaphoreType type; ///< Original semaphore type - std::atomic value; ///< Simulated fence value for testing +struct ExampleExternalSemaphoreHandle : OrtExternalSemaphoreHandle { + std::atomic value; ///< Simulated fence value for testing ExampleExternalSemaphoreHandle() - : type(ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE), value(0) {} + : value(0) { + // Initialize base struct fields + version = ORT_EXTERNAL_SEMAPHORE_HANDLE_VERSION; + ep_device = nullptr; + type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + Release = ReleaseCallback; + } + + static void ORT_API_CALL ReleaseCallback(_In_ OrtExternalSemaphoreHandle* handle) noexcept { + if (handle == nullptr) return; + delete static_cast(handle); + } }; /** @@ -61,9 +84,7 @@ struct ExampleExternalSemaphoreHandle { */ class ExampleExternalResourceImporter : public OrtExternalResourceImporterImpl { public: - ExampleExternalResourceImporter(int device_id, const ApiPtrs& apis); - - // ──────────────── Memory operations ──────────────── + ExampleExternalResourceImporter(const ApiPtrs& apis); static bool ORT_API_CALL CanImportMemoryImpl( _In_ const OrtExternalResourceImporterImpl* this_ptr, @@ -72,20 +93,18 @@ class ExampleExternalResourceImporter : public OrtExternalResourceImporterImpl { static OrtStatus* ORT_API_CALL ImportMemoryImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, _In_ const OrtExternalMemoryDescriptor* desc, - _Outptr_ OrtExternalMemoryHandleImpl** out_handle) noexcept; + _Outptr_ OrtExternalMemoryHandle** out_handle) noexcept; static void ORT_API_CALL ReleaseMemoryImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, - _In_ OrtExternalMemoryHandleImpl* handle) noexcept; + _In_ OrtExternalMemoryHandle* handle) noexcept; static OrtStatus* ORT_API_CALL CreateTensorFromMemoryImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, - _In_ const OrtExternalMemoryHandleImpl* mem_handle, + _In_ const OrtExternalMemoryHandle* mem_handle, _In_ const 
OrtExternalTensorDescriptor* tensor_desc, _Outptr_ OrtValue** out_tensor) noexcept; - // ──────────────── Semaphore operations ──────────────── - static bool ORT_API_CALL CanImportSemaphoreImpl( _In_ const OrtExternalResourceImporterImpl* this_ptr, _In_ OrtExternalSemaphoreType type) noexcept; @@ -93,29 +112,26 @@ class ExampleExternalResourceImporter : public OrtExternalResourceImporterImpl { static OrtStatus* ORT_API_CALL ImportSemaphoreImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, _In_ const OrtExternalSemaphoreDescriptor* desc, - _Outptr_ OrtExternalSemaphoreHandleImpl** out_handle) noexcept; + _Outptr_ OrtExternalSemaphoreHandle** out_handle) noexcept; static void ORT_API_CALL ReleaseSemaphoreImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, - _In_ OrtExternalSemaphoreHandleImpl* handle) noexcept; + _In_ OrtExternalSemaphoreHandle* handle) noexcept; static OrtStatus* ORT_API_CALL WaitSemaphoreImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, - _In_ OrtExternalSemaphoreHandleImpl* handle, + _In_ OrtExternalSemaphoreHandle* handle, _In_ OrtSyncStream* stream, _In_ uint64_t value) noexcept; static OrtStatus* ORT_API_CALL SignalSemaphoreImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, - _In_ OrtExternalSemaphoreHandleImpl* handle, + _In_ OrtExternalSemaphoreHandle* handle, _In_ OrtSyncStream* stream, _In_ uint64_t value) noexcept; - // ──────────────── Release ──────────────── - static void ORT_API_CALL ReleaseImpl(_In_ OrtExternalResourceImporterImpl* this_ptr) noexcept; private: - int device_id_; ApiPtrs apis_; }; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 3f0d8b335e361..fd652b8882df9 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -314,7 +314,7 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateSyncStreamForDeviceImpl(OrtEpFac 
/*static*/ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateExternalResourceImporterForDeviceImpl( OrtEpFactory* this_ptr, - const OrtMemoryDevice* memory_device, + const OrtEpDevice* /*ep_device*/, OrtExternalResourceImporterImpl** out_importer) noexcept { auto& factory = *static_cast(this_ptr); @@ -323,19 +323,8 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateExternalResourceImporterForDevic "out_importer cannot be nullptr"); } - *out_importer = nullptr; - - // For the example EP, we support external resource import on the default (GPU-simulated) device memory - if (factory.ep_api.MemoryDevice_GetMemoryType(memory_device) != OrtDeviceMemoryType_DEFAULT) { - return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, - "External resource import only supported for DEFAULT device memory"); - } - - // Get the device ID from the memory device - auto device_id = factory.ep_api.MemoryDevice_GetDeviceId(memory_device); - // Create the external resource importer - auto importer = std::make_unique(device_id, factory); + auto importer = std::make_unique(factory); *out_importer = importer.release(); return nullptr; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index c7bfbcfb918a1..230fdef772e2f 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -70,7 +70,7 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { static OrtStatus* ORT_API_CALL CreateExternalResourceImporterForDeviceImpl( OrtEpFactory* this_ptr, - const OrtMemoryDevice* memory_device, + const OrtEpDevice* ep_device, OrtExternalResourceImporterImpl** out_importer) noexcept; const OrtLogger& default_logger_; // default logger for the EP factory diff --git a/onnxruntime/test/autoep/test_external_resource_importer.cc b/onnxruntime/test/autoep/test_external_resource_importer.cc index f13934e9e94ce..04240a57f6c25 100644 
--- a/onnxruntime/test/autoep/test_external_resource_importer.cc +++ b/onnxruntime/test/autoep/test_external_resource_importer.cc @@ -176,18 +176,21 @@ TEST_F(ExternalResourceImporterTest, CreateTensorFromMemory) { ASSERT_EQ(status, nullptr); size_t rank = 0; - Ort::GetApi().GetDimensionsCount(type_info, &rank); + status = Ort::GetApi().GetDimensionsCount(type_info, &rank); + ASSERT_EQ(status, nullptr); EXPECT_EQ(rank, 4u); std::vector actual_shape(rank); - Ort::GetApi().GetDimensions(type_info, actual_shape.data(), rank); + status = Ort::GetApi().GetDimensions(type_info, actual_shape.data(), rank); + ASSERT_EQ(status, nullptr); EXPECT_EQ(actual_shape[0], batch); EXPECT_EQ(actual_shape[1], channels); EXPECT_EQ(actual_shape[2], height); EXPECT_EQ(actual_shape[3], width); ONNXTensorElementDataType elem_type; - Ort::GetApi().GetTensorElementType(type_info, &elem_type); + status = Ort::GetApi().GetTensorElementType(type_info, &elem_type); + ASSERT_EQ(status, nullptr); EXPECT_EQ(elem_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); Ort::GetApi().ReleaseTensorTypeAndShapeInfo(type_info); @@ -382,43 +385,5 @@ TEST_F(ExternalResourceImporterTest, SessionGetEpDeviceForOutputs) { } } -// Test: RunOptions_SetSyncStream -TEST_F(ExternalResourceImporterTest, RunOptionsSetSyncStream) { - // Create run options - Ort::RunOptions run_options; - - // Set sync stream to nullptr (which is valid - clears the stream) - OrtStatus* status = Ort::GetApi().RunOptions_SetSyncStream(run_options, nullptr); - ASSERT_EQ(status, nullptr) << "RunOptions_SetSyncStream with nullptr should succeed"; - - // Try to get a real sync stream from the EP device - OrtSyncStream* stream = nullptr; - status = Ort::GetApi().CreateSyncStreamForEpDevice(ep_device_, nullptr, &stream); - if (status != nullptr) { - std::string error = Ort::GetApi().GetErrorMessage(status); - Ort::GetApi().ReleaseStatus(status); - // Sync stream not supported - just test with nullptr - return; - } - - // Set the sync stream on run options 
- status = Ort::GetApi().RunOptions_SetSyncStream(run_options, stream); - ASSERT_EQ(status, nullptr) << "RunOptions_SetSyncStream with stream should succeed"; - - // Clean up - Ort::GetApi().ReleaseSyncStream(stream); -} - -// Test: RunOptions_SetSyncStream with Invalid Arguments -TEST_F(ExternalResourceImporterTest, RunOptionsSetSyncStreamInvalidArgs) { - // Test with nullptr run_options - OrtStatus* status = Ort::GetApi().RunOptions_SetSyncStream(nullptr, nullptr); - ASSERT_NE(status, nullptr) << "RunOptions_SetSyncStream with nullptr run_options should fail"; - - OrtErrorCode error_code = Ort::GetApi().GetErrorCode(status); - EXPECT_EQ(error_code, ORT_INVALID_ARGUMENT); - Ort::GetApi().ReleaseStatus(status); -} - } // namespace test } // namespace onnxruntime From 92248caef8723c42aafbd276fa5cd26e90dc42dd Mon Sep 17 00:00:00 2001 From: Nick Eubanks Date: Fri, 19 Dec 2025 14:24:24 -0500 Subject: [PATCH 04/10] Update based on switch to OrtEpDevice parameter --- .../nv_tensorrt_rtx/nv_provider_factory.cc | 207 +++++++++--------- .../nv_external_resource_importer_test.cc | 9 +- 2 files changed, 110 insertions(+), 106 deletions(-) diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index 77154bfcabe71..8ff93c103de57 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -519,37 +519,71 @@ struct NvTrtRtxSyncStreamImpl : OrtSyncStreamImpl { const OrtApi& ort_api; }; -// External Resource Import Implementation (D3D12 ↔ CUDA) +#if defined(_WIN32) + +// External Resource Import Implementation (D3D12 to CUDA) /** - * @brief Wrapper for imported external memory from D3D12 to CUDA. + * @brief Derived handle for imported external memory from D3D12 to CUDA. * + * Derives from OrtExternalMemoryHandle (base struct) and adds CUDA-specific fields. 
* This struct holds the CUDA external memory object and the mapped device pointer * that can be used for zero-copy tensor creation. */ -struct NvTrtRtxExternalMemoryHandle { - CUexternalMemory ext_memory; ///< CUDA external memory object - CUdeviceptr mapped_ptr; ///< Mapped device pointer for tensor access - size_t size_bytes; ///< Size of the imported memory - size_t offset_bytes; ///< Offset into the imported memory - OrtExternalMemoryHandleType handle_type; ///< Original handle type for tracking - bool is_dedicated; ///< Whether the D3D12 resource is a dedicated allocation +struct NvTrtRtxExternalMemoryHandle : OrtExternalMemoryHandle { + CUexternalMemory ext_memory; ///< CUDA external memory object + CUdeviceptr mapped_ptr; ///< Mapped device pointer for tensor access + bool is_dedicated; ///< Whether the D3D12 resource is a dedicated allocation NvTrtRtxExternalMemoryHandle() - : ext_memory(nullptr), mapped_ptr(0), size_bytes(0), offset_bytes(0), handle_type(ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE), is_dedicated(true) {} + : ext_memory(nullptr), mapped_ptr(0), is_dedicated(true) { + // Initialize base struct fields + version = ORT_EXTERNAL_MEMORY_HANDLE_VERSION; + ep_device = nullptr; + handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + size_bytes = 0; + offset_bytes = 0; + Release = ReleaseCallback; + } + + static void ORT_API_CALL ReleaseCallback(_In_ OrtExternalMemoryHandle* handle) noexcept { + if (handle == nullptr) return; + auto* derived = static_cast(handle); + // Destroy the external memory object (also releases mapped buffer) + if (derived->ext_memory != nullptr) { + cuDestroyExternalMemory(derived->ext_memory); + } + delete derived; + } }; /** - * @brief Wrapper for imported external semaphore from D3D12 fence to CUDA. + * @brief Derived handle for imported external semaphore from D3D12 fence to CUDA. * + * Derives from OrtExternalSemaphoreHandle (base struct) and adds CUDA-specific fields. 
* D3D12 timeline fences are imported as CUDA external semaphores, enabling * GPU-GPU synchronization between D3D12 and CUDA streams. */ -struct NvTrtRtxExternalSemaphoreHandle { +struct NvTrtRtxExternalSemaphoreHandle : OrtExternalSemaphoreHandle { CUexternalSemaphore ext_semaphore; ///< CUDA external semaphore object - OrtExternalSemaphoreType type; ///< Original semaphore type NvTrtRtxExternalSemaphoreHandle() - : ext_semaphore(nullptr), type(ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE) {} + : ext_semaphore(nullptr) { + // Initialize base struct fields + version = ORT_EXTERNAL_SEMAPHORE_HANDLE_VERSION; + ep_device = nullptr; + type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + Release = ReleaseCallback; + } + + static void ORT_API_CALL ReleaseCallback(_In_ OrtExternalSemaphoreHandle* handle) noexcept { + if (handle == nullptr) return; + auto* derived = static_cast(handle); + // Destroy the external semaphore object + if (derived->ext_semaphore != nullptr) { + cuDestroyExternalSemaphore(derived->ext_semaphore); + } + delete derived; + } }; /** @@ -564,8 +598,8 @@ struct NvTrtRtxExternalSemaphoreHandle { * - ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE → CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE */ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { - NvTrtRtxExternalResourceImporterImpl(int device_id, const OrtApi& ort_api_in) - : device_id_{device_id}, ort_api{ort_api_in}, ep_api{*ort_api_in.GetEpApi()} { + NvTrtRtxExternalResourceImporterImpl(const OrtEpDevice* ep_device, const OrtApi& ort_api_in) + : ep_device_{ep_device}, ort_api{ort_api_in}, ep_api{*ort_api_in.GetEpApi()} { ort_version_supported = ORT_API_VERSION; // Memory operations @@ -585,8 +619,6 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { Release = ReleaseImpl; } - // ──────────────── Memory operations ──────────────── - static bool ORT_API_CALL CanImportMemoryImpl( _In_ const OrtExternalResourceImporterImpl* this_ptr, _In_ OrtExternalMemoryHandleType 
handle_type) noexcept { @@ -599,7 +631,7 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { static OrtStatus* ORT_API_CALL ImportMemoryImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, _In_ const OrtExternalMemoryDescriptor* desc, - _Outptr_ OrtExternalMemoryHandleImpl** out_handle) noexcept { + _Outptr_ OrtExternalMemoryHandle** out_handle) noexcept { auto& impl = *static_cast(this_ptr); if (desc == nullptr || out_handle == nullptr) { @@ -628,7 +660,7 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { // Set CUDA device CUresult cu_result = cuCtxSetCurrent(nullptr); // Reset context - CUDA_RETURN_IF_ERROR(cudaSetDevice(impl.device_id_)); + CUDA_RETURN_IF_ERROR(cudaSetDevice(impl.DeviceId())); // Map ORT handle type to CUDA handle type CUexternalMemoryHandleType cu_handle_type; @@ -681,35 +713,38 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { return impl.ort_api.CreateStatus(ORT_FAIL, error_msg.c_str()); } - // Create and return the handle wrapper + // Create and return the derived handle (cast to base pointer) auto* handle = new (std::nothrow) NvTrtRtxExternalMemoryHandle(); if (handle == nullptr) { cuDestroyExternalMemory(ext_memory); return impl.ort_api.CreateStatus(ORT_FAIL, "Failed to allocate external memory handle"); } - handle->ext_memory = ext_memory; - handle->mapped_ptr = mapped_ptr; + handle->ep_device = impl.ep_device_; + handle->handle_type = desc->handle_type; handle->size_bytes = desc->size_bytes; handle->offset_bytes = desc->offset_bytes; - handle->handle_type = desc->handle_type; + handle->ext_memory = ext_memory; + handle->mapped_ptr = mapped_ptr; handle->is_dedicated = is_dedicated; - *out_handle = reinterpret_cast(handle); + *out_handle = handle; return nullptr; } static void ORT_API_CALL ReleaseMemoryImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, - _In_ OrtExternalMemoryHandleImpl* handle) noexcept { - auto& impl = 
*static_cast(this_ptr); - (void)impl; + _In_ OrtExternalMemoryHandle* handle) noexcept { + (void)this_ptr; if (handle == nullptr) { return; } - auto* mem_handle = reinterpret_cast(handle); + // The handle has a Release callback that does the actual cleanup + // This method is called from OrtExternalResourceImporterImpl::ReleaseMemory + // The Release callback in the handle will call the static ReleaseCallback + auto* mem_handle = static_cast(handle); // Destroy the external memory object (also releases mapped buffer) if (mem_handle->ext_memory != nullptr) { @@ -721,7 +756,7 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { static OrtStatus* ORT_API_CALL CreateTensorFromMemoryImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, - _In_ const OrtExternalMemoryHandleImpl* mem_handle, + _In_ const OrtExternalMemoryHandle* mem_handle, _In_ const OrtExternalTensorDescriptor* tensor_desc, _Outptr_ OrtValue** out_tensor) noexcept { auto& impl = *static_cast(this_ptr); @@ -738,7 +773,7 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { *out_tensor = nullptr; - auto* handle = reinterpret_cast(mem_handle); + auto* handle = static_cast(mem_handle); // Validate tensor offset does not exceed available buffer size size_t available_size = handle->size_bytes - handle->offset_bytes; @@ -750,24 +785,14 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { // Calculate the data pointer with tensor offset void* data_ptr = reinterpret_cast(handle->mapped_ptr + tensor_desc->offset_bytes); - // Create memory info for CUDA device - OrtMemoryInfo* memory_info = nullptr; - OrtStatus* status = impl.ort_api.CreateMemoryInfo_V2( - "NvTensorRTRTX", - OrtMemoryInfoDeviceType_GPU, - OrtDevice::VendorIds::NVIDIA, - impl.device_id_, - OrtDeviceMemoryType_DEFAULT, - 0, // alignment - OrtDeviceAllocator, - &memory_info); - - if (status != nullptr) { - return status; - } + // Get memory info from the EP device 
(the importer is associated with the OrtEpDevice) + const OrtMemoryInfo* memory_info = impl.ep_device_->device_memory_info; - // Create tensor with pre-allocated memory - status = impl.ort_api.CreateTensorWithDataAsOrtValue( + // Create tensor that references the imported memory. The tensor does not own the memory - + // the user manages the lifetime of both the OrtValue and OrtExternalMemoryHandle. + // The user must keep the handle alive while the tensor is in use. + // No deleter is needed since this is for inference inputs/outputs where the user controls lifetime. + OrtStatus* status = impl.ort_api.CreateTensorWithDataAsOrtValue( memory_info, data_ptr, handle->size_bytes - tensor_desc->offset_bytes, @@ -776,12 +801,9 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { tensor_desc->element_type, out_tensor); - impl.ort_api.ReleaseMemoryInfo(memory_info); return status; } - // ──────────────── Semaphore operations ──────────────── - static bool ORT_API_CALL CanImportSemaphoreImpl( _In_ const OrtExternalResourceImporterImpl* this_ptr, _In_ OrtExternalSemaphoreType type) noexcept { @@ -793,7 +815,7 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { static OrtStatus* ORT_API_CALL ImportSemaphoreImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, _In_ const OrtExternalSemaphoreDescriptor* desc, - _Outptr_ OrtExternalSemaphoreHandleImpl** out_handle) noexcept { + _Outptr_ OrtExternalSemaphoreHandle** out_handle) noexcept { auto& impl = *static_cast(this_ptr); if (desc == nullptr || out_handle == nullptr) { @@ -815,7 +837,7 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { } // Set CUDA device - CUDA_RETURN_IF_ERROR(cudaSetDevice(impl.device_id_)); + CUDA_RETURN_IF_ERROR(cudaSetDevice(impl.DeviceId())); // Setup external semaphore handle descriptor for D3D12 fence CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC ext_sem_desc = {}; @@ -834,30 +856,34 @@ struct 
NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { return impl.ort_api.CreateStatus(ORT_FAIL, error_msg.c_str()); } - // Create and return the handle wrapper + // Create and return the derived handle (cast to base pointer) auto* handle = new (std::nothrow) NvTrtRtxExternalSemaphoreHandle(); if (handle == nullptr) { cuDestroyExternalSemaphore(ext_semaphore); return impl.ort_api.CreateStatus(ORT_FAIL, "Failed to allocate external semaphore handle"); } - handle->ext_semaphore = ext_semaphore; + // Populate base struct fields + handle->ep_device = impl.ep_device_; handle->type = desc->type; - *out_handle = reinterpret_cast(handle); + // Populate derived fields + handle->ext_semaphore = ext_semaphore; + + *out_handle = handle; // Return base pointer return nullptr; } static void ORT_API_CALL ReleaseSemaphoreImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, - _In_ OrtExternalSemaphoreHandleImpl* handle) noexcept { + _In_ OrtExternalSemaphoreHandle* handle) noexcept { (void)this_ptr; if (handle == nullptr) { return; } - auto* sem_handle = reinterpret_cast(handle); + auto* sem_handle = static_cast(handle); if (sem_handle->ext_semaphore != nullptr) { cuDestroyExternalSemaphore(sem_handle->ext_semaphore); @@ -868,7 +894,7 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { static OrtStatus* ORT_API_CALL WaitSemaphoreImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, - _In_ OrtExternalSemaphoreHandleImpl* handle, + _In_ OrtExternalSemaphoreHandle* handle, _In_ OrtSyncStream* stream, _In_ uint64_t value) noexcept { auto& impl = *static_cast(this_ptr); @@ -877,7 +903,7 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to WaitSemaphore"); } - auto* sem_handle = reinterpret_cast(handle); + auto* sem_handle = static_cast(handle); // Get the CUDA stream from OrtSyncStream cudaStream_t cuda_stream = 
static_cast(impl.ort_api.SyncStream_GetHandle(stream)); @@ -907,7 +933,7 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { static OrtStatus* ORT_API_CALL SignalSemaphoreImpl( _In_ OrtExternalResourceImporterImpl* this_ptr, - _In_ OrtExternalSemaphoreHandleImpl* handle, + _In_ OrtExternalSemaphoreHandle* handle, _In_ OrtSyncStream* stream, _In_ uint64_t value) noexcept { auto& impl = *static_cast(this_ptr); @@ -916,7 +942,7 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to SignalSemaphore"); } - auto* sem_handle = reinterpret_cast(handle); + auto* sem_handle = static_cast(handle); // Get the CUDA stream from OrtSyncStream cudaStream_t cuda_stream = static_cast(impl.ort_api.SyncStream_GetHandle(stream)); @@ -944,18 +970,23 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { return nullptr; } - // ──────────────── Release ──────────────── - static void ORT_API_CALL ReleaseImpl(_In_ OrtExternalResourceImporterImpl* this_ptr) noexcept { delete static_cast(this_ptr); } + /// @brief Get the CUDA device ID from the EP device's memory info. + int DeviceId() const { + return ep_device_->device_memory_info->device.Id(); + } + private: - int device_id_; + const OrtEpDevice* ep_device_; const OrtApi& ort_api; const OrtEpApi& ep_api; }; +#endif // defined(_WIN32) + // OrtEpApi infrastructure to be able to use the NvTensorRTRTX EP as an OrtEpFactory for auto EP selection. struct NvTensorRtRtxEpFactory : OrtEpFactory { using MemoryInfoUniquePtr = std::unique_ptr>; @@ -1193,13 +1224,13 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { * The implementation uses CUDA Driver APIs (cuImportExternalMemory, cuImportExternalSemaphore). * * @param this_ptr The OrtEpFactory instance. - * @param memory_device The OrtMemoryDevice to create the importer for. 
+ * @param ep_device The OrtEpDevice to create the importer for (must have "device_id" in ep_options). * @param out_importer Output parameter set to the created OrtExternalResourceImporterImpl. * @return nullptr on success, OrtStatus with error on failure. */ static OrtStatus* ORT_API_CALL CreateExternalResourceImporterForDeviceImpl( OrtEpFactory* this_ptr, - const OrtMemoryDevice* memory_device, + const OrtEpDevice* ep_device, OrtExternalResourceImporterImpl** out_importer) noexcept { auto& factory = *static_cast(this_ptr); @@ -1210,44 +1241,18 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { *out_importer = nullptr; - // Check memory type - only DEFAULT device memory is supported - auto mem_type = factory.ep_api.MemoryDevice_GetMemoryType(memory_device); - if (mem_type != OrtDeviceMemoryType_DEFAULT) { - return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, - "External resource import only supported for DEFAULT device memory"); - } - - // Validate that this is a GPU device - OrtMemoryInfoDeviceType device_type = factory.ep_api.MemoryDevice_GetDeviceType(memory_device); - if (device_type != OrtMemoryInfoDeviceType_GPU) { - return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, - "External resource import only supported for GPU devices"); - } - - // Get the CUDA device ID - auto device_id = factory.ep_api.MemoryDevice_GetDeviceId(memory_device); - - // Verify the device is an NVIDIA GPU - auto vendor_id = factory.ep_api.MemoryDevice_GetVendorId(memory_device); - if (vendor_id != OrtDevice::VendorIds::NVIDIA) { - return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, - "External resource import only supported for NVIDIA GPUs"); - } - - // Verify CUDA device is valid and has necessary capabilities - int cuda_device_count = 0; - cudaError_t cuda_err = cudaGetDeviceCount(&cuda_device_count); - if (cuda_err != cudaSuccess || cuda_device_count <= 0 || - device_id >= static_cast(cuda_device_count)) { - return factory.ort_api.CreateStatus(ORT_FAIL, - "Invalid CUDA 
device ID for external resource import"); - } - +#if defined(_WIN32) // Create the external resource importer - auto importer = std::make_unique(device_id, factory.ort_api); + // The importer gets the CUDA device ID from ep_device->ep_options["device_id"] + auto importer = std::make_unique(ep_device, factory.ort_api); *out_importer = importer.release(); return nullptr; +#else + ORT_UNUSED_PARAMETER(ep_device); + return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, + "External resource import is only available on Windows builds."); +#endif } OrtStatus* CreateMemoryInfoForDevices(int num_devices) { diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_external_resource_importer_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_external_resource_importer_test.cc index edb886abdd0c8..4ae346ff06450 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_external_resource_importer_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_external_resource_importer_test.cc @@ -755,7 +755,7 @@ TEST_F(ExternalResourceImporterTest, FullInferenceWithExternalMemory) { status = ort_api_->ExternalResourceImporter_WaitSemaphore(importer, sem_handle, ort_stream, upload_complete_value); ASSERT_EQ(status, nullptr) << "WaitSemaphore should succeed"; - // Setup ORT session with user_compute_stream (like the patch PR test) + // Setup ORT session with user_compute_stream Ort::SessionOptions session_options; session_options.SetExecutionMode(ORT_SEQUENTIAL); session_options.DisableMemPattern(); @@ -787,10 +787,9 @@ TEST_F(ExternalResourceImporterTest, FullInferenceWithExternalMemory) { io_binding.BindOutput(output_name.get(), Ort::Value(output_tensor)); io_binding.SynchronizeInputs(); - // Run inference with synchronization disabled (we handle it manually) - Ort::RunOptions run_options; - run_options.AddConfigEntry("disable_synchronize_execution_providers", "1"); - session.Run(run_options, io_binding); + // Run inference. 
ORT synchronizes the stream before returning, so GPU work is complete + // when we signal the semaphore below. + session.Run(Ort::RunOptions{}, io_binding); // Signal from CUDA that inference is complete uint64_t inference_complete_value = 2; From 6199160b2d3b6153bfd6b800822ac33d4129427f Mon Sep 17 00:00:00 2001 From: Nick Eubanks Date: Fri, 19 Dec 2025 14:48:13 -0500 Subject: [PATCH 05/10] Update error handling and version validation in NvTrtRtxExternalResourceImporter --- .../nv_tensorrt_rtx/nv_provider_factory.cc | 59 +++++++++---------- 1 file changed, 28 insertions(+), 31 deletions(-) diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index 8ff93c103de57..6ce998b58747e 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -638,10 +638,10 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to ImportMemory"); } - // Validate descriptor version - if (desc->version != ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION) { + // Validate descriptor version - check minimum supported version for forward compatibility + if (desc->version < ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION) { return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, - "Invalid OrtExternalMemoryDescriptor version"); + "OrtExternalMemoryDescriptor version too old"); } *out_handle = nullptr; @@ -658,8 +658,11 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { "offset_bytes exceeds size_bytes in OrtExternalMemoryDescriptor"); } - // Set CUDA device - CUresult cu_result = cuCtxSetCurrent(nullptr); // Reset context + // Set CUDA device for this EP. 
The imported external memory handle is associated with + // the device where it was imported and remains valid regardless of subsequent cudaSetDevice + // calls. Multi-GPU scenarios with different sessions/EPs work correctly because each + // importer is bound to its EP's device via ep_device_->device_memory_info. + (void)cuCtxSetCurrent(nullptr); // Reset context CUDA_RETURN_IF_ERROR(cudaSetDevice(impl.DeviceId())); // Map ORT handle type to CUDA handle type @@ -675,7 +678,8 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { is_dedicated = false; // D3D12 heaps are not dedicated break; default: - return impl.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, "Unknown external memory handle type"); + // Should not reach here - CanImportMemory already validated handle type + return impl.ort_api.CreateStatus(ORT_EP_FAIL, "Unexpected external memory handle type"); } // Setup external memory handle descriptor @@ -687,13 +691,13 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { // Import the external memory CUexternalMemory ext_memory = nullptr; - cu_result = cuImportExternalMemory(&ext_memory, &ext_mem_desc); + CUresult cu_result = cuImportExternalMemory(&ext_memory, &ext_mem_desc); if (cu_result != CUDA_SUCCESS) { const char* error_str = nullptr; cuGetErrorString(cu_result, &error_str); std::string error_msg = "cuImportExternalMemory failed: "; error_msg += error_str ? error_str : "unknown error"; - return impl.ort_api.CreateStatus(ORT_FAIL, error_msg.c_str()); + return impl.ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); } // Map the external memory to get a device pointer @@ -710,15 +714,11 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { cuGetErrorString(cu_result, &error_str); std::string error_msg = "cuExternalMemoryGetMappedBuffer failed: "; error_msg += error_str ? 
error_str : "unknown error"; - return impl.ort_api.CreateStatus(ORT_FAIL, error_msg.c_str()); + return impl.ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); } // Create and return the derived handle (cast to base pointer) - auto* handle = new (std::nothrow) NvTrtRtxExternalMemoryHandle(); - if (handle == nullptr) { - cuDestroyExternalMemory(ext_memory); - return impl.ort_api.CreateStatus(ORT_FAIL, "Failed to allocate external memory handle"); - } + auto handle = std::make_unique(); handle->ep_device = impl.ep_device_; handle->handle_type = desc->handle_type; @@ -728,7 +728,7 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { handle->mapped_ptr = mapped_ptr; handle->is_dedicated = is_dedicated; - *out_handle = handle; + *out_handle = handle.release(); return nullptr; } @@ -765,10 +765,10 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to CreateTensorFromMemory"); } - // Validate descriptor version - if (tensor_desc->version != ORT_EXTERNAL_TENSOR_DESCRIPTOR_VERSION) { + // Validate descriptor version - check minimum supported version for forward compatibility + if (tensor_desc->version < ORT_EXTERNAL_TENSOR_DESCRIPTOR_VERSION) { return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, - "Invalid OrtExternalTensorDescriptor version"); + "OrtExternalTensorDescriptor version too old"); } *out_tensor = nullptr; @@ -822,10 +822,10 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to ImportSemaphore"); } - // Validate descriptor version - if (desc->version != ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION) { + // Validate descriptor version - check minimum supported version for forward compatibility + if (desc->version < ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION) { return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, - "Invalid 
OrtExternalSemaphoreDescriptor version"); + "OrtExternalSemaphoreDescriptor version too old"); } *out_handle = nullptr; @@ -836,7 +836,8 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { "Unsupported external semaphore type for CUDA import"); } - // Set CUDA device + // Set CUDA device for this EP. Imported semaphore handles remain valid regardless of + // subsequent cudaSetDevice calls. CUDA_RETURN_IF_ERROR(cudaSetDevice(impl.DeviceId())); // Setup external semaphore handle descriptor for D3D12 fence @@ -853,15 +854,11 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { cuGetErrorString(cu_result, &error_str); std::string error_msg = "cuImportExternalSemaphore failed: "; error_msg += error_str ? error_str : "unknown error"; - return impl.ort_api.CreateStatus(ORT_FAIL, error_msg.c_str()); + return impl.ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); } // Create and return the derived handle (cast to base pointer) - auto* handle = new (std::nothrow) NvTrtRtxExternalSemaphoreHandle(); - if (handle == nullptr) { - cuDestroyExternalSemaphore(ext_semaphore); - return impl.ort_api.CreateStatus(ORT_FAIL, "Failed to allocate external semaphore handle"); - } + auto handle = std::make_unique(); // Populate base struct fields handle->ep_device = impl.ep_device_; @@ -870,7 +867,7 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { // Populate derived fields handle->ext_semaphore = ext_semaphore; - *out_handle = handle; // Return base pointer + *out_handle = handle.release(); // Return base pointer return nullptr; } @@ -925,7 +922,7 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { cuGetErrorString(cu_result, &error_str); std::string error_msg = "cuWaitExternalSemaphoresAsync failed: "; error_msg += error_str ? 
error_str : "unknown error"; - return impl.ort_api.CreateStatus(ORT_FAIL, error_msg.c_str()); + return impl.ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); } return nullptr; @@ -964,7 +961,7 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { cuGetErrorString(cu_result, &error_str); std::string error_msg = "cuSignalExternalSemaphoresAsync failed: "; error_msg += error_str ? error_str : "unknown error"; - return impl.ort_api.CreateStatus(ORT_FAIL, error_msg.c_str()); + return impl.ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); } return nullptr; From 2d49c05cac2374d29b942676c5e12cbfa3b1c962 Mon Sep 17 00:00:00 2001 From: Nick Eubanks Date: Mon, 22 Dec 2025 12:39:51 -0500 Subject: [PATCH 06/10] Use ORT_API_VERSION instead of separate version constants for external resource structs --- .../onnxruntime/core/session/onnxruntime_c_api.h | 15 ++++++--------- .../core/session/onnxruntime_ep_c_api.h | 6 ++---- .../ep_external_resource_importer.h | 4 ++-- .../autoep/test_external_resource_importer.cc | 14 +++++++------- 4 files changed, 17 insertions(+), 22 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 7414d3732e83b..a8822d941110b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -984,14 +984,13 @@ typedef enum OrtExternalMemoryAccessMode { /** \brief Descriptor for importing external memory. * - * \note The version field must be set to ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION. + * \note The version field must be set to ORT_API_VERSION. * This ensures forward compatibility as fields may be added in future versions. * * \since Version 1.24. 
*/ -#define ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION 1 typedef struct OrtExternalMemoryDescriptor { - uint32_t version; /**< Must be ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION */ + uint32_t version; /**< Must be ORT_API_VERSION */ OrtExternalMemoryHandleType handle_type; /**< Type of the external memory handle */ void* native_handle; /**< Platform-specific handle (e.g., Windows HANDLE) */ size_t size_bytes; /**< Total size in bytes of the external allocation */ @@ -1009,28 +1008,26 @@ typedef enum OrtExternalSemaphoreType { /** \brief Descriptor for importing external semaphores. * - * \note The version field must be set to ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION. + * \note The version field must be set to ORT_API_VERSION. * This ensures forward compatibility as fields may be added in future versions. * * \since Version 1.24. */ -#define ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION 1 typedef struct OrtExternalSemaphoreDescriptor { - uint32_t version; /**< Must be ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION */ + uint32_t version; /**< Must be ORT_API_VERSION */ OrtExternalSemaphoreType type; /**< Type of the external semaphore */ void* native_handle; /**< Platform-specific handle (e.g., Windows HANDLE) */ } OrtExternalSemaphoreDescriptor; /** \brief Descriptor for creating a tensor from imported external memory. * - * \note The version field must be set to ORT_EXTERNAL_TENSOR_DESCRIPTOR_VERSION. + * \note The version field must be set to ORT_API_VERSION. * This ensures forward compatibility as fields may be added in future versions. * * \since Version 1.24. 
*/ -#define ORT_EXTERNAL_TENSOR_DESCRIPTOR_VERSION 1 typedef struct OrtExternalTensorDescriptor { - uint32_t version; /**< Must be ORT_EXTERNAL_TENSOR_DESCRIPTOR_VERSION */ + uint32_t version; /**< Must be ORT_API_VERSION */ ONNXTensorElementDataType element_type; /**< Data type of tensor elements */ const int64_t* shape; /**< Array of dimension sizes */ size_t rank; /**< Number of dimensions */ diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 3130cc4da3238..eb716afc76d6f 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -42,9 +42,8 @@ ORT_RUNTIME_CLASS(ExternalResourceImporterImpl); * * \since Version 1.24. */ -#define ORT_EXTERNAL_MEMORY_HANDLE_VERSION 1 struct OrtExternalMemoryHandle { - uint32_t version; ///< Must be ORT_EXTERNAL_MEMORY_HANDLE_VERSION + uint32_t version; ///< Must be ORT_API_VERSION const OrtEpDevice* ep_device; ///< EP device that created this handle OrtExternalMemoryHandleType handle_type; ///< Original handle type for tracking size_t size_bytes; ///< Size of the imported memory @@ -72,9 +71,8 @@ struct OrtExternalMemoryHandle { * * \since Version 1.24. 
*/ -#define ORT_EXTERNAL_SEMAPHORE_HANDLE_VERSION 1 struct OrtExternalSemaphoreHandle { - uint32_t version; ///< Must be ORT_EXTERNAL_SEMAPHORE_HANDLE_VERSION + uint32_t version; ///< Must be ORT_API_VERSION const OrtEpDevice* ep_device; ///< EP device that created this handle OrtExternalSemaphoreType type; ///< Original semaphore type diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h index 64b42611b5402..903623847c795 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h @@ -23,7 +23,7 @@ struct ExampleExternalMemoryHandle : OrtExternalMemoryHandle { ExampleExternalMemoryHandle() : simulated_ptr(nullptr), access_mode(ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE) { // Initialize base struct fields - version = ORT_EXTERNAL_MEMORY_HANDLE_VERSION; + version = ORT_API_VERSION; ep_device = nullptr; handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; size_bytes = 0; @@ -58,7 +58,7 @@ struct ExampleExternalSemaphoreHandle : OrtExternalSemaphoreHandle { ExampleExternalSemaphoreHandle() : value(0) { // Initialize base struct fields - version = ORT_EXTERNAL_SEMAPHORE_HANDLE_VERSION; + version = ORT_API_VERSION; ep_device = nullptr; type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; Release = ReleaseCallback; diff --git a/onnxruntime/test/autoep/test_external_resource_importer.cc b/onnxruntime/test/autoep/test_external_resource_importer.cc index 04240a57f6c25..0362b74bbb27d 100644 --- a/onnxruntime/test/autoep/test_external_resource_importer.cc +++ b/onnxruntime/test/autoep/test_external_resource_importer.cc @@ -108,7 +108,7 @@ TEST_F(ExternalResourceImporterTest, ImportMemory) { void* dummy_handle = reinterpret_cast(static_cast(0x12345678)); // Simulated handle OrtExternalMemoryDescriptor mem_desc = {}; - mem_desc.version = 
ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION; + mem_desc.version = ORT_API_VERSION; mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; mem_desc.native_handle = dummy_handle; mem_desc.size_bytes = buffer_size; @@ -145,7 +145,7 @@ TEST_F(ExternalResourceImporterTest, CreateTensorFromMemory) { void* dummy_handle = reinterpret_cast(static_cast(0x12345678)); OrtExternalMemoryDescriptor mem_desc = {}; - mem_desc.version = ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION; + mem_desc.version = ORT_API_VERSION; mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; mem_desc.native_handle = dummy_handle; mem_desc.size_bytes = buffer_size; @@ -158,7 +158,7 @@ TEST_F(ExternalResourceImporterTest, CreateTensorFromMemory) { // Create tensor from imported memory OrtExternalTensorDescriptor tensor_desc = {}; - tensor_desc.version = ORT_EXTERNAL_TENSOR_DESCRIPTOR_VERSION; + tensor_desc.version = ORT_API_VERSION; tensor_desc.element_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; tensor_desc.shape = shape; tensor_desc.rank = 4; @@ -214,7 +214,7 @@ TEST_F(ExternalResourceImporterTest, ImportSemaphore) { void* dummy_handle = reinterpret_cast(static_cast(0xABCDEF00)); OrtExternalSemaphoreDescriptor sem_desc = {}; - sem_desc.version = ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION; + sem_desc.version = ORT_API_VERSION; sem_desc.type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; sem_desc.native_handle = dummy_handle; @@ -247,7 +247,7 @@ TEST_F(ExternalResourceImporterTest, WaitAndSignalSemaphore) { void* dummy_handle = reinterpret_cast(static_cast(0xABCDEF00)); OrtExternalSemaphoreDescriptor sem_desc = {}; - sem_desc.version = ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION; + sem_desc.version = ORT_API_VERSION; sem_desc.type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; sem_desc.native_handle = dummy_handle; @@ -292,7 +292,7 @@ TEST_F(ExternalResourceImporterTest, MultipleMemoryImports) { // Import multiple memory regions for (int i = 0; i < kNumBuffers; ++i) { OrtExternalMemoryDescriptor mem_desc 
= {}; - mem_desc.version = ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION; + mem_desc.version = ORT_API_VERSION; mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; mem_desc.native_handle = reinterpret_cast(static_cast(0x10000000 + i * 0x1000)); mem_desc.size_bytes = (i + 1) * 1024; @@ -328,7 +328,7 @@ TEST_F(ExternalResourceImporterTest, AccessModeVariations) { for (auto access_mode : access_modes) { OrtExternalMemoryDescriptor mem_desc = {}; - mem_desc.version = ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION; + mem_desc.version = ORT_API_VERSION; mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; mem_desc.native_handle = reinterpret_cast(static_cast(0x12345678)); mem_desc.size_bytes = 4096; From 1c309118ea4fdc70a5909ff40b106c617f8d2c6d Mon Sep 17 00:00:00 2001 From: Nick Eubanks Date: Mon, 22 Dec 2025 12:42:27 -0500 Subject: [PATCH 07/10] Use ORT_API_VERSION in NV TensorRT RTX EP external resource implementation --- .../nv_tensorrt_rtx/nv_provider_factory.cc | 10 +++++----- .../nv_external_resource_importer_test.cc | 16 ++++++++-------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index 6ce998b58747e..e6710d32fe904 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -537,7 +537,7 @@ struct NvTrtRtxExternalMemoryHandle : OrtExternalMemoryHandle { NvTrtRtxExternalMemoryHandle() : ext_memory(nullptr), mapped_ptr(0), is_dedicated(true) { // Initialize base struct fields - version = ORT_EXTERNAL_MEMORY_HANDLE_VERSION; + version = ORT_API_VERSION; ep_device = nullptr; handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; size_bytes = 0; @@ -569,7 +569,7 @@ struct NvTrtRtxExternalSemaphoreHandle : OrtExternalSemaphoreHandle { NvTrtRtxExternalSemaphoreHandle() : ext_semaphore(nullptr) { // 
Initialize base struct fields - version = ORT_EXTERNAL_SEMAPHORE_HANDLE_VERSION; + version = ORT_API_VERSION; ep_device = nullptr; type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; Release = ReleaseCallback; @@ -639,7 +639,7 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { } // Validate descriptor version - check minimum supported version for forward compatibility - if (desc->version < ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION) { + if (desc->version < ORT_API_VERSION) { return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "OrtExternalMemoryDescriptor version too old"); } @@ -766,7 +766,7 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { } // Validate descriptor version - check minimum supported version for forward compatibility - if (tensor_desc->version < ORT_EXTERNAL_TENSOR_DESCRIPTOR_VERSION) { + if (tensor_desc->version < ORT_API_VERSION) { return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "OrtExternalTensorDescriptor version too old"); } @@ -823,7 +823,7 @@ struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { } // Validate descriptor version - check minimum supported version for forward compatibility - if (desc->version < ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION) { + if (desc->version < ORT_API_VERSION) { return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "OrtExternalSemaphoreDescriptor version too old"); } diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_external_resource_importer_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_external_resource_importer_test.cc index 4ae346ff06450..6cc959dd3faa7 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_external_resource_importer_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_external_resource_importer_test.cc @@ -354,7 +354,7 @@ TEST_F(ExternalResourceImporterTest, ImportD3D12SharedResource) { // Import the memory OrtExternalMemoryDescriptor mem_desc = {}; - mem_desc.version = 
ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION; + mem_desc.version = ORT_API_VERSION; mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; mem_desc.native_handle = shared_handle; mem_desc.size_bytes = buffer_size; @@ -408,7 +408,7 @@ TEST_F(ExternalResourceImporterTest, CreateTensorFromImportedMemory) { // Import the memory OrtExternalMemoryDescriptor mem_desc = {}; - mem_desc.version = ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION; + mem_desc.version = ORT_API_VERSION; mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; mem_desc.native_handle = shared_handle; mem_desc.size_bytes = buffer_size; @@ -421,7 +421,7 @@ TEST_F(ExternalResourceImporterTest, CreateTensorFromImportedMemory) { // Create tensor from imported memory OrtExternalTensorDescriptor tensor_desc = {}; - tensor_desc.version = ORT_EXTERNAL_TENSOR_DESCRIPTOR_VERSION; + tensor_desc.version = ORT_API_VERSION; tensor_desc.element_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; tensor_desc.shape = shape; tensor_desc.rank = 4; @@ -505,7 +505,7 @@ TEST_F(ExternalResourceImporterTest, ImportD3D12Fence) { // Import the semaphore OrtExternalSemaphoreDescriptor sem_desc = {}; - sem_desc.version = ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION; + sem_desc.version = ORT_API_VERSION; sem_desc.type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; sem_desc.native_handle = shared_handle; @@ -555,7 +555,7 @@ TEST_F(ExternalResourceImporterTest, WaitAndSignalSemaphore) { // Import semaphore OrtExternalSemaphoreDescriptor sem_desc = {}; - sem_desc.version = ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION; + sem_desc.version = ORT_API_VERSION; sem_desc.type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; sem_desc.native_handle = shared_handle; @@ -664,7 +664,7 @@ TEST_F(ExternalResourceImporterTest, FullInferenceWithExternalMemory) { // Import memory into CUDA OrtExternalMemoryDescriptor input_mem_desc = {}; - input_mem_desc.version = ORT_EXTERNAL_MEMORY_DESCRIPTOR_VERSION; + input_mem_desc.version = ORT_API_VERSION; 
input_mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; input_mem_desc.native_handle = input_handle; input_mem_desc.size_bytes = buffer_size; @@ -682,7 +682,7 @@ TEST_F(ExternalResourceImporterTest, FullInferenceWithExternalMemory) { // Create ORT tensors from imported memory OrtExternalTensorDescriptor tensor_desc = {}; - tensor_desc.version = ORT_EXTERNAL_TENSOR_DESCRIPTOR_VERSION; + tensor_desc.version = ORT_API_VERSION; tensor_desc.element_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; tensor_desc.shape = shape; tensor_desc.rank = 4; @@ -715,7 +715,7 @@ TEST_F(ExternalResourceImporterTest, FullInferenceWithExternalMemory) { d3d12_device_->CreateSharedHandle(sync_fence.Get(), nullptr, GENERIC_ALL, nullptr, &fence_handle); OrtExternalSemaphoreDescriptor sem_desc = {}; - sem_desc.version = ORT_EXTERNAL_SEMAPHORE_DESCRIPTOR_VERSION; + sem_desc.version = ORT_API_VERSION; sem_desc.type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; sem_desc.native_handle = fence_handle; From 3bb37b9f77edbce20fae28d8ce7c6890e26d3fed Mon Sep 17 00:00:00 2001 From: Nick Eubanks Date: Mon, 22 Dec 2025 12:46:44 -0500 Subject: [PATCH 08/10] remove redundant and outdated comment --- .../core/providers/nv_tensorrt_rtx/nv_provider_factory.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index e6710d32fe904..545330c5bee20 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -1221,7 +1221,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { * The implementation uses CUDA Driver APIs (cuImportExternalMemory, cuImportExternalSemaphore). * * @param this_ptr The OrtEpFactory instance. - * @param ep_device The OrtEpDevice to create the importer for (must have "device_id" in ep_options). 
+ * @param ep_device The OrtEpDevice to create the importer for. * @param out_importer Output parameter set to the created OrtExternalResourceImporterImpl. * @return nullptr on success, OrtStatus with error on failure. */ @@ -1240,7 +1240,6 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { #if defined(_WIN32) // Create the external resource importer - // The importer gets the CUDA device ID from ep_device->ep_options["device_id"] auto importer = std::make_unique(ep_device, factory.ort_api); *out_importer = importer.release(); From c94db3ef5013f848fd3aa34d62a9e8d647b0fa11 Mon Sep 17 00:00:00 2001 From: Nick Eubanks Date: Mon, 22 Dec 2025 21:24:30 -0500 Subject: [PATCH 09/10] Use std::unique_ptr for simulated memory management in ExampleExternalMemoryHandle --- .../ep_external_resource_importer.cc | 11 ++--------- .../ep_external_resource_importer.h | 12 +++--------- 2 files changed, 5 insertions(+), 18 deletions(-) diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc index e8347bc28cb84..c3d413ac82ab9 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc @@ -72,14 +72,7 @@ OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::ImportMemoryImpl( // Allocate simulated memory (using CPU memory for the example) size_t effective_size = desc->size_bytes - desc->offset_bytes; - handle->simulated_ptr = malloc(effective_size); - if (handle->simulated_ptr == nullptr) { - delete handle; - return impl.apis_.ort_api.CreateStatus(ORT_FAIL, "Failed to allocate simulated memory"); - } - - // Initialize to zero - memset(handle->simulated_ptr, 0, effective_size); + handle->simulated_ptr = std::make_unique(effective_size); handle->size_bytes = desc->size_bytes; handle->offset_bytes = desc->offset_bytes; @@ -119,7 +112,7 @@ 
OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::CreateTensorFromMemoryI auto* handle = static_cast(mem_handle); // Calculate the data pointer with tensor offset - void* data_ptr = static_cast(handle->simulated_ptr) + tensor_desc->offset_bytes; + void* data_ptr = handle->simulated_ptr.get() + tensor_desc->offset_bytes; // For the example EP, we use CPU memory info since we're simulating with CPU memory // In a real implementation, you would use the appropriate GPU memory info diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h index 903623847c795..06b003cd3feaa 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h @@ -7,7 +7,7 @@ #include #include -#include +#include /** * @brief Example derived handle for imported external memory. @@ -17,7 +17,7 @@ * In a real EP, this would hold a GPU-mapped pointer from an imported D3D12/Vulkan/CUDA resource. 
*/ struct ExampleExternalMemoryHandle : OrtExternalMemoryHandle { - void* simulated_ptr; ///< Simulated mapped pointer (CPU memory for testing) + std::unique_ptr simulated_ptr; ///< Simulated mapped pointer (CPU memory for testing) OrtExternalMemoryAccessMode access_mode; ///< Access mode for the imported memory ExampleExternalMemoryHandle() @@ -31,13 +31,7 @@ struct ExampleExternalMemoryHandle : OrtExternalMemoryHandle { Release = ReleaseCallback; } - ~ExampleExternalMemoryHandle() { - // Free the simulated pointer if allocated - if (simulated_ptr != nullptr) { - free(simulated_ptr); - simulated_ptr = nullptr; - } - } + ~ExampleExternalMemoryHandle() = default; static void ORT_API_CALL ReleaseCallback(_In_ OrtExternalMemoryHandle* handle) noexcept { if (handle == nullptr) return; From 8a1d1485b5562f9999de8b83e07650820198d780 Mon Sep 17 00:00:00 2001 From: Nick Eubanks Date: Tue, 6 Jan 2026 19:44:40 -0500 Subject: [PATCH 10/10] Convert External Resource Importer APIs to use Interop API pattern - Added `ep_interop_api.h` to define the Interop API for external resource importers. - Implemented functions for creating and managing external resource importers, including memory and semaphore import capabilities. - Updated `onnxruntime_c_api.cc` to integrate the new Interop API, replacing previous external resource importer implementations. - Modified `ort_apis.h` to declare the new Interop API functions. - Refactored tests in `test_external_resource_importer.cc` to utilize the new Interop API for external resource importer operations. 
--- .../core/session/onnxruntime_c_api.h | 381 +++++++++-------- .../core/session/onnxruntime_cxx_api.h | 14 + onnxruntime/core/session/ep_interop_api.cc | 389 ++++++++++++++++++ onnxruntime/core/session/ep_interop_api.h | 54 +++ onnxruntime/core/session/onnxruntime_c_api.cc | 331 +-------------- onnxruntime/core/session/ort_apis.h | 43 +- .../autoep/test_external_resource_importer.cc | 80 ++-- 7 files changed, 726 insertions(+), 566 deletions(-) create mode 100644 onnxruntime/core/session/ep_interop_api.cc create mode 100644 onnxruntime/core/session/ep_interop_api.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index a8822d941110b..03df16c315376 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -881,6 +881,9 @@ typedef struct OrtModelEditorApi OrtModelEditorApi; struct OrtCompileApi; typedef struct OrtCompileApi OrtCompileApi; +struct OrtInteropApi; +typedef struct OrtInteropApi OrtInteropApi; + struct OrtEpApi; typedef struct OrtEpApi OrtEpApi; @@ -6690,175 +6693,17 @@ struct OrtApi { */ ORT_API2_STATUS(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out); - /// \name External Resource Import - /// @{ - - /** \brief Create an external resource importer for a specific EP device. - * - * The external resource importer is a capability object that provides methods for importing - * external GPU memory and semaphores for zero-copy import with an execution provider. - * - * \param[in] ep_device The OrtEpDevice instance to create the importer for. - * \param[out] out_importer Output parameter set to the created OrtExternalResourceImporter instance. - * Returns nullptr if the EP does not support external resource import. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.24. 
- */ - ORT_API2_STATUS(CreateExternalResourceImporterForDevice, - _In_ const OrtEpDevice* ep_device, - _Outptr_result_maybenull_ OrtExternalResourceImporter** out_importer); - - /** \brief Release an OrtExternalResourceImporter instance. - * - * \param[in] importer The OrtExternalResourceImporter instance to release. May be nullptr. - * - * \since Version 1.24. - */ - ORT_CLASS_RELEASE(ExternalResourceImporter); - - /** \brief Check if the external resource importer can import a specific memory handle type. - * - * \param[in] importer The OrtExternalResourceImporter instance. - * \param[in] handle_type The type of external memory handle to check. - * \param[out] out_supported Set to true if the handle type is supported. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.24. - */ - ORT_API2_STATUS(ExternalResourceImporter_CanImportMemory, - _In_ const OrtExternalResourceImporter* importer, - _In_ OrtExternalMemoryHandleType handle_type, - _Out_ bool* out_supported); - - /** \brief Import external memory into the execution provider. - * - * \param[in] importer The OrtExternalResourceImporter instance. - * \param[in] desc Descriptor containing the external memory handle and properties. - * \param[out] out_handle Output parameter set to the created OrtExternalMemoryHandle. - * The caller owns the returned handle and must call ReleaseExternalMemoryHandle to free it. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.24. - */ - ORT_API2_STATUS(ExternalResourceImporter_ImportMemory, - _In_ OrtExternalResourceImporter* importer, - _In_ const OrtExternalMemoryDescriptor* desc, - _Outptr_ OrtExternalMemoryHandle** out_handle); - - /** \brief Release an OrtExternalMemoryHandle instance. - * - * \param[in] handle The OrtExternalMemoryHandle instance to release. May be nullptr. - * - * \since Version 1.24. 
- */ - ORT_CLASS_RELEASE(ExternalMemoryHandle); - - /** \brief Create a tensor backed by imported external memory. - * - * The created tensor is a view over the imported memory and does not copy data. - * The OrtExternalMemoryHandle must remain valid for the lifetime of the tensor. - * - * \param[in] importer The OrtExternalResourceImporter instance. - * \param[in] mem_handle The imported external memory handle. - * \param[in] tensor_desc Descriptor specifying tensor element type, shape, and optional offset. - * \param[in] tensor_location Optional OrtMemoryInfo for the tensor location. May be nullptr. - * \param[out] out_tensor Output parameter set to the created OrtValue containing the tensor. - * The caller owns the returned tensor and must call ReleaseValue to free it. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.24. - */ - ORT_API2_STATUS(ExternalResourceImporter_CreateTensorFromMemory, - _In_ OrtExternalResourceImporter* importer, - _In_ const OrtExternalMemoryHandle* mem_handle, - _In_ const OrtExternalTensorDescriptor* tensor_desc, - _In_opt_ const OrtMemoryInfo* tensor_location, - _Outptr_ OrtValue** out_tensor); - - /** \brief Check if the external resource importer can import a specific semaphore type. - * - * \param[in] importer The OrtExternalResourceImporter instance. - * \param[in] type The type of external semaphore to check. - * \param[out] out_supported Set to true if the semaphore type is supported. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.24. - */ - ORT_API2_STATUS(ExternalResourceImporter_CanImportSemaphore, - _In_ const OrtExternalResourceImporter* importer, - _In_ OrtExternalSemaphoreType type, - _Out_ bool* out_supported); - - /** \brief Import an external semaphore into the execution provider. - * - * \param[in] importer The OrtExternalResourceImporter instance. - * \param[in] desc Descriptor containing the external semaphore handle and type. 
- * \param[out] out_handle Output parameter set to the created OrtExternalSemaphoreHandle. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.24. - */ - ORT_API2_STATUS(ExternalResourceImporter_ImportSemaphore, - _In_ OrtExternalResourceImporter* importer, - _In_ const OrtExternalSemaphoreDescriptor* desc, - _Outptr_ OrtExternalSemaphoreHandle** out_handle); - - /** \brief Release an OrtExternalSemaphoreHandle instance. - * - * \param[in] handle The OrtExternalSemaphoreHandle instance to release. May be nullptr. - * - * \since Version 1.24. - */ - ORT_CLASS_RELEASE(ExternalSemaphoreHandle); - - /** \brief Wait on an external semaphore on the EP's stream. - * - * Inserts a wait operation into the EP's stream that blocks until the semaphore - * reaches the specified value. This is used to synchronize with external GPU work - * (e.g., D3D12 timeline fence). - * - * \param[in] importer The OrtExternalResourceImporter instance. - * \param[in] semaphore_handle The imported external semaphore. - * \param[in] stream The OrtSyncStream to wait on. - * \param[in] value The fence/semaphore value to wait for. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.24. - */ - ORT_API2_STATUS(ExternalResourceImporter_WaitSemaphore, - _In_ OrtExternalResourceImporter* importer, - _In_ OrtExternalSemaphoreHandle* semaphore_handle, - _In_ OrtSyncStream* stream, - _In_ uint64_t value); - - /** \brief Signal an external semaphore from the EP's stream. - * - * Inserts a signal operation into the EP's stream that sets the semaphore - * to the specified value when reached. This is used to notify external GPU work - * (e.g., D3D12 timeline fence) that ORT inference is complete. + /** \brief Get the EP Interop API instance. * - * \param[in] importer The OrtExternalResourceImporter instance. - * \param[in] semaphore_handle The imported external semaphore. - * \param[in] stream The OrtSyncStream to signal from. 
- * \param[in] value The fence/semaphore value to signal. + * Get the Interop API instance to work with external resources. This API provides functions + * for importing external GPU memory and semaphores for zero-copy sharing between ORT inference + * and other GPU workloads. * - * \snippet{doc} snippets.dox OrtStatus Return Value + * \return Interop API struct instance. * * \since Version 1.24. */ - ORT_API2_STATUS(ExternalResourceImporter_SignalSemaphore, - _In_ OrtExternalResourceImporter* importer, - _In_ OrtExternalSemaphoreHandle* semaphore_handle, - _In_ OrtSyncStream* stream, - _In_ uint64_t value); + const OrtInteropApi*(ORT_API_CALL* GetInteropApi)(void); /** \brief Get the EP device assigned to each session output. * @@ -7683,6 +7528,214 @@ struct OrtCompileApi { _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state); }; +/** + * \brief The OrtInteropApi struct provides functions for external resource interop with execution providers. + * + * This API enables importing external GPU resources (memory and semaphores) for zero-copy sharing + * between ORT inference and other GPU workloads (e.g., D3D12 applications, media pipelines). + * + * The API is designed to be EP-agnostic and can be extended to support various GPU interop mechanisms + * (D3D12 shared handles, CUDA external memory, Vulkan, etc.). + * + * Example usage (error handling not shown): + * const OrtInteropApi* interop_api = ort_api->GetInteropApi(); + * OrtExternalResourceImporter* importer = NULL; + * + * status = interop_api->CreateExternalResourceImporterForDevice(ep_device, &importer); + * bool can_import = false; + * status = interop_api->CanImportMemory(importer, ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE, &can_import); + * if (can_import) { + * OrtExternalMemoryHandle* mem_handle = NULL; + * status = interop_api->ImportMemory(importer, &mem_desc, &mem_handle); + * // ... use mem_handle to create tensors ... 
+ * interop_api->ReleaseExternalMemoryHandle(mem_handle); + * } + * interop_api->ReleaseExternalResourceImporter(importer); + * + * \since Version 1.24. + */ +struct OrtInteropApi { + /// \name OrtExternalResourceImporter + /// @{ + + /** \brief Create an external resource importer for a specific EP device. + * + * The external resource importer is a capability object that provides methods for importing + * external GPU memory and semaphores for zero-copy import with an execution provider. + * + * \param[in] ep_device The OrtEpDevice instance to create the importer for. + * \param[out] out_importer Output parameter set to the created OrtExternalResourceImporter instance. + * Returns nullptr if the EP does not support external resource import. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CreateExternalResourceImporterForDevice, + _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporter** out_importer); + + /** \brief Release an OrtExternalResourceImporter instance. + * + * \param[in] importer The OrtExternalResourceImporter instance to release. May be nullptr. + * + * \since Version 1.24. + */ + ORT_CLASS_RELEASE(ExternalResourceImporter); + + /// @} + /// \name Memory Import + /// @{ + + /** \brief Check if the external resource importer can import a specific memory handle type. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] handle_type The type of external memory handle to check. + * \param[out] out_supported Set to true if the handle type is supported. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CanImportMemory, + _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalMemoryHandleType handle_type, + _Out_ bool* out_supported); + + /** \brief Import external memory into the execution provider. 
+ * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] desc Descriptor containing the external memory handle and properties. + * \param[out] out_handle Output parameter set to the created OrtExternalMemoryHandle. + * The caller owns the returned handle and must call ReleaseExternalMemoryHandle to free it. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ImportMemory, + _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle); + + /** \brief Release an OrtExternalMemoryHandle instance. + * + * \param[in] handle The OrtExternalMemoryHandle instance to release. May be nullptr. + * + * \since Version 1.24. + */ + ORT_CLASS_RELEASE(ExternalMemoryHandle); + + /** \brief Create a tensor backed by imported external memory. + * + * The created tensor is a view over the imported memory and does not copy data. + * The OrtExternalMemoryHandle must remain valid for the lifetime of the tensor. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] mem_handle The imported external memory handle. + * \param[in] tensor_desc Descriptor specifying tensor element type, shape, and optional offset. + * \param[in] tensor_location Optional OrtMemoryInfo for the tensor location. May be nullptr. Currently not forwarded to the EP implementation. + * \param[out] out_tensor Output parameter set to the created OrtValue containing the tensor. + * The caller owns the returned tensor and must call ReleaseValue to free it. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. 
+ */ + ORT_API2_STATUS(CreateTensorFromMemory, + _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _In_opt_ const OrtMemoryInfo* tensor_location, + _Outptr_ OrtValue** out_tensor); + + /// @} + /// \name Semaphore Import + /// @{ + + /** \brief Check if the external resource importer can import a specific semaphore type. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] type The type of external semaphore to check. + * \param[out] out_supported Set to true if the semaphore type is supported. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CanImportSemaphore, + _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreType type, + _Out_ bool* out_supported); + + /** \brief Import an external semaphore into the execution provider. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] desc Descriptor containing the external semaphore handle and type. + * \param[out] out_handle Output parameter set to the created OrtExternalSemaphoreHandle. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ImportSemaphore, + _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle); + + /** \brief Release an OrtExternalSemaphoreHandle instance. + * + * \param[in] handle The OrtExternalSemaphoreHandle instance to release. May be nullptr. + * + * \since Version 1.24. + */ + ORT_CLASS_RELEASE(ExternalSemaphoreHandle); + + /** \brief Wait on an external semaphore on the EP's stream. + * + * Inserts a wait operation into the EP's stream that blocks until the semaphore + * reaches the specified value. This is used to synchronize with external GPU work + * (e.g., D3D12 timeline fence). 
+ * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] semaphore_handle The imported external semaphore. + * \param[in] stream The OrtSyncStream to wait on. + * \param[in] value The fence/semaphore value to wait for. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(WaitSemaphore, + _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + + /** \brief Signal an external semaphore from the EP's stream. + * + * Inserts a signal operation into the EP's stream that sets the semaphore + * to the specified value when reached. This is used to notify external GPU work + * (e.g., D3D12 timeline fence) that ORT inference is complete. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] semaphore_handle The imported external semaphore. + * \param[in] stream The OrtSyncStream to signal from. + * \param[in] value The fence/semaphore value to signal. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. 
+ */ + ORT_API2_STATUS(SignalSemaphore, + _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + + /// @} +}; + /* * This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality * This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index c5de8b3e40a23..4d675277dbad2 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -236,6 +236,20 @@ inline const OrtCompileApi& GetCompileApi() { return *api; } +/// +/// This returns a reference to the ORT C Interop API. Used for external resource import with EPs. +/// +/// ORT C Interop API reference +inline const OrtInteropApi& GetInteropApi() { + auto* api = GetApi().GetInteropApi(); + if (api == nullptr) { + // minimal build + ORT_CXX_API_THROW("Interop API is not available in this build", ORT_FAIL); + } + + return *api; +} + /// /// This returns a reference to the ORT C EP API. Used if authoring a plugin execution provider. /// diff --git a/onnxruntime/core/session/ep_interop_api.cc b/onnxruntime/core/session/ep_interop_api.cc new file mode 100644 index 0000000000000..aa580792873d6 --- /dev/null +++ b/onnxruntime/core/session/ep_interop_api.cc @@ -0,0 +1,389 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/session/ep_interop_api.h" + +#if !defined(ORT_MINIMAL_BUILD) +#include + +#include "core/session/ort_apis.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_env.h" +#include "core/session/abi_devices.h" +#include "core/common/logging/logging.h" +#include "core/framework/error_code_helper.h" +#else +#include "core/framework/error_code_helper.h" +#include "core/session/ort_apis.h" +#endif // !defined(ORT_MINIMAL_BUILD) + +using namespace onnxruntime; + +#if !defined(ORT_MINIMAL_BUILD) + +// Wrapper class for OrtExternalResourceImporterImpl +namespace { +struct ExternalResourceImporterWrapper { + const OrtEpDevice* ep_device; + OrtExternalResourceImporterImpl* impl; + + ExternalResourceImporterWrapper(const OrtEpDevice* device, OrtExternalResourceImporterImpl* importer) + : ep_device(device), impl(importer) {} + + ~ExternalResourceImporterWrapper() { + if (impl && impl->Release) { + impl->Release(impl); + } + } + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ExternalResourceImporterWrapper); +}; + +} // namespace + +ORT_API_STATUS_IMPL(OrtInteropAPI::CreateExternalResourceImporterForDevice, _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporter** out_importer) { + API_IMPL_BEGIN + if (ep_device == nullptr || out_importer == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device and out_importer must be provided."); + } + + *out_importer = nullptr; + + // OrtEpFactory::CreateExternalResourceImporterForDevice was added in ORT 1.24. 
+ const auto* factory = ep_device->ep_factory; + if (factory == nullptr || + factory->ort_version_supported < 24 || + factory->CreateExternalResourceImporterForDevice == nullptr) { + // EP doesn't support external resource import - not an error, just return nullptr + return nullptr; + } + + OrtExternalResourceImporterImpl* impl = nullptr; + ORT_API_RETURN_IF_ERROR(factory->CreateExternalResourceImporterForDevice( + ep_device->GetMutableFactory(), + ep_device, + &impl)); + + if (impl == nullptr) { + // EP supports the factory method but returned null - not supported for this device + return nullptr; + } + + auto wrapper = std::make_unique(ep_device, impl); + *out_importer = reinterpret_cast(wrapper.release()); + + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtInteropAPI::ReleaseExternalResourceImporter, _Frees_ptr_opt_ OrtExternalResourceImporter* importer) { +#if !defined(ORT_MINIMAL_BUILD) + delete reinterpret_cast(importer); +#else + ORT_UNUSED_PARAMETER(importer); +#endif // !defined(ORT_MINIMAL_BUILD) +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::CanImportMemory, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalMemoryHandleType handle_type, + _Out_ bool* out_supported) { + API_IMPL_BEGIN + if (importer == nullptr || out_supported == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer and out_supported must be provided."); + } + + auto* wrapper = reinterpret_cast(importer); + if (wrapper->impl == nullptr || wrapper->impl->CanImportMemory == nullptr) { + *out_supported = false; + return nullptr; + } + + *out_supported = wrapper->impl->CanImportMemory(wrapper->impl, handle_type); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::ImportMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle) { + API_IMPL_BEGIN + if (importer == nullptr || desc == nullptr || out_handle == nullptr) { + return 
OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, desc, and out_handle must be provided."); + } + + auto* wrapper = reinterpret_cast(importer); + if (wrapper->impl == nullptr || wrapper->impl->ImportMemory == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External memory import is not supported by this EP."); + } + + // EP creates the derived type and returns a base pointer; the handle is destroyed via its Release callback (ReleaseExternalMemoryHandle). + OrtExternalMemoryHandle* handle = nullptr; + ORT_API_RETURN_IF_ERROR(wrapper->impl->ImportMemory(wrapper->impl, desc, &handle)); + + if (handle == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, "ImportMemory returned null handle."); + } + + *out_handle = handle; + + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtInteropAPI::ReleaseExternalMemoryHandle, _Frees_ptr_opt_ OrtExternalMemoryHandle* handle) { +#if !defined(ORT_MINIMAL_BUILD) + if (handle != nullptr && handle->Release != nullptr) { + handle->Release(handle); + } +#else + ORT_UNUSED_PARAMETER(handle); +#endif // !defined(ORT_MINIMAL_BUILD) +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::CreateTensorFromMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _In_opt_ const OrtMemoryInfo* /*tensor_location*/, + _Outptr_ OrtValue** out_tensor) { + API_IMPL_BEGIN + if (importer == nullptr || mem_handle == nullptr || tensor_desc == nullptr || out_tensor == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, mem_handle, tensor_desc, and out_tensor must be provided."); + } + + auto* imp_wrapper = reinterpret_cast(importer); + + if (imp_wrapper->impl == nullptr || imp_wrapper->impl->CreateTensorFromMemory == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateTensorFromMemory is not supported by this EP."); + } + + OrtValue* tensor = nullptr; + ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->CreateTensorFromMemory(imp_wrapper->impl, mem_handle, tensor_desc, 
&tensor)); + + *out_tensor = tensor; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::CanImportSemaphore, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreType type, + _Out_ bool* out_supported) { + API_IMPL_BEGIN + if (importer == nullptr || out_supported == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer and out_supported must be provided."); + } + + auto* wrapper = reinterpret_cast(importer); + if (wrapper->impl == nullptr || wrapper->impl->CanImportSemaphore == nullptr) { + *out_supported = false; + return nullptr; + } + + *out_supported = wrapper->impl->CanImportSemaphore(wrapper->impl, type); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::ImportSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle) { + API_IMPL_BEGIN + if (importer == nullptr || desc == nullptr || out_handle == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, desc, and out_handle must be provided."); + } + + auto* wrapper = reinterpret_cast(importer); + if (wrapper->impl == nullptr || wrapper->impl->ImportSemaphore == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External semaphore import is not supported by this EP."); + } + + // EP creates derived type and returns base pointer. EP owns the handle lifetime. 
+ OrtExternalSemaphoreHandle* handle = nullptr; + ORT_API_RETURN_IF_ERROR(wrapper->impl->ImportSemaphore(wrapper->impl, desc, &handle)); + + if (handle == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, "ImportSemaphore returned null handle."); + } + + *out_handle = handle; + + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtInteropAPI::ReleaseExternalSemaphoreHandle, _Frees_ptr_opt_ OrtExternalSemaphoreHandle* handle) { +#if !defined(ORT_MINIMAL_BUILD) + if (handle != nullptr && handle->Release != nullptr) { + handle->Release(handle); + } +#else + ORT_UNUSED_PARAMETER(handle); +#endif // !defined(ORT_MINIMAL_BUILD) +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::WaitSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) { + API_IMPL_BEGIN + if (importer == nullptr || semaphore_handle == nullptr || stream == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, semaphore_handle, and stream must be provided."); + } + + auto* imp_wrapper = reinterpret_cast(importer); + + if (imp_wrapper->impl == nullptr || imp_wrapper->impl->WaitSemaphore == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "WaitSemaphore is not supported by this EP."); + } + + ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->WaitSemaphore(imp_wrapper->impl, semaphore_handle, stream, value)); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::SignalSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) { + API_IMPL_BEGIN + if (importer == nullptr || semaphore_handle == nullptr || stream == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, semaphore_handle, and stream must be provided."); + } + + auto* imp_wrapper = reinterpret_cast(importer); + + if (imp_wrapper->impl == nullptr || 
imp_wrapper->impl->SignalSemaphore == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SignalSemaphore is not supported by this EP."); + } + + ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->SignalSemaphore(imp_wrapper->impl, semaphore_handle, stream, value)); + + return nullptr; + API_IMPL_END +} + +#else // defined(ORT_MINIMAL_BUILD) + +ORT_API_STATUS_IMPL(OrtInteropAPI::CreateExternalResourceImporterForDevice, _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporter** out_importer) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(ep_device); + ORT_UNUSED_PARAMETER(out_importer); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::CanImportMemory, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalMemoryHandleType handle_type, + _Out_ bool* out_supported) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(handle_type); + ORT_UNUSED_PARAMETER(out_supported); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::ImportMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(desc); + ORT_UNUSED_PARAMETER(out_handle); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::CreateTensorFromMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _In_opt_ const OrtMemoryInfo* tensor_location, + _Outptr_ OrtValue** out_tensor) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(mem_handle); + 
ORT_UNUSED_PARAMETER(tensor_desc); + ORT_UNUSED_PARAMETER(tensor_location); + ORT_UNUSED_PARAMETER(out_tensor); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::CanImportSemaphore, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreType type, + _Out_ bool* out_supported) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(type); + ORT_UNUSED_PARAMETER(out_supported); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::ImportSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(desc); + ORT_UNUSED_PARAMETER(out_handle); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::WaitSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(semaphore_handle); + ORT_UNUSED_PARAMETER(stream); + ORT_UNUSED_PARAMETER(value); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::SignalSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(semaphore_handle); + ORT_UNUSED_PARAMETER(stream); + ORT_UNUSED_PARAMETER(value); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not 
supported in this build"); + API_IMPL_END +} + +#endif // !defined(ORT_MINIMAL_BUILD) + +static constexpr OrtInteropApi ort_interop_api = { + // NOTE: Application compatibility with newer versions of ORT depends on the Api order within this struct so + // all new functions must be added at the end, and no functions that already exist in an officially released version + // of ORT can be reordered or removed. + + &OrtInteropAPI::CreateExternalResourceImporterForDevice, + &OrtInteropAPI::ReleaseExternalResourceImporter, + &OrtInteropAPI::CanImportMemory, + &OrtInteropAPI::ImportMemory, + &OrtInteropAPI::ReleaseExternalMemoryHandle, + &OrtInteropAPI::CreateTensorFromMemory, + &OrtInteropAPI::CanImportSemaphore, + &OrtInteropAPI::ImportSemaphore, + &OrtInteropAPI::ReleaseExternalSemaphoreHandle, + &OrtInteropAPI::WaitSemaphore, + &OrtInteropAPI::SignalSemaphore, + // End of Version 24 - DO NOT MODIFY ABOVE +}; + +// Checks that we don't violate the rule that the functions must remain in the slots they were originally assigned +static_assert(offsetof(OrtInteropApi, SignalSemaphore) / sizeof(void*) == 10, + "Size of version 24 Api cannot change"); // initial version in ORT 1.24 + +ORT_API(const OrtInteropApi*, OrtInteropAPI::GetInteropApi) { + return &ort_interop_api; +} diff --git a/onnxruntime/core/session/ep_interop_api.h b/onnxruntime/core/session/ep_interop_api.h new file mode 100644 index 0000000000000..25ac7b58af8c0 --- /dev/null +++ b/onnxruntime/core/session/ep_interop_api.h @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/session/onnxruntime_c_api.h" + +namespace OrtInteropAPI { + +// implementation that returns the API struct +ORT_API(const OrtInteropApi*, GetInteropApi); + +ORT_API_STATUS_IMPL(CreateExternalResourceImporterForDevice, _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporter** out_importer); + +ORT_API(void, ReleaseExternalResourceImporter, _Frees_ptr_opt_ OrtExternalResourceImporter* importer); + +ORT_API_STATUS_IMPL(CanImportMemory, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalMemoryHandleType handle_type, + _Out_ bool* out_supported); + +ORT_API_STATUS_IMPL(ImportMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle); + +ORT_API(void, ReleaseExternalMemoryHandle, _Frees_ptr_opt_ OrtExternalMemoryHandle* handle); + +ORT_API_STATUS_IMPL(CreateTensorFromMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _In_opt_ const OrtMemoryInfo* tensor_location, + _Outptr_ OrtValue** out_tensor); + +ORT_API_STATUS_IMPL(CanImportSemaphore, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreType type, + _Out_ bool* out_supported); + +ORT_API_STATUS_IMPL(ImportSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle); + +ORT_API(void, ReleaseExternalSemaphoreHandle, _Frees_ptr_opt_ OrtExternalSemaphoreHandle* handle); + +ORT_API_STATUS_IMPL(WaitSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + +ORT_API_STATUS_IMPL(SignalSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + 
_In_ uint64_t value); + +} // namespace OrtInteropAPI diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index c658365adedec..a0e090a4d70ad 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -38,6 +38,7 @@ #include "core/session/allocator_adapters.h" #include "core/session/compile_api.h" #include "core/session/environment.h" +#include "core/session/ep_interop_api.h" #include "core/session/plugin_ep/ep_api.h" #include "core/session/plugin_ep/ep_library_internal.h" #include "core/session/inference_session.h" @@ -3564,240 +3565,9 @@ ORT_API_STATUS_IMPL(OrtApis::GetModelCompatibilityForEpDevices, API_IMPL_END } -// ============================================================================ -// External Resource Importer APIs (Version 1.24) -// ============================================================================ - -// Wrapper class for OrtExternalResourceImporterImpl -namespace { -struct ExternalResourceImporterWrapper { - const OrtEpDevice* ep_device; - OrtExternalResourceImporterImpl* impl; - - ExternalResourceImporterWrapper(const OrtEpDevice* device, OrtExternalResourceImporterImpl* importer) - : ep_device(device), impl(importer) {} - - ~ExternalResourceImporterWrapper() { - if (impl && impl->Release) { - impl->Release(impl); - } - } - - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ExternalResourceImporterWrapper); -}; - -} // namespace - -ORT_API_STATUS_IMPL(OrtApis::CreateExternalResourceImporterForDevice, _In_ const OrtEpDevice* ep_device, - _Outptr_result_maybenull_ OrtExternalResourceImporter** out_importer) { - API_IMPL_BEGIN - if (ep_device == nullptr || out_importer == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device and out_importer must be provided."); - } - - *out_importer = nullptr; - - // OrtEpFactory::CreateExternalResourceImporterForDevice was added in ORT 1.24. 
- const auto* factory = ep_device->ep_factory; - if (factory == nullptr || - factory->ort_version_supported < 24 || - factory->CreateExternalResourceImporterForDevice == nullptr) { - // EP doesn't support external resource import - not an error, just return nullptr - return nullptr; - } - - OrtExternalResourceImporterImpl* impl = nullptr; - ORT_API_RETURN_IF_ERROR(factory->CreateExternalResourceImporterForDevice( - ep_device->GetMutableFactory(), - ep_device, - &impl)); - - if (impl == nullptr) { - // EP supports the factory method but returned null - not supported for this device - return nullptr; - } - - auto wrapper = std::make_unique(ep_device, impl); - *out_importer = reinterpret_cast(wrapper.release()); - - return nullptr; - API_IMPL_END -} - -ORT_API(void, OrtApis::ReleaseExternalResourceImporter, _Frees_ptr_opt_ OrtExternalResourceImporter* importer) { - if (importer != nullptr) { - delete reinterpret_cast(importer); - } -} - -ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_CanImportMemory, _In_ const OrtExternalResourceImporter* importer, - _In_ OrtExternalMemoryHandleType handle_type, - _Out_ bool* out_supported) { - API_IMPL_BEGIN - if (importer == nullptr || out_supported == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer and out_supported must be provided."); - } - - auto* wrapper = reinterpret_cast(importer); - if (wrapper->impl == nullptr || wrapper->impl->CanImportMemory == nullptr) { - *out_supported = false; - return nullptr; - } - - *out_supported = wrapper->impl->CanImportMemory(wrapper->impl, handle_type); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_ImportMemory, _In_ OrtExternalResourceImporter* importer, - _In_ const OrtExternalMemoryDescriptor* desc, - _Outptr_ OrtExternalMemoryHandle** out_handle) { - API_IMPL_BEGIN - if (importer == nullptr || desc == nullptr || out_handle == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, desc, and 
out_handle must be provided."); - } - - auto* wrapper = reinterpret_cast(importer); - if (wrapper->impl == nullptr || wrapper->impl->ImportMemory == nullptr) { - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External memory import is not supported by this EP."); - } - - // EP creates derived type and returns base pointer. EP owns the handle lifetime. - OrtExternalMemoryHandle* handle = nullptr; - ORT_API_RETURN_IF_ERROR(wrapper->impl->ImportMemory(wrapper->impl, desc, &handle)); - - if (handle == nullptr) { - return OrtApis::CreateStatus(ORT_FAIL, "ImportMemory returned null handle."); - } - - *out_handle = handle; - - return nullptr; - API_IMPL_END -} - -ORT_API(void, OrtApis::ReleaseExternalMemoryHandle, _Frees_ptr_opt_ OrtExternalMemoryHandle* handle) { - if (handle != nullptr && handle->Release != nullptr) { - handle->Release(handle); - } -} - -ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_CreateTensorFromMemory, _In_ OrtExternalResourceImporter* importer, - _In_ const OrtExternalMemoryHandle* mem_handle, - _In_ const OrtExternalTensorDescriptor* tensor_desc, - _In_opt_ const OrtMemoryInfo* /*tensor_location*/, - _Outptr_ OrtValue** out_tensor) { - API_IMPL_BEGIN - if (importer == nullptr || mem_handle == nullptr || tensor_desc == nullptr || out_tensor == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, mem_handle, tensor_desc, and out_tensor must be provided."); - } - - auto* imp_wrapper = reinterpret_cast(importer); - - if (imp_wrapper->impl == nullptr || imp_wrapper->impl->CreateTensorFromMemory == nullptr) { - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateTensorFromMemory is not supported by this EP."); - } - - OrtValue* tensor = nullptr; - ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->CreateTensorFromMemory(imp_wrapper->impl, mem_handle, tensor_desc, &tensor)); - - *out_tensor = tensor; - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_CanImportSemaphore, _In_ const 
OrtExternalResourceImporter* importer, - _In_ OrtExternalSemaphoreType type, - _Out_ bool* out_supported) { - API_IMPL_BEGIN - if (importer == nullptr || out_supported == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer and out_supported must be provided."); - } - - auto* wrapper = reinterpret_cast(importer); - if (wrapper->impl == nullptr || wrapper->impl->CanImportSemaphore == nullptr) { - *out_supported = false; - return nullptr; - } - - *out_supported = wrapper->impl->CanImportSemaphore(wrapper->impl, type); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_ImportSemaphore, _In_ OrtExternalResourceImporter* importer, - _In_ const OrtExternalSemaphoreDescriptor* desc, - _Outptr_ OrtExternalSemaphoreHandle** out_handle) { - API_IMPL_BEGIN - if (importer == nullptr || desc == nullptr || out_handle == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, desc, and out_handle must be provided."); - } - - auto* wrapper = reinterpret_cast(importer); - if (wrapper->impl == nullptr || wrapper->impl->ImportSemaphore == nullptr) { - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External semaphore import is not supported by this EP."); - } - - // EP creates derived type and returns base pointer. EP owns the handle lifetime. 
- OrtExternalSemaphoreHandle* handle = nullptr; - ORT_API_RETURN_IF_ERROR(wrapper->impl->ImportSemaphore(wrapper->impl, desc, &handle)); - - if (handle == nullptr) { - return OrtApis::CreateStatus(ORT_FAIL, "ImportSemaphore returned null handle."); - } - - *out_handle = handle; - - return nullptr; - API_IMPL_END -} - -ORT_API(void, OrtApis::ReleaseExternalSemaphoreHandle, _Frees_ptr_opt_ OrtExternalSemaphoreHandle* handle) { - if (handle != nullptr && handle->Release != nullptr) { - handle->Release(handle); - } -} - -ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_WaitSemaphore, _In_ OrtExternalResourceImporter* importer, - _In_ OrtExternalSemaphoreHandle* semaphore_handle, - _In_ OrtSyncStream* stream, - _In_ uint64_t value) { - API_IMPL_BEGIN - if (importer == nullptr || semaphore_handle == nullptr || stream == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, semaphore_handle, and stream must be provided."); - } - - auto* imp_wrapper = reinterpret_cast(importer); - - if (imp_wrapper->impl == nullptr || imp_wrapper->impl->WaitSemaphore == nullptr) { - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "WaitSemaphore is not supported by this EP."); - } - - ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->WaitSemaphore(imp_wrapper->impl, semaphore_handle, stream, value)); - - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_SignalSemaphore, _In_ OrtExternalResourceImporter* importer, - _In_ OrtExternalSemaphoreHandle* semaphore_handle, - _In_ OrtSyncStream* stream, - _In_ uint64_t value) { - API_IMPL_BEGIN - if (importer == nullptr || semaphore_handle == nullptr || stream == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, semaphore_handle, and stream must be provided."); - } - - auto* imp_wrapper = reinterpret_cast(importer); - - if (imp_wrapper->impl == nullptr || imp_wrapper->impl->SignalSemaphore == nullptr) { - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, 
"SignalSemaphore is not supported by this EP."); - } - - ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->SignalSemaphore(imp_wrapper->impl, semaphore_handle, stream, value)); - - return nullptr; - API_IMPL_END +// GetInteropApi - returns the Interop API struct +ORT_API(const OrtInteropApi*, OrtApis::GetInteropApi) { + return OrtInteropAPI::GetInteropApi(); } #else // defined(ORT_MINIMAL_BUILD) @@ -3885,84 +3655,9 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* /*env*/, API_IMPL_END } -// External Resource Importer minimal build stubs -ORT_API_STATUS_IMPL(OrtApis::CreateExternalResourceImporterForDevice, _In_ const OrtEpDevice* /*ep_device*/, - _Outptr_result_maybenull_ OrtExternalResourceImporter** /*out_importer*/) { - API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateExternalResourceImporterForDevice is not supported in a minimal build."); - API_IMPL_END -} - -ORT_API(void, OrtApis::ReleaseExternalResourceImporter, _Frees_ptr_opt_ OrtExternalResourceImporter* /*importer*/) { - fprintf(stderr, "External resource import is not supported in a minimal build.\n"); -} - -ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_CanImportMemory, _In_ const OrtExternalResourceImporter* /*importer*/, - _In_ OrtExternalMemoryHandleType /*handle_type*/, - _Out_ bool* /*out_supported*/) { - API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External resource import is not supported in a minimal build."); - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_ImportMemory, _In_ OrtExternalResourceImporter* /*importer*/, - _In_ const OrtExternalMemoryDescriptor* /*desc*/, - _Outptr_ OrtExternalMemoryHandle** /*out_handle*/) { - API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External resource import is not supported in a minimal build."); - API_IMPL_END -} - -ORT_API(void, OrtApis::ReleaseExternalMemoryHandle, _Frees_ptr_opt_ OrtExternalMemoryHandle* /*handle*/) { - fprintf(stderr, 
"External resource import is not supported in a minimal build.\n"); -} - -ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_CreateTensorFromMemory, _In_ OrtExternalResourceImporter* /*importer*/, - _In_ const OrtExternalMemoryHandle* /*mem_handle*/, - _In_ const OrtExternalTensorDescriptor* /*tensor_desc*/, - _In_opt_ const OrtMemoryInfo* /*tensor_location*/, - _Outptr_ OrtValue** /*out_tensor*/) { - API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External resource import is not supported in a minimal build."); - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_CanImportSemaphore, _In_ const OrtExternalResourceImporter* /*importer*/, - _In_ OrtExternalSemaphoreType /*type*/, - _Out_ bool* /*out_supported*/) { - API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External resource import is not supported in a minimal build."); - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_ImportSemaphore, _In_ OrtExternalResourceImporter* /*importer*/, - _In_ const OrtExternalSemaphoreDescriptor* /*desc*/, - _Outptr_ OrtExternalSemaphoreHandle** /*out_handle*/) { - API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External resource import is not supported in a minimal build."); - API_IMPL_END -} - -ORT_API(void, OrtApis::ReleaseExternalSemaphoreHandle, _Frees_ptr_opt_ OrtExternalSemaphoreHandle* /*handle*/) { - fprintf(stderr, "External resource import is not supported in a minimal build.\n"); -} - -ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_WaitSemaphore, _In_ OrtExternalResourceImporter* /*importer*/, - _In_ OrtExternalSemaphoreHandle* /*semaphore_handle*/, - _In_ OrtSyncStream* /*stream*/, - _In_ uint64_t /*value*/) { - API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External resource import is not supported in a minimal build."); - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::ExternalResourceImporter_SignalSemaphore, _In_ 
OrtExternalResourceImporter* /*importer*/, - _In_ OrtExternalSemaphoreHandle* /*semaphore_handle*/, - _In_ OrtSyncStream* /*stream*/, - _In_ uint64_t /*value*/) { - API_IMPL_BEGIN - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External resource import is not supported in a minimal build."); - API_IMPL_END +ORT_API(const OrtInteropApi*, OrtApis::GetInteropApi) { + fprintf(stderr, "The Interop API is not supported in a minimal build.\n"); + return nullptr; } ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForOutputs, _In_ const OrtSession* /*ort_session*/, @@ -4591,17 +4286,7 @@ static constexpr OrtApi ort_api_1_to_24 = { &OrtApis::TensorTypeAndShape_HasShape, &OrtApis::KernelInfo_GetConfigEntries, - &OrtApis::CreateExternalResourceImporterForDevice, - &OrtApis::ReleaseExternalResourceImporter, - &OrtApis::ExternalResourceImporter_CanImportMemory, - &OrtApis::ExternalResourceImporter_ImportMemory, - &OrtApis::ReleaseExternalMemoryHandle, - &OrtApis::ExternalResourceImporter_CreateTensorFromMemory, - &OrtApis::ExternalResourceImporter_CanImportSemaphore, - &OrtApis::ExternalResourceImporter_ImportSemaphore, - &OrtApis::ReleaseExternalSemaphoreHandle, - &OrtApis::ExternalResourceImporter_WaitSemaphore, - &OrtApis::ExternalResourceImporter_SignalSemaphore, + &OrtApis::GetInteropApi, &OrtApis::SessionGetEpDeviceForOutputs, }; diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 96ea33c8027d0..3a82a4ca3f362 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -755,47 +755,8 @@ ORT_API_STATUS_IMPL(CopyTensors, _In_ const OrtEnv* env, ORT_API_STATUS_IMPL(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out); -// External Resource Importer APIs -ORT_API_STATUS_IMPL(CreateExternalResourceImporterForDevice, _In_ const OrtEpDevice* ep_device, - _Outptr_result_maybenull_ OrtExternalResourceImporter** out_importer); - -ORT_API(void, 
ReleaseExternalResourceImporter, _Frees_ptr_opt_ OrtExternalResourceImporter* importer); - -ORT_API_STATUS_IMPL(ExternalResourceImporter_CanImportMemory, _In_ const OrtExternalResourceImporter* importer, - _In_ OrtExternalMemoryHandleType handle_type, - _Out_ bool* out_supported); - -ORT_API_STATUS_IMPL(ExternalResourceImporter_ImportMemory, _In_ OrtExternalResourceImporter* importer, - _In_ const OrtExternalMemoryDescriptor* desc, - _Outptr_ OrtExternalMemoryHandle** out_handle); - -ORT_API(void, ReleaseExternalMemoryHandle, _Frees_ptr_opt_ OrtExternalMemoryHandle* handle); - -ORT_API_STATUS_IMPL(ExternalResourceImporter_CreateTensorFromMemory, _In_ OrtExternalResourceImporter* importer, - _In_ const OrtExternalMemoryHandle* mem_handle, - _In_ const OrtExternalTensorDescriptor* tensor_desc, - _In_opt_ const OrtMemoryInfo* tensor_location, - _Outptr_ OrtValue** out_tensor); - -ORT_API_STATUS_IMPL(ExternalResourceImporter_CanImportSemaphore, _In_ const OrtExternalResourceImporter* importer, - _In_ OrtExternalSemaphoreType type, - _Out_ bool* out_supported); - -ORT_API_STATUS_IMPL(ExternalResourceImporter_ImportSemaphore, _In_ OrtExternalResourceImporter* importer, - _In_ const OrtExternalSemaphoreDescriptor* desc, - _Outptr_ OrtExternalSemaphoreHandle** out_handle); - -ORT_API(void, ReleaseExternalSemaphoreHandle, _Frees_ptr_opt_ OrtExternalSemaphoreHandle* handle); - -ORT_API_STATUS_IMPL(ExternalResourceImporter_WaitSemaphore, _In_ OrtExternalResourceImporter* importer, - _In_ OrtExternalSemaphoreHandle* semaphore_handle, - _In_ OrtSyncStream* stream, - _In_ uint64_t value); - -ORT_API_STATUS_IMPL(ExternalResourceImporter_SignalSemaphore, _In_ OrtExternalResourceImporter* importer, - _In_ OrtExternalSemaphoreHandle* semaphore_handle, - _In_ OrtSyncStream* stream, - _In_ uint64_t value); +// Interop API +ORT_API(const OrtInteropApi*, GetInteropApi); ORT_API_STATUS_IMPL(SessionGetEpDeviceForOutputs, _In_ const OrtSession* session, _Out_writes_(num_outputs) const 
OrtEpDevice** outputs_ep_devices, diff --git a/onnxruntime/test/autoep/test_external_resource_importer.cc b/onnxruntime/test/autoep/test_external_resource_importer.cc index 0362b74bbb27d..06a818bf4805f 100644 --- a/onnxruntime/test/autoep/test_external_resource_importer.cc +++ b/onnxruntime/test/autoep/test_external_resource_importer.cc @@ -28,6 +28,10 @@ class ExternalResourceImporterTest : public ::testing::Test { ep_device_ = registered_ep_.get(); } + const OrtInteropApi& GetInteropApi() const { + return Ort::GetInteropApi(); + } + RegisteredEpDeviceUniquePtr registered_ep_; const OrtEpDevice* ep_device_ = nullptr; }; @@ -35,7 +39,7 @@ class ExternalResourceImporterTest : public ::testing::Test { // Test: Create External Resource Importer TEST_F(ExternalResourceImporterTest, CreateExternalResourceImporter) { OrtExternalResourceImporter* importer = nullptr; - OrtStatus* status = Ort::GetApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); if (status != nullptr) { std::string error = Ort::GetApi().GetErrorMessage(status); @@ -46,13 +50,13 @@ TEST_F(ExternalResourceImporterTest, CreateExternalResourceImporter) { ASSERT_NE(importer, nullptr) << "External resource importer should not be null"; // Release the importer - Ort::GetApi().ReleaseExternalResourceImporter(importer); + GetInteropApi().ReleaseExternalResourceImporter(importer); } // Test: Memory Import Capability TEST_F(ExternalResourceImporterTest, CanImportMemory) { OrtExternalResourceImporter* importer = nullptr; - OrtStatus* status = Ort::GetApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); if (status != nullptr) { Ort::GetApi().ReleaseStatus(status); GTEST_SKIP() << "External resource interop not supported"; @@ -60,25 +64,25 @@ TEST_F(ExternalResourceImporterTest, 
CanImportMemory) { // Check D3D12 Resource support bool can_import_resource = false; - status = Ort::GetApi().ExternalResourceImporter_CanImportMemory( + status = GetInteropApi().CanImportMemory( importer, ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE, &can_import_resource); ASSERT_EQ(status, nullptr) << "CanImportMemory for D3D12_RESOURCE should succeed"; EXPECT_TRUE(can_import_resource) << "Example EP should support D3D12 Resource import"; // Check D3D12 Heap support bool can_import_heap = false; - status = Ort::GetApi().ExternalResourceImporter_CanImportMemory( + status = GetInteropApi().CanImportMemory( importer, ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP, &can_import_heap); ASSERT_EQ(status, nullptr) << "CanImportMemory for D3D12_HEAP should succeed"; EXPECT_TRUE(can_import_heap) << "Example EP should support D3D12 Heap import"; - Ort::GetApi().ReleaseExternalResourceImporter(importer); + GetInteropApi().ReleaseExternalResourceImporter(importer); } // Test: Semaphore Import Capability TEST_F(ExternalResourceImporterTest, CanImportSemaphore) { OrtExternalResourceImporter* importer = nullptr; - OrtStatus* status = Ort::GetApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); if (status != nullptr) { Ort::GetApi().ReleaseStatus(status); GTEST_SKIP() << "External resource interop not supported"; @@ -86,18 +90,18 @@ TEST_F(ExternalResourceImporterTest, CanImportSemaphore) { // Check D3D12 Fence support bool can_import_fence = false; - status = Ort::GetApi().ExternalResourceImporter_CanImportSemaphore( + status = GetInteropApi().CanImportSemaphore( importer, ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE, &can_import_fence); ASSERT_EQ(status, nullptr) << "CanImportSemaphore for D3D12_FENCE should succeed"; EXPECT_TRUE(can_import_fence) << "Example EP should support D3D12 Fence import"; - Ort::GetApi().ReleaseExternalResourceImporter(importer); + 
GetInteropApi().ReleaseExternalResourceImporter(importer); } // Test: Import Memory (Simulated) TEST_F(ExternalResourceImporterTest, ImportMemory) { OrtExternalResourceImporter* importer = nullptr; - OrtStatus* status = Ort::GetApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); if (status != nullptr) { Ort::GetApi().ReleaseStatus(status); GTEST_SKIP() << "External resource interop not supported"; @@ -116,20 +120,20 @@ TEST_F(ExternalResourceImporterTest, ImportMemory) { mem_desc.access_mode = ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE; OrtExternalMemoryHandle* mem_handle = nullptr; - status = Ort::GetApi().ExternalResourceImporter_ImportMemory(importer, &mem_desc, &mem_handle); + status = GetInteropApi().ImportMemory(importer, &mem_desc, &mem_handle); ASSERT_EQ(status, nullptr) << "ImportMemory should succeed"; ASSERT_NE(mem_handle, nullptr) << "Memory handle should not be null"; // Release memory handle - Ort::GetApi().ReleaseExternalMemoryHandle(mem_handle); + GetInteropApi().ReleaseExternalMemoryHandle(mem_handle); - Ort::GetApi().ReleaseExternalResourceImporter(importer); + GetInteropApi().ReleaseExternalResourceImporter(importer); } // Test: Create Tensor from Imported Memory TEST_F(ExternalResourceImporterTest, CreateTensorFromMemory) { OrtExternalResourceImporter* importer = nullptr; - OrtStatus* status = Ort::GetApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); if (status != nullptr) { Ort::GetApi().ReleaseStatus(status); GTEST_SKIP() << "External resource interop not supported"; @@ -153,7 +157,7 @@ TEST_F(ExternalResourceImporterTest, CreateTensorFromMemory) { mem_desc.access_mode = ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE; OrtExternalMemoryHandle* mem_handle = nullptr; - status = 
Ort::GetApi().ExternalResourceImporter_ImportMemory(importer, &mem_desc, &mem_handle); + status = GetInteropApi().ImportMemory(importer, &mem_desc, &mem_handle); ASSERT_EQ(status, nullptr); // Create tensor from imported memory @@ -165,7 +169,7 @@ TEST_F(ExternalResourceImporterTest, CreateTensorFromMemory) { tensor_desc.offset_bytes = 0; OrtValue* tensor = nullptr; - status = Ort::GetApi().ExternalResourceImporter_CreateTensorFromMemory( + status = GetInteropApi().CreateTensorFromMemory( importer, mem_handle, &tensor_desc, nullptr, &tensor); ASSERT_EQ(status, nullptr) << "CreateTensorFromMemory should succeed"; ASSERT_NE(tensor, nullptr) << "Tensor should not be null"; @@ -197,14 +201,14 @@ TEST_F(ExternalResourceImporterTest, CreateTensorFromMemory) { // Cleanup Ort::GetApi().ReleaseValue(tensor); - Ort::GetApi().ReleaseExternalMemoryHandle(mem_handle); - Ort::GetApi().ReleaseExternalResourceImporter(importer); + GetInteropApi().ReleaseExternalMemoryHandle(mem_handle); + GetInteropApi().ReleaseExternalResourceImporter(importer); } // Test: Import Semaphore (Simulated) TEST_F(ExternalResourceImporterTest, ImportSemaphore) { OrtExternalResourceImporter* importer = nullptr; - OrtStatus* status = Ort::GetApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); if (status != nullptr) { Ort::GetApi().ReleaseStatus(status); GTEST_SKIP() << "External resource interop not supported"; @@ -219,20 +223,20 @@ TEST_F(ExternalResourceImporterTest, ImportSemaphore) { sem_desc.native_handle = dummy_handle; OrtExternalSemaphoreHandle* sem_handle = nullptr; - status = Ort::GetApi().ExternalResourceImporter_ImportSemaphore(importer, &sem_desc, &sem_handle); + status = GetInteropApi().ImportSemaphore(importer, &sem_desc, &sem_handle); ASSERT_EQ(status, nullptr) << "ImportSemaphore should succeed"; ASSERT_NE(sem_handle, nullptr) << "Semaphore handle should not be null"; 
// Release semaphore handle - Ort::GetApi().ReleaseExternalSemaphoreHandle(sem_handle); + GetInteropApi().ReleaseExternalSemaphoreHandle(sem_handle); - Ort::GetApi().ReleaseExternalResourceImporter(importer); + GetInteropApi().ReleaseExternalResourceImporter(importer); } // Test: Wait and Signal Semaphore (Simulated) TEST_F(ExternalResourceImporterTest, WaitAndSignalSemaphore) { OrtExternalResourceImporter* importer = nullptr; - OrtStatus* status = Ort::GetApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); if (status != nullptr) { Ort::GetApi().ReleaseStatus(status); GTEST_SKIP() << "External resource interop not supported"; @@ -252,35 +256,35 @@ TEST_F(ExternalResourceImporterTest, WaitAndSignalSemaphore) { sem_desc.native_handle = dummy_handle; OrtExternalSemaphoreHandle* sem_handle = nullptr; - status = Ort::GetApi().ExternalResourceImporter_ImportSemaphore(importer, &sem_desc, &sem_handle); + status = GetInteropApi().ImportSemaphore(importer, &sem_desc, &sem_handle); ASSERT_EQ(status, nullptr); // Signal the semaphore with value 1 - status = Ort::GetApi().ExternalResourceImporter_SignalSemaphore(importer, sem_handle, stream, 1); + status = GetInteropApi().SignalSemaphore(importer, sem_handle, stream, 1); ASSERT_EQ(status, nullptr) << "SignalSemaphore should succeed"; // Wait for value 1 (should succeed immediately since we just signaled it) - status = Ort::GetApi().ExternalResourceImporter_WaitSemaphore(importer, sem_handle, stream, 1); + status = GetInteropApi().WaitSemaphore(importer, sem_handle, stream, 1); ASSERT_EQ(status, nullptr) << "WaitSemaphore should succeed"; // Signal with value 5 - status = Ort::GetApi().ExternalResourceImporter_SignalSemaphore(importer, sem_handle, stream, 5); + status = GetInteropApi().SignalSemaphore(importer, sem_handle, stream, 5); ASSERT_EQ(status, nullptr); // Wait for value 3 (should succeed since 
current value is 5) - status = Ort::GetApi().ExternalResourceImporter_WaitSemaphore(importer, sem_handle, stream, 3); + status = GetInteropApi().WaitSemaphore(importer, sem_handle, stream, 3); ASSERT_EQ(status, nullptr); // Cleanup - Ort::GetApi().ReleaseExternalSemaphoreHandle(sem_handle); + GetInteropApi().ReleaseExternalSemaphoreHandle(sem_handle); Ort::GetApi().ReleaseSyncStream(stream); - Ort::GetApi().ReleaseExternalResourceImporter(importer); + GetInteropApi().ReleaseExternalResourceImporter(importer); } // Test: Multiple Memory Imports TEST_F(ExternalResourceImporterTest, MultipleMemoryImports) { OrtExternalResourceImporter* importer = nullptr; - OrtStatus* status = Ort::GetApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); if (status != nullptr) { Ort::GetApi().ReleaseStatus(status); GTEST_SKIP() << "External resource interop not supported"; @@ -299,23 +303,23 @@ TEST_F(ExternalResourceImporterTest, MultipleMemoryImports) { mem_desc.offset_bytes = 0; mem_desc.access_mode = ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE; - status = Ort::GetApi().ExternalResourceImporter_ImportMemory(importer, &mem_desc, &handles[i]); + status = GetInteropApi().ImportMemory(importer, &mem_desc, &handles[i]); ASSERT_EQ(status, nullptr) << "ImportMemory " << i << " should succeed"; ASSERT_NE(handles[i], nullptr); } // Release all handles for (int i = 0; i < kNumBuffers; ++i) { - Ort::GetApi().ReleaseExternalMemoryHandle(handles[i]); + GetInteropApi().ReleaseExternalMemoryHandle(handles[i]); } - Ort::GetApi().ReleaseExternalResourceImporter(importer); + GetInteropApi().ReleaseExternalResourceImporter(importer); } // Test: Access Mode Variations TEST_F(ExternalResourceImporterTest, AccessModeVariations) { OrtExternalResourceImporter* importer = nullptr; - OrtStatus* status = Ort::GetApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + OrtStatus* 
status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); if (status != nullptr) { Ort::GetApi().ReleaseStatus(status); GTEST_SKIP() << "External resource interop not supported"; @@ -336,14 +340,14 @@ TEST_F(ExternalResourceImporterTest, AccessModeVariations) { mem_desc.access_mode = access_mode; OrtExternalMemoryHandle* mem_handle = nullptr; - status = Ort::GetApi().ExternalResourceImporter_ImportMemory(importer, &mem_desc, &mem_handle); + status = GetInteropApi().ImportMemory(importer, &mem_desc, &mem_handle); ASSERT_EQ(status, nullptr) << "ImportMemory with access_mode " << access_mode << " should succeed"; ASSERT_NE(mem_handle, nullptr); - Ort::GetApi().ReleaseExternalMemoryHandle(mem_handle); + GetInteropApi().ReleaseExternalMemoryHandle(mem_handle); } - Ort::GetApi().ReleaseExternalResourceImporter(importer); + GetInteropApi().ReleaseExternalResourceImporter(importer); } // Test: SessionGetEpDeviceForOutputs