diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index a8e7d2cc9299a..9cf672da44337 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -244,6 +244,31 @@ option(onnxruntime_USE_AZURE "Build with azure inferencing support" OFF) option(onnxruntime_USE_LOCK_FREE_QUEUE "Build with lock-free task queue for threadpool." OFF) option(onnxruntime_FORCE_GENERIC_ALGORITHMS "Disable optimized arch-specific algorithms. Use only for testing and debugging generic algorithms." OFF) +# DX for interop feature option +option(onnxruntime_USE_DX_FOR_INTEROP "Build with the DX for Interop feature." OFF) + +if (onnxruntime_USE_DX_FOR_INTEROP) + add_compile_definitions(DX_FOR_INTEROP=1) +else() + add_compile_definitions(DX_FOR_INTEROP=0) +endif() + +# Vulkan for interop feature option +find_package(Vulkan QUIET) +option(onnxruntime_USE_VULKAN_FOR_INTEROP "Build with the Vulkan for Interop feature." OFF) + +if (onnxruntime_USE_VULKAN_FOR_INTEROP AND Vulkan_FOUND) + if (WIN32) + add_compile_definitions(VK_USE_PLATFORM_WIN32_KHR=1) + endif() + add_compile_definitions(VULKAN_FOR_INTEROP=1) +else() + add_compile_definitions(VULKAN_FOR_INTEROP=0) + if (NOT Vulkan_FOUND) + message(STATUS "Vulkan not found. Vulkan interop disabled.") + endif() +endif() + option(onnxruntime_USE_TENSORRT_INTERFACE "Build ONNXRuntime shared lib which is compatible with TensorRT EP interface" OFF) option(onnxruntime_USE_NV_INTERFACE "Build ONNXRuntime shared lib which is compatible with NV EP interface" OFF) option(onnxruntime_USE_CUDA_INTERFACE "Build ONNXRuntime shared lib which is compatible with Cuda EP interface" OFF) @@ -1198,6 +1223,9 @@ function(onnxruntime_configure_target target_name) set_target_properties(${target_name} PROPERTIES VS_USER_PROPS ${PROJECT_SOURCE_DIR}/EnableVisualStudioCodeAnalysis.props) endif() target_include_directories(${target_name} PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) + if(Vulkan_FOUND) + target_include_directories(${target_name} PRIVATE ${Vulkan_INCLUDE_DIRS}) + endif() if (onnxruntime_ENABLE_TRAINING_OPS) target_include_directories(${target_name} PRIVATE ${ORTTRAINING_ROOT}) endif() diff --git a/cmake/onnxruntime_providers_nv.cmake b/cmake/onnxruntime_providers_nv.cmake index e59463b6b91f1..5ec45a64e46bb 100644 --- a/cmake/onnxruntime_providers_nv.cmake +++ b/cmake/onnxruntime_providers_nv.cmake @@ -146,9 +146,9 @@ endif () target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE Eigen3::Eigen onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface Eigen3::Eigen) add_dependencies(onnxruntime_providers_nv_tensorrt_rtx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) - target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart) + target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart CUDA::cuda_driver) else() - target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart) + target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart CUDA::cuda_driver) endif() target_include_directories(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${TENSORRT_RTX_INCLUDE_DIR} ${onnx_tensorrt_SOURCE_DIR} PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 1ae7b5c9eb991..20351450906c6 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -451,6 +451,10 @@ public struct OrtApi public IntPtr Graph_GetModelMetadata; public IntPtr GetModelCompatibilityForEpDevices; public IntPtr CreateExternalInitializerInfo; + + public IntPtr GetOrtFenceForGraphicsInterop; + public IntPtr InteropEpWait; + public IntPtr InteropEpSignal; } internal static class NativeMethods @@ -482,7 +486,7 @@ static NativeMethods() DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(OrtGetApiBase().GetApi, typeof(DOrtGetApi)); #endif - const uint ORT_API_VERSION = 14; + const uint ORT_API_VERSION = 15; #if NETSTANDARD2_0 IntPtr ortApiPtr = OrtGetApi(ORT_API_VERSION); api_ = (OrtApi)Marshal.PtrToStructure(ortApiPtr, typeof(OrtApi)); @@ -847,7 +851,7 @@ static NativeMethods() api_.CreateSyncStreamForEpDevice, typeof(DOrtCreateSyncStreamForEpDevice)); - OrtSyncStream_GetHandle = + OrtSyncStream_GetHandle = (DOrtSyncStream_GetHandle)Marshal.GetDelegateForFunctionPointer( api_.SyncStream_GetHandle, typeof(DOrtSyncStream_GetHandle)); @@ -861,6 +865,21 @@ static NativeMethods() (DOrtCopyTensors)Marshal.GetDelegateForFunctionPointer( api_.CopyTensors, typeof(DOrtCopyTensors)); + + OrtGetOrtFenceForGraphicsInterop = + (DOrtGetOrtFenceForGraphicsInterop)Marshal.GetDelegateForFunctionPointer( + api_.GetOrtFenceForGraphicsInterop, + typeof(DOrtGetOrtFenceForGraphicsInterop)); + + OrtInteropEpWait = + (DOrtInteropEpWait)Marshal.GetDelegateForFunctionPointer( + api_.InteropEpWait, + typeof(DOrtInteropEpWait)); + + OrtInteropEpSignal = + (DOrtInteropEpSignal)Marshal.GetDelegateForFunctionPointer( + api_.InteropEpSignal, + typeof(DOrtInteropEpSignal)); } internal class NativeLib @@ -2644,7 +2663,7 @@ public delegate void DOrtAddKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, byte[] /* const char* */ value); /// - /// Get the value for the provided key. + /// Get the value for the provided key. /// /// Value. Returns IntPtr.Zero if key was not found. [UnmanagedFunctionPointer(CallingConvention.Winapi)] @@ -2743,6 +2762,30 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, out IntPtr /* OrtSyncStream** */ stream ); + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtGetOrtFenceForGraphicsInterop( + IntPtr /* OrtSession* */ session, + IntPtr /* struct GraphicsInteropParams* */ graphicsInteropParams, + IntPtr /* struct FenceInteropParams* */ fenceInteropParams, + out IntPtr /* void** */ ortFence + ); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtInteropEpWait( + IntPtr /* OrtSession* */ session, + IntPtr /* void* */ ortFence, + IntPtr /* OrtSyncStream* */ stream, + uint /* uint64_t */ fenceValue + ); + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtInteropEpSignal( + IntPtr /* OrtSession* */ session, + IntPtr /* void* */ ortFence, + IntPtr /* OrtSyncStream* */ stream, + uint /* uint64_t */ fenceValue + ); + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* void* */ DOrtSyncStream_GetHandle( IntPtr /* OrtSyncStream* */ stream @@ -2760,6 +2803,9 @@ out IntPtr /* OrtSyncStream** */ stream public static DOrtEpDevice_Device OrtEpDevice_Device; public static DOrtEpDevice_MemoryInfo OrtEpDevice_MemoryInfo; public static DOrtCreateSyncStreamForEpDevice OrtCreateSyncStreamForEpDevice; + public static DOrtGetOrtFenceForGraphicsInterop OrtGetOrtFenceForGraphicsInterop; + public static DOrtInteropEpWait OrtInteropEpWait; + public static DOrtInteropEpSignal OrtInteropEpSignal; public static DOrtSyncStream_GetHandle OrtSyncStream_GetHandle; public static DOrtReleaseSyncStream OrtReleaseSyncStream; @@ -2767,7 +2813,7 @@ out IntPtr /* OrtSyncStream** */ stream // Auto Selection EP registration and selection customization /// - /// Register an execution provider library. + /// Register an execution provider library. /// The library must implement CreateEpFactories and ReleaseEpFactory. /// /// Environment to add the EP library to. diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index f54f4a5a6f1ef..a2bf333444107 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -38,6 +38,14 @@ class GraphOptimizerRegistry; #include "core/framework/tuning_context.h" #include "core/session/onnxruntime_c_api.h" +#if DX_FOR_INTEROP && _WIN32 +#include +#endif + +#if VULKAN_FOR_INTEROP +#include +#endif + struct OrtEpDevice; struct OrtRunOptions; @@ -92,6 +100,86 @@ class IExecutionProvider { public: virtual ~IExecutionProvider() = default; + virtual Status GetExtSemaphore(const struct GraphicsInteropParams* graphicsInteropParams, const struct FenceInteropParams* fenceInteropParams, void** extSemFence) { + auto interop_params_sptr = std::make_shared(*fenceInteropParams); + *extSemFence = new std::shared_ptr(interop_params_sptr); + ORT_UNUSED_PARAMETER(graphicsInteropParams); + return Status::OK(); + } + + virtual Status SetupInteropEpWait(void* extSemFence, OrtSyncStream* stream, uint64_t fenceValue) { + ORT_UNUSED_PARAMETER(stream); + auto* sptr_ptr = static_cast*>(extSemFence); + std::shared_ptr interopWaitParamsSptr = *sptr_ptr; + delete sptr_ptr; + + auto* interopWaitParams = interopWaitParamsSptr.get(); + + ExternalSyncPrimitive extSyncPrimitive = interopWaitParams->extSyncPrimitive; + // to-do: The fallback logic needs more refinement to deal with multi threaded scenarios. + if (extSyncPrimitive == ExternalSyncPrimitive_D3D12Fence) { +#if DX_FOR_INTEROP && _WIN32 + HANDLE hEvent = CreateEvent(nullptr, FALSE, FALSE, nullptr); + reinterpret_cast(interopWaitParams->FencePtr.pFence)->SetEventOnCompletion(fenceValue, hEvent); + WaitForSingleObject(hEvent, INFINITE); + CloseHandle(hEvent); + return Status::OK(); +#endif + } + else if(extSyncPrimitive == ExternalSyncPrimitive_VulkanSemaphore) + { +#if VULKAN_FOR_INTEROP + PFN_vkWaitForFences pfnVkWaitForFences = reinterpret_cast( + reinterpret_cast(interopWaitParams->VulkanDeviceParams.pVkGetDeviceProcAddr)( + reinterpret_cast(interopWaitParams->VulkanDeviceParams.pVkDevice), "vkWaitForFences")); + + if (!pfnVkWaitForFences) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get function pointer for vkWaitForFences"); + } + VkResult result = pfnVkWaitForFences(reinterpret_cast(interopWaitParams->VulkanDeviceParams.pVkDevice), 1, reinterpret_cast(&interopWaitParams->FencePtr.pVkFence), VK_TRUE, UINT64_MAX); + + if (result != VK_SUCCESS) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "vkWaitForFences failed with Vulkan error code: " + std::to_string(result)); + } + + PFN_vkResetFences pfnVkResetFences = reinterpret_cast( + reinterpret_cast(interopWaitParams->VulkanDeviceParams.pVkGetDeviceProcAddr)( + reinterpret_cast(interopWaitParams->VulkanDeviceParams.pVkDevice), "vkResetFences")); + if (pfnVkResetFences) { + pfnVkResetFences(reinterpret_cast(interopWaitParams->VulkanDeviceParams.pVkDevice), 1, reinterpret_cast(&interopWaitParams->FencePtr.pVkFence)); + } + + return Status::OK(); +#endif + } + ORT_UNUSED_PARAMETER(fenceValue); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported External Sync primitive"); + } + virtual Status SetupInteropEpSignal(const OrtEpApi* ortEpApi, void* extSemFence, OrtSyncStream* stream, uint64_t fenceValue) { + ORT_UNUSED_PARAMETER(extSemFence); + ORT_UNUSED_PARAMETER(fenceValue); + + const OrtSyncStreamImpl* streamImpl; + OrtSyncNotificationImpl* streamNotification; + streamImpl = ortEpApi->SyncStream_GetImpl(static_cast(stream)); + + OrtStatus* status = nullptr; + status = streamImpl->CreateNotification(const_cast(streamImpl), &streamNotification); + if(status != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create notification"); + } + + status = streamNotification->Activate(streamNotification); + if(status != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to activate notification"); + } + status = streamNotification->WaitOnHost(streamNotification); + if(status != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait on host"); + } + return Status::OK(); + } + /** * Returns a data transfer object that implements methods to copy to and * from this device. diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 434aa075e62d6..2480a31d690f2 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -329,6 +329,7 @@ ORT_RUNTIME_CLASS(HardwareDevice); ORT_RUNTIME_CLASS(EpDevice); ORT_RUNTIME_CLASS(KeyValuePairs); ORT_RUNTIME_CLASS(SyncStream); // Opaque class to create an onnxruntime::Stream. +ORT_RUNTIME_CLASS(Fence); ORT_RUNTIME_CLASS(ExternalInitializerInfo); #ifdef _MSC_VER @@ -507,6 +508,43 @@ typedef enum OrtExecutionProviderDevicePolicy { OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER, } OrtExecutionProviderDevicePolicy; +typedef enum ExternalSyncPrimitive { + ExternalSyncPrimitive_D3D12Fence, + ExternalSyncPrimitive_VulkanSemaphore, // Vulkan timeline semaphore +} ExternalSyncPrimitive; + +typedef struct VulkanDeviceParams { + void* pVkDevice; + void* pVkGetDeviceProcAddr; +} VulkanDeviceParams; + +typedef struct GraphicsInteropParams { + ExternalSyncPrimitive extSyncPrimitive; + union DevicePtr { + struct DXDeviceParams { + void* pDevice; + void* pCommandQueue; + } DXDeviceParams; + VulkanDeviceParams VulkanDeviceParams; + } DevicePtr; + +} GraphicsInteropParams; + +typedef struct FenceInteropParams { + ExternalSyncPrimitive extSyncPrimitive; + union FencePtr { + void* pFence; + void* pVkFence; + void* VkSemaphore; + } FencePtr; + VulkanDeviceParams VulkanDeviceParams; +} FenceInteropParams; + +typedef struct SemaphoreEpMap { + void* extSemFence; + void* selectedEp; +} SemaphoreEpMap; + /** \brief Delegate to allow providing custom OrtEpDevice selection logic * * This delegate is called by the EP selection code to allow the user to provide custom device selection logic. @@ -6590,6 +6628,99 @@ struct OrtApi { * \since Version 1.24 */ ORT_API_T(bool, TensorTypeAndShape_HasShape, _In_ const OrtTensorTypeAndShapeInfo* info); + + /** \brief Setup Graphics Interopcontext for an execution provider device. + * + * This function enables D3D12/Vulkan interoperability with a graphics API command queue/device. Once setup, any OrtSyncStream + * created for this ep_device via CreateSyncStreamForEpDevice will be created, enabling efficient GPU-side synchronization. + * + * This must be called BEFORE CreateSyncStreamForEpDevice for the same ep_device. + * + * \param[in] ep_device The OrtEpDevice to setup Graphics interop for. + * \param[in] graphicsInteropParams Pointer to struct containing D3D12 command queue or Vulkan device info. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24 + */ + ORT_API2_STATUS(SetupGraphicsInteropForEpDevice, _In_ const OrtEpDevice* ep_device, + _In_ const struct GraphicsInteropParams* graphicsInteropParams); + + /** + * \brief Create an OrtFence wrapper for an external GPU fence/semaphore to enable graphics interop. + * + * This function creates an OrtFence object that wraps an external synchronization primitive (D3D12 fence or Vulkan timeline semaphore) + * provided by the caller. The resulting OrtFence can be used with InteropEpWait and InteropEpSignal to synchronize ONNX Runtime + * computation with external graphics APIs like D3D12 or Vulkan. + * + * This enables integration between ONNX Runtime GPU computation and external graphics or compute pipelines, allowing efficient + * GPU-side synchronization between ONNX Runtime inference and external graphics workloads. + * + * Prerequisites: + * - SetupGraphicsInteropForEpDevice must be called first for the relevant ep_device. + * - The external fence/semaphore provided must be compatible with the graphics API specified in graphicsInteropParams. + * + * \param[in] session An OrtSession instance whose execution providers participate in the graphics interop. + * \param[in] graphicsInteropParams Pointer to a struct specifying the graphics API type (D3D12 or Vulkan) and device information. + * \param[in] fenceInteropParams Pointer to a struct containing the external fence/semaphore handle to be wrapped. + * \param[out] ortFence Pointer to receive the created OrtFence object. This OrtFence wraps the external synchronization primitive + * and can be used with InteropEpWait and InteropEpSignal APIs. + * + * \retval ORT_OK On success. + * \retval ORT_NOT_IMPLEMENTED If none of the active providers support graphics-fence interop. + * \retval ORT_FAIL or provider-specific error if the operation fails for another reason. + * + * \since Version 1.24 + */ + ORT_API2_STATUS(GetOrtFenceForGraphicsInterop, _In_ OrtSession* session, _In_ const struct GraphicsInteropParams* graphicsInteropParams, _In_ const struct FenceInteropParams* fenceInteropParams, _Outptr_ OrtFence** ortFence); + + /** + * \brief Wait on a graphics interop external fence/semaphore using an ONNX Runtime execution provider. + * + * This function synchronizes ONNX Runtime computation with an external graphics or compute API by waiting on an external + * semaphore or fence. It is typically used in scenarios where GPU computation in ONNX Runtime must be synchronized with + * non-ONNX Runtime workloads, such as graphics pipelines (e.g., Direct3D 12 or Vulkan). + * + * The external synchronization primitive (such as a fence or semaphore handle) referenced by `extSemFence` should be + * obtained from a previous call to GetOrtFenceForGraphicsInterop. + * + * The implementation and support for this is execution provider-specific. + * + * \param[in] extSemFence The handle to the external synchronization primitive, as returned from GetOrtFenceForGraphicsInterop. + * \param[in] stream The OrtSyncStream instance on which the synchronization will be performed. + * \param[in] fenceValue The fence value for synchronization (if required for the specific graphics API/interop scenario). + * + * \retval ORT_OK On successful synchronization. + * \retval ORT_NOT_IMPLEMENTED If the current execution provider does not support graphics interop wait. + * \retval ORT_FAIL or provider-specific error if the synchronization operation fails for another reason. + * + * \since Version 1.24 + */ + ORT_API2_STATUS(InteropEpWait, _In_ OrtFence* ortFence, _In_ OrtSyncStream* stream, _In_ uint64_t fenceValue); + + /** + * \brief Signal a graphics interop external fence/semaphore using an ONNX Runtime execution provider. + * + * This function synchronizes external graphics or compute APIs with ONNX Runtime computation by signaling an external + * semaphore or fence. This is typically used when GPU computation in ONNX Runtime has completed and an external + * workload (such as a graphics pipeline) should continue execution. + * + * The external synchronization primitive (such as a fence or semaphore handle) referenced by `extSemFence` should + * have been obtained from a previous call to GetOrtFenceForGraphicsInterop. + * + * The behavior and support for this operation is execution provider-specific. + * + * \param[in] extSemFence The handle to the external synchronization primitive, as returned from GetOrtFenceForGraphicsInterop. + * \param[in] stream The OrtSyncStream on which the signal operation should occur. + * \param[in] fenceValue The fence value to signal (if required for the specific graphics API or interop scenario). + * + * \retval ORT_OK On successful signal operation. + * \retval ORT_NOT_IMPLEMENTED If the current execution provider does not support graphics interop signaling. + * \retval ORT_FAIL or provider-specific error if signaling fails for another reason. + * + * \since Version 1.24 + */ + ORT_API2_STATUS(InteropEpSignal, _In_ OrtFence* ortFence, _In_ OrtSyncStream* stream, _In_ uint64_t fenceValue); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index de38085914516..e15fe2dd69857 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -1010,6 +1010,7 @@ struct OrtEpFactory { * This is used to create a synchronization stream for the memory device that can be used for operations outside of * a session. * + * * \param[in] this_ptr The OrtEpFactory instance. * \param[in] memory_device The OrtMemoryDevice to create the synchronization stream for. * \param[in] stream_options Options for stream creation. May be nullptr. @@ -1052,6 +1053,24 @@ struct OrtEpFactory { * \since Version 1.24. */ ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options); + + /** \brief Setup Graphics interop for a memory device. + * + * This function sets up Graphics interop associated with a graphics API (D3D12/Vulkan) command queue/device. + * + * Optional - EP factories that don't support Graphics interop setup should set this to nullptr. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] memory_device The OrtMemoryDevice to setup Graphics interop for. + * \param[in] graphicsInteropParams Graphics API parameters (D3D12 command queue or Vulkan device info). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(SetupGraphicsInterop, _In_ OrtEpFactory* this_ptr, + _In_ const OrtMemoryDevice* memory_device, + _In_ const struct GraphicsInteropParams* graphicsInteropParams); }; #ifdef __cplusplus diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index e2a8005aba1da..110d14fcda3e0 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -33,6 +33,13 @@ #include // TODO: find a better way to share this #include "core/providers/cuda/cuda_stream_handle.h" +#include "core/providers/cuda/cuda_common.h" + +#if VULKAN_FOR_INTEROP +#if _WIN32 +typedef VkResult (VKAPI_PTR *PFN_vkGetSemaphoreWin32HandleKHR)(VkDevice device, const VkSemaphoreGetWin32HandleInfoKHR* pGetWin32HandleInfo, HANDLE* pHandle); +#endif +#endif #ifdef _WIN32 #include @@ -104,6 +111,111 @@ static bool IsSupportedInputOutputDataType(ONNXTensorElementDataType data_type) } } +Status NvExecutionProvider::GetExtSemaphore(const struct GraphicsInteropParams* graphicsInteropParams, const struct FenceInteropParams* fenceInteropParams, void** extSemFence) +{ + if (!info_.has_user_compute_stream) + { + (void)GetPerThreadContext(); + } + cudaExternalSemaphore_t cSemFence; + cudaExternalSemaphoreHandleDesc semHandleDesc = {}; + + assert(graphicsInteropParams->extSyncPrimitive == fenceInteropParams->extSyncPrimitive && + "ExternalSyncPrimitive mismatch between graphicsInteropParams and fenceInteropParams"); + (void)graphicsInteropParams; + + ExternalSyncPrimitive extSyncPrimitive = fenceInteropParams->extSyncPrimitive; + + if(extSyncPrimitive == ExternalSyncPrimitive_D3D12Fence) + { +#if DX_FOR_INTEROP && _WIN32 + HANDLE sharedFenceHandle = nullptr; + if(reinterpret_cast(graphicsInteropParams->DevicePtr.DXDeviceParams.pDevice)->CreateSharedHandle(reinterpret_cast(fenceInteropParams->FencePtr.pFence), nullptr, GENERIC_ALL, nullptr, &sharedFenceHandle) != S_OK) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create shared handle for D3D12 fence"); + } + semHandleDesc.type = cudaExternalSemaphoreHandleTypeD3D12Fence; + semHandleDesc.handle.win32.handle = sharedFenceHandle; +#endif + } + else if(extSyncPrimitive == ExternalSyncPrimitive_VulkanSemaphore) + { +#if VULKAN_FOR_INTEROP +#if _WIN32 + VkSemaphoreGetWin32HandleInfoKHR handleInfo = {}; + handleInfo.sType = VK_STRUCTURE_TYPE_SEMAPHORE_GET_WIN32_HANDLE_INFO_KHR; + handleInfo.semaphore = reinterpret_cast(fenceInteropParams->FencePtr.VkSemaphore); + handleInfo.handleType = VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_BIT_KHR; + + HANDLE sharedFenceHandle = nullptr; + + // Get the function pointer for vkGetSemaphoreWin32HandleKHR + PFN_vkGetSemaphoreWin32HandleKHR pfnVkGetSemaphoreWin32HandleKHR = reinterpret_cast( + reinterpret_cast(graphicsInteropParams->DevicePtr.VulkanDeviceParams.pVkGetDeviceProcAddr)( + reinterpret_cast(graphicsInteropParams->DevicePtr.VulkanDeviceParams.pVkDevice), "vkGetSemaphoreWin32HandleKHR")); + + if (pfnVkGetSemaphoreWin32HandleKHR && + pfnVkGetSemaphoreWin32HandleKHR(reinterpret_cast(graphicsInteropParams->DevicePtr.VulkanDeviceParams.pVkDevice), &handleInfo, &sharedFenceHandle) == VK_SUCCESS) + { + semHandleDesc.type = cudaExternalSemaphoreHandleTypeOpaqueWin32; + semHandleDesc.handle.win32.handle = sharedFenceHandle; + } + else + { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get shared semaphore handle"); + } +#endif +#endif + } + if(cudaImportExternalSemaphore(&cSemFence, &semHandleDesc) != cudaSuccess) + { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to import external semaphore"); + } +#ifdef _WIN32 + CloseHandle(semHandleDesc.handle.win32.handle); +#endif + *extSemFence = cSemFence; + + return Status::OK(); +} + +Status NvExecutionProvider::SetupInteropEpWait(void* extSemFence, OrtSyncStream* stream, uint64_t fenceValue) +{ + LOGS_DEFAULT(VERBOSE) << "NvExecutionProvider::SetupInteropEpWait() called."; + + // make CUDA wait for the upload by Graphics API to finish + cudaExternalSemaphoreWaitParams waitParams = {}; + waitParams.params.fence.value = fenceValue; + cudaExternalSemaphore_t cSemFence = static_cast(extSemFence); + cudaStream_t cudaStream = static_cast(Ort::GetApi().SyncStream_GetHandle(stream)); + cudaError_t result = cudaWaitExternalSemaphoresAsync(&cSemFence, &waitParams, 1, cudaStream); + if(result != cudaSuccess) + { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for external semaphore"); + } + + return Status::OK(); +} + +Status NvExecutionProvider::SetupInteropEpSignal(const OrtEpApi* ortEpApi, void* extSemFence, OrtSyncStream* stream, uint64_t fenceValue) +{ + LOGS_DEFAULT(VERBOSE) << "NvExecutionProvider::SetupInteropEpSignal() called."; + + cudaStream_t cudaStream = static_cast(Ort::GetApi().SyncStream_GetHandle(stream)); + + // make Graphics API wait for the CUDA kernel to finish + cudaExternalSemaphoreSignalParams signalParams = {}; + signalParams.params.fence.value = fenceValue; + cudaExternalSemaphore_t cSemFence = static_cast(extSemFence); + cudaError_t result = cudaSignalExternalSemaphoresAsync(&cSemFence, &signalParams, 1, cudaStream); + if(result != cudaSuccess) + { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to signal external semaphore"); + } + + ORT_UNUSED_PARAMETER(ortEpApi); + return Status::OK(); +} + // Helper function to check if a data type is supported by NvTensorRTRTX EP static bool IsSupportedDataType(ONNXTensorElementDataType data_type) { switch (data_type) { diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index bb8f687db094f..7be9feb48c8e6 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -289,6 +289,10 @@ class NvExecutionProvider : public IExecutionProvider { // explicit NvExecutionProvider(const ProviderOptions& provider_options_map, const ConfigOptions* config_options); virtual ~NvExecutionProvider(); + virtual Status GetExtSemaphore(const struct GraphicsInteropParams* graphicsInteropParams, const struct FenceInteropParams* fenceInteropParams, void** extSemFence) override; + virtual Status SetupInteropEpWait(void* extSemFence, OrtSyncStream* stream, uint64_t fenceValue) override; + virtual Status SetupInteropEpSignal(const OrtEpApi* ortEpApi, void* extSemFence, OrtSyncStream* stream, uint64_t fenceValue) override; + cublasHandle_t PerThreadDefaultCublasHandle() { return GetPerThreadContext().CublasHandle(); } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index c3fbccef84883..1635e987fb334 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -531,6 +531,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { CreateDataTransfer = CreateDataTransferImpl; IsStreamAware = IsStreamAwareImpl; + SetupGraphicsInterop = SetupCigContextImpl; CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; ort_version_supported = ORT_API_VERSION; // Set to the ORT version we were compiled with. @@ -706,6 +707,62 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { return true; } + static OrtStatus* ORT_API_CALL SetupCigContextImpl(OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const struct GraphicsInteropParams* graphicsInteropParams) noexcept { + auto& factory = *static_cast(this_ptr); + auto device_id = factory.ep_api.MemoryDevice_GetDeviceId(memory_device); + + // Initialize CUDA + CUresult result = cuInit(0); + if (result != CUDA_SUCCESS) { + return factory.ort_api.CreateStatus(ORT_FAIL, "Failed to initialize CUDA for CIG context creation"); + } + + // Get LUID from CUDA device + cudaDeviceProp cuda_prop; + cudaError_t cuda_err = cudaGetDeviceProperties(&cuda_prop, device_id); + if (cuda_err != cudaSuccess) { + return factory.ort_api.CreateStatus(ORT_FAIL, "Failed to get CUDA device properties"); + } + + // Create CIG context based on graphics API type + CUcontext cig_context = nullptr; + + if (graphicsInteropParams->extSyncPrimitive == ExternalSyncPrimitive_D3D12Fence) { +#if DX_FOR_INTEROP && _WIN32 + // Get LUID of memory device and D3D12 device and compare it to that of memory device + uint64_t cuda_luid = *reinterpret_cast(cuda_prop.luid); + LUID d3d12_luid = reinterpret_cast(graphicsInteropParams->DevicePtr.DXDeviceParams.pDevice)->GetAdapterLuid(); + uint64_t d3d12_luid_64 = (static_cast(d3d12_luid.HighPart) << 32) | d3d12_luid.LowPart; + if (d3d12_luid_64 != cuda_luid) { + return factory.ort_api.CreateStatus(ORT_FAIL, "D3D12 device LUID does not match CUDA device LUID"); + } + + // Create CIG context bound to D3D12 command queue + CUctxCigParam ctxCigParams = { CIG_DATA_TYPE_D3D12_COMMAND_QUEUE, reinterpret_cast(graphicsInteropParams->DevicePtr.DXDeviceParams.pCommandQueue) }; + CUctxCreateParams ctxParams = { nullptr, 0, &ctxCigParams }; + + result = cuCtxCreate_v4(&cig_context, &ctxParams, 0, device_id); + if (result != CUDA_SUCCESS) { + return factory.ort_api.CreateStatus(ORT_FAIL, "Failed to create CIG context for D3D12"); + } +#else + return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, "CIG context creation not supported on this platform"); +#endif + } else if (graphicsInteropParams->extSyncPrimitive == ExternalSyncPrimitive_VulkanSemaphore) { + // TODO: Add Vulkan CIG context support if needed + return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, "Vulkan CIG context not yet implemented"); + } else { + return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Unsupported graphics API for CIG context"); + } + + // Store the CIG context for this device + factory.cig_contexts_[device_id] = cig_context; + + return nullptr; + } + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr, const OrtMemoryDevice* memory_device, const OrtKeyValuePairs* /*stream_options*/, @@ -714,8 +771,23 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { auto device_id = factory.ep_api.MemoryDevice_GetDeviceId(memory_device); cudaStream_t stream = nullptr; - CUDA_RETURN_IF_ERROR(cudaSetDevice(device_id)); - CUDA_RETURN_IF_ERROR(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + + // Check if we have a CIG context for this device + auto cig_it = factory.cig_contexts_.find(device_id); + if (cig_it != factory.cig_contexts_.end()) { + // We have a CIG context - make it current and create stream on it + CUresult result = cuCtxSetCurrent(cig_it->second); + if (result != CUDA_SUCCESS) { + return factory.ort_api.CreateStatus(ORT_FAIL, "Failed to set CIG context current"); + } + + // Create stream on the CIG context + CUDA_RETURN_IF_ERROR(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + } else { + // No CIG context - use default behavior + CUDA_RETURN_IF_ERROR(cudaSetDevice(device_id)); + CUDA_RETURN_IF_ERROR(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + } const OrtDevice* ort_device = static_cast(memory_device); @@ -770,6 +842,9 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { // we use a shared instance for the OrtDataTransferImpl instead of creating a new one on every call to NvTrtRtxDataTransferImpl data_transfer_impl; + // Map to store CIG context per device ID (for D3D12/Vulkan interop) + std::unordered_map cig_contexts_; + NvTensorRtRtxEpFactory(const NvTensorRtRtxEpFactory&) = delete; NvTensorRtRtxEpFactory& operator=(const NvTensorRtRtxEpFactory&) = delete; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 14f0892687ad1..3d8267e5e2e1f 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -730,6 +730,12 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const #endif // !defined(ORT_MINIMAL_BUILD) InferenceSession::~InferenceSession() { + + for (auto* sptr_ptr : ort_fences_for_cleanup_) { + delete sptr_ptr; + } + ort_fences_for_cleanup_.clear(); + if (session_options_.enable_profiling) { ORT_TRY { EndProfiling(); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 8bea15c169ed4..61c7c611f87e6 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -653,6 +653,14 @@ class InferenceSession { return session_id_; } + void RegisterOrtFenceForCleanup(std::shared_ptr* ortFence) { + if(!ortFence) { + ORT_THROW("Cannot register null fence for cleanup"); + } + std::lock_guard lock(ort_fences_mutex_); + ort_fences_for_cleanup_.push_back(ortFence); + } + protected: #if !defined(ORT_MINIMAL_BUILD) @@ -1045,6 +1053,8 @@ class InferenceSession { // Enable nodestats collection std::optional node_stats_recorder_; #endif + mutable std::mutex ort_fences_mutex_; + std::vector*> ort_fences_for_cleanup_; }; struct SessionIOBinding { diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 546b11ae580d5..1663730719a91 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3366,6 +3366,43 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForInputs, _In_ const OrtSession* API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::SetupGraphicsInteropForEpDevice, _In_ const OrtEpDevice* ep_device, + _In_ const struct GraphicsInteropParams* graphicsInteropParams) { + API_IMPL_BEGIN + if (ep_device == nullptr || graphicsInteropParams == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device and graphicsInteropParams must be provided."); + } + + if (ep_device->device_memory_info == nullptr) + { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device must have valid device_memory_info."); + } + const OrtDevice* device = ep_device->device_memory_info ? &ep_device->device_memory_info->device : nullptr; + + if (device == nullptr || device->MemType() != OrtDevice::MemType::DEFAULT) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device does not use DEFAULT memory of a non-CPU device."); + } + + const auto* factory = ep_device->ep_factory; + if (!factory->IsStreamAware(factory)) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "The execution provider does not support streams."); + } + + // Check if the factory supports Graphics interop setup + if (factory->SetupGraphicsInterop == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "The execution provider does not support Graphics interop setup."); + } + + // Call the EP factory to setup Graphics interop + ORT_API_RETURN_IF_ERROR(factory->SetupGraphicsInterop(ep_device->GetMutableFactory(), + static_cast(device), // alias + graphicsInteropParams)); + + return nullptr; + + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice* ep_device, _In_opt_ const OrtKeyValuePairs* stream_options, _Outptr_ OrtSyncStream** ort_stream) { @@ -3374,6 +3411,10 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device and stream must be provided."); } + if (ep_device->device_memory_info == nullptr) + { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device must have valid device_memory_info."); + } const OrtDevice* device = ep_device->device_memory_info ? &ep_device->device_memory_info->device : nullptr; if (device == nullptr || device->MemType() != OrtDevice::MemType::DEFAULT) { @@ -3407,6 +3448,108 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::GetOrtFenceForGraphicsInterop, _In_ OrtSession* session, _In_ const struct GraphicsInteropParams* graphicsInteropParams, _In_ const struct FenceInteropParams* fenceInteropParams, _Outptr_ OrtFence** ortFence) { + API_IMPL_BEGIN + auto* inference_session = reinterpret_cast(session); + if (!inference_session) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Session is null"); + } + + auto semaphore_ep_map_sptr = std::make_shared(); + semaphore_ep_map_sptr->extSemFence = nullptr; + semaphore_ep_map_sptr->selectedEp = nullptr; + + const auto& session_state = inference_session->GetSessionState(); + const auto& execution_providers = session_state.GetExecutionProviders(); + const auto& graph_viewer = session_state.GetGraphViewer(); + + // Collect the unique set of execution providers assigned to nodes in the graph. + std::unordered_set active_provider_types; + for (const auto& node : graph_viewer.Nodes()) { + if (!node.GetExecutionProviderType().empty()) { + active_provider_types.insert(node.GetExecutionProviderType()); + } + } + + // Call GetExtSemaphore only for the providers that are actively being used. + for (const auto& provider_type : active_provider_types) { + const onnxruntime::IExecutionProvider* const_provider = execution_providers.Get(provider_type); + if (const_provider) { + auto* provider = const_cast(const_provider); + auto status = provider->GetExtSemaphore(graphicsInteropParams, fenceInteropParams, &semaphore_ep_map_sptr->extSemFence); + if(status.IsOK()) { + semaphore_ep_map_sptr->selectedEp = provider; + auto* wrapper = new std::shared_ptr(semaphore_ep_map_sptr); + inference_session->RegisterOrtFenceForCleanup(wrapper); + *ortFence = reinterpret_cast(wrapper); + return nullptr; + } + if (status.Code() != onnxruntime::common::StatusCode::NOT_IMPLEMENTED) { + return ToOrtStatus(status); + } + } + } + + return OrtApis::CreateStatus(ORT_FAIL, "No active execution provider returned a semaphore for the given session"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::InteropEpWait, _In_ OrtFence* ortFence, _In_ OrtSyncStream* stream, _In_ uint64_t fenceValue) { + API_IMPL_BEGIN + auto* sptr_ptr = reinterpret_cast*>(ortFence); + auto* semaphoreEpMap = sptr_ptr->get(); + + if(semaphoreEpMap->extSemFence == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "External Fence Semaphore is null."); + } + if(stream == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Stream is null."); + } + + const onnxruntime::IExecutionProvider* selectedEp = static_cast(semaphoreEpMap->selectedEp); + if(selectedEp){ + auto* execution_provider = const_cast(selectedEp); + auto status = execution_provider->SetupInteropEpWait(semaphoreEpMap->extSemFence, stream, fenceValue); + if(status.IsOK()) { + return nullptr; + } + if (status.Code() != onnxruntime::common::StatusCode::NOT_IMPLEMENTED) { + return ToOrtStatus(status); + } + } + + return OrtApis::CreateStatus(ORT_FAIL, "No execution provider is actively being used"); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::InteropEpSignal, _In_ OrtFence* ortFence, _In_ OrtSyncStream* stream, _In_ uint64_t fenceValue) { + API_IMPL_BEGIN + auto* sptr_ptr = reinterpret_cast*>(ortFence); + auto* semaphoreEpMap = sptr_ptr->get(); // Get raw pointer for existing code + + if(semaphoreEpMap->extSemFence == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "External Fence Semaphore is null."); + } + if(stream == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Stream is null."); + } + + const onnxruntime::IExecutionProvider* selectedEp = static_cast(semaphoreEpMap->selectedEp); + if(selectedEp){ + auto* execution_provider = const_cast(selectedEp); + auto status = execution_provider->SetupInteropEpSignal(OrtApis::GetEpApi(), semaphoreEpMap->extSemFence, stream, fenceValue); + if(status.IsOK()) { + return nullptr; + } + if (status.Code() != onnxruntime::common::StatusCode::NOT_IMPLEMENTED) { + return ToOrtStatus(status); + } + } + + return OrtApis::CreateStatus(ORT_FAIL, "No execution provider is actively being used"); + API_IMPL_END +} + ORT_API(void*, OrtApis::SyncStream_GetHandle, _In_ OrtSyncStream* stream) { return stream->GetHandle(); } @@ -3588,6 +3731,13 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForInputs, _In_ const OrtSession* API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::SetupGraphicsInteropForEpDevice, _In_ const OrtEpDevice* /*ep_device*/, + _In_ const struct GraphicsInteropParams* /*graphicsInteropParams*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SetupGraphicsInteropForEpDevice is not supported in a minimal build."); + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice* /*ep_device*/, _In_opt_ const OrtKeyValuePairs* /*stream_options*/, _Outptr_ OrtSyncStream** /*ort_stream*/) { @@ -3596,6 +3746,24 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::GetOrtFenceForGraphicsInterop, _In_ OrtSession* session, _In_ const struct GraphicsInteropParams* graphicsInteropParams, _In_ const struct FenceInteropParams* fenceInteropParams, _Outptr_ OrtFence** ortFence) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetOrtFenceForGraphicsInterop is not supported in a minimal build."); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::InteropEpWait, _In_ OrtFence* /*ortFence*/, _In_ OrtSyncStream* /*stream*/, _In_ uint64_t /*fenceValue*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "InteropEpWait is not supported in a minimal build."); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::InteropEpSignal, _In_ OrtFence* /*ortFence*/, _In_ OrtSyncStream* /*stream*/, _In_ uint64_t /*fenceValue*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "InteropEpSignal is not supported in a minimal build."); + API_IMPL_END +} + ORT_API(void*, OrtApis::SyncStream_GetHandle, _In_ OrtSyncStream* /*stream*/) { fprintf(stderr, "OrtSyncStream is not supported in a minimal build.\n"); return nullptr; @@ -4231,6 +4399,11 @@ static constexpr OrtApi ort_api_1_to_24 = { // End of Version 23 - DO NOT MODIFY ABOVE (see above text for more information) &OrtApis::TensorTypeAndShape_HasShape, + + &OrtApis::SetupGraphicsInteropForEpDevice, + &OrtApis::GetOrtFenceForGraphicsInterop, + &OrtApis::InteropEpWait, + &OrtApis::InteropEpSignal, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. @@ -4267,6 +4440,7 @@ static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Siz static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change"); static_assert(offsetof(OrtApi, CreateExternalInitializerInfo) / sizeof(void*) == 389, "Size of version 23 API cannot change"); +static_assert(offsetof(OrtApi, InteropEpSignal) / sizeof(void*) == 394, "Size of version 24 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.24.0", diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index f016bb3215330..2104c2530798b 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -751,4 +751,11 @@ ORT_API_STATUS_IMPL(CopyTensors, _In_ const OrtEnv* env, _In_reads_(num_tensors) OrtValue* const* dst_tensors, _In_opt_ OrtSyncStream* stream, _In_ size_t num_tensors); + +ORT_API_STATUS_IMPL(SetupGraphicsInteropForEpDevice, _In_ const OrtEpDevice* ep_device, + _In_ const struct GraphicsInteropParams* graphicsInteropParams); + +ORT_API_STATUS_IMPL(GetOrtFenceForGraphicsInterop, _In_ OrtSession* session, _In_ const struct GraphicsInteropParams* graphicsInteropParams, _In_ const struct FenceInteropParams* fenceInteropParams, _Outptr_ OrtFence** ortFence); +ORT_API_STATUS_IMPL(InteropEpWait, _In_ OrtFence* ortFence, _In_ OrtSyncStream* stream, _In_ uint64_t fenceValue); +ORT_API_STATUS_IMPL(InteropEpSignal, _In_ OrtFence* ortFence, _In_ OrtSyncStream* stream, _In_ uint64_t fenceValue); } // namespace OrtApis diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc index 364bab471ddbe..5709f4c29cb03 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc @@ -31,6 +31,7 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl OrtEpFactory::ReleaseAllocator = Forward::ReleaseAllocator; OrtEpFactory::CreateDataTransfer = Forward::CreateDataTransfer; OrtEpFactory::IsStreamAware = Forward::IsStreamAware; + OrtEpFactory::SetupGraphicsInterop = Forward::SetupGraphicsInterop; OrtEpFactory::CreateSyncStreamForDevice = Forward::CreateSyncStreamForDevice; OrtEpFactory::SetEnvironmentOptions = Forward::SetEnvironmentOptions; } diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h index 6eb83a117fb63..88f8aba15b813 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -74,6 +74,11 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->IsStreamAware(); } + OrtStatus* SetupGraphicsInterop(_In_ const OrtMemoryDevice* memory_device, + _In_ const struct GraphicsInteropParams* graphicsInteropParams) noexcept { + return impl_->SetupGraphicsInterop(memory_device, graphicsInteropParams); + } + OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* memory_device, _In_opt_ const OrtKeyValuePairs* stream_options, _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) noexcept { diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index de9e2d44431bf..4275c56cdd42c 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -62,6 +62,12 @@ class EpFactoryInternalImpl { return false; } + virtual OrtStatus* SetupGraphicsInterop(_In_ const OrtMemoryDevice* /*memory_device*/, + _In_ const struct GraphicsInteropParams* /*graphicsInteropParams*/) noexcept { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, + "SetupGraphicsInterop is not implemented for this EP factory."); + } + virtual OrtStatus* ValidateCompiledModelCompatibilityInfo( _In_reads_(num_devices) const OrtHardwareDevice* const* devices, _In_ size_t num_devices, diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h index 8c5ef526baba1..7b1dfd329247d 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h @@ -59,6 +59,14 @@ class ProviderBridgeEpFactory : public EpFactoryInternalImpl { return ep_factory_.IsStreamAware(&ep_factory_); } + OrtStatus* SetupGraphicsInterop(const OrtMemoryDevice* memory_device, + const struct GraphicsInteropParams* graphicsInteropParams) noexcept override { + if (ep_factory_.SetupGraphicsInterop == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Graphics interop is not supported by this EP"); + } + return ep_factory_.SetupGraphicsInterop(&ep_factory_, memory_device, graphicsInteropParams); + } + OrtStatus* CreateSyncStreamForDevice(const OrtMemoryDevice* device, const OrtKeyValuePairs* stream_options, OrtSyncStreamImpl** stream) noexcept override { diff --git a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index 65c396181f0a7..51d63e192a86a 100644 --- a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -74,6 +74,12 @@ struct ForwardToFactoryImpl { return static_cast(this_ptr)->IsStreamAware(); } + static OrtStatus* ORT_API_CALL SetupGraphicsInterop(_In_ OrtEpFactory* this_ptr, + _In_ const OrtMemoryDevice* memory_device, + _In_ const struct GraphicsInteropParams* graphicsInteropParams) noexcept { + return static_cast(this_ptr)->SetupGraphicsInterop(memory_device, graphicsInteropParams); + } + static OrtStatus* ORT_API_CALL CreateSyncStreamForDevice(_In_ OrtEpFactory* this_ptr, _In_ const OrtMemoryDevice* memory_device, _In_opt_ const OrtKeyValuePairs* stream_options, diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_ort_interop_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_ort_interop_test.cc new file mode 100644 index 0000000000000..82983adcfe830 --- /dev/null +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_ort_interop_test.cc @@ -0,0 +1,438 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. +#include "core/graph/onnx_protobuf.h" +#include "core/session/inference_session.h" +#include "test/providers/provider_test_utils.h" +#include "test/unittest_util/framework_test_utils.h" + +#include "test/util/include/scoped_env_vars.h" +#include "test/common/trt_op_test_utils.h" +#include "test/common/random_generator.h" +#include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" + +#include +#include + +#if DX_FOR_INTEROP && _WIN32 +#include +#include +using Microsoft::WRL::ComPtr; +#endif + +using namespace std; +using namespace ONNX_NAMESPACE; +using namespace ::onnxruntime::logging; +extern std::unique_ptr ort_env; +namespace onnxruntime { + + #if DX_FOR_INTEROP && _WIN32 +void CreateD3D12Buffer(ID3D12Device* pDevice, const size_t size, ID3D12Resource** ppResource, D3D12_RESOURCE_STATES initState) +{ + D3D12_RESOURCE_DESC bufferDesc = {}; + bufferDesc.MipLevels = 1; + bufferDesc.Format = DXGI_FORMAT_UNKNOWN; + bufferDesc.Width = size; + bufferDesc.Height = 1; + bufferDesc.Flags = D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS; + bufferDesc.DepthOrArraySize = 1; + bufferDesc.SampleDesc.Count = 1; + bufferDesc.SampleDesc.Quality = 0; + bufferDesc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER; + bufferDesc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR; + + D3D12_HEAP_PROPERTIES heapProps = {}; + heapProps.Type = D3D12_HEAP_TYPE_DEFAULT; + heapProps.CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN; + heapProps.MemoryPoolPreference = D3D12_MEMORY_POOL_UNKNOWN; + heapProps.CreationNodeMask = 1; + heapProps.VisibleNodeMask = 1; + + HRESULT hr = pDevice->CreateCommittedResource( + &heapProps, + D3D12_HEAP_FLAG_NONE, + &bufferDesc, + initState, + nullptr, + IID_PPV_ARGS(ppResource)); + + if (FAILED(hr)) + { + GTEST_FAIL() << "Failed creating a D3D12 resource, HRESULT: 0x" << std::hex << hr; + } +} + +void CreateUploadBuffer(ID3D12Device* pDevice, const size_t size, ID3D12Resource** ppResource) +{ + D3D12_HEAP_PROPERTIES heapProps = {}; + heapProps.Type = D3D12_HEAP_TYPE_UPLOAD; + heapProps.CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN; + heapProps.MemoryPoolPreference = D3D12_MEMORY_POOL_UNKNOWN; + heapProps.CreationNodeMask = 1; + heapProps.VisibleNodeMask = 1; + + D3D12_RESOURCE_DESC bufferDesc = {}; + bufferDesc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER; + bufferDesc.Alignment = 0; + bufferDesc.Width = size; + bufferDesc.Height = 1; + bufferDesc.DepthOrArraySize = 1; + bufferDesc.MipLevels = 1; + bufferDesc.Format = DXGI_FORMAT_UNKNOWN; + bufferDesc.SampleDesc.Count = 1; + bufferDesc.SampleDesc.Quality = 0; + bufferDesc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR; + bufferDesc.Flags = D3D12_RESOURCE_FLAG_NONE; + + HRESULT hr = pDevice->CreateCommittedResource( + &heapProps, + D3D12_HEAP_FLAG_NONE, + &bufferDesc, + D3D12_RESOURCE_STATE_GENERIC_READ, + nullptr, + IID_PPV_ARGS(ppResource)); + if (FAILED(hr)) + { + GTEST_FAIL() << "Failed creating a D3D12 upload resource, HRESULT: 0x" << std::hex << hr; + } +} + +void CreateReadBackBuffer(ID3D12Device* pDevice, const size_t size, ID3D12Resource** ppResource) +{ + D3D12_HEAP_PROPERTIES heapProps = {}; + heapProps.Type = D3D12_HEAP_TYPE_READBACK; + heapProps.CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN; + heapProps.MemoryPoolPreference = D3D12_MEMORY_POOL_UNKNOWN; + heapProps.CreationNodeMask = 1; + heapProps.VisibleNodeMask = 1; + + D3D12_RESOURCE_DESC bufferDesc = {}; + bufferDesc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER; + bufferDesc.Alignment = 0; + bufferDesc.Width = size; + bufferDesc.Height = 1; + bufferDesc.DepthOrArraySize = 1; + bufferDesc.MipLevels = 1; + bufferDesc.Format = DXGI_FORMAT_UNKNOWN; + bufferDesc.SampleDesc.Count = 1; + bufferDesc.SampleDesc.Quality = 0; + bufferDesc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR; + bufferDesc.Flags = D3D12_RESOURCE_FLAG_NONE; + + HRESULT hr = pDevice->CreateCommittedResource( + &heapProps, + D3D12_HEAP_FLAG_NONE, + &bufferDesc, + D3D12_RESOURCE_STATE_COPY_DEST, + nullptr, + IID_PPV_ARGS(ppResource)); + if (FAILED(hr)) + { + GTEST_FAIL() << "Failed creating a D3D12 read back resource, HRESULT: 0x" << std::hex << hr; + } +} + + +void FlushAndWait(ID3D12Device* pDevice, ID3D12CommandQueue* pQueue) +{ + // Event and D3D12 Fence to manage CPU<->GPU sync (we want to keep 2 iterations in "flight") + HANDLE hEvent = CreateEvent(nullptr, FALSE, FALSE, nullptr); + ComPtr pFence; + pDevice->CreateFence(0, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(&pFence)); + + pQueue->Signal(pFence.Get(), 1); + pFence->SetEventOnCompletion(1, hEvent); + DWORD retVal = WaitForSingleObject(hEvent, INFINITE); + + CloseHandle(hEvent); + // ComPtr automatically releases pFence +} +#endif +namespace test { + +TEST(NvExecutionProviderTest, GraphicsORTInteropTest) { +#if TRT_MAJOR_RTX > 1 || TRT_MINOR_RTX >= 3 + PathString model_name = ORT_TSTR("nv_execution_provider_test.onnx"); + std::string graph_name = "test"; + constexpr int image_dim = 1080; + + // Create a simple 1-input, 1-output Relu model + onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger()); + auto& graph = model.MainGraph(); + + ONNX_NAMESPACE::TypeProto tensor_type; + tensor_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3); + tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(image_dim); + tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(image_dim); + + auto& input_arg = graph.GetOrCreateNodeArg("input", &tensor_type); + auto& output_arg = graph.GetOrCreateNodeArg("output", &tensor_type); + graph.AddNode("relu_node", "Relu", "Relu operation", {&input_arg}, {&output_arg}); + + ASSERT_STATUS_OK(graph.Resolve()); + ASSERT_STATUS_OK(onnxruntime::Model::Save(model, model_name)); + +#if DX_FOR_INTEROP && _WIN32 + { + std::vector cpuInputHalf(3 * image_dim * image_dim); + std::vector cpuOutputHalf(3 * image_dim * image_dim); + using StreamUniquePtr = std::unique_ptr>; + + // Generate random data for input + { + RandomValueGenerator random{}; + std::vector shape{3, image_dim, image_dim}; + std::vector input_data = random.Uniform(shape, static_cast(0), static_cast(65535)); + memcpy(cpuInputHalf.data(), input_data.data(), cpuInputHalf.size() * sizeof(uint16_t)); + } // input_data is freed here + + + // set up d3d12 + ComPtr pDevice; + ComPtr pCommandQueue; + ComPtr pInput; + ComPtr pOutput; + ComPtr pUploadRes; + ComPtr pUploadResCorrupt; + ComPtr pDownloadRes; + ComPtr pUploadCommandList; + ComPtr pDownloadCommandList; + ComPtr pAllocatorCopy; + + uint64_t fenceValue = 0; + GraphicsInteropParams graphicsInteropParams; + graphicsInteropParams.extSyncPrimitive = ExternalSyncPrimitive_D3D12Fence; + graphicsInteropParams.DevicePtr.DXDeviceParams.pDevice = nullptr; + graphicsInteropParams.DevicePtr.DXDeviceParams.pCommandQueue = nullptr; + HANDLE sharedFenceHandle = nullptr; + FenceInteropParams fenceInteropParams; + fenceInteropParams.extSyncPrimitive = ExternalSyncPrimitive_D3D12Fence; + fenceInteropParams.FencePtr.pFence = nullptr; + OrtFence* ortFence = nullptr; + + HRESULT hr = D3D12CreateDevice(nullptr, D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&pDevice)); + if (FAILED(hr)) + { + GTEST_SKIP() << "Failed to create D3D12 device, HRESULT: 0x" << std::hex << hr << " - D3D12 may not be available on this system"; + } + graphicsInteropParams.DevicePtr.DXDeviceParams.pDevice = pDevice.Get(); + + D3D12_COMMAND_QUEUE_DESC queueDesc = {}; + queueDesc.Type = D3D12_COMMAND_LIST_TYPE_COMPUTE; + queueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE; + hr = pDevice->CreateCommandQueue(&queueDesc, IID_PPV_ARGS(&pCommandQueue)); + if (FAILED(hr)) + { + GTEST_SKIP() << "Failed to create D3D12 command queue, HRESULT: 0x" << std::hex << hr << " - Command queue may not be available on this system"; + } + graphicsInteropParams.DevicePtr.DXDeviceParams.pCommandQueue = pCommandQueue.Get(); + + // Use ORT APIs to load the model + OrtApi const& ortApi = Ort::GetApi(); + Ort::SessionOptions sessionOptions; + sessionOptions.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL); + sessionOptions.DisableMemPattern(); + sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); + sessionOptions.AddConfigEntry("session.disable_cpu_ep_fallback", "1"); + ortApi.AddFreeDimensionOverrideByName(sessionOptions, "batch_size", 1); + + std::string trtLibPath = "onnxruntime_providers_nv_tensorrt_rtx.dll"; + std::wstring wideTrtLibPath = std::wstring(trtLibPath.begin(), trtLibPath.end()); + + OrtStatus* status = ortApi.RegisterExecutionProviderLibrary(*ort_env, "NvTensorRtRtx", wideTrtLibPath.c_str()); + if (status != nullptr) { + std::string error_message = ortApi.GetErrorMessage(status); + ortApi.ReleaseStatus(status); + FAIL() << "Failed to register EP library: " << error_message; + } + + sessionOptions.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_GPU); + + ComPtr pFence; + pDevice->CreateFence(fenceValue, D3D12_FENCE_FLAG_SHARED, IID_PPV_ARGS(&pFence)); + fenceInteropParams.FencePtr.pFence = pFence.Get(); + pDevice->CreateSharedHandle(pFence.Get(), nullptr, GENERIC_ALL, nullptr, &sharedFenceHandle); + + const OrtEpDevice* const* ep_devices = nullptr; + size_t num_ep_devices; + ortApi.GetEpDevices(*ort_env, &ep_devices, &num_ep_devices); + const OrtEpDevice* trt_ep_device = nullptr; + for (UINT i = 0; i < num_ep_devices; i++) + { + if (strcmp(ortApi.EpDevice_EpName(ep_devices[i]), "NvTensorRTRTXExecutionProvider") == 0) + { + trt_ep_device = ep_devices[i]; + break; + } + } + + // Must be called before other interop functions to create the context + ortApi.SetupGraphicsInteropForEpDevice(trt_ep_device, &graphicsInteropParams); + + // Create ORT stream - this will be created on the context we just set up + OrtSyncStream* stream = nullptr; + StreamUniquePtr stream_ptr; + ortApi.CreateSyncStreamForEpDevice(trt_ep_device, nullptr, &stream); + stream_ptr = StreamUniquePtr(stream, [ortApi](OrtSyncStream* stream) { ortApi.ReleaseSyncStream(stream); }); + + // Create IHV-agnostic memory info using hardware device vendor ID + OrtMemoryInfo* memory_info_agnostic = nullptr; + const OrtHardwareDevice* hw_device = ortApi.EpDevice_Device(trt_ep_device); + UINT vID = ortApi.HardwareDevice_VendorId(hw_device); + ortApi.CreateMemoryInfo_V2("Device_Agnostic", OrtMemoryInfoDeviceType_GPU, + /*vendor_id*/vID, /*device_id*/0, + OrtDeviceMemoryType_DEFAULT, /*default alignment*/0, + OrtArenaAllocator, &memory_info_agnostic); + + auto memory_info_cleanup = std::unique_ptr>( + memory_info_agnostic, + [&ortApi](OrtMemoryInfo* ptr) { + if (ptr) ortApi.ReleaseMemoryInfo(ptr); + } + ); + + char streamAddress[32]; + size_t stream_addr_val = reinterpret_cast(ortApi.SyncStream_GetHandle(stream)); + sprintf_s(streamAddress, "%llu", static_cast(stream_addr_val)); + const char* option_keys[] = { "user_compute_stream", "has_user_compute_stream" }; + const char* option_values[] = { streamAddress, "1" }; + for (size_t i = 0; i < num_ep_devices; i++) + { + if (strcmp(ortApi.EpDevice_EpName(ep_devices[i]), "CPUExecutionProvider") != 0) + ortApi.SessionOptionsAppendExecutionProvider_V2(sessionOptions, *ort_env, &ep_devices[i], 1, option_keys, option_values, 2); + } + + // default resources + CreateD3D12Buffer(pDevice.Get(), 3 * image_dim * image_dim * sizeof(uint16_t), pInput.GetAddressOf(), D3D12_RESOURCE_STATE_COPY_DEST); + CreateD3D12Buffer(pDevice.Get(), 3 * image_dim * image_dim * sizeof(uint16_t), pOutput.GetAddressOf(), D3D12_RESOURCE_STATE_COPY_SOURCE); + + // upload and download resources + CreateUploadBuffer(pDevice.Get(), 3 * image_dim * image_dim * sizeof(uint16_t), pUploadRes.GetAddressOf()); + CreateUploadBuffer(pDevice.Get(), 3 * image_dim * image_dim * sizeof(uint16_t), pUploadResCorrupt.GetAddressOf()); + + CreateReadBackBuffer(pDevice.Get(), 3 * image_dim * image_dim * sizeof(uint16_t), pDownloadRes.GetAddressOf()); + + hr = pDevice->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_COMPUTE, IID_PPV_ARGS(&pAllocatorCopy)); + + hr = pDevice->CreateCommandList(1, D3D12_COMMAND_LIST_TYPE_COMPUTE, pAllocatorCopy.Get(), NULL, IID_PPV_ARGS(&pUploadCommandList)); + + + // heavy GPU load for reproducing race condition + for (int i = 0; i < 1000; i++) + { + pUploadCommandList->CopyResource(pInput.Get(), pUploadResCorrupt.Get()); + + D3D12_RESOURCE_BARRIER barrier = {}; + barrier.Type = D3D12_RESOURCE_BARRIER_TYPE_UAV; + barrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE; + barrier.UAV.pResource = nullptr; // This makes it a NULL UAV barrier + + pUploadCommandList->ResourceBarrier(1, &barrier); + } + + pUploadCommandList->CopyResource(pInput.Get(), pUploadRes.Get()); + pUploadCommandList->Close(); + + std::cerr << "Test completed successfully1" <CreateCommandList(1, D3D12_COMMAND_LIST_TYPE_COMPUTE, pAllocatorCopy.Get(), NULL, IID_PPV_ARGS(&pDownloadCommandList)); + pDownloadCommandList->CopyResource(pDownloadRes.Get(), pOutput.Get()); + pDownloadCommandList->Close(); + + Ort::Session session = Ort::Session(*ort_env, L"nv_execution_provider_test.onnx", sessionOptions); + + ortApi.GetOrtFenceForGraphicsInterop(session, &graphicsInteropParams, &fenceInteropParams, &ortFence); + + Ort::IoBinding ioBinding = Ort::IoBinding::IoBinding(session); + + Ort::AllocatorWithDefaultOptions allocator; + Ort::AllocatedStringPtr InputTensorName = session.GetInputNameAllocated(0, allocator); + Ort::AllocatedStringPtr OuptutTensorName = session.GetOutputNameAllocated(0, allocator); + + int64_t inputDim[] = { 1, 3, image_dim, image_dim }; + int64_t outputDim[] = { 1, 3, image_dim, image_dim }; + + // upload the input + void* pData; + pUploadRes->Map(0, nullptr, (void**)&pData); + memcpy(pData, cpuInputHalf.data(), cpuInputHalf.size() * sizeof(uint16_t)); + pUploadRes->Unmap(0, nullptr); + + // Upload corrupted data to test synchronization (should not affect the output) + void* pDataCorrupt; + pUploadResCorrupt->Map(0, nullptr, (void**)&pDataCorrupt); + std::fill_n((uint8_t*)pDataCorrupt, 3 * image_dim * image_dim * sizeof(uint16_t), 0xFF); + pUploadResCorrupt->Unmap(0, nullptr); + + // bind the resources using IHV-agnostic memory info but keep zero-copy external memory sharing + Ort::Value inputTensor(Ort::Value::CreateTensor(memory_info_agnostic, (void*)pInput->GetGPUVirtualAddress(), cpuInputHalf.size() * sizeof(uint16_t), inputDim, 4, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)); + Ort::Value outputTensor(Ort::Value::CreateTensor(memory_info_agnostic, (void*)pOutput->GetGPUVirtualAddress(), cpuOutputHalf.size() * sizeof(uint16_t), outputDim, 4, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)); + ioBinding.BindInput(InputTensorName.get(), inputTensor); + ioBinding.BindOutput(OuptutTensorName.get(), outputTensor); + ioBinding.SynchronizeInputs(); + + std::cerr << "Test completed successfully2" <ExecuteCommandLists(1, &pUploadCmdList); + + // make ORT wait for upload + pCommandQueue->Signal(pFence.Get(), fenceValue); + ortApi.InteropEpWait(ortFence, stream, fenceValue); // make ORT wait on the fence (on CUDA side internally) + + // run the model + Ort::RunOptions runOptions; + runOptions.AddConfigEntry("disable_synchronize_execution_providers", "1"); + session.Run(runOptions, ioBinding); + + fenceValue++; + // make DX wait for ORT + ortApi.InteropEpSignal(ortFence, stream, fenceValue); // signal from CUDA side (internally) + pCommandQueue->Wait(pFence.Get(), fenceValue); + + // download the output to cpu memory (again using DX) + ID3D12CommandList* pDownloadCmdList = pDownloadCommandList.Get(); + pCommandQueue->ExecuteCommandLists(1, &pDownloadCmdList); + FlushAndWait(pDevice.Get(), pCommandQueue.Get()); + } + + + std::cerr << "Test completed successfully3" <Map(0, nullptr, (void**)&pOutputData); + memcpy(cpuOutputHalf.data(), pOutputData, cpuOutputHalf.size() * sizeof(uint16_t)); + pDownloadRes->Unmap(0, nullptr); + + std::cerr << "First 50 elements of cpuInputHalf:\n"; + for (int i = 0; i < 50; i++) { + std::cerr << cpuInputHalf[i] << " "; + } + std::cerr << std::endl; + + std::cerr << "First 50 elements of cpuOutputHalf:\n"; + for (int i = 0; i < 50; i++) { + std::cerr << cpuOutputHalf[i] << " "; + } + std::cerr << std::endl; + + // ComPtr automatically handles cleanup via RAII + CloseHandle(sharedFenceHandle); + + std::cerr << "\nInference done. Check output image." < None: azure_group = parser.add_argument_group("Azure Execution Provider") azure_group.add_argument("--use_azure", action="store_true", help="Enable Azure EP.") + # --- DX for Interop Feature --- + dx_group = parser.add_argument_group("DX for Interop Feature") + dx_group.add_argument("--use_dx_for_interop", action="store_true", help="Enable DX for Interop feature.") + + # --- Vulkan for Interop Feature --- + vulkan_group = parser.add_argument_group("Vulkan for Interop Feature") + vulkan_group.add_argument("--use_vulkan_for_interop", action="store_true", help="Enable VULKAN for Interop feature.") def add_other_feature_args(parser: argparse.ArgumentParser) -> None: """Adds arguments for other miscellaneous features."""