diff --git a/cmake/onnxruntime_providers_nv.cmake b/cmake/onnxruntime_providers_nv.cmake index e59463b6b91f1..5ec45a64e46bb 100644 --- a/cmake/onnxruntime_providers_nv.cmake +++ b/cmake/onnxruntime_providers_nv.cmake @@ -146,9 +146,9 @@ endif () target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE Eigen3::Eigen onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface Eigen3::Eigen) add_dependencies(onnxruntime_providers_nv_tensorrt_rtx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) - target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart) + target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart CUDA::cuda_driver) else() - target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart) + target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart CUDA::cuda_driver) endif() target_include_directories(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${TENSORRT_RTX_INCLUDE_DIR} ${onnx_tensorrt_SOURCE_DIR} PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index d1b652229e4b6..03df16c315376 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -330,6 +330,9 @@ ORT_RUNTIME_CLASS(EpDevice); ORT_RUNTIME_CLASS(KeyValuePairs); ORT_RUNTIME_CLASS(SyncStream); // Opaque class to create an onnxruntime::Stream. ORT_RUNTIME_CLASS(ExternalInitializerInfo); +ORT_RUNTIME_CLASS(ExternalResourceImporter); // Capability object for external resource import +ORT_RUNTIME_CLASS(ExternalMemoryHandle); // EP-imported view of shared external allocation +ORT_RUNTIME_CLASS(ExternalSemaphoreHandle); // EP-imported view of shared external semaphore #ifdef _MSC_VER typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -878,6 +881,9 @@ typedef struct OrtModelEditorApi OrtModelEditorApi; struct OrtCompileApi; typedef struct OrtCompileApi OrtCompileApi; +struct OrtInteropApi; +typedef struct OrtInteropApi OrtInteropApi; + struct OrtEpApi; typedef struct OrtEpApi OrtEpApi; @@ -955,6 +961,84 @@ typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t n * * \nosubgrouping */ + +/** \addtogroup Global + * @{ + */ + +/** \brief External memory handle type for importing GPU resources. + * + * \since Version 1.24. + */ +typedef enum OrtExternalMemoryHandleType { + ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE = 0, /**< Shared HANDLE from ID3D12Device::CreateSharedHandle(resource) */ + ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP = 1, /**< Shared HANDLE from ID3D12Device::CreateSharedHandle(heap) */ +} OrtExternalMemoryHandleType; + +/** \brief Access mode for imported external memory. + * + * \since Version 1.24. + */ +typedef enum OrtExternalMemoryAccessMode { + ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE = 0, /**< Memory can be read and written */ + ORT_EXTERNAL_MEMORY_ACCESS_READ_ONLY = 1, /**< Memory is read-only */ + ORT_EXTERNAL_MEMORY_ACCESS_WRITE_ONLY = 2, /**< Memory is write-only */ +} OrtExternalMemoryAccessMode; + +/** \brief Descriptor for importing external memory. + * + * \note The version field must be set to ORT_API_VERSION. + * This ensures forward compatibility as fields may be added in future versions. + * + * \since Version 1.24. + */ +typedef struct OrtExternalMemoryDescriptor { + uint32_t version; /**< Must be ORT_API_VERSION */ + OrtExternalMemoryHandleType handle_type; /**< Type of the external memory handle */ + void* native_handle; /**< Platform-specific handle (e.g., Windows HANDLE) */ + size_t size_bytes; /**< Total size in bytes of the external allocation */ + size_t offset_bytes; /**< Offset in bytes into the allocation (default 0) */ + OrtExternalMemoryAccessMode access_mode; /**< Access mode for the imported memory */ +} OrtExternalMemoryDescriptor; + +/** \brief External semaphore type for GPU synchronization. + * + * \since Version 1.24. + */ +typedef enum OrtExternalSemaphoreType { + ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE = 0, /**< Shared HANDLE from ID3D12Device::CreateSharedHandle(fence) */ +} OrtExternalSemaphoreType; + +/** \brief Descriptor for importing external semaphores. + * + * \note The version field must be set to ORT_API_VERSION. + * This ensures forward compatibility as fields may be added in future versions. + * + * \since Version 1.24. + */ +typedef struct OrtExternalSemaphoreDescriptor { + uint32_t version; /**< Must be ORT_API_VERSION */ + OrtExternalSemaphoreType type; /**< Type of the external semaphore */ + void* native_handle; /**< Platform-specific handle (e.g., Windows HANDLE) */ +} OrtExternalSemaphoreDescriptor; + +/** \brief Descriptor for creating a tensor from imported external memory. + * + * \note The version field must be set to ORT_API_VERSION. + * This ensures forward compatibility as fields may be added in future versions. + * + * \since Version 1.24. + */ +typedef struct OrtExternalTensorDescriptor { + uint32_t version; /**< Must be ORT_API_VERSION */ + ONNXTensorElementDataType element_type; /**< Data type of tensor elements */ + const int64_t* shape; /**< Array of dimension sizes */ + size_t rank; /**< Number of dimensions */ + size_t offset_bytes; /**< Optional offset within imported memory (default 0) */ +} OrtExternalTensorDescriptor; + +/// @} + /* * Public enum for compiled model compatibility across EPs. */ @@ -6608,6 +6692,45 @@ struct OrtApi { * \since Version 1.24 */ ORT_API2_STATUS(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out); + + /** \brief Get the EP Interop API instance. + * + * Get the Interop API instance to work with external resources. This API provides functions + * for importing external GPU memory and semaphores for zero-copy sharing between ORT inference + * and other GPU workloads. + * + * \return Interop API struct instance. + * + * \since Version 1.24. + */ + const OrtInteropApi*(ORT_API_CALL* GetInteropApi)(void); + + /** \brief Get the EP device assigned to each session output. + * + * Returns the OrtEpDevice assigned to each output of the session after graph partitioning. + * This allows validation that outputs are placed on the expected device for external resource sharing. + * + * The EP device for each output is determined by which execution provider claims that output + * during graph partitioning. This information is useful for: + * - Validating that outputs will be placed on the expected device for external resource sharing + * - Deciding whether to use external memory handles for outputs + * + * \param[in] session The OrtSession instance to query. + * \param[out] outputs_ep_devices An array to be filled with the EP device for each output. + * The array must be allocated by the caller with space for + * OrtEpDevice* values for each output. + * The order is the same as returned by SessionGetOutputName. + * \param[in] num_outputs The number of outputs in the session. Must match SessionGetOutputCount. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24 + */ + ORT_API2_STATUS(SessionGetEpDeviceForOutputs, _In_ const OrtSession* session, + _Out_writes_(num_outputs) const OrtEpDevice** outputs_ep_devices, + _In_ size_t num_outputs); + + /// @} }; /* @@ -7405,6 +7528,214 @@ struct OrtCompileApi { _In_ OrtGetInitializerLocationFunc get_initializer_location_func, _In_ void* state); }; +/** + * \brief The OrtInteropApi struct provides functions for external resource interop with execution providers. + * + * This API enables importing external GPU resources (memory and semaphores) for zero-copy sharing + * between ORT inference and other GPU workloads (e.g., D3D12 applications, media pipelines). + * + * The API is designed to be EP-agnostic and can be extended to support various GPU interop mechanisms + * (D3D12 shared handles, CUDA external memory, Vulkan, etc.). + * + * Example usage (error handling not shown): + * const OrtInteropApi* interop_api = ort_api->GetInteropApi(); + * OrtExternalResourceImporter* importer = NULL; + * + * status = interop_api->CreateExternalResourceImporterForDevice(ep_device, &importer); + * bool can_import = false; + * status = interop_api->CanImportMemory(importer, ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE, &can_import); + * if (can_import) { + * OrtExternalMemoryHandle* mem_handle = NULL; + * status = interop_api->ImportMemory(importer, &mem_desc, &mem_handle); + * // ... use mem_handle to create tensors ... + * interop_api->ReleaseExternalMemoryHandle(mem_handle); + * } + * interop_api->ReleaseExternalResourceImporter(importer); + * + * \since Version 1.24. + */ +struct OrtInteropApi { + /// \name OrtExternalResourceImporter + /// @{ + + /** \brief Create an external resource importer for a specific EP device. + * + * The external resource importer is a capability object that provides methods for importing + * external GPU memory and semaphores for zero-copy import with an execution provider. + * + * \param[in] ep_device The OrtEpDevice instance to create the importer for. + * \param[out] out_importer Output parameter set to the created OrtExternalResourceImporter instance. + * Returns nullptr if the EP does not support external resource import. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CreateExternalResourceImporterForDevice, + _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporter** out_importer); + + /** \brief Release an OrtExternalResourceImporter instance. + * + * \param[in] importer The OrtExternalResourceImporter instance to release. May be nullptr. + * + * \since Version 1.24. + */ + ORT_CLASS_RELEASE(ExternalResourceImporter); + + /// @} + /// \name Memory Import + /// @{ + + /** \brief Check if the external resource importer can import a specific memory handle type. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] handle_type The type of external memory handle to check. + * \param[out] out_supported Set to true if the handle type is supported. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CanImportMemory, + _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalMemoryHandleType handle_type, + _Out_ bool* out_supported); + + /** \brief Import external memory into the execution provider. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] desc Descriptor containing the external memory handle and properties. + * \param[out] out_handle Output parameter set to the created OrtExternalMemoryHandle. + * The caller owns the returned handle and must call ReleaseExternalMemoryHandle to free it. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ImportMemory, + _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle); + + /** \brief Release an OrtExternalMemoryHandle instance. + * + * \param[in] handle The OrtExternalMemoryHandle instance to release. May be nullptr. + * + * \since Version 1.24. + */ + ORT_CLASS_RELEASE(ExternalMemoryHandle); + + /** \brief Create a tensor backed by imported external memory. + * + * The created tensor is a view over the imported memory and does not copy data. + * The OrtExternalMemoryHandle must remain valid for the lifetime of the tensor. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] mem_handle The imported external memory handle. + * \param[in] tensor_desc Descriptor specifying tensor element type, shape, and optional offset. + * \param[in] tensor_location Optional OrtMemoryInfo for the tensor location. May be nullptr. + * \param[out] out_tensor Output parameter set to the created OrtValue containing the tensor. + * The caller owns the returned tensor and must call ReleaseValue to free it. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CreateTensorFromMemory, + _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _In_opt_ const OrtMemoryInfo* tensor_location, + _Outptr_ OrtValue** out_tensor); + + /// @} + /// \name Semaphore Import + /// @{ + + /** \brief Check if the external resource importer can import a specific semaphore type. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] type The type of external semaphore to check. + * \param[out] out_supported Set to true if the semaphore type is supported. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CanImportSemaphore, + _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreType type, + _Out_ bool* out_supported); + + /** \brief Import an external semaphore into the execution provider. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] desc Descriptor containing the external semaphore handle and type. + * \param[out] out_handle Output parameter set to the created OrtExternalSemaphoreHandle. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ImportSemaphore, + _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle); + + /** \brief Release an OrtExternalSemaphoreHandle instance. + * + * \param[in] handle The OrtExternalSemaphoreHandle instance to release. May be nullptr. + * + * \since Version 1.24. + */ + ORT_CLASS_RELEASE(ExternalSemaphoreHandle); + + /** \brief Wait on an external semaphore on the EP's stream. + * + * Inserts a wait operation into the EP's stream that blocks until the semaphore + * reaches the specified value. This is used to synchronize with external GPU work + * (e.g., D3D12 timeline fence). + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] semaphore_handle The imported external semaphore. + * \param[in] stream The OrtSyncStream to wait on. + * \param[in] value The fence/semaphore value to wait for. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(WaitSemaphore, + _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + + /** \brief Signal an external semaphore from the EP's stream. + * + * Inserts a signal operation into the EP's stream that sets the semaphore + * to the specified value when reached. This is used to notify external GPU work + * (e.g., D3D12 timeline fence) that ORT inference is complete. + * + * \param[in] importer The OrtExternalResourceImporter instance. + * \param[in] semaphore_handle The imported external semaphore. + * \param[in] stream The OrtSyncStream to signal from. + * \param[in] value The fence/semaphore value to signal. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(SignalSemaphore, + _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + + /// @} +}; + /* * This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality * This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index bc75aabc7e229..4d675277dbad2 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -236,6 +236,20 @@ inline const OrtCompileApi& GetCompileApi() { return *api; } +/// +/// This returns a reference to the ORT C Interop API. Used for external resource import with EPs. +/// +/// ORT C Interop API reference +inline const OrtInteropApi& GetInteropApi() { + auto* api = GetApi().GetInteropApi(); + if (api == nullptr) { + // minimal build + ORT_CXX_API_THROW("Interop API is not available in this build", ORT_FAIL); + } + + return *api; +} + /// /// This returns a reference to the ORT C EP API. Used if authoring a plugin execution provider. /// @@ -1607,6 +1621,7 @@ struct ConstSessionImpl : Base { std::vector GetMemoryInfoForInputs() const; ///< Wrapper for OrtApi::SessionGetMemoryInfoForInputs std::vector GetMemoryInfoForOutputs() const; ///< Wrapper for OrtApi::SessionGetMemoryInfoForOutputs std::vector GetEpDeviceForInputs() const; ///< Wrapper for OrtApi::SessionGetEpDeviceForInputs + std::vector GetEpDeviceForOutputs() const; ///< Wrapper for OrtApi::SessionGetEpDeviceForOutputs /** \brief Returns a copy of input name at the specified index. * diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index aff1061a67fea..dc73614ef0445 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1660,6 +1660,19 @@ inline std::vector ConstSessionImpl::GetEpDeviceForInputs() co return input_devices; } +template +inline std::vector ConstSessionImpl::GetEpDeviceForOutputs() const { + auto num_outputs = GetOutputCount(); + std::vector output_devices; + if (num_outputs > 0) { + output_devices.resize(num_outputs); + ThrowOnError(GetApi().SessionGetEpDeviceForOutputs(this->p_, + reinterpret_cast(output_devices.data()), + num_outputs)); + } + return output_devices; +} + template inline uint64_t ConstSessionImpl::GetProfilingStartTimeNs() const { uint64_t out; diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 6fa5c8dea04e6..eb716afc76d6f 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -24,6 +24,66 @@ ORT_RUNTIME_CLASS(DataTransferImpl); ORT_RUNTIME_CLASS(SyncNotificationImpl); ORT_RUNTIME_CLASS(SyncStreamImpl); +ORT_RUNTIME_CLASS(ExternalResourceImporterImpl); + +/** \brief Base struct for imported external memory handles. + * + * EPs derive from this struct to add EP-specific fields (e.g., CUdeviceptr for CUDA). + * EP is responsible for creating and releasing instances of the derived type. + * + * Example derived type for CUDA EP: + * \code + * struct MyCudaExternalMemoryHandle : OrtExternalMemoryHandle { + * CUexternalMemory ext_memory; + * CUdeviceptr mapped_ptr; + * bool is_dedicated; + * }; + * \endcode + * + * \since Version 1.24. + */ +struct OrtExternalMemoryHandle { + uint32_t version; ///< Must be ORT_API_VERSION + const OrtEpDevice* ep_device; ///< EP device that created this handle + OrtExternalMemoryHandleType handle_type; ///< Original handle type for tracking + size_t size_bytes; ///< Size of the imported memory + size_t offset_bytes; ///< Offset into the imported memory + + /** \brief Release callback for this handle. EP sets this to its release function. + * + * ORT calls this when ReleaseExternalMemoryHandle is invoked. The EP's callback + * should cast the handle to its derived type and delete it. + */ + void(ORT_API_CALL* Release)(_In_ OrtExternalMemoryHandle* handle); +}; + +/** \brief Base struct for imported external semaphore handles. + * + * EPs derive from this struct to add EP-specific fields (e.g., CUexternalSemaphore for CUDA). + * EP is responsible for creating and releasing instances of the derived type. + * + * Example derived type for CUDA EP: + * \code + * struct MyCudaExternalSemaphoreHandle : OrtExternalSemaphoreHandle { + * CUexternalSemaphore ext_semaphore; + * }; + * \endcode + * + * \since Version 1.24. + */ +struct OrtExternalSemaphoreHandle { + uint32_t version; ///< Must be ORT_API_VERSION + const OrtEpDevice* ep_device; ///< EP device that created this handle + OrtExternalSemaphoreType type; ///< Original semaphore type + + /** \brief Release callback for this handle. EP sets this to its release function. + * + * ORT calls this when ReleaseExternalSemaphoreHandle is invoked. The EP's callback + * should cast the handle to its derived type and delete it. + */ + void(ORT_API_CALL* Release)(_In_ OrtExternalSemaphoreHandle* handle); +}; + // Opaque types for kernel-based EPs ORT_RUNTIME_CLASS(KernelRegistry); ORT_RUNTIME_CLASS(KernelDefBuilder); @@ -190,6 +250,180 @@ struct OrtSyncStreamImpl { ORT_API2_STATUS(OnSessionRunEnd, _In_ OrtSyncStreamImpl* this_ptr); }; +/** \brief Struct that an EP implements for external resource import (memory + semaphore import). + * + * This capability object provides methods for importing external GPU memory and semaphores + * for zero-copy import. EPs that support D3D12, CUDA, HIP, or Vulkan external resource APIs + * can implement this interface. + * + * \since Version 1.24. + */ +struct OrtExternalResourceImporterImpl { + uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION + + // Memory operations (stream-independent) + + /** \brief Check if the implementation can import external memory of the given handle type. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] handle_type The type of external memory handle to check. + * \return True if the handle type is supported. + * + * \since Version 1.24. + */ + ORT_API_T(bool, CanImportMemory, + _In_ const OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalMemoryHandleType handle_type); + + /** \brief Import external memory. + * + * The EP creates a derived type of OrtExternalMemoryHandle and returns a pointer to the base. + * EP is responsible for the lifetime of the handle (release via ReleaseMemory). + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] desc Descriptor containing the external memory handle and properties. + * \param[out] out_handle Output parameter set to the created OrtExternalMemoryHandle (EP's derived type). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ImportMemory, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle); + + /** \brief Release an imported external memory handle. + * + * The EP deletes its derived type instance. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] handle The OrtExternalMemoryHandle to release (EP casts to its derived type). + * + * \since Version 1.24. + */ + ORT_API_T(void, ReleaseMemory, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalMemoryHandle* handle); + + /** \brief Create a tensor backed by imported external memory. + * + * The created tensor is a view over the imported memory and does not copy data. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] mem_handle The imported external memory handle (EP casts to its derived type). + * \param[in] tensor_desc Descriptor specifying tensor element type, shape, and optional offset. + * \param[out] out_tensor Output parameter set to the created OrtValue containing the tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CreateTensorFromMemory, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _Outptr_ OrtValue** out_tensor); + + // Semaphore operations (require stream) + + /** \brief Check if the implementation can import external semaphores of the given type. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] type The type of external semaphore to check. + * \return True if the semaphore type is supported. + * + * \since Version 1.24. + */ + ORT_API_T(bool, CanImportSemaphore, + _In_ const OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreType type); + + /** \brief Import an external semaphore. + * + * The EP creates a derived type of OrtExternalSemaphoreHandle and returns a pointer to the base. + * EP is responsible for the lifetime of the handle (release via ReleaseSemaphore). + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] desc Descriptor containing the external semaphore handle and type. + * \param[out] out_handle Output parameter set to the created OrtExternalSemaphoreHandle (EP's derived type). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(ImportSemaphore, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle); + + /** \brief Release an imported external semaphore handle. + * + * The EP deletes its derived type instance. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] handle The OrtExternalSemaphoreHandle to release (EP casts to its derived type). + * + * \since Version 1.24. + */ + ORT_API_T(void, ReleaseSemaphore, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle); + + /** \brief Wait on an external semaphore on the EP's stream. + * + * Inserts a wait operation into the EP's stream that blocks until the semaphore + * reaches the specified value. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] handle The imported external semaphore (EP casts to its derived type). + * \param[in] stream The OrtSyncStream to wait on. + * \param[in] value The fence/semaphore value to wait for. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(WaitSemaphore, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + + /** \brief Signal an external semaphore from the EP's stream. + * + * Inserts a signal operation into the EP's stream that sets the semaphore + * to the specified value when reached. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * \param[in] handle The imported external semaphore (EP casts to its derived type). + * \param[in] stream The OrtSyncStream to signal from. + * \param[in] value The fence/semaphore value to signal. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(SignalSemaphore, + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + + // Release the capability object itself + + /** \brief Release the OrtExternalResourceImporterImpl instance. + * + * This is called by ORT when the OrtExternalResourceImporterImpl instance is no longer needed. + * The implementation should release any resources held by the instance. + * + * \param[in] this_ptr Pointer to the OrtExternalResourceImporterImpl instance. + * + * \since Version 1.24. + */ + ORT_API_T(void, Release, _In_ OrtExternalResourceImporterImpl* this_ptr); +}; + struct OrtNodeFusionOptions; typedef struct OrtNodeFusionOptions OrtNodeFusionOptions; @@ -1413,6 +1647,32 @@ struct OrtEpFactory { * \since Version 1.24. */ ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options); + + /** \brief Create an OrtExternalResourceImporterImpl for external resource import. + * + * This is used to create an external resource importer that enables zero-copy import of + * external GPU memory (e.g., D3D12 shared resources) and synchronization primitives + * (e.g., D3D12 timeline fences). + * + * EPs that support external resource import (via CUDA, HIP, Vulkan, or D3D12 APIs) can + * implement this to allow applications to share GPU resources without copies. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] ep_device The OrtEpDevice to create the external resource importer for. + * \param[out] out_importer The created OrtExternalResourceImporterImpl instance. + * Set to nullptr if external resource import is not supported. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \note Implementation of this function is optional. + * An EP factory should only implement this if it supports external resource import. + * If not implemented or not supported, return ORT_NOT_IMPLEMENTED or set out_importer to nullptr. + * + * \since Version 1.24. + */ + ORT_API2_STATUS(CreateExternalResourceImporterForDevice, _In_ OrtEpFactory* this_ptr, + _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** out_importer); }; #ifdef __cplusplus diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index e5015e705958d..545330c5bee20 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -4,6 +4,7 @@ #include #include +#include #include "core/providers/shared_library/provider_api.h" #include "core/framework/provider_options.h" @@ -19,6 +20,8 @@ #include "nv_data_transfer.h" #include "nv_allocator.h" +#include + using namespace onnxruntime; namespace onnxruntime { @@ -516,32 +519,509 @@ struct NvTrtRtxSyncStreamImpl : OrtSyncStreamImpl { const OrtApi& ort_api; }; +#if defined(_WIN32) + +// External Resource Import Implementation (D3D12 to CUDA) +/** + * @brief Derived handle for imported external memory from D3D12 to CUDA. + * + * Derives from OrtExternalMemoryHandle (base struct) and adds CUDA-specific fields. + * This struct holds the CUDA external memory object and the mapped device pointer + * that can be used for zero-copy tensor creation. + */ +struct NvTrtRtxExternalMemoryHandle : OrtExternalMemoryHandle { + CUexternalMemory ext_memory; ///< CUDA external memory object + CUdeviceptr mapped_ptr; ///< Mapped device pointer for tensor access + bool is_dedicated; ///< Whether the D3D12 resource is a dedicated allocation + + NvTrtRtxExternalMemoryHandle() + : ext_memory(nullptr), mapped_ptr(0), is_dedicated(true) { + // Initialize base struct fields + version = ORT_API_VERSION; + ep_device = nullptr; + handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + size_bytes = 0; + offset_bytes = 0; + Release = ReleaseCallback; + } + + static void ORT_API_CALL ReleaseCallback(_In_ OrtExternalMemoryHandle* handle) noexcept { + if (handle == nullptr) return; + auto* derived = static_cast(handle); + // Destroy the external memory object (also releases mapped buffer) + if (derived->ext_memory != nullptr) { + cuDestroyExternalMemory(derived->ext_memory); + } + delete derived; + } +}; + +/** + * @brief Derived handle for imported external semaphore from D3D12 fence to CUDA. + * + * Derives from OrtExternalSemaphoreHandle (base struct) and adds CUDA-specific fields. + * D3D12 timeline fences are imported as CUDA external semaphores, enabling + * GPU-GPU synchronization between D3D12 and CUDA streams. + */ +struct NvTrtRtxExternalSemaphoreHandle : OrtExternalSemaphoreHandle { + CUexternalSemaphore ext_semaphore; ///< CUDA external semaphore object + + NvTrtRtxExternalSemaphoreHandle() + : ext_semaphore(nullptr) { + // Initialize base struct fields + version = ORT_API_VERSION; + ep_device = nullptr; + type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + Release = ReleaseCallback; + } + + static void ORT_API_CALL ReleaseCallback(_In_ OrtExternalSemaphoreHandle* handle) noexcept { + if (handle == nullptr) return; + auto* derived = static_cast(handle); + // Destroy the external semaphore object + if (derived->ext_semaphore != nullptr) { + cuDestroyExternalSemaphore(derived->ext_semaphore); + } + delete derived; + } +}; + +/** + * @brief Implementation of OrtExternalResourceImporterImpl for NvTensorRtRtx EP. + * + * This struct implements the external resource importer interface using CUDA Driver APIs + * to import D3D12 shared resources and timeline fences for zero-copy import. + * + * Supported handle types: + * - ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE → CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE + * - ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP → CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP + * - ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE → CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE + */ +struct NvTrtRtxExternalResourceImporterImpl : OrtExternalResourceImporterImpl { + NvTrtRtxExternalResourceImporterImpl(const OrtEpDevice* ep_device, const OrtApi& ort_api_in) + : ep_device_{ep_device}, ort_api{ort_api_in}, ep_api{*ort_api_in.GetEpApi()} { + ort_version_supported = ORT_API_VERSION; + + // Memory operations + CanImportMemory = CanImportMemoryImpl; + ImportMemory = ImportMemoryImpl; + ReleaseMemory = ReleaseMemoryImpl; + CreateTensorFromMemory = CreateTensorFromMemoryImpl; + + // Semaphore operations + CanImportSemaphore = CanImportSemaphoreImpl; + ImportSemaphore = ImportSemaphoreImpl; + ReleaseSemaphore = ReleaseSemaphoreImpl; + WaitSemaphore = WaitSemaphoreImpl; + SignalSemaphore = SignalSemaphoreImpl; + + // Release + Release = ReleaseImpl; + } + + static bool ORT_API_CALL CanImportMemoryImpl( + _In_ const OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalMemoryHandleType handle_type) noexcept { + (void)this_ptr; + // CUDA supports both D3D12 resource and heap handles + return handle_type == ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE || + handle_type == ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP; + } + + static OrtStatus* ORT_API_CALL ImportMemoryImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle) noexcept { + auto& impl = *static_cast(this_ptr); + + if (desc == nullptr || out_handle == nullptr) { + return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to ImportMemory"); + } + + // Validate descriptor version - check minimum supported version for forward compatibility + if (desc->version < ORT_API_VERSION) { + return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "OrtExternalMemoryDescriptor version too old"); + } + + *out_handle = nullptr; + + // Validate handle type + if (!CanImportMemoryImpl(this_ptr, desc->handle_type)) { + return impl.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, + "Unsupported external memory handle type for CUDA import"); + } + + // Validate offset does not exceed allocation size + if (desc->offset_bytes > desc->size_bytes) { + return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "offset_bytes exceeds size_bytes in OrtExternalMemoryDescriptor"); + } + + // Set CUDA device for this EP. The imported external memory handle is associated with + // the device where it was imported and remains valid regardless of subsequent cudaSetDevice + // calls. Multi-GPU scenarios with different sessions/EPs work correctly because each + // importer is bound to its EP's device via ep_device_->device_memory_info. + (void)cuCtxSetCurrent(nullptr); // Reset context + CUDA_RETURN_IF_ERROR(cudaSetDevice(impl.DeviceId())); + + // Map ORT handle type to CUDA handle type + CUexternalMemoryHandleType cu_handle_type; + bool is_dedicated = true; + switch (desc->handle_type) { + case ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE: + cu_handle_type = CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + is_dedicated = true; // D3D12 committed resources are dedicated + break; + case ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP: + cu_handle_type = CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP; + is_dedicated = false; // D3D12 heaps are not dedicated + break; + default: + // Should not reach here - CanImportMemory already validated handle type + return impl.ort_api.CreateStatus(ORT_EP_FAIL, "Unexpected external memory handle type"); + } + + // Setup external memory handle descriptor + CUDA_EXTERNAL_MEMORY_HANDLE_DESC ext_mem_desc = {}; + ext_mem_desc.type = cu_handle_type; + ext_mem_desc.handle.win32.handle = desc->native_handle; + ext_mem_desc.size = desc->size_bytes; + ext_mem_desc.flags = is_dedicated ? CUDA_EXTERNAL_MEMORY_DEDICATED : 0; + + // Import the external memory + CUexternalMemory ext_memory = nullptr; + CUresult cu_result = cuImportExternalMemory(&ext_memory, &ext_mem_desc); + if (cu_result != CUDA_SUCCESS) { + const char* error_str = nullptr; + cuGetErrorString(cu_result, &error_str); + std::string error_msg = "cuImportExternalMemory failed: "; + error_msg += error_str ? error_str : "unknown error"; + return impl.ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); + } + + // Map the external memory to get a device pointer + CUDA_EXTERNAL_MEMORY_BUFFER_DESC buffer_desc = {}; + buffer_desc.offset = desc->offset_bytes; + buffer_desc.size = desc->size_bytes - desc->offset_bytes; + buffer_desc.flags = 0; + + CUdeviceptr mapped_ptr = 0; + cu_result = cuExternalMemoryGetMappedBuffer(&mapped_ptr, ext_memory, &buffer_desc); + if (cu_result != CUDA_SUCCESS) { + cuDestroyExternalMemory(ext_memory); + const char* error_str = nullptr; + cuGetErrorString(cu_result, &error_str); + std::string error_msg = "cuExternalMemoryGetMappedBuffer failed: "; + error_msg += error_str ? error_str : "unknown error"; + return impl.ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); + } + + // Create and return the derived handle (cast to base pointer) + auto handle = std::make_unique(); + + handle->ep_device = impl.ep_device_; + handle->handle_type = desc->handle_type; + handle->size_bytes = desc->size_bytes; + handle->offset_bytes = desc->offset_bytes; + handle->ext_memory = ext_memory; + handle->mapped_ptr = mapped_ptr; + handle->is_dedicated = is_dedicated; + + *out_handle = handle.release(); + return nullptr; + } + + static void ORT_API_CALL ReleaseMemoryImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalMemoryHandle* handle) noexcept { + (void)this_ptr; + + if (handle == nullptr) { + return; + } + + // The handle has a Release callback that does the actual cleanup + // This method is called from OrtExternalResourceImporterImpl::ReleaseMemory + // The Release callback in the handle will call the static ReleaseCallback + auto* mem_handle = static_cast(handle); + + // Destroy the external memory object (also releases mapped buffer) + if (mem_handle->ext_memory != nullptr) { + cuDestroyExternalMemory(mem_handle->ext_memory); + } + + delete mem_handle; + } + + static OrtStatus* ORT_API_CALL CreateTensorFromMemoryImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _Outptr_ OrtValue** out_tensor) noexcept { + auto& impl = *static_cast(this_ptr); + + if (mem_handle == nullptr || tensor_desc == nullptr || out_tensor == nullptr) { + return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to CreateTensorFromMemory"); + } + + // Validate descriptor version - check minimum supported version for forward compatibility + if (tensor_desc->version < ORT_API_VERSION) { + return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "OrtExternalTensorDescriptor version too old"); + } + + *out_tensor = nullptr; + + auto* handle = static_cast(mem_handle); + + // Validate tensor offset does not exceed available buffer size + size_t available_size = handle->size_bytes - handle->offset_bytes; + if (tensor_desc->offset_bytes > available_size) { + return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "tensor offset_bytes exceeds available imported memory size"); + } + + // Calculate the data pointer with tensor offset + void* data_ptr = reinterpret_cast(handle->mapped_ptr + tensor_desc->offset_bytes); + + // Get memory info from the EP device (the importer is associated with the OrtEpDevice) + const OrtMemoryInfo* memory_info = impl.ep_device_->device_memory_info; + + // Create tensor that references the imported memory. The tensor does not own the memory - + // the user manages the lifetime of both the OrtValue and OrtExternalMemoryHandle. + // The user must keep the handle alive while the tensor is in use. + // No deleter is needed since this is for inference inputs/outputs where the user controls lifetime. + OrtStatus* status = impl.ort_api.CreateTensorWithDataAsOrtValue( + memory_info, + data_ptr, + handle->size_bytes - tensor_desc->offset_bytes, + tensor_desc->shape, + tensor_desc->rank, + tensor_desc->element_type, + out_tensor); + + return status; + } + + static bool ORT_API_CALL CanImportSemaphoreImpl( + _In_ const OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreType type) noexcept { + (void)this_ptr; + // CUDA supports D3D12 timeline fences + return type == ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + } + + static OrtStatus* ORT_API_CALL ImportSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle) noexcept { + auto& impl = *static_cast(this_ptr); + + if (desc == nullptr || out_handle == nullptr) { + return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to ImportSemaphore"); + } + + // Validate descriptor version - check minimum supported version for forward compatibility + if (desc->version < ORT_API_VERSION) { + return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "OrtExternalSemaphoreDescriptor version too old"); + } + + *out_handle = nullptr; + + // Validate semaphore type + if (!CanImportSemaphoreImpl(this_ptr, desc->type)) { + return impl.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, + "Unsupported external semaphore type for CUDA import"); + } + + // Set CUDA device for this EP. Imported semaphore handles remain valid regardless of + // subsequent cudaSetDevice calls. + CUDA_RETURN_IF_ERROR(cudaSetDevice(impl.DeviceId())); + + // Setup external semaphore handle descriptor for D3D12 fence + CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC ext_sem_desc = {}; + ext_sem_desc.type = CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE; + ext_sem_desc.handle.win32.handle = desc->native_handle; + ext_sem_desc.flags = 0; + + // Import the external semaphore + CUexternalSemaphore ext_semaphore = nullptr; + CUresult cu_result = cuImportExternalSemaphore(&ext_semaphore, &ext_sem_desc); + if (cu_result != CUDA_SUCCESS) { + const char* error_str = nullptr; + cuGetErrorString(cu_result, &error_str); + std::string error_msg = "cuImportExternalSemaphore failed: "; + error_msg += error_str ? error_str : "unknown error"; + return impl.ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); + } + + // Create and return the derived handle (cast to base pointer) + auto handle = std::make_unique(); + + // Populate base struct fields + handle->ep_device = impl.ep_device_; + handle->type = desc->type; + + // Populate derived fields + handle->ext_semaphore = ext_semaphore; + + *out_handle = handle.release(); // Return base pointer + return nullptr; + } + + static void ORT_API_CALL ReleaseSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle) noexcept { + (void)this_ptr; + + if (handle == nullptr) { + return; + } + + auto* sem_handle = static_cast(handle); + + if (sem_handle->ext_semaphore != nullptr) { + cuDestroyExternalSemaphore(sem_handle->ext_semaphore); + } + + delete sem_handle; + } + + static OrtStatus* ORT_API_CALL WaitSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) noexcept { + auto& impl = *static_cast(this_ptr); + + if (handle == nullptr || stream == nullptr) { + return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to WaitSemaphore"); + } + + auto* sem_handle = static_cast(handle); + + // Get the CUDA stream from OrtSyncStream + cudaStream_t cuda_stream = static_cast(impl.ort_api.SyncStream_GetHandle(stream)); + + // Setup wait parameters for D3D12 fence (timeline semaphore) + CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS wait_params = {}; + wait_params.params.fence.value = value; + wait_params.flags = 0; + + // Wait on the external semaphore asynchronously + CUresult cu_result = cuWaitExternalSemaphoresAsync( + &sem_handle->ext_semaphore, + &wait_params, + 1, // numExtSems + cuda_stream); + + if (cu_result != CUDA_SUCCESS) { + const char* error_str = nullptr; + cuGetErrorString(cu_result, &error_str); + std::string error_msg = "cuWaitExternalSemaphoresAsync failed: "; + error_msg += error_str ? error_str : "unknown error"; + return impl.ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); + } + + return nullptr; + } + + static OrtStatus* ORT_API_CALL SignalSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) noexcept { + auto& impl = *static_cast(this_ptr); + + if (handle == nullptr || stream == nullptr) { + return impl.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to SignalSemaphore"); + } + + auto* sem_handle = static_cast(handle); + + // Get the CUDA stream from OrtSyncStream + cudaStream_t cuda_stream = static_cast(impl.ort_api.SyncStream_GetHandle(stream)); + + // Setup signal parameters for D3D12 fence (timeline semaphore) + CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS signal_params = {}; + signal_params.params.fence.value = value; + signal_params.flags = 0; + + // Signal the external semaphore asynchronously + CUresult cu_result = cuSignalExternalSemaphoresAsync( + &sem_handle->ext_semaphore, + &signal_params, + 1, // numExtSems + cuda_stream); + + if (cu_result != CUDA_SUCCESS) { + const char* error_str = nullptr; + cuGetErrorString(cu_result, &error_str); + std::string error_msg = "cuSignalExternalSemaphoresAsync failed: "; + error_msg += error_str ? error_str : "unknown error"; + return impl.ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); + } + + return nullptr; + } + + static void ORT_API_CALL ReleaseImpl(_In_ OrtExternalResourceImporterImpl* this_ptr) noexcept { + delete static_cast(this_ptr); + } + + /// @brief Get the CUDA device ID from the EP device's memory info. + int DeviceId() const { + return ep_device_->device_memory_info->device.Id(); + } + + private: + const OrtEpDevice* ep_device_; + const OrtApi& ort_api; + const OrtEpApi& ep_api; +}; + +#endif // defined(_WIN32) + // OrtEpApi infrastructure to be able to use the NvTensorRTRTX EP as an OrtEpFactory for auto EP selection. struct NvTensorRtRtxEpFactory : OrtEpFactory { using MemoryInfoUniquePtr = std::unique_ptr>; NvTensorRtRtxEpFactory(const OrtApi& ort_api_in, - const OrtLogger& default_logger_in) : ort_api{ort_api_in}, + const OrtLogger& default_logger_in) : OrtEpFactory{}, // Zero-initialize base struct + ort_api{ort_api_in}, ep_api{*ort_api_in.GetEpApi()}, default_logger{default_logger_in}, data_transfer_impl{ort_api_in} { + // Initialize all OrtEpFactory function pointers explicitly to avoid garbage values + // Required members GetName = GetNameImpl; GetVendor = GetVendorImpl; GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; - GetVendorId = GetVendorIdImpl; GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; + // Allocator CreateAllocator = CreateAllocatorImpl; ReleaseAllocator = ReleaseAllocatorImpl; + // Data transfer CreateDataTransfer = CreateDataTransferImpl; + // Stream support IsStreamAware = IsStreamAwareImpl; CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; + // Optional members - explicitly set to nullptr if not implemented + ValidateCompiledModelCompatibilityInfo = nullptr; // Not implemented + SetEnvironmentOptions = nullptr; // Not implemented + + // External resource import (D3D12 to CUDA) + CreateExternalResourceImporterForDevice = CreateExternalResourceImporterForDeviceImpl; + ort_version_supported = ORT_API_VERSION; // Set to the ORT version we were compiled with. } @@ -735,6 +1215,42 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { return nullptr; } + /** @brief Create an external resource importer for D3D12 to CUDA import. + * + * This enables zero-copy import of D3D12 shared resources and timeline fences. + * The implementation uses CUDA Driver APIs (cuImportExternalMemory, cuImportExternalSemaphore). + * + * @param this_ptr The OrtEpFactory instance. + * @param ep_device The OrtEpDevice to create the importer for. + * @param out_importer Output parameter set to the created OrtExternalResourceImporterImpl. + * @return nullptr on success, OrtStatus with error on failure. + */ + static OrtStatus* ORT_API_CALL CreateExternalResourceImporterForDeviceImpl( + OrtEpFactory* this_ptr, + const OrtEpDevice* ep_device, + OrtExternalResourceImporterImpl** out_importer) noexcept { + auto& factory = *static_cast(this_ptr); + + if (out_importer == nullptr) { + return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "out_importer cannot be nullptr"); + } + + *out_importer = nullptr; + +#if defined(_WIN32) + // Create the external resource importer + auto importer = std::make_unique(ep_device, factory.ort_api); + *out_importer = importer.release(); + + return nullptr; +#else + ORT_UNUSED_PARAMETER(ep_device); + return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, + "External resource import is only available on Windows builds."); +#endif + } + OrtStatus* CreateMemoryInfoForDevices(int num_devices) { gpu_memory_infos.reserve(num_devices); host_accessible_memory_infos.reserve(num_devices); diff --git a/onnxruntime/core/session/ep_interop_api.cc b/onnxruntime/core/session/ep_interop_api.cc new file mode 100644 index 0000000000000..aa580792873d6 --- /dev/null +++ b/onnxruntime/core/session/ep_interop_api.cc @@ -0,0 +1,389 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/ep_interop_api.h" + +#if !defined(ORT_MINIMAL_BUILD) +#include + +#include "core/session/ort_apis.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_env.h" +#include "core/session/abi_devices.h" +#include "core/common/logging/logging.h" +#include "core/framework/error_code_helper.h" +#else +#include "core/framework/error_code_helper.h" +#include "core/session/ort_apis.h" +#endif // !defined(ORT_MINIMAL_BUILD) + +using namespace onnxruntime; + +#if !defined(ORT_MINIMAL_BUILD) + +// Wrapper class for OrtExternalResourceImporterImpl +namespace { +struct ExternalResourceImporterWrapper { + const OrtEpDevice* ep_device; + OrtExternalResourceImporterImpl* impl; + + ExternalResourceImporterWrapper(const OrtEpDevice* device, OrtExternalResourceImporterImpl* importer) + : ep_device(device), impl(importer) {} + + ~ExternalResourceImporterWrapper() { + if (impl && impl->Release) { + impl->Release(impl); + } + } + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ExternalResourceImporterWrapper); +}; + +} // namespace + +ORT_API_STATUS_IMPL(OrtInteropAPI::CreateExternalResourceImporterForDevice, _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporter** out_importer) { + API_IMPL_BEGIN + if (ep_device == nullptr || out_importer == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device and out_importer must be provided."); + } + + *out_importer = nullptr; + + // OrtEpFactory::CreateExternalResourceImporterForDevice was added in ORT 1.24. + const auto* factory = ep_device->ep_factory; + if (factory == nullptr || + factory->ort_version_supported < 24 || + factory->CreateExternalResourceImporterForDevice == nullptr) { + // EP doesn't support external resource import - not an error, just return nullptr + return nullptr; + } + + OrtExternalResourceImporterImpl* impl = nullptr; + ORT_API_RETURN_IF_ERROR(factory->CreateExternalResourceImporterForDevice( + ep_device->GetMutableFactory(), + ep_device, + &impl)); + + if (impl == nullptr) { + // EP supports the factory method but returned null - not supported for this device + return nullptr; + } + + auto wrapper = std::make_unique(ep_device, impl); + *out_importer = reinterpret_cast(wrapper.release()); + + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtInteropAPI::ReleaseExternalResourceImporter, _Frees_ptr_opt_ OrtExternalResourceImporter* importer) { +#if !defined(ORT_MINIMAL_BUILD) + delete reinterpret_cast(importer); +#else + ORT_UNUSED_PARAMETER(importer); +#endif // !defined(ORT_MINIMAL_BUILD) +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::CanImportMemory, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalMemoryHandleType handle_type, + _Out_ bool* out_supported) { + API_IMPL_BEGIN + if (importer == nullptr || out_supported == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer and out_supported must be provided."); + } + + auto* wrapper = reinterpret_cast(importer); + if (wrapper->impl == nullptr || wrapper->impl->CanImportMemory == nullptr) { + *out_supported = false; + return nullptr; + } + + *out_supported = wrapper->impl->CanImportMemory(wrapper->impl, handle_type); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::ImportMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle) { + API_IMPL_BEGIN + if (importer == nullptr || desc == nullptr || out_handle == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, desc, and out_handle must be provided."); + } + + auto* wrapper = reinterpret_cast(importer); + if (wrapper->impl == nullptr || wrapper->impl->ImportMemory == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External memory import is not supported by this EP."); + } + + // EP creates derived type and returns base pointer. EP owns the handle lifetime. + OrtExternalMemoryHandle* handle = nullptr; + ORT_API_RETURN_IF_ERROR(wrapper->impl->ImportMemory(wrapper->impl, desc, &handle)); + + if (handle == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, "ImportMemory returned null handle."); + } + + *out_handle = handle; + + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtInteropAPI::ReleaseExternalMemoryHandle, _Frees_ptr_opt_ OrtExternalMemoryHandle* handle) { +#if !defined(ORT_MINIMAL_BUILD) + if (handle != nullptr && handle->Release != nullptr) { + handle->Release(handle); + } +#else + ORT_UNUSED_PARAMETER(handle); +#endif // !defined(ORT_MINIMAL_BUILD) +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::CreateTensorFromMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _In_opt_ const OrtMemoryInfo* /*tensor_location*/, + _Outptr_ OrtValue** out_tensor) { + API_IMPL_BEGIN + if (importer == nullptr || mem_handle == nullptr || tensor_desc == nullptr || out_tensor == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, mem_handle, tensor_desc, and out_tensor must be provided."); + } + + auto* imp_wrapper = reinterpret_cast(importer); + + if (imp_wrapper->impl == nullptr || imp_wrapper->impl->CreateTensorFromMemory == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateTensorFromMemory is not supported by this EP."); + } + + OrtValue* tensor = nullptr; + ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->CreateTensorFromMemory(imp_wrapper->impl, mem_handle, tensor_desc, &tensor)); + + *out_tensor = tensor; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::CanImportSemaphore, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreType type, + _Out_ bool* out_supported) { + API_IMPL_BEGIN + if (importer == nullptr || out_supported == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer and out_supported must be provided."); + } + + auto* wrapper = reinterpret_cast(importer); + if (wrapper->impl == nullptr || wrapper->impl->CanImportSemaphore == nullptr) { + *out_supported = false; + return nullptr; + } + + *out_supported = wrapper->impl->CanImportSemaphore(wrapper->impl, type); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::ImportSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle) { + API_IMPL_BEGIN + if (importer == nullptr || desc == nullptr || out_handle == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, desc, and out_handle must be provided."); + } + + auto* wrapper = reinterpret_cast(importer); + if (wrapper->impl == nullptr || wrapper->impl->ImportSemaphore == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "External semaphore import is not supported by this EP."); + } + + // EP creates derived type and returns base pointer. EP owns the handle lifetime. + OrtExternalSemaphoreHandle* handle = nullptr; + ORT_API_RETURN_IF_ERROR(wrapper->impl->ImportSemaphore(wrapper->impl, desc, &handle)); + + if (handle == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, "ImportSemaphore returned null handle."); + } + + *out_handle = handle; + + return nullptr; + API_IMPL_END +} + +ORT_API(void, OrtInteropAPI::ReleaseExternalSemaphoreHandle, _Frees_ptr_opt_ OrtExternalSemaphoreHandle* handle) { +#if !defined(ORT_MINIMAL_BUILD) + if (handle != nullptr && handle->Release != nullptr) { + handle->Release(handle); + } +#else + ORT_UNUSED_PARAMETER(handle); +#endif // !defined(ORT_MINIMAL_BUILD) +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::WaitSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) { + API_IMPL_BEGIN + if (importer == nullptr || semaphore_handle == nullptr || stream == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, semaphore_handle, and stream must be provided."); + } + + auto* imp_wrapper = reinterpret_cast(importer); + + if (imp_wrapper->impl == nullptr || imp_wrapper->impl->WaitSemaphore == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "WaitSemaphore is not supported by this EP."); + } + + ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->WaitSemaphore(imp_wrapper->impl, semaphore_handle, stream, value)); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::SignalSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) { + API_IMPL_BEGIN + if (importer == nullptr || semaphore_handle == nullptr || stream == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "importer, semaphore_handle, and stream must be provided."); + } + + auto* imp_wrapper = reinterpret_cast(importer); + + if (imp_wrapper->impl == nullptr || imp_wrapper->impl->SignalSemaphore == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SignalSemaphore is not supported by this EP."); + } + + ORT_API_RETURN_IF_ERROR(imp_wrapper->impl->SignalSemaphore(imp_wrapper->impl, semaphore_handle, stream, value)); + + return nullptr; + API_IMPL_END +} + +#else // defined(ORT_MINIMAL_BUILD) + +ORT_API_STATUS_IMPL(OrtInteropAPI::CreateExternalResourceImporterForDevice, _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporter** out_importer) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(ep_device); + ORT_UNUSED_PARAMETER(out_importer); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::CanImportMemory, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalMemoryHandleType handle_type, + _Out_ bool* out_supported) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(handle_type); + ORT_UNUSED_PARAMETER(out_supported); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::ImportMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(desc); + ORT_UNUSED_PARAMETER(out_handle); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::CreateTensorFromMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _In_opt_ const OrtMemoryInfo* tensor_location, + _Outptr_ OrtValue** out_tensor) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(mem_handle); + ORT_UNUSED_PARAMETER(tensor_desc); + ORT_UNUSED_PARAMETER(tensor_location); + ORT_UNUSED_PARAMETER(out_tensor); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::CanImportSemaphore, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreType type, + _Out_ bool* out_supported) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(type); + ORT_UNUSED_PARAMETER(out_supported); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::ImportSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(desc); + ORT_UNUSED_PARAMETER(out_handle); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::WaitSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(semaphore_handle); + ORT_UNUSED_PARAMETER(stream); + ORT_UNUSED_PARAMETER(value); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtInteropAPI::SignalSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) { + API_IMPL_BEGIN + ORT_UNUSED_PARAMETER(importer); + ORT_UNUSED_PARAMETER(semaphore_handle); + ORT_UNUSED_PARAMETER(stream); + ORT_UNUSED_PARAMETER(value); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Interop API is not supported in this build"); + API_IMPL_END +} + +#endif // !defined(ORT_MINIMAL_BUILD) + +static constexpr OrtInteropApi ort_interop_api = { + // NOTE: Application compatibility with newer versions of ORT depends on the Api order within this struct so + // all new functions must be added at the end, and no functions that already exist in an officially released version + // of ORT can be reordered or removed. + + &OrtInteropAPI::CreateExternalResourceImporterForDevice, + &OrtInteropAPI::ReleaseExternalResourceImporter, + &OrtInteropAPI::CanImportMemory, + &OrtInteropAPI::ImportMemory, + &OrtInteropAPI::ReleaseExternalMemoryHandle, + &OrtInteropAPI::CreateTensorFromMemory, + &OrtInteropAPI::CanImportSemaphore, + &OrtInteropAPI::ImportSemaphore, + &OrtInteropAPI::ReleaseExternalSemaphoreHandle, + &OrtInteropAPI::WaitSemaphore, + &OrtInteropAPI::SignalSemaphore, + // End of Version 24 - DO NOT MODIFY ABOVE +}; + +// Checks that we don't violate the rule that the functions must remain in the slots they were originally assigned +static_assert(offsetof(OrtInteropApi, SignalSemaphore) / sizeof(void*) == 10, + "Size of version 24 Api cannot change"); // initial version in ORT 1.24 + +ORT_API(const OrtInteropApi*, OrtInteropAPI::GetInteropApi) { + return &ort_interop_api; +} diff --git a/onnxruntime/core/session/ep_interop_api.h b/onnxruntime/core/session/ep_interop_api.h new file mode 100644 index 0000000000000..25ac7b58af8c0 --- /dev/null +++ b/onnxruntime/core/session/ep_interop_api.h @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/session/onnxruntime_c_api.h" + +namespace OrtInteropAPI { + +// implementation that returns the API struct +ORT_API(const OrtInteropApi*, GetInteropApi); + +ORT_API_STATUS_IMPL(CreateExternalResourceImporterForDevice, _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporter** out_importer); + +ORT_API(void, ReleaseExternalResourceImporter, _Frees_ptr_opt_ OrtExternalResourceImporter* importer); + +ORT_API_STATUS_IMPL(CanImportMemory, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalMemoryHandleType handle_type, + _Out_ bool* out_supported); + +ORT_API_STATUS_IMPL(ImportMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle); + +ORT_API(void, ReleaseExternalMemoryHandle, _Frees_ptr_opt_ OrtExternalMemoryHandle* handle); + +ORT_API_STATUS_IMPL(CreateTensorFromMemory, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _In_opt_ const OrtMemoryInfo* tensor_location, + _Outptr_ OrtValue** out_tensor); + +ORT_API_STATUS_IMPL(CanImportSemaphore, _In_ const OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreType type, + _Out_ bool* out_supported); + +ORT_API_STATUS_IMPL(ImportSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle); + +ORT_API(void, ReleaseExternalSemaphoreHandle, _Frees_ptr_opt_ OrtExternalSemaphoreHandle* handle); + +ORT_API_STATUS_IMPL(WaitSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + +ORT_API_STATUS_IMPL(SignalSemaphore, _In_ OrtExternalResourceImporter* importer, + _In_ OrtExternalSemaphoreHandle* semaphore_handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value); + +} // namespace OrtInteropAPI diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e5523dc78b5d2..e9b0c2f230263 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -3480,6 +3480,48 @@ common::Status InferenceSession::GetEpDeviceForInputs(InlinedVector& ep_devices) const { + ep_devices.clear(); + +#if defined(ORT_MINIMAL_BUILD) + return common::Status(common::ONNXRUNTIME, common::FAIL, + "GetEpDeviceForOutputs is not available in a minimal build."); +#else + if (!is_inited_) { + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Session has not been initialized."); + } + + std::pair outputs = GetModelOutputs(); + + ORT_RETURN_IF_ERROR(outputs.first); + + const auto& def_list = *outputs.second; + ep_devices.reserve(def_list.size()); + + const auto& available_eps = environment_.GetOrtEpDevices(); + + for (const auto* def : def_list) { + InlinedVector node_info_vec; + ORT_RETURN_IF_ERROR(session_state_->GetOutputNodeInfo(def->Name(), node_info_vec)); + assert(!node_info_vec.empty()); + // If we have an output that is not produced by any node, + // then we return nullptr. + const auto* p_node = node_info_vec.front().p_node; + if (p_node != nullptr) { + const auto ep_name = p_node->GetExecutionProviderType(); + auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) { + return entry->ep_name == ep_name; + }); + ep_devices.push_back(it != available_eps.end() ? *it : nullptr); + } else { + ep_devices.push_back(nullptr); + } + } + + return Status::OK(); +#endif +} + common::Status InferenceSession::NewIOBinding(std::unique_ptr* io_binding) { { std::lock_guard l(session_mutex_); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 8bea15c169ed4..ff54f6fa7bca0 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -484,6 +484,15 @@ class InferenceSession { * This is required for a user to know the location of the input/output when autoep selection is enabled. */ common::Status GetEpDeviceForInputs(InlinedVector& memory_info) const; + + /** + * Get the OrtEpDevice (if available) for the outputs of the model. + * + * This is required for a user to validate that outputs will be placed on the expected device + * for external resource sharing. + */ + common::Status GetEpDeviceForOutputs(InlinedVector& memory_info) const; + /** * Get the current number of in-progress concurrent Run calls. */ diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 82f7cef4aec49..a0e090a4d70ad 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -38,6 +38,7 @@ #include "core/session/allocator_adapters.h" #include "core/session/compile_api.h" #include "core/session/environment.h" +#include "core/session/ep_interop_api.h" #include "core/session/plugin_ep/ep_api.h" #include "core/session/plugin_ep/ep_library_internal.h" #include "core/session/inference_session.h" @@ -3372,6 +3373,34 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForInputs, _In_ const OrtSession* API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForOutputs, _In_ const OrtSession* ort_session, + _Out_writes_(num_values) const OrtEpDevice** outputs_ep_devices, + _In_ size_t num_values) { + API_IMPL_BEGIN + if (ort_session == nullptr || outputs_ep_devices == nullptr || num_values == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid argument provided to SessionGetEpDeviceForOutputs."); + } + + auto session = reinterpret_cast(ort_session); + + InlinedVector ep_devices; + + ORT_API_RETURN_IF_STATUS_NOT_OK(session->GetEpDeviceForOutputs(ep_devices)); + + auto num_found = ep_devices.size(); + if (num_found > num_values) { + auto msg = MakeString("Number of outputs ", num_found, " exceeds the provided size of ", num_values); + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, msg.c_str()); + } + + for (size_t i = 0; i < num_values; ++i) { + outputs_ep_devices[i] = (i < num_found) ? ep_devices[i] : nullptr; + } + + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice* ep_device, _In_opt_ const OrtKeyValuePairs* stream_options, _Outptr_ OrtSyncStream** ort_stream) { @@ -3536,6 +3565,11 @@ ORT_API_STATUS_IMPL(OrtApis::GetModelCompatibilityForEpDevices, API_IMPL_END } +// GetInteropApi - returns the Interop API struct +ORT_API(const OrtInteropApi*, OrtApis::GetInteropApi) { + return OrtInteropAPI::GetInteropApi(); +} + #else // defined(ORT_MINIMAL_BUILD) ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, _In_ const char* /*registration_name*/, const ORTCHAR_T* /*path*/) { @@ -3621,6 +3655,19 @@ ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* /*env*/, API_IMPL_END } +ORT_API(const OrtInteropApi*, OrtApis::GetInteropApi) { + fprintf(stderr, "The Interop API is not supported in a minimal build.\n"); + return nullptr; +} + +ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForOutputs, _In_ const OrtSession* /*ort_session*/, + _Out_writes_(num_values) const OrtEpDevice** /*outputs_ep_devices*/, + _In_ size_t /*num_values*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SessionGetEpDeviceForOutputs is not supported in a minimal build."); + API_IMPL_END +} + #endif // !defined(ORT_MINIMAL_BUILD) // OrtEpDevice accessors @@ -4238,6 +4285,9 @@ static constexpr OrtApi ort_api_1_to_24 = { &OrtApis::TensorTypeAndShape_HasShape, &OrtApis::KernelInfo_GetConfigEntries, + + &OrtApis::GetInteropApi, + &OrtApis::SessionGetEpDeviceForOutputs, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index f3525d8de7b95..3a82a4ca3f362 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -755,4 +755,11 @@ ORT_API_STATUS_IMPL(CopyTensors, _In_ const OrtEnv* env, ORT_API_STATUS_IMPL(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out); +// Interop API +ORT_API(const OrtInteropApi*, GetInteropApi); + +ORT_API_STATUS_IMPL(SessionGetEpDeviceForOutputs, _In_ const OrtSession* session, + _Out_writes_(num_outputs) const OrtEpDevice** outputs_ep_devices, + _In_ size_t num_outputs); + } // namespace OrtApis diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc index 364bab471ddbe..8dc92802aa84b 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc @@ -33,6 +33,7 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl OrtEpFactory::IsStreamAware = Forward::IsStreamAware; OrtEpFactory::CreateSyncStreamForDevice = Forward::CreateSyncStreamForDevice; OrtEpFactory::SetEnvironmentOptions = Forward::SetEnvironmentOptions; + OrtEpFactory::CreateExternalResourceImporterForDevice = Forward::CreateExternalResourceImporterForDevice; } InternalExecutionProviderFactory::InternalExecutionProviderFactory(EpFactoryInternal& ep_factory, diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h index 6eb83a117fb63..ae98f2c0ac589 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -91,6 +91,11 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->SetEnvironmentOptions(options); } + OrtStatus* CreateExternalResourceImporterForDevice(_In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** importer) noexcept { + return impl_->CreateExternalResourceImporterForDevice(ep_device, importer); + } + // Function ORT calls to release an EP instance. void ReleaseEp(OrtEp* /*ep*/) noexcept { // we never create an OrtEp so we should never be trying to release one diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index de9e2d44431bf..20a47715df2b8 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -88,6 +88,14 @@ class EpFactoryInternalImpl { return nullptr; } + virtual OrtStatus* CreateExternalResourceImporterForDevice( + _In_ const OrtEpDevice* /*ep_device*/, + _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** importer) noexcept { + // Default implementation does not support external resource import + *importer = nullptr; + return nullptr; + } + // Function ORT calls to release an EP instance. void ReleaseEp(OrtEp* ep); diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h index 8c5ef526baba1..26173f0055ed7 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h @@ -65,6 +65,18 @@ class ProviderBridgeEpFactory : public EpFactoryInternalImpl { return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); } + OrtStatus* CreateExternalResourceImporterForDevice( + const OrtEpDevice* ep_device, + OrtExternalResourceImporterImpl** importer) noexcept override { + // OrtEpFactory::CreateExternalResourceImporterForDevice was added in ORT 1.24. + if (ep_factory_.ort_version_supported < 24 || + ep_factory_.CreateExternalResourceImporterForDevice == nullptr) { + *importer = nullptr; + return nullptr; + } + return ep_factory_.CreateExternalResourceImporterForDevice(&ep_factory_, ep_device, importer); + } + OrtEpFactory& ep_factory_; ProviderLibrary& provider_library_; std::optional library_path_; diff --git a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index 65c396181f0a7..2530ae8eb3c2b 100644 --- a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -87,6 +87,13 @@ struct ForwardToFactoryImpl { return static_cast(this_ptr)->SetEnvironmentOptions(options); } + static OrtStatus* ORT_API_CALL CreateExternalResourceImporterForDevice( + _In_ OrtEpFactory* this_ptr, + _In_ const OrtEpDevice* ep_device, + _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** importer) noexcept { + return static_cast(this_ptr)->CreateExternalResourceImporterForDevice(ep_device, importer); + } + static void ORT_API_CALL ReleaseEp(OrtEpFactory* this_ptr, OrtEp* ep) noexcept { static_cast(this_ptr)->ReleaseEp(ep); } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc new file mode 100644 index 0000000000000..c3d413ac82ab9 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc @@ -0,0 +1,272 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ep_external_resource_importer.h" + +#include +#include +#include +#include +#include + +ExampleExternalResourceImporter::ExampleExternalResourceImporter(const ApiPtrs& apis) + : OrtExternalResourceImporterImpl{}, apis_{apis} { + ort_version_supported = ORT_API_VERSION; + + // Memory operations + CanImportMemory = CanImportMemoryImpl; + ImportMemory = ImportMemoryImpl; + ReleaseMemory = ReleaseMemoryImpl; + CreateTensorFromMemory = CreateTensorFromMemoryImpl; + + // Semaphore operations + CanImportSemaphore = CanImportSemaphoreImpl; + ImportSemaphore = ImportSemaphoreImpl; + ReleaseSemaphore = ReleaseSemaphoreImpl; + WaitSemaphore = WaitSemaphoreImpl; + SignalSemaphore = SignalSemaphoreImpl; + + // Release + Release = ReleaseImpl; +} + +/*static*/ +bool ORT_API_CALL ExampleExternalResourceImporter::CanImportMemoryImpl( + _In_ const OrtExternalResourceImporterImpl* /*this_ptr*/, + _In_ OrtExternalMemoryHandleType handle_type) noexcept { + // The example EP supports both D3D12 resource and heap handle types for testing + return handle_type == ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE || + handle_type == ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::ImportMemoryImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle) noexcept { + auto& impl = *static_cast(this_ptr); + + if (desc == nullptr || out_handle == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to ImportMemory"); + } + + *out_handle = nullptr; + + // Validate handle type + if (!CanImportMemoryImpl(this_ptr, desc->handle_type)) { + return impl.apis_.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, + "Unsupported external memory handle type"); + } + + // In a real implementation, you would: + // 1. Open/import the native handle (e.g., cuImportExternalMemory for CUDA) + // 2. Map the memory to get a device pointer + // + // For testing purposes, we simulate this by allocating CPU memory + // that mirrors the size of the external allocation. + + auto* handle = new (std::nothrow) ExampleExternalMemoryHandle(); + if (handle == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_FAIL, "Failed to allocate external memory handle"); + } + + // Allocate simulated memory (using CPU memory for the example) + size_t effective_size = desc->size_bytes - desc->offset_bytes; + handle->simulated_ptr = std::make_unique(effective_size); + + handle->size_bytes = desc->size_bytes; + handle->offset_bytes = desc->offset_bytes; + handle->handle_type = desc->handle_type; + handle->access_mode = desc->access_mode; + + *out_handle = handle; + return nullptr; +} + +/*static*/ +void ORT_API_CALL ExampleExternalResourceImporter::ReleaseMemoryImpl( + _In_ OrtExternalResourceImporterImpl* /*this_ptr*/, + _In_ OrtExternalMemoryHandle* handle) noexcept { + if (handle == nullptr) { + return; + } + + auto* mem_handle = static_cast(handle); + delete mem_handle; // destructor frees simulated_ptr +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::CreateTensorFromMemoryImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _Outptr_ OrtValue** out_tensor) noexcept { + auto& impl = *static_cast(this_ptr); + + if (mem_handle == nullptr || tensor_desc == nullptr || out_tensor == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to CreateTensorFromMemory"); + } + + *out_tensor = nullptr; + + auto* handle = static_cast(mem_handle); + + // Calculate the data pointer with tensor offset + void* data_ptr = handle->simulated_ptr.get() + tensor_desc->offset_bytes; + + // For the example EP, we use CPU memory info since we're simulating with CPU memory + // In a real implementation, you would use the appropriate GPU memory info + OrtMemoryInfo* memory_info = nullptr; + OrtStatus* status = impl.apis_.ort_api.CreateMemoryInfo( + "Cpu", // For testing, we use CPU memory + OrtDeviceAllocator, + 0, // device ID + OrtMemTypeDefault, + &memory_info); + + if (status != nullptr) { + return status; + } + + // Calculate buffer size + size_t buffer_size = handle->size_bytes - handle->offset_bytes - tensor_desc->offset_bytes; + + // Create tensor with pre-allocated memory + status = impl.apis_.ort_api.CreateTensorWithDataAsOrtValue( + memory_info, + data_ptr, + buffer_size, + tensor_desc->shape, + tensor_desc->rank, + tensor_desc->element_type, + out_tensor); + + impl.apis_.ort_api.ReleaseMemoryInfo(memory_info); + return status; +} + +/*static*/ +bool ORT_API_CALL ExampleExternalResourceImporter::CanImportSemaphoreImpl( + _In_ const OrtExternalResourceImporterImpl* /*this_ptr*/, + _In_ OrtExternalSemaphoreType type) noexcept { + // The example EP supports D3D12 fence for testing + return type == ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::ImportSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle) noexcept { + auto& impl = *static_cast(this_ptr); + + if (desc == nullptr || out_handle == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to ImportSemaphore"); + } + + *out_handle = nullptr; + + // Validate semaphore type + if (!CanImportSemaphoreImpl(this_ptr, desc->type)) { + return impl.apis_.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, + "Unsupported external semaphore type"); + } + + // In a real implementation, you would: + // 1. Import the native fence handle (e.g., cuImportExternalSemaphore for CUDA) + // + // For testing purposes, we create a simulated semaphore using an atomic counter + + auto* handle = new (std::nothrow) ExampleExternalSemaphoreHandle(); + if (handle == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_FAIL, "Failed to allocate external semaphore handle"); + } + + handle->type = desc->type; + handle->value.store(0); + + *out_handle = handle; + return nullptr; +} + +/*static*/ +void ORT_API_CALL ExampleExternalResourceImporter::ReleaseSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* /*this_ptr*/, + _In_ OrtExternalSemaphoreHandle* handle) noexcept { + if (handle == nullptr) { + return; + } + + auto* sem_handle = static_cast(handle); + delete sem_handle; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::WaitSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) noexcept { + auto& impl = *static_cast(this_ptr); + + if (handle == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to WaitSemaphore"); + } + + // stream can be nullptr for synchronous wait + (void)stream; + + auto* sem_handle = static_cast(handle); + + // In a real implementation, you would: + // 1. Queue a wait operation on the GPU stream (e.g., cuWaitExternalSemaphoresAsync) + // + // For testing, we do a simple spin-wait on the atomic counter + // with a reasonable timeout to prevent infinite loops in tests + + const int max_iterations = 10000; + int iterations = 0; + while (sem_handle->value.load() < value && iterations < max_iterations) { + std::this_thread::sleep_for(std::chrono::microseconds(100)); + ++iterations; + } + + if (iterations >= max_iterations) { + return impl.apis_.ort_api.CreateStatus(ORT_FAIL, "WaitSemaphore timed out"); + } + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::SignalSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) noexcept { + auto& impl = *static_cast(this_ptr); + + if (handle == nullptr) { + return impl.apis_.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments to SignalSemaphore"); + } + + // stream can be nullptr for synchronous signal + (void)stream; + + auto* sem_handle = static_cast(handle); + + // In a real implementation, you would: + // 1. Queue a signal operation on the GPU stream (e.g., cuSignalExternalSemaphoresAsync) + // + // For testing, we simply update the atomic counter + + sem_handle->value.store(value); + + return nullptr; +} + +/*static*/ +void ORT_API_CALL ExampleExternalResourceImporter::ReleaseImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h new file mode 100644 index 0000000000000..06b003cd3feaa --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h @@ -0,0 +1,131 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "../plugin_ep_utils.h" + +#include +#include +#include + +/** + * @brief Example derived handle for imported external memory. + * + * Derives from OrtExternalMemoryHandle and adds example-specific fields. + * This mock implementation simulates imported external memory for testing purposes. + * In a real EP, this would hold a GPU-mapped pointer from an imported D3D12/Vulkan/CUDA resource. + */ +struct ExampleExternalMemoryHandle : OrtExternalMemoryHandle { + std::unique_ptr simulated_ptr; ///< Simulated mapped pointer (CPU memory for testing) + OrtExternalMemoryAccessMode access_mode; ///< Access mode for the imported memory + + ExampleExternalMemoryHandle() + : simulated_ptr(nullptr), access_mode(ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE) { + // Initialize base struct fields + version = ORT_API_VERSION; + ep_device = nullptr; + handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + size_bytes = 0; + offset_bytes = 0; + Release = ReleaseCallback; + } + + ~ExampleExternalMemoryHandle() = default; + + static void ORT_API_CALL ReleaseCallback(_In_ OrtExternalMemoryHandle* handle) noexcept { + if (handle == nullptr) return; + delete static_cast(handle); + } +}; + +/** + * @brief Example derived handle for imported external semaphore. + * + * Derives from OrtExternalSemaphoreHandle and adds example-specific fields. + * This mock implementation simulates imported external semaphores for testing purposes. + * In a real EP, this would hold an imported D3D12 fence / Vulkan semaphore / CUDA external semaphore. + */ +struct ExampleExternalSemaphoreHandle : OrtExternalSemaphoreHandle { + std::atomic value; ///< Simulated fence value for testing + + ExampleExternalSemaphoreHandle() + : value(0) { + // Initialize base struct fields + version = ORT_API_VERSION; + ep_device = nullptr; + type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + Release = ReleaseCallback; + } + + static void ORT_API_CALL ReleaseCallback(_In_ OrtExternalSemaphoreHandle* handle) noexcept { + if (handle == nullptr) return; + delete static_cast(handle); + } +}; + +/** + * @brief Example implementation of OrtExternalResourceImporterImpl. + * + * This is a mock implementation that simulates external resource interop for testing + * the ORT public API without requiring actual D3D12/CUDA/Vulkan hardware. + * + * Key features: + * - Reports support for D3D12 resource/heap and fence handle types + * - Creates simulated memory mappings using CPU memory + * - Simulates fence wait/signal operations + * - Allows tensor creation from "imported" memory + */ +class ExampleExternalResourceImporter : public OrtExternalResourceImporterImpl { + public: + ExampleExternalResourceImporter(const ApiPtrs& apis); + + static bool ORT_API_CALL CanImportMemoryImpl( + _In_ const OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalMemoryHandleType handle_type) noexcept; + + static OrtStatus* ORT_API_CALL ImportMemoryImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryDescriptor* desc, + _Outptr_ OrtExternalMemoryHandle** out_handle) noexcept; + + static void ORT_API_CALL ReleaseMemoryImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalMemoryHandle* handle) noexcept; + + static OrtStatus* ORT_API_CALL CreateTensorFromMemoryImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalMemoryHandle* mem_handle, + _In_ const OrtExternalTensorDescriptor* tensor_desc, + _Outptr_ OrtValue** out_tensor) noexcept; + + static bool ORT_API_CALL CanImportSemaphoreImpl( + _In_ const OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreType type) noexcept; + + static OrtStatus* ORT_API_CALL ImportSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ const OrtExternalSemaphoreDescriptor* desc, + _Outptr_ OrtExternalSemaphoreHandle** out_handle) noexcept; + + static void ORT_API_CALL ReleaseSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle) noexcept; + + static OrtStatus* ORT_API_CALL WaitSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) noexcept; + + static OrtStatus* ORT_API_CALL SignalSemaphoreImpl( + _In_ OrtExternalResourceImporterImpl* this_ptr, + _In_ OrtExternalSemaphoreHandle* handle, + _In_ OrtSyncStream* stream, + _In_ uint64_t value) noexcept; + + static void ORT_API_CALL ReleaseImpl(_In_ OrtExternalResourceImporterImpl* this_ptr) noexcept; + + private: + ApiPtrs apis_; +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 85f9504b14a4d..fd652b8882df9 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -37,6 +37,8 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL IsStreamAware = IsStreamAwareImpl; CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; + CreateExternalResourceImporterForDevice = CreateExternalResourceImporterForDeviceImpl; + // setup the OrtMemoryInfo instances required by the EP. // We pretend the device the EP is running on is GPU. default_memory_info_ = Ort::MemoryInfo{"ExampleEP GPU", @@ -308,3 +310,22 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateSyncStreamForDeviceImpl(OrtEpFac return nullptr; } + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::CreateExternalResourceImporterForDeviceImpl( + OrtEpFactory* this_ptr, + const OrtEpDevice* /*ep_device*/, + OrtExternalResourceImporterImpl** out_importer) noexcept { + auto& factory = *static_cast(this_ptr); + + if (out_importer == nullptr) { + return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "out_importer cannot be nullptr"); + } + + // Create the external resource importer + auto importer = std::make_unique(factory); + *out_importer = importer.release(); + + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index 196e67fc5c558..230fdef772e2f 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -7,6 +7,7 @@ #include "ep_arena.h" #include "ep_data_transfer.h" +#include "ep_external_resource_importer.h" #include "../plugin_ep_utils.h" /// @@ -67,6 +68,11 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { const OrtKeyValuePairs* stream_options, OrtSyncStreamImpl** stream) noexcept; + static OrtStatus* ORT_API_CALL CreateExternalResourceImporterForDeviceImpl( + OrtEpFactory* this_ptr, + const OrtEpDevice* ep_device, + OrtExternalResourceImporterImpl** out_importer) noexcept; + const OrtLogger& default_logger_; // default logger for the EP factory const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name diff --git a/onnxruntime/test/autoep/test_external_resource_importer.cc b/onnxruntime/test/autoep/test_external_resource_importer.cc new file mode 100644 index 0000000000000..06a818bf4805f --- /dev/null +++ b/onnxruntime/test/autoep/test_external_resource_importer.cc @@ -0,0 +1,393 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Tests for the External Resource Interop API using the example_plugin_ep. +// This tests the public ORT API without requiring actual D3D12/CUDA hardware. + +#include +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" + +#include "test/autoep/test_autoep_utils.h" +#include "test/util/include/api_asserts.h" +#include "test/util/include/asserts.h" + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { + +class ExternalResourceImporterTest : public ::testing::Test { + protected: + void SetUp() override { + // Register the example EP and get the device using shared utility + Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, registered_ep_); + ASSERT_NE(registered_ep_.get(), nullptr) << "Example EP device not found"; + ep_device_ = registered_ep_.get(); + } + + const OrtInteropApi& GetInteropApi() const { + return Ort::GetInteropApi(); + } + + RegisteredEpDeviceUniquePtr registered_ep_; + const OrtEpDevice* ep_device_ = nullptr; +}; + +// Test: Create External Resource Importer +TEST_F(ExternalResourceImporterTest, CreateExternalResourceImporter) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + + if (status != nullptr) { + std::string error = Ort::GetApi().GetErrorMessage(status); + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "CreateExternalResourceImporterForDevice not supported: " << error; + } + + ASSERT_NE(importer, nullptr) << "External resource importer should not be null"; + + // Release the importer + GetInteropApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Memory Import Capability +TEST_F(ExternalResourceImporterTest, CanImportMemory) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "External resource interop not supported"; + } + + // Check D3D12 Resource support + bool can_import_resource = false; + status = GetInteropApi().CanImportMemory( + importer, ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE, &can_import_resource); + ASSERT_EQ(status, nullptr) << "CanImportMemory for D3D12_RESOURCE should succeed"; + EXPECT_TRUE(can_import_resource) << "Example EP should support D3D12 Resource import"; + + // Check D3D12 Heap support + bool can_import_heap = false; + status = GetInteropApi().CanImportMemory( + importer, ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP, &can_import_heap); + ASSERT_EQ(status, nullptr) << "CanImportMemory for D3D12_HEAP should succeed"; + EXPECT_TRUE(can_import_heap) << "Example EP should support D3D12 Heap import"; + + GetInteropApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Semaphore Import Capability +TEST_F(ExternalResourceImporterTest, CanImportSemaphore) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "External resource interop not supported"; + } + + // Check D3D12 Fence support + bool can_import_fence = false; + status = GetInteropApi().CanImportSemaphore( + importer, ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE, &can_import_fence); + ASSERT_EQ(status, nullptr) << "CanImportSemaphore for D3D12_FENCE should succeed"; + EXPECT_TRUE(can_import_fence) << "Example EP should support D3D12 Fence import"; + + GetInteropApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Import Memory (Simulated) +TEST_F(ExternalResourceImporterTest, ImportMemory) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "External resource interop not supported"; + } + + // Import memory (using a dummy handle for testing) + const size_t buffer_size = 1024 * sizeof(float); + void* dummy_handle = reinterpret_cast(static_cast(0x12345678)); // Simulated handle + + OrtExternalMemoryDescriptor mem_desc = {}; + mem_desc.version = ORT_API_VERSION; + mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + mem_desc.native_handle = dummy_handle; + mem_desc.size_bytes = buffer_size; + mem_desc.offset_bytes = 0; + mem_desc.access_mode = ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE; + + OrtExternalMemoryHandle* mem_handle = nullptr; + status = GetInteropApi().ImportMemory(importer, &mem_desc, &mem_handle); + ASSERT_EQ(status, nullptr) << "ImportMemory should succeed"; + ASSERT_NE(mem_handle, nullptr) << "Memory handle should not be null"; + + // Release memory handle + GetInteropApi().ReleaseExternalMemoryHandle(mem_handle); + + GetInteropApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Create Tensor from Imported Memory +TEST_F(ExternalResourceImporterTest, CreateTensorFromMemory) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "External resource interop not supported"; + } + + // Create tensor shape: [1, 3, 32, 32] + const int64_t batch = 1, channels = 3, height = 32, width = 32; + const int64_t shape[] = {batch, channels, height, width}; + const size_t num_elements = batch * channels * height * width; + const size_t buffer_size = num_elements * sizeof(float); + + // Import memory + void* dummy_handle = reinterpret_cast(static_cast(0x12345678)); + + OrtExternalMemoryDescriptor mem_desc = {}; + mem_desc.version = ORT_API_VERSION; + mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + mem_desc.native_handle = dummy_handle; + mem_desc.size_bytes = buffer_size; + mem_desc.offset_bytes = 0; + mem_desc.access_mode = ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE; + + OrtExternalMemoryHandle* mem_handle = nullptr; + status = GetInteropApi().ImportMemory(importer, &mem_desc, &mem_handle); + ASSERT_EQ(status, nullptr); + + // Create tensor from imported memory + OrtExternalTensorDescriptor tensor_desc = {}; + tensor_desc.version = ORT_API_VERSION; + tensor_desc.element_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + tensor_desc.shape = shape; + tensor_desc.rank = 4; + tensor_desc.offset_bytes = 0; + + OrtValue* tensor = nullptr; + status = GetInteropApi().CreateTensorFromMemory( + importer, mem_handle, &tensor_desc, nullptr, &tensor); + ASSERT_EQ(status, nullptr) << "CreateTensorFromMemory should succeed"; + ASSERT_NE(tensor, nullptr) << "Tensor should not be null"; + + // Verify tensor properties + OrtTensorTypeAndShapeInfo* type_info = nullptr; + status = Ort::GetApi().GetTensorTypeAndShape(tensor, &type_info); + ASSERT_EQ(status, nullptr); + + size_t rank = 0; + status = Ort::GetApi().GetDimensionsCount(type_info, &rank); + ASSERT_EQ(status, nullptr); + EXPECT_EQ(rank, 4u); + + std::vector actual_shape(rank); + status = Ort::GetApi().GetDimensions(type_info, actual_shape.data(), rank); + ASSERT_EQ(status, nullptr); + EXPECT_EQ(actual_shape[0], batch); + EXPECT_EQ(actual_shape[1], channels); + EXPECT_EQ(actual_shape[2], height); + EXPECT_EQ(actual_shape[3], width); + + ONNXTensorElementDataType elem_type; + status = Ort::GetApi().GetTensorElementType(type_info, &elem_type); + ASSERT_EQ(status, nullptr); + EXPECT_EQ(elem_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + + Ort::GetApi().ReleaseTensorTypeAndShapeInfo(type_info); + + // Cleanup + Ort::GetApi().ReleaseValue(tensor); + GetInteropApi().ReleaseExternalMemoryHandle(mem_handle); + GetInteropApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Import Semaphore (Simulated) +TEST_F(ExternalResourceImporterTest, ImportSemaphore) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "External resource interop not supported"; + } + + // Import semaphore (using a dummy handle for testing) + void* dummy_handle = reinterpret_cast(static_cast(0xABCDEF00)); + + OrtExternalSemaphoreDescriptor sem_desc = {}; + sem_desc.version = ORT_API_VERSION; + sem_desc.type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + sem_desc.native_handle = dummy_handle; + + OrtExternalSemaphoreHandle* sem_handle = nullptr; + status = GetInteropApi().ImportSemaphore(importer, &sem_desc, &sem_handle); + ASSERT_EQ(status, nullptr) << "ImportSemaphore should succeed"; + ASSERT_NE(sem_handle, nullptr) << "Semaphore handle should not be null"; + + // Release semaphore handle + GetInteropApi().ReleaseExternalSemaphoreHandle(sem_handle); + + GetInteropApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Wait and Signal Semaphore (Simulated) +TEST_F(ExternalResourceImporterTest, WaitAndSignalSemaphore) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "External resource interop not supported"; + } + + // Create a stream for the EP + OrtSyncStream* stream = nullptr; + status = Ort::GetApi().CreateSyncStreamForEpDevice(ep_device_, nullptr, &stream); + ASSERT_EQ(status, nullptr) << "CreateSyncStreamForEpDevice should succeed"; + + // Import semaphore + void* dummy_handle = reinterpret_cast(static_cast(0xABCDEF00)); + + OrtExternalSemaphoreDescriptor sem_desc = {}; + sem_desc.version = ORT_API_VERSION; + sem_desc.type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + sem_desc.native_handle = dummy_handle; + + OrtExternalSemaphoreHandle* sem_handle = nullptr; + status = GetInteropApi().ImportSemaphore(importer, &sem_desc, &sem_handle); + ASSERT_EQ(status, nullptr); + + // Signal the semaphore with value 1 + status = GetInteropApi().SignalSemaphore(importer, sem_handle, stream, 1); + ASSERT_EQ(status, nullptr) << "SignalSemaphore should succeed"; + + // Wait for value 1 (should succeed immediately since we just signaled it) + status = GetInteropApi().WaitSemaphore(importer, sem_handle, stream, 1); + ASSERT_EQ(status, nullptr) << "WaitSemaphore should succeed"; + + // Signal with value 5 + status = GetInteropApi().SignalSemaphore(importer, sem_handle, stream, 5); + ASSERT_EQ(status, nullptr); + + // Wait for value 3 (should succeed since current value is 5) + status = GetInteropApi().WaitSemaphore(importer, sem_handle, stream, 3); + ASSERT_EQ(status, nullptr); + + // Cleanup + GetInteropApi().ReleaseExternalSemaphoreHandle(sem_handle); + Ort::GetApi().ReleaseSyncStream(stream); + GetInteropApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Multiple Memory Imports +TEST_F(ExternalResourceImporterTest, MultipleMemoryImports) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "External resource interop not supported"; + } + + constexpr int kNumBuffers = 5; + std::vector handles(kNumBuffers); + + // Import multiple memory regions + for (int i = 0; i < kNumBuffers; ++i) { + OrtExternalMemoryDescriptor mem_desc = {}; + mem_desc.version = ORT_API_VERSION; + mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + mem_desc.native_handle = reinterpret_cast(static_cast(0x10000000 + i * 0x1000)); + mem_desc.size_bytes = (i + 1) * 1024; + mem_desc.offset_bytes = 0; + mem_desc.access_mode = ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE; + + status = GetInteropApi().ImportMemory(importer, &mem_desc, &handles[i]); + ASSERT_EQ(status, nullptr) << "ImportMemory " << i << " should succeed"; + ASSERT_NE(handles[i], nullptr); + } + + // Release all handles + for (int i = 0; i < kNumBuffers; ++i) { + GetInteropApi().ReleaseExternalMemoryHandle(handles[i]); + } + + GetInteropApi().ReleaseExternalResourceImporter(importer); +} + +// Test: Access Mode Variations +TEST_F(ExternalResourceImporterTest, AccessModeVariations) { + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = GetInteropApi().CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr) { + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "External resource interop not supported"; + } + + const OrtExternalMemoryAccessMode access_modes[] = { + ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE, + ORT_EXTERNAL_MEMORY_ACCESS_READ_ONLY, + ORT_EXTERNAL_MEMORY_ACCESS_WRITE_ONLY}; + + for (auto access_mode : access_modes) { + OrtExternalMemoryDescriptor mem_desc = {}; + mem_desc.version = ORT_API_VERSION; + mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + mem_desc.native_handle = reinterpret_cast(static_cast(0x12345678)); + mem_desc.size_bytes = 4096; + mem_desc.offset_bytes = 0; + mem_desc.access_mode = access_mode; + + OrtExternalMemoryHandle* mem_handle = nullptr; + status = GetInteropApi().ImportMemory(importer, &mem_desc, &mem_handle); + ASSERT_EQ(status, nullptr) << "ImportMemory with access_mode " << access_mode << " should succeed"; + ASSERT_NE(mem_handle, nullptr); + + GetInteropApi().ReleaseExternalMemoryHandle(mem_handle); + } + + GetInteropApi().ReleaseExternalResourceImporter(importer); +} + +// Test: SessionGetEpDeviceForOutputs +TEST_F(ExternalResourceImporterTest, SessionGetEpDeviceForOutputs) { + // Load a simple model with the example EP + Ort::SessionOptions session_options; + + // Add the example EP to the session + const OrtEpDevice* devices[] = {ep_device_}; + OrtStatus* status = Ort::GetApi().SessionOptionsAppendExecutionProvider_V2( + session_options, *ort_env, devices, 1, nullptr, nullptr, 0); + if (status != nullptr) { + std::string error = Ort::GetApi().GetErrorMessage(status); + Ort::GetApi().ReleaseStatus(status); + GTEST_SKIP() << "Example EP not available: " << error; + } + + // Create session with test model (mul_1.onnx - a simple model the example EP supports) + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); + + // Get output count + size_t num_outputs = session.GetOutputCount(); + ASSERT_GT(num_outputs, 0U) << "Model should have at least one output"; + + // Get EP devices for outputs + std::vector output_devices(num_outputs); + status = Ort::GetApi().SessionGetEpDeviceForOutputs( + session, output_devices.data(), num_outputs); + ASSERT_EQ(status, nullptr) << "SessionGetEpDeviceForOutputs should succeed"; + + // Validate that we got EP devices (may be nullptr if not assigned to EP) + // At least verify the call succeeded and returned valid array + for (size_t i = 0; i < num_outputs; ++i) { + if (output_devices[i] != nullptr) { + // If an EP device is returned, validate it has a name + const char* ep_name = Ort::GetApi().EpDevice_EpName(output_devices[i]); + ASSERT_NE(ep_name, nullptr) << "EP device should have a name"; + } + } +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_external_resource_importer_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_external_resource_importer_test.cc new file mode 100644 index 0000000000000..6cc959dd3faa7 --- /dev/null +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_external_resource_importer_test.cc @@ -0,0 +1,847 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This test validates the D3D12 ↔ CUDA external resource import functionality +// for the NvTensorRtRtx execution provider. +// +// Test Coverage: +// 1. External Resource Importer creation and destruction +// 2. Memory import capability check (D3D12 Resource & Heap) +// 3. Semaphore import capability check (D3D12 Fence) +// 4. D3D12 shared resource import to CUDA +// 5. Tensor creation from imported external memory +// 6. D3D12 timeline fence import for GPU synchronization +// 7. Wait/Signal semaphore operations +// 8. Full inference pipeline with zero-copy external memory + +#include "core/graph/onnx_protobuf.h" +#include "core/session/inference_session.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "test/providers/provider_test_utils.h" +#include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" +#include "test/unittest_util/framework_test_utils.h" +#include "test/util/include/scoped_env_vars.h" +#include "test/common/random_generator.h" + +#include +#include +#include +#include + +#if defined(_WIN32) +#include +#include +#include +using Microsoft::WRL::ComPtr; +#endif + +// Include CUDA headers for pointer attribute verification +#include + +using namespace std; +using namespace ONNX_NAMESPACE; +using namespace ::onnxruntime::logging; + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { + +#if defined(_WIN32) + +// Helper functions for D3D12 resource creation +class D3D12ResourceHelper { + public: + static void CreateSharedBuffer(ID3D12Device* device, + size_t size, + ID3D12Resource** out_resource, + D3D12_RESOURCE_STATES initial_state = D3D12_RESOURCE_STATE_COMMON) { + D3D12_RESOURCE_DESC desc = {}; + desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER; + desc.Alignment = 0; + desc.Width = size; + desc.Height = 1; + desc.DepthOrArraySize = 1; + desc.MipLevels = 1; + desc.Format = DXGI_FORMAT_UNKNOWN; + desc.SampleDesc.Count = 1; + desc.SampleDesc.Quality = 0; + desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR; + desc.Flags = D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS; + + D3D12_HEAP_PROPERTIES heap_props = {}; + heap_props.Type = D3D12_HEAP_TYPE_DEFAULT; + heap_props.CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN; + heap_props.MemoryPoolPreference = D3D12_MEMORY_POOL_UNKNOWN; + heap_props.CreationNodeMask = 1; + heap_props.VisibleNodeMask = 1; + + // Create with SHARED heap flag for cross-API import + HRESULT hr = device->CreateCommittedResource( + &heap_props, + D3D12_HEAP_FLAG_SHARED, + &desc, + initial_state, + nullptr, + IID_PPV_ARGS(out_resource)); + + if (FAILED(hr)) { + GTEST_FAIL() << "Failed to create shared D3D12 buffer, HRESULT: 0x" << std::hex << hr; + } + } + + static void CreateUploadBuffer(ID3D12Device* device, size_t size, ID3D12Resource** out_resource) { + D3D12_RESOURCE_DESC desc = {}; + desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER; + desc.Alignment = 0; + desc.Width = size; + desc.Height = 1; + desc.DepthOrArraySize = 1; + desc.MipLevels = 1; + desc.Format = DXGI_FORMAT_UNKNOWN; + desc.SampleDesc.Count = 1; + desc.SampleDesc.Quality = 0; + desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR; + desc.Flags = D3D12_RESOURCE_FLAG_NONE; + + D3D12_HEAP_PROPERTIES heap_props = {}; + heap_props.Type = D3D12_HEAP_TYPE_UPLOAD; + + HRESULT hr = device->CreateCommittedResource( + &heap_props, + D3D12_HEAP_FLAG_NONE, + &desc, + D3D12_RESOURCE_STATE_GENERIC_READ, + nullptr, + IID_PPV_ARGS(out_resource)); + + if (FAILED(hr)) { + GTEST_FAIL() << "Failed to create upload buffer, HRESULT: 0x" << std::hex << hr; + } + } + + static void CreateReadbackBuffer(ID3D12Device* device, size_t size, ID3D12Resource** out_resource) { + D3D12_RESOURCE_DESC desc = {}; + desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER; + desc.Alignment = 0; + desc.Width = size; + desc.Height = 1; + desc.DepthOrArraySize = 1; + desc.MipLevels = 1; + desc.Format = DXGI_FORMAT_UNKNOWN; + desc.SampleDesc.Count = 1; + desc.SampleDesc.Quality = 0; + desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR; + desc.Flags = D3D12_RESOURCE_FLAG_NONE; + + D3D12_HEAP_PROPERTIES heap_props = {}; + heap_props.Type = D3D12_HEAP_TYPE_READBACK; + + HRESULT hr = device->CreateCommittedResource( + &heap_props, + D3D12_HEAP_FLAG_NONE, + &desc, + D3D12_RESOURCE_STATE_COPY_DEST, + nullptr, + IID_PPV_ARGS(out_resource)); + + if (FAILED(hr)) { + GTEST_FAIL() << "Failed to create readback buffer, HRESULT: 0x" << std::hex << hr; + } + } + + static void FlushAndWait(ID3D12Device* device, ID3D12CommandQueue* queue) { + ComPtr fence; + HRESULT hr = device->CreateFence(0, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(&fence)); + if (FAILED(hr)) { + GTEST_FAIL() << "Failed to create fence for flush, HRESULT: 0x" << std::hex << hr; + } + + HANDLE event = CreateEvent(nullptr, FALSE, FALSE, nullptr); + queue->Signal(fence.Get(), 1); + fence->SetEventOnCompletion(1, event); + WaitForSingleObject(event, INFINITE); + CloseHandle(event); + } +}; + +// Test Fixture +class ExternalResourceImporterTest : public testing::Test { + protected: + void SetUp() override { + // Get the ORT API + ort_api_ = &Ort::GetApi(); + + // Try to create D3D12 device + HRESULT hr = D3D12CreateDevice(nullptr, D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&d3d12_device_)); + if (FAILED(hr)) { + d3d12_available_ = false; + return; + } + d3d12_available_ = true; + + // Create command queue + D3D12_COMMAND_QUEUE_DESC queue_desc = {}; + queue_desc.Type = D3D12_COMMAND_LIST_TYPE_COMPUTE; + queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE; + hr = d3d12_device_->CreateCommandQueue(&queue_desc, IID_PPV_ARGS(&command_queue_)); + if (FAILED(hr)) { + d3d12_available_ = false; + return; + } + + // Create command allocator and list + hr = d3d12_device_->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_COMPUTE, + IID_PPV_ARGS(&command_allocator_)); + if (FAILED(hr)) { + d3d12_available_ = false; + return; + } + + hr = d3d12_device_->CreateCommandList(0, D3D12_COMMAND_LIST_TYPE_COMPUTE, + command_allocator_.Get(), nullptr, + IID_PPV_ARGS(&command_list_)); + if (FAILED(hr)) { + d3d12_available_ = false; + return; + } + command_list_->Close(); + + // Register NvTensorRtRtx EP + Utils::RegisterAndGetNvTensorRtRtxEp(*ort_env, registered_ep_); + if (registered_ep_.get() == nullptr) { + ep_available_ = false; + return; + } + ep_available_ = true; + ep_device_ = registered_ep_.get(); + } + + void TearDown() override { + // Release resources + command_list_.Reset(); + command_allocator_.Reset(); + command_queue_.Reset(); + d3d12_device_.Reset(); + } + + bool IsD3D12Available() const { return d3d12_available_; } + bool IsEPAvailable() const { return ep_available_; } + + ComPtr d3d12_device_; + ComPtr command_queue_; + ComPtr command_allocator_; + ComPtr command_list_; + const OrtApi* ort_api_ = nullptr; + RegisteredEpDeviceUniquePtr registered_ep_; // RAII - auto-unregisters EP + const OrtEpDevice* ep_device_ = nullptr; + bool d3d12_available_ = false; + bool ep_available_ = false; +}; + +// Test: External Resource Importer Creation +TEST_F(ExternalResourceImporterTest, CreateExternalResourceImporter) { + if (!IsD3D12Available()) { + GTEST_SKIP() << "D3D12 not available"; + } + if (!IsEPAvailable()) { + GTEST_SKIP() << "NvTensorRtRtx EP not available"; + } + + // Create external resource importer + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = ort_api_->CreateExternalResourceImporterForDevice(ep_device_, &importer); + + if (status != nullptr) { + std::string error = ort_api_->GetErrorMessage(status); + ort_api_->ReleaseStatus(status); + GTEST_SKIP() << "CreateExternalResourceImporterForDevice not supported: " << error; + } + + if (importer == nullptr) { + // EP doesn't support external resource import yet + GTEST_SKIP() << "External resource import not yet implemented by this EP"; + } + + // Release the importer + ort_api_->ReleaseExternalResourceImporter(importer); +} + +// Test: Memory Import Capability Check +TEST_F(ExternalResourceImporterTest, CanImportMemoryCapabilities) { + if (!IsD3D12Available()) { + GTEST_SKIP() << "D3D12 not available"; + } + if (!IsEPAvailable()) { + GTEST_SKIP() << "NvTensorRtRtx EP not available"; + } + + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = ort_api_->CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr || importer == nullptr) { + if (status != nullptr) ort_api_->ReleaseStatus(status); + GTEST_SKIP() << "External resource import not supported"; + } + + // Check D3D12 Resource support + bool can_import_resource = false; + status = ort_api_->ExternalResourceImporter_CanImportMemory( + importer, ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE, &can_import_resource); + ASSERT_EQ(status, nullptr) << "CanImportMemory for D3D12_RESOURCE should succeed"; + EXPECT_TRUE(can_import_resource) << "Should support D3D12 Resource import"; + + // Check D3D12 Heap support + bool can_import_heap = false; + status = ort_api_->ExternalResourceImporter_CanImportMemory( + importer, ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP, &can_import_heap); + ASSERT_EQ(status, nullptr) << "CanImportMemory for D3D12_HEAP should succeed"; + EXPECT_TRUE(can_import_heap) << "Should support D3D12 Heap import"; + + ort_api_->ReleaseExternalResourceImporter(importer); +} + +// Test: Semaphore Import Capability Check +TEST_F(ExternalResourceImporterTest, CanImportSemaphoreCapabilities) { + if (!IsD3D12Available()) { + GTEST_SKIP() << "D3D12 not available"; + } + if (!IsEPAvailable()) { + GTEST_SKIP() << "NvTensorRtRtx EP not available"; + } + + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = ort_api_->CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr || importer == nullptr) { + if (status != nullptr) ort_api_->ReleaseStatus(status); + GTEST_SKIP() << "External resource import not supported"; + } + + // Check D3D12 Fence support + bool can_import_fence = false; + status = ort_api_->ExternalResourceImporter_CanImportSemaphore( + importer, ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE, &can_import_fence); + ASSERT_EQ(status, nullptr) << "CanImportSemaphore for D3D12_FENCE should succeed"; + EXPECT_TRUE(can_import_fence) << "Should support D3D12 Fence import"; + + ort_api_->ReleaseExternalResourceImporter(importer); +} + +// Test: Import D3D12 Shared Resource +TEST_F(ExternalResourceImporterTest, ImportD3D12SharedResource) { + if (!IsD3D12Available()) { + GTEST_SKIP() << "D3D12 not available"; + } + if (!IsEPAvailable()) { + GTEST_SKIP() << "NvTensorRtRtx EP not available"; + } + + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = ort_api_->CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr || importer == nullptr) { + if (status != nullptr) ort_api_->ReleaseStatus(status); + GTEST_SKIP() << "External resource import not supported"; + } + + // Create a shared D3D12 buffer + const size_t buffer_size = 1024 * sizeof(float); + ComPtr d3d12_buffer; + D3D12ResourceHelper::CreateSharedBuffer(d3d12_device_.Get(), buffer_size, &d3d12_buffer); + + // Create shared handle + HANDLE shared_handle = nullptr; + HRESULT hr = d3d12_device_->CreateSharedHandle(d3d12_buffer.Get(), nullptr, GENERIC_ALL, nullptr, &shared_handle); + ASSERT_TRUE(SUCCEEDED(hr)) << "Failed to create shared handle"; + + // Import the memory + OrtExternalMemoryDescriptor mem_desc = {}; + mem_desc.version = ORT_API_VERSION; + mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + mem_desc.native_handle = shared_handle; + mem_desc.size_bytes = buffer_size; + mem_desc.offset_bytes = 0; + mem_desc.access_mode = ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE; + + OrtExternalMemoryHandle* mem_handle = nullptr; + status = ort_api_->ExternalResourceImporter_ImportMemory(importer, &mem_desc, &mem_handle); + ASSERT_EQ(status, nullptr) << "ImportMemory should succeed"; + ASSERT_NE(mem_handle, nullptr) << "Memory handle should not be null"; + + // Release memory handle + ort_api_->ReleaseExternalMemoryHandle(mem_handle); + + // Close shared handle + CloseHandle(shared_handle); + + ort_api_->ReleaseExternalResourceImporter(importer); +} + +// Test: Create Tensor from Imported Memory +TEST_F(ExternalResourceImporterTest, CreateTensorFromImportedMemory) { + if (!IsD3D12Available()) { + GTEST_SKIP() << "D3D12 not available"; + } + if (!IsEPAvailable()) { + GTEST_SKIP() << "NvTensorRtRtx EP not available"; + } + + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = ort_api_->CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr || importer == nullptr) { + if (status != nullptr) ort_api_->ReleaseStatus(status); + GTEST_SKIP() << "External resource import not supported"; + } + + // Create tensor shape: [1, 3, 32, 32] (batch, channels, height, width) + const int64_t batch = 1, channels = 3, height = 32, width = 32; + const int64_t shape[] = {batch, channels, height, width}; + const size_t num_elements = batch * channels * height * width; + const size_t buffer_size = num_elements * sizeof(float); + + // Create shared D3D12 buffer + ComPtr d3d12_buffer; + D3D12ResourceHelper::CreateSharedBuffer(d3d12_device_.Get(), buffer_size, &d3d12_buffer); + + // Create shared handle + HANDLE shared_handle = nullptr; + HRESULT hr = d3d12_device_->CreateSharedHandle(d3d12_buffer.Get(), nullptr, GENERIC_ALL, nullptr, &shared_handle); + ASSERT_TRUE(SUCCEEDED(hr)); + + // Import the memory + OrtExternalMemoryDescriptor mem_desc = {}; + mem_desc.version = ORT_API_VERSION; + mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + mem_desc.native_handle = shared_handle; + mem_desc.size_bytes = buffer_size; + mem_desc.offset_bytes = 0; + mem_desc.access_mode = ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE; + + OrtExternalMemoryHandle* mem_handle = nullptr; + status = ort_api_->ExternalResourceImporter_ImportMemory(importer, &mem_desc, &mem_handle); + ASSERT_EQ(status, nullptr); + + // Create tensor from imported memory + OrtExternalTensorDescriptor tensor_desc = {}; + tensor_desc.version = ORT_API_VERSION; + tensor_desc.element_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + tensor_desc.shape = shape; + tensor_desc.rank = 4; + tensor_desc.offset_bytes = 0; + + OrtValue* tensor = nullptr; + status = ort_api_->ExternalResourceImporter_CreateTensorFromMemory(importer, mem_handle, &tensor_desc, nullptr, &tensor); + ASSERT_EQ(status, nullptr) << "CreateTensorFromMemory should succeed"; + ASSERT_NE(tensor, nullptr) << "Tensor should not be null"; + + // Verify tensor properties + OrtTensorTypeAndShapeInfo* type_info = nullptr; + status = ort_api_->GetTensorTypeAndShape(tensor, &type_info); + ASSERT_EQ(status, nullptr); + + size_t rank = 0; + ort_api_->GetDimensionsCount(type_info, &rank); + EXPECT_EQ(rank, 4u); + + std::vector actual_shape(rank); + ort_api_->GetDimensions(type_info, actual_shape.data(), rank); + EXPECT_EQ(actual_shape[0], batch); + EXPECT_EQ(actual_shape[1], channels); + EXPECT_EQ(actual_shape[2], height); + EXPECT_EQ(actual_shape[3], width); + + ONNXTensorElementDataType elem_type; + ort_api_->GetTensorElementType(type_info, &elem_type); + EXPECT_EQ(elem_type, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + + ort_api_->ReleaseTensorTypeAndShapeInfo(type_info); + + // Get the tensor's data pointer and verify it's CUDA device memory + // This proves the D3D12 to CUDA memory import actually happened + void* tensor_data = nullptr; + status = ort_api_->GetTensorMutableData(tensor, &tensor_data); + ASSERT_EQ(status, nullptr) << "GetTensorMutableData should succeed"; + ASSERT_NE(tensor_data, nullptr) << "Tensor data pointer should not be null"; + + // Use cudaPointerGetAttributes to verify this is CUDA device memory + cudaPointerAttributes attrs; + cudaError_t cuda_err = cudaPointerGetAttributes(&attrs, tensor_data); + ASSERT_EQ(cuda_err, cudaSuccess) << "cudaPointerGetAttributes failed: " << cudaGetErrorString(cuda_err); + EXPECT_EQ(attrs.type, cudaMemoryTypeDevice) + << "Memory should be CUDA device memory, but got type " << attrs.type + << " (cudaMemoryTypeDevice=" << cudaMemoryTypeDevice << ")"; + EXPECT_NE(attrs.device, -1) << "Device should be valid"; + + // Cleanup + ort_api_->ReleaseValue(tensor); + ort_api_->ReleaseExternalMemoryHandle(mem_handle); + CloseHandle(shared_handle); + ort_api_->ReleaseExternalResourceImporter(importer); +} + +// Test: Import D3D12 Timeline Fence +TEST_F(ExternalResourceImporterTest, ImportD3D12Fence) { + if (!IsD3D12Available()) { + GTEST_SKIP() << "D3D12 not available"; + } + if (!IsEPAvailable()) { + GTEST_SKIP() << "NvTensorRtRtx EP not available"; + } + + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = ort_api_->CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr || importer == nullptr) { + if (status != nullptr) ort_api_->ReleaseStatus(status); + GTEST_SKIP() << "External resource import not supported"; + } + + // Create a D3D12 fence with SHARED flag for cross-API import + ComPtr d3d12_fence; + HRESULT hr = d3d12_device_->CreateFence(0, D3D12_FENCE_FLAG_SHARED, IID_PPV_ARGS(&d3d12_fence)); + ASSERT_TRUE(SUCCEEDED(hr)) << "Failed to create D3D12 fence"; + + // Create shared handle + HANDLE shared_handle = nullptr; + hr = d3d12_device_->CreateSharedHandle(d3d12_fence.Get(), nullptr, GENERIC_ALL, nullptr, &shared_handle); + ASSERT_TRUE(SUCCEEDED(hr)) << "Failed to create shared fence handle"; + + // Import the semaphore + OrtExternalSemaphoreDescriptor sem_desc = {}; + sem_desc.version = ORT_API_VERSION; + sem_desc.type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + sem_desc.native_handle = shared_handle; + + OrtExternalSemaphoreHandle* sem_handle = nullptr; + status = ort_api_->ExternalResourceImporter_ImportSemaphore(importer, &sem_desc, &sem_handle); + ASSERT_EQ(status, nullptr) << "ImportSemaphore should succeed"; + ASSERT_NE(sem_handle, nullptr) << "Semaphore handle should not be null"; + + // Release semaphore handle + ort_api_->ReleaseExternalSemaphoreHandle(sem_handle); + + // Close shared handle + CloseHandle(shared_handle); + + ort_api_->ReleaseExternalResourceImporter(importer); +} + +// Test: Wait and Signal Semaphore +TEST_F(ExternalResourceImporterTest, WaitAndSignalSemaphore) { + if (!IsD3D12Available()) { + GTEST_SKIP() << "D3D12 not available"; + } + if (!IsEPAvailable()) { + GTEST_SKIP() << "NvTensorRtRtx EP not available"; + } + + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = ort_api_->CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr || importer == nullptr) { + if (status != nullptr) ort_api_->ReleaseStatus(status); + GTEST_SKIP() << "External resource import not supported"; + } + + // Create a CUDA stream via ORT + OrtSyncStream* ort_stream = nullptr; + status = ort_api_->CreateSyncStreamForEpDevice(ep_device_, nullptr, &ort_stream); + ASSERT_EQ(status, nullptr) << "CreateSyncStreamForEpDevice should succeed"; + + // Create a D3D12 fence + ComPtr d3d12_fence; + HRESULT hr = d3d12_device_->CreateFence(0, D3D12_FENCE_FLAG_SHARED, IID_PPV_ARGS(&d3d12_fence)); + ASSERT_TRUE(SUCCEEDED(hr)); + + HANDLE shared_handle = nullptr; + hr = d3d12_device_->CreateSharedHandle(d3d12_fence.Get(), nullptr, GENERIC_ALL, nullptr, &shared_handle); + ASSERT_TRUE(SUCCEEDED(hr)); + + // Import semaphore + OrtExternalSemaphoreDescriptor sem_desc = {}; + sem_desc.version = ORT_API_VERSION; + sem_desc.type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + sem_desc.native_handle = shared_handle; + + OrtExternalSemaphoreHandle* sem_handle = nullptr; + status = ort_api_->ExternalResourceImporter_ImportSemaphore(importer, &sem_desc, &sem_handle); + ASSERT_EQ(status, nullptr); + + // Signal the fence from D3D12 side + uint64_t signal_value = 1; + command_queue_->Signal(d3d12_fence.Get(), signal_value); + + // Wait on the fence from CUDA side + status = ort_api_->ExternalResourceImporter_WaitSemaphore(importer, sem_handle, ort_stream, signal_value); + ASSERT_EQ(status, nullptr) << "WaitSemaphore should succeed"; + + // Signal from CUDA side + uint64_t cuda_signal_value = 2; + status = ort_api_->ExternalResourceImporter_SignalSemaphore(importer, sem_handle, ort_stream, cuda_signal_value); + ASSERT_EQ(status, nullptr) << "SignalSemaphore should succeed"; + + // Synchronize by getting the native stream handle and calling cudaStreamSynchronize + // (In real code, the signal enqueued on the stream will complete when the stream is flushed) + void* stream_handle = ort_api_->SyncStream_GetHandle(ort_stream); + ASSERT_NE(stream_handle, nullptr); + // Note: In production code, you'd call cudaStreamSynchronize((cudaStream_t)stream_handle) + // For this test, we just verify the handle is valid + + // Verify D3D12 can see the signaled value + HANDLE wait_event = CreateEvent(nullptr, FALSE, FALSE, nullptr); + d3d12_fence->SetEventOnCompletion(cuda_signal_value, wait_event); + DWORD wait_result = WaitForSingleObject(wait_event, 5000); // 5 second timeout + CloseHandle(wait_event); + EXPECT_EQ(wait_result, WAIT_OBJECT_0) << "D3D12 should see the fence signaled by CUDA"; + + // Cleanup + ort_api_->ReleaseExternalSemaphoreHandle(sem_handle); + CloseHandle(shared_handle); + ort_api_->ReleaseSyncStream(ort_stream); + ort_api_->ReleaseExternalResourceImporter(importer); +} + +// Test: Full Inference with External Memory (E2E) +// This test validates the complete D3D12 to CUDA interop pipeline: +// 1. Create D3D12 shared resources and fences +// 2. Import them into CUDA via OrtExternalResourceImporter +// 3. Create ORT tensors from imported memory +// 4. Run inference with proper synchronization +// 5. Verify output correctness +TEST_F(ExternalResourceImporterTest, FullInferenceWithExternalMemory) { + if (!IsD3D12Available()) { + GTEST_SKIP() << "D3D12 not available"; + } + if (!IsEPAvailable()) { + GTEST_SKIP() << "NvTensorRtRtx EP not available"; + } + + // Create a simple ReLU model using shared utility pattern + PathString model_path = ORT_TSTR("external_mem_relu_test.onnx"); + { + onnxruntime::Model model("relu_test", false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto tensor_type; + tensor_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3); + tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(64); + tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(64); + + auto& input_arg = graph.GetOrCreateNodeArg("X", &tensor_type); + auto& output_arg = graph.GetOrCreateNodeArg("Y", &tensor_type); + graph.AddNode("relu", "Relu", "ReLU operation", {&input_arg}, {&output_arg}); + + ASSERT_STATUS_OK(graph.Resolve()); + ASSERT_STATUS_OK(onnxruntime::Model::Save(model, model_path)); + } + + const int64_t batch = 1, channels = 3, dim = 64; + const int64_t shape[] = {batch, channels, dim, dim}; + const size_t num_elements = batch * channels * dim * dim; + const size_t buffer_size = num_elements * sizeof(float); + + // Create external resource importer + OrtExternalResourceImporter* importer = nullptr; + OrtStatus* status = ort_api_->CreateExternalResourceImporterForDevice(ep_device_, &importer); + if (status != nullptr || importer == nullptr) { + if (status != nullptr) ort_api_->ReleaseStatus(status); + clearFileIfExists(model_path); + GTEST_SKIP() << "External resource import not supported"; + } + + // Create CUDA stream via ORT + OrtSyncStream* ort_stream = nullptr; + status = ort_api_->CreateSyncStreamForEpDevice(ep_device_, nullptr, &ort_stream); + ASSERT_EQ(status, nullptr); + + // Create shared D3D12 buffers for input and output + ComPtr input_buffer, output_buffer; + D3D12ResourceHelper::CreateSharedBuffer(d3d12_device_.Get(), buffer_size, &input_buffer); + D3D12ResourceHelper::CreateSharedBuffer(d3d12_device_.Get(), buffer_size, &output_buffer); + + // Create shared handles for cross-API import + HANDLE input_handle = nullptr, output_handle = nullptr; + d3d12_device_->CreateSharedHandle(input_buffer.Get(), nullptr, GENERIC_ALL, nullptr, &input_handle); + d3d12_device_->CreateSharedHandle(output_buffer.Get(), nullptr, GENERIC_ALL, nullptr, &output_handle); + + // Import memory into CUDA + OrtExternalMemoryDescriptor input_mem_desc = {}; + input_mem_desc.version = ORT_API_VERSION; + input_mem_desc.handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; + input_mem_desc.native_handle = input_handle; + input_mem_desc.size_bytes = buffer_size; + input_mem_desc.offset_bytes = 0; + input_mem_desc.access_mode = ORT_EXTERNAL_MEMORY_ACCESS_READ_WRITE; + + OrtExternalMemoryDescriptor output_mem_desc = input_mem_desc; + output_mem_desc.native_handle = output_handle; + + OrtExternalMemoryHandle *input_mem = nullptr, *output_mem = nullptr; + status = ort_api_->ExternalResourceImporter_ImportMemory(importer, &input_mem_desc, &input_mem); + ASSERT_EQ(status, nullptr) << "ImportMemory for input should succeed (proves cuImportExternalMemory called)"; + status = ort_api_->ExternalResourceImporter_ImportMemory(importer, &output_mem_desc, &output_mem); + ASSERT_EQ(status, nullptr) << "ImportMemory for output should succeed"; + + // Create ORT tensors from imported memory + OrtExternalTensorDescriptor tensor_desc = {}; + tensor_desc.version = ORT_API_VERSION; + tensor_desc.element_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + tensor_desc.shape = shape; + tensor_desc.rank = 4; + tensor_desc.offset_bytes = 0; + + OrtValue *input_tensor = nullptr, *output_tensor = nullptr; + status = ort_api_->ExternalResourceImporter_CreateTensorFromMemory(importer, input_mem, &tensor_desc, nullptr, &input_tensor); + ASSERT_EQ(status, nullptr); + status = ort_api_->ExternalResourceImporter_CreateTensorFromMemory(importer, output_mem, &tensor_desc, nullptr, &output_tensor); + ASSERT_EQ(status, nullptr); + + // Verify the tensor data pointers are CUDA device memory + void* input_data_ptr = nullptr; + void* output_data_ptr = nullptr; + status = ort_api_->GetTensorMutableData(input_tensor, &input_data_ptr); + ASSERT_EQ(status, nullptr); + status = ort_api_->GetTensorMutableData(output_tensor, &output_data_ptr); + ASSERT_EQ(status, nullptr); + + cudaPointerAttributes input_attrs, output_attrs; + ASSERT_EQ(cudaPointerGetAttributes(&input_attrs, input_data_ptr), cudaSuccess); + ASSERT_EQ(cudaPointerGetAttributes(&output_attrs, output_data_ptr), cudaSuccess); + EXPECT_EQ(input_attrs.type, cudaMemoryTypeDevice) << "Input tensor must be CUDA device memory"; + EXPECT_EQ(output_attrs.type, cudaMemoryTypeDevice) << "Output tensor must be CUDA device memory"; + + // Create D3D12 fence for bidirectional synchronization + ComPtr sync_fence; + d3d12_device_->CreateFence(0, D3D12_FENCE_FLAG_SHARED, IID_PPV_ARGS(&sync_fence)); + HANDLE fence_handle = nullptr; + d3d12_device_->CreateSharedHandle(sync_fence.Get(), nullptr, GENERIC_ALL, nullptr, &fence_handle); + + OrtExternalSemaphoreDescriptor sem_desc = {}; + sem_desc.version = ORT_API_VERSION; + sem_desc.type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + sem_desc.native_handle = fence_handle; + + OrtExternalSemaphoreHandle* sem_handle = nullptr; + status = ort_api_->ExternalResourceImporter_ImportSemaphore(importer, &sem_desc, &sem_handle); + ASSERT_EQ(status, nullptr) << "ImportSemaphore should succeed"; + + // Setup test data via D3D12 upload buffer + ComPtr upload_buffer; + D3D12ResourceHelper::CreateUploadBuffer(d3d12_device_.Get(), buffer_size, &upload_buffer); + + // Generate test data: alternating positive and negative values for ReLU verification + std::vector test_data(num_elements); + for (size_t i = 0; i < num_elements; ++i) { + test_data[i] = (i % 2 == 0) ? static_cast(i + 1) : -static_cast(i + 1); + } + + void* upload_ptr = nullptr; + upload_buffer->Map(0, nullptr, &upload_ptr); + memcpy(upload_ptr, test_data.data(), buffer_size); + upload_buffer->Unmap(0, nullptr); + + // Copy upload buffer to input buffer via D3D12 + command_allocator_->Reset(); + command_list_->Reset(command_allocator_.Get(), nullptr); + command_list_->CopyBufferRegion(input_buffer.Get(), 0, upload_buffer.Get(), 0, buffer_size); + command_list_->Close(); + + ID3D12CommandList* cmd_lists[] = {command_list_.Get()}; + command_queue_->ExecuteCommandLists(1, cmd_lists); + + // Signal fence after D3D12 upload completes + uint64_t upload_complete_value = 1; + command_queue_->Signal(sync_fence.Get(), upload_complete_value); + + // Make CUDA wait for D3D12 upload to complete + status = ort_api_->ExternalResourceImporter_WaitSemaphore(importer, sem_handle, ort_stream, upload_complete_value); + ASSERT_EQ(status, nullptr) << "WaitSemaphore should succeed"; + + // Setup ORT session with user_compute_stream + Ort::SessionOptions session_options; + session_options.SetExecutionMode(ORT_SEQUENTIAL); + session_options.DisableMemPattern(); + session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL); + + // Configure to use our CUDA stream + char stream_address[32]; + size_t stream_addr_val = reinterpret_cast(ort_api_->SyncStream_GetHandle(ort_stream)); + sprintf(stream_address, "%llu", static_cast(stream_addr_val)); + const char* option_keys[] = {"user_compute_stream", "has_user_compute_stream"}; + const char* option_values[] = {stream_address, "1"}; + + // Add the NvTensorRtRtx EP with user stream + status = ort_api_->SessionOptionsAppendExecutionProvider_V2( + session_options, *ort_env, &ep_device_, 1, option_keys, option_values, 2); + ASSERT_EQ(status, nullptr); + + // Create session + Ort::Session session(*ort_env, model_path.c_str(), session_options); + + // Create IoBinding and bind external tensors + Ort::IoBinding io_binding(session); + Ort::AllocatorWithDefaultOptions allocator; + + Ort::AllocatedStringPtr input_name = session.GetInputNameAllocated(0, allocator); + Ort::AllocatedStringPtr output_name = session.GetOutputNameAllocated(0, allocator); + + io_binding.BindInput(input_name.get(), Ort::Value(input_tensor)); + io_binding.BindOutput(output_name.get(), Ort::Value(output_tensor)); + io_binding.SynchronizeInputs(); + + // Run inference. ORT synchronizes the stream before returning, so GPU work is complete + // when we signal the semaphore below. + session.Run(Ort::RunOptions{}, io_binding); + + // Signal from CUDA that inference is complete + uint64_t inference_complete_value = 2; + status = ort_api_->ExternalResourceImporter_SignalSemaphore(importer, sem_handle, ort_stream, inference_complete_value); + ASSERT_EQ(status, nullptr) << "SignalSemaphore should succeed"; + + // Wait on D3D12 for CUDA inference to complete + command_queue_->Wait(sync_fence.Get(), inference_complete_value); + + // Copy output to readback buffer + ComPtr readback_buffer; + D3D12ResourceHelper::CreateReadbackBuffer(d3d12_device_.Get(), buffer_size, &readback_buffer); + + command_allocator_->Reset(); + command_list_->Reset(command_allocator_.Get(), nullptr); + command_list_->CopyBufferRegion(readback_buffer.Get(), 0, output_buffer.Get(), 0, buffer_size); + command_list_->Close(); + + command_queue_->ExecuteCommandLists(1, cmd_lists); + D3D12ResourceHelper::FlushAndWait(d3d12_device_.Get(), command_queue_.Get()); + + // Read back and verify ReLU output: max(0, x) + std::vector output_data(num_elements); + void* readback_ptr = nullptr; + readback_buffer->Map(0, nullptr, &readback_ptr); + memcpy(output_data.data(), readback_ptr, buffer_size); + readback_buffer->Unmap(0, nullptr); + + // Verify ReLU correctness + for (size_t i = 0; i < num_elements; ++i) { + float expected = std::max(0.0f, test_data[i]); + EXPECT_FLOAT_EQ(output_data[i], expected) + << "Mismatch at index " << i << ": input=" << test_data[i] + << ", expected=" << expected << ", got=" << output_data[i]; + } + + // Note: io_binding takes ownership of input_tensor and output_tensor, so don't release them manually + + // Cleanup + ort_api_->ReleaseExternalSemaphoreHandle(sem_handle); + ort_api_->ReleaseExternalMemoryHandle(output_mem); + ort_api_->ReleaseExternalMemoryHandle(input_mem); + CloseHandle(fence_handle); + CloseHandle(output_handle); + CloseHandle(input_handle); + ort_api_->ReleaseSyncStream(ort_stream); + ort_api_->ReleaseExternalResourceImporter(importer); + + clearFileIfExists(model_path); +} + +#endif // _WIN32 + +} // namespace test +} // namespace onnxruntime