diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index a8e7d2cc9299a..9cf672da44337 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -244,6 +244,31 @@ option(onnxruntime_USE_AZURE "Build with azure inferencing support" OFF)
option(onnxruntime_USE_LOCK_FREE_QUEUE "Build with lock-free task queue for threadpool." OFF)
option(onnxruntime_FORCE_GENERIC_ALGORITHMS "Disable optimized arch-specific algorithms. Use only for testing and debugging generic algorithms." OFF)
+# Direct3D 12 interop switch. DX_FOR_INTEROP is always defined (1 when the option is ON, 0 otherwise)
+# so that sources can test it with a plain `#if DX_FOR_INTEROP`.
+option(onnxruntime_USE_DX_FOR_INTEROP "Build with the DX for Interop feature." OFF)
+
+if (NOT onnxruntime_USE_DX_FOR_INTEROP)
+  add_compile_definitions(DX_FOR_INTEROP=0)
+else()
+  add_compile_definitions(DX_FOR_INTEROP=1)
+endif()
+
+# Vulkan interop switch. VULKAN_FOR_INTEROP is always defined (1 or 0) so that sources can test it with
+# a plain `#if VULKAN_FOR_INTEROP`. Vulkan is only searched for when the feature is requested, so a
+# Vulkan installation on the build machine cannot influence builds that leave the option OFF.
+option(onnxruntime_USE_VULKAN_FOR_INTEROP "Build with the Vulkan for Interop feature." OFF)
+
+if (onnxruntime_USE_VULKAN_FOR_INTEROP)
+  find_package(Vulkan QUIET)
+endif()
+
+if (onnxruntime_USE_VULKAN_FOR_INTEROP AND Vulkan_FOUND)
+  if (WIN32)
+    # The interop code uses the Win32 external-semaphore declarations of the Vulkan headers.
+    add_compile_definitions(VK_USE_PLATFORM_WIN32_KHR=1)
+  endif()
+  add_compile_definitions(VULKAN_FOR_INTEROP=1)
+else()
+  add_compile_definitions(VULKAN_FOR_INTEROP=0)
+  if (onnxruntime_USE_VULKAN_FOR_INTEROP)
+    # The feature was explicitly requested: make the downgrade visible instead of a quiet STATUS line.
+    message(WARNING "Vulkan not found. Vulkan interop disabled.")
+  endif()
+endif()
+
option(onnxruntime_USE_TENSORRT_INTERFACE "Build ONNXRuntime shared lib which is compatible with TensorRT EP interface" OFF)
option(onnxruntime_USE_NV_INTERFACE "Build ONNXRuntime shared lib which is compatible with NV EP interface" OFF)
option(onnxruntime_USE_CUDA_INTERFACE "Build ONNXRuntime shared lib which is compatible with Cuda EP interface" OFF)
@@ -1198,6 +1223,9 @@ function(onnxruntime_configure_target target_name)
set_target_properties(${target_name} PROPERTIES VS_USER_PROPS ${PROJECT_SOURCE_DIR}/EnableVisualStudioCodeAnalysis.props)
endif()
target_include_directories(${target_name} PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT})
+ # Vulkan headers are only needed when the Vulkan interop code paths are compiled in, so do not add
+ # them to every target just because a Vulkan SDK happens to be installed.
+ if(onnxruntime_USE_VULKAN_FOR_INTEROP AND Vulkan_FOUND)
+   target_include_directories(${target_name} PRIVATE ${Vulkan_INCLUDE_DIRS})
+ endif()
if (onnxruntime_ENABLE_TRAINING_OPS)
target_include_directories(${target_name} PRIVATE ${ORTTRAINING_ROOT})
endif()
diff --git a/cmake/onnxruntime_providers_nv.cmake b/cmake/onnxruntime_providers_nv.cmake
index e59463b6b91f1..5ec45a64e46bb 100644
--- a/cmake/onnxruntime_providers_nv.cmake
+++ b/cmake/onnxruntime_providers_nv.cmake
@@ -146,9 +146,9 @@ endif ()
target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE Eigen3::Eigen onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface Eigen3::Eigen)
add_dependencies(onnxruntime_providers_nv_tensorrt_rtx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
+# CUDA::cuda_driver supplies the CUDA driver API (cuInit, cuCtxCreate_v4, cuCtxSetCurrent) that the
+# graphics interop code in nv_provider_factory.cc calls directly.
if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER)
- target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart)
+ target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart CUDA::cuda_driver)
else()
- target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart)
+ target_link_libraries(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart CUDA::cuda_driver)
endif()
target_include_directories(onnxruntime_providers_nv_tensorrt_rtx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${TENSORRT_RTX_INCLUDE_DIR} ${onnx_tensorrt_SOURCE_DIR}
PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
index 1ae7b5c9eb991..20351450906c6 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -451,6 +451,10 @@ public struct OrtApi
public IntPtr Graph_GetModelMetadata;
public IntPtr GetModelCompatibilityForEpDevices;
public IntPtr CreateExternalInitializerInfo;
+
+ // Graphics interop (D3D12/Vulkan) fence entry points.
+ // NOTE(review): this struct is populated with Marshal.PtrToStructure, so its fields must appear in
+ // exactly the same order as the native OrtApi. onnxruntime_c_api.h declares
+ // SetupGraphicsInteropForEpDevice immediately before GetOrtFenceForGraphicsInterop, but there is no
+ // matching field here - confirm that these three slots line up with the native struct.
+ public IntPtr GetOrtFenceForGraphicsInterop;
+ public IntPtr InteropEpWait;
+ public IntPtr InteropEpSignal;
}
internal static class NativeMethods
@@ -482,7 +486,7 @@ static NativeMethods()
DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(OrtGetApiBase().GetApi, typeof(DOrtGetApi));
#endif
- const uint ORT_API_VERSION = 14;
+ const uint ORT_API_VERSION = 15;
#if NETSTANDARD2_0
IntPtr ortApiPtr = OrtGetApi(ORT_API_VERSION);
api_ = (OrtApi)Marshal.PtrToStructure(ortApiPtr, typeof(OrtApi));
@@ -847,7 +851,7 @@ static NativeMethods()
api_.CreateSyncStreamForEpDevice,
typeof(DOrtCreateSyncStreamForEpDevice));
- OrtSyncStream_GetHandle =
+ OrtSyncStream_GetHandle =
(DOrtSyncStream_GetHandle)Marshal.GetDelegateForFunctionPointer(
api_.SyncStream_GetHandle,
typeof(DOrtSyncStream_GetHandle));
@@ -861,6 +865,21 @@ static NativeMethods()
(DOrtCopyTensors)Marshal.GetDelegateForFunctionPointer(
api_.CopyTensors,
typeof(DOrtCopyTensors));
+
+ // Graphics interop (D3D12/Vulkan): wrap the three fence-related OrtApi function pointers in delegates.
+ OrtGetOrtFenceForGraphicsInterop =
+ (DOrtGetOrtFenceForGraphicsInterop)Marshal.GetDelegateForFunctionPointer(
+ api_.GetOrtFenceForGraphicsInterop,
+ typeof(DOrtGetOrtFenceForGraphicsInterop));
+
+ // Makes the execution provider wait on an external fence/semaphore.
+ OrtInteropEpWait =
+ (DOrtInteropEpWait)Marshal.GetDelegateForFunctionPointer(
+ api_.InteropEpWait,
+ typeof(DOrtInteropEpWait));
+
+ // Makes the execution provider signal an external fence/semaphore.
+ OrtInteropEpSignal =
+ (DOrtInteropEpSignal)Marshal.GetDelegateForFunctionPointer(
+ api_.InteropEpSignal,
+ typeof(DOrtInteropEpSignal));
}
internal class NativeLib
@@ -2644,7 +2663,7 @@ public delegate void DOrtAddKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps,
byte[] /* const char* */ value);
///
- /// Get the value for the provided key.
+ /// Get the value for the provided key.
///
/// Value. Returns IntPtr.Zero if key was not found.
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
@@ -2743,6 +2762,30 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps,
out IntPtr /* OrtSyncStream** */ stream
);
+ /// <summary>
+ /// Wraps an external D3D12 fence / Vulkan semaphore in an OrtFence for the given session.
+ /// Mirrors OrtApi::GetOrtFenceForGraphicsInterop in onnxruntime_c_api.h.
+ /// </summary>
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtGetOrtFenceForGraphicsInterop(
+ IntPtr /* OrtSession* */ session,
+ IntPtr /* const struct GraphicsInteropParams* */ graphicsInteropParams,
+ IntPtr /* const struct FenceInteropParams* */ fenceInteropParams,
+ out IntPtr /* OrtFence** */ ortFence
+ );
+
+ /// <summary>
+ /// Makes the execution provider wait on an external fence/semaphore at the given fence value.
+ /// fenceValue is a native uint64_t and is therefore marshaled as a 64-bit ulong.
+ /// </summary>
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtInteropEpWait(
+ IntPtr /* OrtSession* */ session, // NOTE(review): native OrtApi::InteropEpWait (onnxruntime_c_api.h) takes (OrtFence*, OrtSyncStream*, uint64_t) with no session argument - confirm and align
+ IntPtr /* OrtFence* */ ortFence,
+ IntPtr /* OrtSyncStream* */ stream,
+ ulong /* uint64_t */ fenceValue
+ );
+
+ /// <summary>
+ /// Makes the execution provider signal an external fence/semaphore with the given fence value.
+ /// fenceValue is a native uint64_t and is therefore marshaled as a 64-bit ulong.
+ /// </summary>
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtInteropEpSignal(
+ IntPtr /* OrtSession* */ session, // NOTE(review): native OrtApi::InteropEpSignal (onnxruntime_c_api.h) takes (OrtFence*, OrtSyncStream*, uint64_t) with no session argument - confirm and align
+ IntPtr /* OrtFence* */ ortFence,
+ IntPtr /* OrtSyncStream* */ stream,
+ ulong /* uint64_t */ fenceValue
+ );
+
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* void* */ DOrtSyncStream_GetHandle(
IntPtr /* OrtSyncStream* */ stream
@@ -2760,6 +2803,9 @@ out IntPtr /* OrtSyncStream** */ stream
public static DOrtEpDevice_Device OrtEpDevice_Device;
public static DOrtEpDevice_MemoryInfo OrtEpDevice_MemoryInfo;
public static DOrtCreateSyncStreamForEpDevice OrtCreateSyncStreamForEpDevice;
+ // Graphics interop delegates; assigned in the static constructor from the matching OrtApi slots.
+ public static DOrtGetOrtFenceForGraphicsInterop OrtGetOrtFenceForGraphicsInterop;
+ public static DOrtInteropEpWait OrtInteropEpWait;
+ public static DOrtInteropEpSignal OrtInteropEpSignal;
public static DOrtSyncStream_GetHandle OrtSyncStream_GetHandle;
public static DOrtReleaseSyncStream OrtReleaseSyncStream;
@@ -2767,7 +2813,7 @@ out IntPtr /* OrtSyncStream** */ stream
// Auto Selection EP registration and selection customization
///
- /// Register an execution provider library.
+ /// Register an execution provider library.
/// The library must implement CreateEpFactories and ReleaseEpFactory.
///
/// Environment to add the EP library to.
diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h
index f54f4a5a6f1ef..a2bf333444107 100644
--- a/include/onnxruntime/core/framework/execution_provider.h
+++ b/include/onnxruntime/core/framework/execution_provider.h
@@ -38,6 +38,14 @@ class GraphOptimizerRegistry;
#include "core/framework/tuning_context.h"
#include "core/session/onnxruntime_c_api.h"
+// Graphics API headers are only pulled in when the corresponding interop feature is compiled in.
+// The default SetupInteropEpWait implementation below uses ID3D12Fence / Win32 events and the
+// PFN_vk* function-pointer types from these headers.
+#if DX_FOR_INTEROP && _WIN32
+#include <d3d12.h>
+#endif
+
+#if VULKAN_FOR_INTEROP
+#include <vulkan/vulkan.h>
+#endif
+
struct OrtEpDevice;
struct OrtRunOptions;
@@ -92,6 +100,86 @@ class IExecutionProvider {
public:
virtual ~IExecutionProvider() = default;
+  /**
+   * Default implementation for execution providers without native external-semaphore support.
+   *
+   * Copies *fenceInteropParams into a std::shared_ptr<FenceInteropParams>, heap-allocates a second
+   * shared_ptr referring to it and returns that heap object through extSemFence. The returned handle
+   * is what the default SetupInteropEpWait expects to receive.
+   *
+   * @param graphicsInteropParams Unused by the default implementation.
+   * @param fenceInteropParams External fence description to copy. Must not be null.
+   * @param extSemFence Receives the opaque handle. Must not be null. This method does not free it.
+   * @return OK on success, INVALID_ARGUMENT when a required pointer is null.
+   */
+  virtual Status GetExtSemaphore(const struct GraphicsInteropParams* graphicsInteropParams, const struct FenceInteropParams* fenceInteropParams, void** extSemFence) {
+    ORT_UNUSED_PARAMETER(graphicsInteropParams);
+    if (fenceInteropParams == nullptr || extSemFence == nullptr) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fenceInteropParams and extSemFence must not be null");
+    }
+    auto interop_params_sptr = std::make_shared<FenceInteropParams>(*fenceInteropParams);
+    *extSemFence = new std::shared_ptr<FenceInteropParams>(interop_params_sptr);
+    return Status::OK();
+  }
+
+  /**
+   * Default host-side wait for execution providers without native external-semaphore support.
+   *
+   * extSemFence must be the handle produced by the default GetExtSemaphore (a heap-allocated
+   * std::shared_ptr<FenceInteropParams>). For a D3D12 fence the calling thread blocks on a Win32 event
+   * until the fence reaches fenceValue; for the Vulkan primitive it blocks in vkWaitForFences and then
+   * resets the fence. The handle is NOT freed here, so it can be waited on repeatedly and released once
+   * by its owner; deleting it in this method would leave every other holder with a dangling pointer.
+   *
+   * @param extSemFence Handle returned by GetExtSemaphore. Must not be null.
+   * @param stream Unused by the default implementation.
+   * @param fenceValue Fence value to wait for (D3D12 path only).
+   * @return OK when the wait completed, INVALID_ARGUMENT for a null handle, FAIL otherwise (including
+   *         when the requested primitive is not compiled into this build).
+   */
+  virtual Status SetupInteropEpWait(void* extSemFence, OrtSyncStream* stream, uint64_t fenceValue) {
+    ORT_UNUSED_PARAMETER(stream);
+    ORT_UNUSED_PARAMETER(fenceValue);
+    if (extSemFence == nullptr) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "extSemFence must not be null");
+    }
+
+    // Take a copy of the shared_ptr so the parameters stay alive for the duration of the wait.
+    auto* sptr_ptr = static_cast<std::shared_ptr<FenceInteropParams>*>(extSemFence);
+    std::shared_ptr<FenceInteropParams> interopWaitParamsSptr = *sptr_ptr;
+    auto* interopWaitParams = interopWaitParamsSptr.get();
+    if (interopWaitParams == nullptr) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "extSemFence does not reference fence parameters");
+    }
+
+    ExternalSyncPrimitive extSyncPrimitive = interopWaitParams->extSyncPrimitive;
+    // to-do: The fallback logic needs more refinement to deal with multi threaded scenarios.
+    if (extSyncPrimitive == ExternalSyncPrimitive_D3D12Fence) {
+#if DX_FOR_INTEROP && _WIN32
+      HANDLE hEvent = CreateEvent(nullptr, FALSE, FALSE, nullptr);
+      if (hEvent == nullptr) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create event for D3D12 fence wait");
+      }
+      // If the event cannot be armed it will never be signaled; bail out instead of waiting forever.
+      HRESULT hr = reinterpret_cast<ID3D12Fence*>(interopWaitParams->FencePtr.pFence)->SetEventOnCompletion(fenceValue, hEvent);
+      if (FAILED(hr)) {
+        CloseHandle(hEvent);
+        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ID3D12Fence::SetEventOnCompletion failed");
+      }
+      WaitForSingleObject(hEvent, INFINITE);
+      CloseHandle(hEvent);
+      return Status::OK();
+#endif
+    } else if (extSyncPrimitive == ExternalSyncPrimitive_VulkanSemaphore) {
+#if VULKAN_FOR_INTEROP
+      // NOTE(review): the primitive is documented as a timeline semaphore, yet this fallback waits on
+      // FencePtr.pVkFence with vkWaitForFences - confirm callers really pass a VkFence on this path.
+      auto pfnGetDeviceProcAddr = reinterpret_cast<PFN_vkGetDeviceProcAddr>(interopWaitParams->VulkanDeviceParams.pVkGetDeviceProcAddr);
+      VkDevice vkDevice = reinterpret_cast<VkDevice>(interopWaitParams->VulkanDeviceParams.pVkDevice);
+      const VkFence* pVkFence = reinterpret_cast<const VkFence*>(&interopWaitParams->FencePtr.pVkFence);
+
+      PFN_vkWaitForFences pfnVkWaitForFences = reinterpret_cast<PFN_vkWaitForFences>(
+          pfnGetDeviceProcAddr(vkDevice, "vkWaitForFences"));
+      if (!pfnVkWaitForFences) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get function pointer for vkWaitForFences");
+      }
+
+      VkResult result = pfnVkWaitForFences(vkDevice, 1, pVkFence, VK_TRUE, UINT64_MAX);
+      if (result != VK_SUCCESS) {
+        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "vkWaitForFences failed with Vulkan error code: " + std::to_string(result));
+      }
+
+      // Return the fence to the unsignaled state so it can be reused; skipped when the entry point is missing.
+      PFN_vkResetFences pfnVkResetFences = reinterpret_cast<PFN_vkResetFences>(
+          pfnGetDeviceProcAddr(vkDevice, "vkResetFences"));
+      if (pfnVkResetFences) {
+        pfnVkResetFences(vkDevice, 1, pVkFence);
+      }
+
+      return Status::OK();
+#endif
+    }
+    // Reached for unknown primitives and for primitives whose support is not compiled in.
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported External Sync primitive");
+  }
+  /**
+   * Default host-side "signal" for execution providers without native external-semaphore support.
+   *
+   * Nothing is signaled on the external fence. Instead a notification is created on the stream,
+   * activated and waited for on the host, so that the stream's work recorded up to this call has
+   * completed when the method returns. The notification is released on every path once it exists.
+   *
+   * @param ortEpApi EP API table used to reach the stream implementation. Must not be null.
+   * @param extSemFence Unused by the default implementation.
+   * @param stream Stream to synchronize with. Must not be null.
+   * @param fenceValue Unused by the default implementation.
+   * @return OK on success, INVALID_ARGUMENT for null inputs, FAIL when a stream operation fails.
+   */
+  virtual Status SetupInteropEpSignal(const OrtEpApi* ortEpApi, void* extSemFence, OrtSyncStream* stream, uint64_t fenceValue) {
+    ORT_UNUSED_PARAMETER(extSemFence);
+    ORT_UNUSED_PARAMETER(fenceValue);
+    if (ortEpApi == nullptr || stream == nullptr) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ortEpApi and stream must not be null");
+    }
+
+    const OrtSyncStreamImpl* streamImpl = ortEpApi->SyncStream_GetImpl(stream);
+    if (streamImpl == nullptr) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get stream implementation");
+    }
+
+    // NOTE(review): as in the original code, an OrtStatus returned on failure is not released here.
+    OrtSyncNotificationImpl* streamNotification = nullptr;
+    OrtStatus* status = streamImpl->CreateNotification(const_cast<OrtSyncStreamImpl*>(streamImpl), &streamNotification);
+    if (status != nullptr || streamNotification == nullptr) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create notification");
+    }
+
+    const char* failure = nullptr;
+    status = streamNotification->Activate(streamNotification);
+    if (status != nullptr) {
+      failure = "Failed to activate notification";
+    } else {
+      status = streamNotification->WaitOnHost(streamNotification);
+      if (status != nullptr) {
+        failure = "Failed to wait on host";
+      }
+    }
+
+    // The notification was created only for this call; release it whether or not the wait succeeded.
+    streamNotification->Release(streamNotification);
+
+    if (failure != nullptr) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, failure);
+    }
+    return Status::OK();
+  }
+
/**
* Returns a data transfer object that implements methods to copy to and
* from this device.
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 434aa075e62d6..2480a31d690f2 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -329,6 +329,7 @@ ORT_RUNTIME_CLASS(HardwareDevice);
ORT_RUNTIME_CLASS(EpDevice);
ORT_RUNTIME_CLASS(KeyValuePairs);
ORT_RUNTIME_CLASS(SyncStream); // Opaque class to create an onnxruntime::Stream.
+ORT_RUNTIME_CLASS(Fence);
ORT_RUNTIME_CLASS(ExternalInitializerInfo);
#ifdef _MSC_VER
@@ -507,6 +508,43 @@ typedef enum OrtExecutionProviderDevicePolicy {
OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER,
} OrtExecutionProviderDevicePolicy;
+/** Kind of external graphics synchronization object handed to the graphics interop APIs. */
+typedef enum ExternalSyncPrimitive {
+ ExternalSyncPrimitive_D3D12Fence,
+ ExternalSyncPrimitive_VulkanSemaphore, // Vulkan timeline semaphore
+} ExternalSyncPrimitive;
+
+/** Vulkan device information passed as opaque pointers so this header does not depend on Vulkan headers. */
+typedef struct VulkanDeviceParams {
+ void* pVkDevice; // cast to VkDevice by the implementation
+ void* pVkGetDeviceProcAddr; // cast to PFN_vkGetDeviceProcAddr and used to look up device entry points
+} VulkanDeviceParams;
+
+/** Describes the graphics device ONNX Runtime should interoperate with.
+ *
+ * extSyncPrimitive selects which member of DevicePtr is read: DXDeviceParams for
+ * ExternalSyncPrimitive_D3D12Fence, VulkanDeviceParams for ExternalSyncPrimitive_VulkanSemaphore.
+ */
+typedef struct GraphicsInteropParams {
+  ExternalSyncPrimitive extSyncPrimitive;
+  union DevicePtr {
+    struct DXDeviceParams {
+      void* pDevice;        // D3D12 device
+      void* pCommandQueue;  // D3D12 command queue
+    } DXDeviceParams;
+    // The elaborated type specifier keeps this valid C++: a member that reuses the name of its own
+    // typedef'd type is otherwise rejected by GCC ("changes meaning of 'VulkanDeviceParams'").
+    struct VulkanDeviceParams VulkanDeviceParams;
+  } DevicePtr;
+
+} GraphicsInteropParams;
+
+/** Describes the external fence/semaphore to wrap.
+ *
+ * extSyncPrimitive states which graphics API the object belongs to. The members of FencePtr share
+ * storage; only the one matching the object actually supplied is meaningful.
+ */
+typedef struct FenceInteropParams {
+  ExternalSyncPrimitive extSyncPrimitive;
+  union FencePtr {
+    void* pFence;       // D3D12 fence
+    void* pVkFence;     // Vulkan fence (used by the host-side fallback wait)
+    void* VkSemaphore;  // Vulkan semaphore handle
+  } FencePtr;
+  // The elaborated type specifier keeps this valid C++: a member that reuses the name of its own
+  // typedef'd type is otherwise rejected by GCC ("changes meaning of 'VulkanDeviceParams'").
+  struct VulkanDeviceParams VulkanDeviceParams;
+} FenceInteropParams;
+
+/** Pairs an external fence handle with an execution provider pointer.
+ * NOTE(review): presumably the backing storage of OrtFence, recording which EP created the handle - confirm.
+ */
+typedef struct SemaphoreEpMap {
+ void* extSemFence; // handle produced by the execution provider's GetExtSemaphore - TODO confirm
+ void* selectedEp; // execution provider associated with extSemFence - TODO confirm
+} SemaphoreEpMap;
+
/** \brief Delegate to allow providing custom OrtEpDevice selection logic
*
* This delegate is called by the EP selection code to allow the user to provide custom device selection logic.
@@ -6590,6 +6628,99 @@ struct OrtApi {
* \since Version 1.24
*/
ORT_API_T(bool, TensorTypeAndShape_HasShape, _In_ const OrtTensorTypeAndShapeInfo* info);
+
+ /** \brief Set up a graphics interop context for an execution provider device.
+ *
+ * This function enables D3D12/Vulkan interoperability with a graphics API command queue/device. Once set up, any OrtSyncStream
+ * created for this ep_device via CreateSyncStreamForEpDevice is created with graphics interop enabled, allowing efficient
+ * GPU-side synchronization.
+ *
+ * This must be called BEFORE CreateSyncStreamForEpDevice for the same ep_device.
+ *
+ * \param[in] ep_device The OrtEpDevice to set up graphics interop for.
+ * \param[in] graphicsInteropParams Pointer to struct containing D3D12 command queue or Vulkan device info.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.24
+ */
+ ORT_API2_STATUS(SetupGraphicsInteropForEpDevice, _In_ const OrtEpDevice* ep_device,
+ _In_ const struct GraphicsInteropParams* graphicsInteropParams);
+
+ /**
+ * \brief Create an OrtFence wrapper for an external GPU fence/semaphore to enable graphics interop.
+ *
+ * This function creates an OrtFence object that wraps an external synchronization primitive (D3D12 fence or Vulkan timeline semaphore)
+ * provided by the caller. The resulting OrtFence can be used with InteropEpWait and InteropEpSignal to synchronize ONNX Runtime
+ * computation with external graphics APIs like D3D12 or Vulkan.
+ *
+ * This enables integration between ONNX Runtime GPU computation and external graphics or compute pipelines, allowing efficient
+ * GPU-side synchronization between ONNX Runtime inference and external graphics workloads.
+ *
+ * Prerequisites:
+ * - SetupGraphicsInteropForEpDevice must be called first for the relevant ep_device.
+ * - The external fence/semaphore provided must be compatible with the graphics API specified in graphicsInteropParams.
+ * - graphicsInteropParams and fenceInteropParams must specify the same ExternalSyncPrimitive.
+ *
+ * \param[in] session An OrtSession instance whose execution providers participate in the graphics interop.
+ * \param[in] graphicsInteropParams Pointer to a struct specifying the graphics API type (D3D12 or Vulkan) and device information.
+ * \param[in] fenceInteropParams Pointer to a struct containing the external fence/semaphore handle to be wrapped.
+ * \param[out] ortFence Pointer to receive the created OrtFence object. This OrtFence wraps the external synchronization primitive
+ * and can be used with InteropEpWait and InteropEpSignal APIs.
+ * NOTE(review): no release function for OrtFence is declared next to this API; document who owns
+ * the returned object and how long it stays valid.
+ *
+ * \retval ORT_OK On success.
+ * \retval ORT_NOT_IMPLEMENTED If none of the active providers support graphics-fence interop.
+ * \retval ORT_FAIL or provider-specific error if the operation fails for another reason.
+ *
+ * \since Version 1.24
+ */
+ ORT_API2_STATUS(GetOrtFenceForGraphicsInterop, _In_ OrtSession* session, _In_ const struct GraphicsInteropParams* graphicsInteropParams, _In_ const struct FenceInteropParams* fenceInteropParams, _Outptr_ OrtFence** ortFence);
+
+ /**
+ * \brief Wait on a graphics interop external fence/semaphore using an ONNX Runtime execution provider.
+ *
+ * This function synchronizes ONNX Runtime computation with an external graphics or compute API by waiting on an external
+ * semaphore or fence. It is typically used in scenarios where GPU computation in ONNX Runtime must be synchronized with
+ * non-ONNX Runtime workloads, such as graphics pipelines (e.g., Direct3D 12 or Vulkan).
+ *
+ * The external synchronization primitive wrapped by `ortFence` must have been obtained from a previous call to
+ * GetOrtFenceForGraphicsInterop.
+ *
+ * The implementation and support for this is execution provider-specific.
+ *
+ * \param[in] ortFence The OrtFence wrapping the external synchronization primitive, as returned from GetOrtFenceForGraphicsInterop.
+ * \param[in] stream The OrtSyncStream instance on which the synchronization will be performed.
+ * \param[in] fenceValue The fence value for synchronization (if required for the specific graphics API/interop scenario).
+ *
+ * \retval ORT_OK On successful synchronization.
+ * \retval ORT_NOT_IMPLEMENTED If the current execution provider does not support graphics interop wait.
+ * \retval ORT_FAIL or provider-specific error if the synchronization operation fails for another reason.
+ *
+ * \since Version 1.24
+ */
+ ORT_API2_STATUS(InteropEpWait, _In_ OrtFence* ortFence, _In_ OrtSyncStream* stream, _In_ uint64_t fenceValue);
+
+ /**
+ * \brief Signal a graphics interop external fence/semaphore using an ONNX Runtime execution provider.
+ *
+ * This function synchronizes external graphics or compute APIs with ONNX Runtime computation by signaling an external
+ * semaphore or fence. This is typically used when GPU computation in ONNX Runtime has completed and an external
+ * workload (such as a graphics pipeline) should continue execution.
+ *
+ * The external synchronization primitive wrapped by `ortFence` must have been obtained from a previous call to
+ * GetOrtFenceForGraphicsInterop.
+ *
+ * The behavior and support for this operation is execution provider-specific.
+ *
+ * \param[in] ortFence The OrtFence wrapping the external synchronization primitive, as returned from GetOrtFenceForGraphicsInterop.
+ * \param[in] stream The OrtSyncStream on which the signal operation should occur.
+ * \param[in] fenceValue The fence value to signal (if required for the specific graphics API or interop scenario).
+ *
+ * \retval ORT_OK On successful signal operation.
+ * \retval ORT_NOT_IMPLEMENTED If the current execution provider does not support graphics interop signaling.
+ * \retval ORT_FAIL or provider-specific error if signaling fails for another reason.
+ *
+ * \since Version 1.24
+ */
+ ORT_API2_STATUS(InteropEpSignal, _In_ OrtFence* ortFence, _In_ OrtSyncStream* stream, _In_ uint64_t fenceValue);
};
/*
diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h
index de38085914516..e15fe2dd69857 100644
--- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h
@@ -1010,6 +1010,7 @@ struct OrtEpFactory {
* This is used to create a synchronization stream for the memory device that can be used for operations outside of
* a session.
*
+ *
* \param[in] this_ptr The OrtEpFactory instance.
* \param[in] memory_device The OrtMemoryDevice to create the synchronization stream for.
* \param[in] stream_options Options for stream creation. May be nullptr.
@@ -1052,6 +1053,24 @@ struct OrtEpFactory {
* \since Version 1.24.
*/
ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options);
+
+ /** \brief Set up graphics interop for a memory device.
+ *
+ * This function sets up graphics interop associated with a graphics API (D3D12/Vulkan) command queue/device.
+ * graphicsInteropParams->extSyncPrimitive identifies which graphics API the parameters describe.
+ *
+ * Optional - EP factories that don't support graphics interop setup should set this to nullptr.
+ *
+ * \param[in] this_ptr The OrtEpFactory instance.
+ * \param[in] memory_device The OrtMemoryDevice to set up graphics interop for.
+ * \param[in] graphicsInteropParams Graphics API parameters (D3D12 command queue or Vulkan device info).
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.24.
+ */
+ ORT_API2_STATUS(SetupGraphicsInterop, _In_ OrtEpFactory* this_ptr,
+ _In_ const OrtMemoryDevice* memory_device,
+ _In_ const struct GraphicsInteropParams* graphicsInteropParams);
};
#ifdef __cplusplus
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
index e2a8005aba1da..110d14fcda3e0 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc
@@ -33,6 +33,13 @@
#include
// TODO: find a better way to share this
#include "core/providers/cuda/cuda_stream_handle.h"
+#include "core/providers/cuda/cuda_common.h"
+
+// Function-pointer type for vkGetSemaphoreWin32HandleKHR. GetExtSemaphore resolves the entry point at
+// runtime through the caller-supplied vkGetDeviceProcAddr, so only the signature is needed here.
+#if VULKAN_FOR_INTEROP
+#if _WIN32
+typedef VkResult (VKAPI_PTR *PFN_vkGetSemaphoreWin32HandleKHR)(VkDevice device, const VkSemaphoreGetWin32HandleInfoKHR* pGetWin32HandleInfo, HANDLE* pHandle);
+#endif
+#endif
#ifdef _WIN32
#include
@@ -104,6 +111,111 @@ static bool IsSupportedInputOutputDataType(ONNXTensorElementDataType data_type)
}
}
+// Imports an external graphics synchronization object into CUDA and returns the resulting
+// cudaExternalSemaphore_t through extSemFence.
+//
+// A D3D12 fence is exported with ID3D12Device::CreateSharedHandle; a Vulkan semaphore is exported with
+// vkGetSemaphoreWin32HandleKHR (Windows only). The temporary Win32 handle is closed once CUDA has
+// imported it, whether or not the import succeeded.
+//
+// Returns INVALID_ARGUMENT for null or inconsistent parameters, NOT_IMPLEMENTED when support for the
+// requested primitive is not compiled into this build, and FAIL when export or import fails.
+Status NvExecutionProvider::GetExtSemaphore(const struct GraphicsInteropParams* graphicsInteropParams, const struct FenceInteropParams* fenceInteropParams, void** extSemFence) {
+  if (graphicsInteropParams == nullptr || fenceInteropParams == nullptr || extSemFence == nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "graphicsInteropParams, fenceInteropParams and extSemFence must not be null");
+  }
+  // Checked at runtime: assert() is compiled out of release builds, and a mismatch would make the code
+  // below read the wrong union member.
+  if (graphicsInteropParams->extSyncPrimitive != fenceInteropParams->extSyncPrimitive) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ExternalSyncPrimitive mismatch between graphicsInteropParams and fenceInteropParams");
+  }
+
+  if (!info_.has_user_compute_stream) {
+    (void)GetPerThreadContext();
+  }
+
+  cudaExternalSemaphoreHandleDesc semHandleDesc = {};
+  bool descriptorReady = false;  // stays false when the matching interop path is not compiled in
+
+  const ExternalSyncPrimitive extSyncPrimitive = fenceInteropParams->extSyncPrimitive;
+
+  if (extSyncPrimitive == ExternalSyncPrimitive_D3D12Fence) {
+#if DX_FOR_INTEROP && _WIN32
+    HANDLE sharedFenceHandle = nullptr;
+    if (reinterpret_cast<ID3D12Device*>(graphicsInteropParams->DevicePtr.DXDeviceParams.pDevice)->CreateSharedHandle(reinterpret_cast<ID3D12Fence*>(fenceInteropParams->FencePtr.pFence), nullptr, GENERIC_ALL, nullptr, &sharedFenceHandle) != S_OK) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create shared handle for D3D12 fence");
+    }
+    semHandleDesc.type = cudaExternalSemaphoreHandleTypeD3D12Fence;
+    semHandleDesc.handle.win32.handle = sharedFenceHandle;
+    descriptorReady = true;
+#endif
+  } else if (extSyncPrimitive == ExternalSyncPrimitive_VulkanSemaphore) {
+#if VULKAN_FOR_INTEROP
+#if _WIN32
+    VkSemaphoreGetWin32HandleInfoKHR handleInfo = {};
+    handleInfo.sType = VK_STRUCTURE_TYPE_SEMAPHORE_GET_WIN32_HANDLE_INFO_KHR;
+    handleInfo.semaphore = reinterpret_cast<VkSemaphore>(fenceInteropParams->FencePtr.VkSemaphore);
+    handleInfo.handleType = VK_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_BIT_KHR;
+
+    HANDLE sharedFenceHandle = nullptr;
+    VkDevice vkDevice = reinterpret_cast<VkDevice>(graphicsInteropParams->DevicePtr.VulkanDeviceParams.pVkDevice);
+
+    // Get the function pointer for vkGetSemaphoreWin32HandleKHR
+    PFN_vkGetSemaphoreWin32HandleKHR pfnVkGetSemaphoreWin32HandleKHR = reinterpret_cast<PFN_vkGetSemaphoreWin32HandleKHR>(
+        reinterpret_cast<PFN_vkGetDeviceProcAddr>(graphicsInteropParams->DevicePtr.VulkanDeviceParams.pVkGetDeviceProcAddr)(
+            vkDevice, "vkGetSemaphoreWin32HandleKHR"));
+
+    if (pfnVkGetSemaphoreWin32HandleKHR == nullptr ||
+        pfnVkGetSemaphoreWin32HandleKHR(vkDevice, &handleInfo, &sharedFenceHandle) != VK_SUCCESS) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to get shared semaphore handle");
+    }
+    // NOTE(review): ExternalSyncPrimitive_VulkanSemaphore is documented as a timeline semaphore and the
+    // wait/signal paths pass a fence value; confirm whether this should be
+    // cudaExternalSemaphoreHandleTypeTimelineSemaphoreWin32 rather than the opaque (binary) type.
+    semHandleDesc.type = cudaExternalSemaphoreHandleTypeOpaqueWin32;
+    semHandleDesc.handle.win32.handle = sharedFenceHandle;
+    descriptorReady = true;
+#endif
+#endif
+  }
+
+  if (!descriptorReady) {
+    // Never hand a zero-initialized descriptor to CUDA; report the missing support explicitly.
+    return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "The requested external sync primitive is not supported in this build");
+  }
+
+  cudaExternalSemaphore_t cSemFence = nullptr;
+  const cudaError_t importResult = cudaImportExternalSemaphore(&cSemFence, &semHandleDesc);
+#ifdef _WIN32
+  // The shared handle was only needed for the import; close it on success and on failure alike.
+  CloseHandle(semHandleDesc.handle.win32.handle);
+#endif
+  if (importResult != cudaSuccess) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to import external semaphore");
+  }
+  *extSemFence = cSemFence;
+
+  return Status::OK();
+}
+
+// Enqueues a wait on the stream: work submitted to the CUDA stream after this call does not start
+// until the external semaphore reaches fenceValue. The call itself is asynchronous.
+//
+// extSemFence must be the cudaExternalSemaphore_t returned by GetExtSemaphore.
+Status NvExecutionProvider::SetupInteropEpWait(void* extSemFence, OrtSyncStream* stream, uint64_t fenceValue) {
+  LOGS_DEFAULT(VERBOSE) << "NvExecutionProvider::SetupInteropEpWait() called.";
+
+  if (extSemFence == nullptr || stream == nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "extSemFence and stream must not be null");
+  }
+
+  // make CUDA wait for the upload by Graphics API to finish
+  cudaExternalSemaphore_t cSemFence = static_cast<cudaExternalSemaphore_t>(extSemFence);
+  cudaStream_t cudaStream = static_cast<cudaStream_t>(Ort::GetApi().SyncStream_GetHandle(stream));
+
+  cudaExternalSemaphoreWaitParams waitParams = {};
+  waitParams.params.fence.value = fenceValue;
+
+  const cudaError_t result = cudaWaitExternalSemaphoresAsync(&cSemFence, &waitParams, 1, cudaStream);
+  if (result != cudaSuccess) {
+    // Keep the CUDA diagnostic; the bare message gives no hint about the cause.
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for external semaphore: ", cudaGetErrorString(result));
+  }
+
+  return Status::OK();
+}
+
+// Enqueues a signal on the stream: once the CUDA work already submitted to the stream has finished,
+// the external semaphore is set to fenceValue. The call itself is asynchronous.
+//
+// extSemFence must be the cudaExternalSemaphore_t returned by GetExtSemaphore. ortEpApi is unused.
+Status NvExecutionProvider::SetupInteropEpSignal(const OrtEpApi* ortEpApi, void* extSemFence, OrtSyncStream* stream, uint64_t fenceValue) {
+  ORT_UNUSED_PARAMETER(ortEpApi);
+  LOGS_DEFAULT(VERBOSE) << "NvExecutionProvider::SetupInteropEpSignal() called.";
+
+  if (extSemFence == nullptr || stream == nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "extSemFence and stream must not be null");
+  }
+
+  cudaExternalSemaphore_t cSemFence = static_cast<cudaExternalSemaphore_t>(extSemFence);
+  cudaStream_t cudaStream = static_cast<cudaStream_t>(Ort::GetApi().SyncStream_GetHandle(stream));
+
+  // make Graphics API wait for the CUDA kernel to finish
+  cudaExternalSemaphoreSignalParams signalParams = {};
+  signalParams.params.fence.value = fenceValue;
+
+  const cudaError_t result = cudaSignalExternalSemaphoresAsync(&cSemFence, &signalParams, 1, cudaStream);
+  if (result != cudaSuccess) {
+    // Keep the CUDA diagnostic; the bare message gives no hint about the cause.
+    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to signal external semaphore: ", cudaGetErrorString(result));
+  }
+
+  return Status::OK();
+}
+
// Helper function to check if a data type is supported by NvTensorRTRTX EP
static bool IsSupportedDataType(ONNXTensorElementDataType data_type) {
switch (data_type) {
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h
index bb8f687db094f..7be9feb48c8e6 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h
@@ -289,6 +289,10 @@ class NvExecutionProvider : public IExecutionProvider {
// explicit NvExecutionProvider(const ProviderOptions& provider_options_map, const ConfigOptions* config_options);
virtual ~NvExecutionProvider();
+  // Graphics interop overrides: import an external D3D12/Vulkan sync object into CUDA, then wait on or
+  // signal it from a CUDA stream.
+  Status GetExtSemaphore(const struct GraphicsInteropParams* graphicsInteropParams,
+                         const struct FenceInteropParams* fenceInteropParams,
+                         void** extSemFence) override;
+  Status SetupInteropEpWait(void* extSemFence, OrtSyncStream* stream, uint64_t fenceValue) override;
+  Status SetupInteropEpSignal(const OrtEpApi* ortEpApi,
+                              void* extSemFence,
+                              OrtSyncStream* stream,
+                              uint64_t fenceValue) override;
+
cublasHandle_t PerThreadDefaultCublasHandle() {
return GetPerThreadContext().CublasHandle();
}
diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc
index c3fbccef84883..1635e987fb334 100644
--- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc
+++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc
@@ -531,6 +531,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory {
CreateDataTransfer = CreateDataTransferImpl;
IsStreamAware = IsStreamAwareImpl;
+ SetupGraphicsInterop = SetupCigContextImpl;
CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl;
ort_version_supported = ORT_API_VERSION; // Set to the ORT version we were compiled with.
@@ -706,6 +707,62 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory {
return true;
}
+  // OrtEpFactory::SetupGraphicsInterop implementation: creates a CIG (CUDA in Graphics) context bound
+  // to the caller's D3D12 command queue and remembers it per CUDA device id, so that
+  // CreateSyncStreamForDeviceImpl can create its stream on that context.
+  //
+  // Returns ORT_INVALID_ARGUMENT for null/unknown parameters, ORT_NOT_IMPLEMENTED for Vulkan or when
+  // D3D12 interop is not compiled in, and ORT_FAIL when a CUDA call fails or the D3D12 adapter is not
+  // the same adapter as the CUDA device.
+  static OrtStatus* ORT_API_CALL SetupCigContextImpl(OrtEpFactory* this_ptr,
+                                                     const OrtMemoryDevice* memory_device,
+                                                     const struct GraphicsInteropParams* graphicsInteropParams) noexcept {
+    auto& factory = *static_cast<NvTensorRtRtxEpFactory*>(this_ptr);
+    if (graphicsInteropParams == nullptr) {
+      return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "graphicsInteropParams must not be null");
+    }
+    auto device_id = factory.ep_api.MemoryDevice_GetDeviceId(memory_device);
+
+    // Initialize CUDA
+    CUresult result = cuInit(0);
+    if (result != CUDA_SUCCESS) {
+      return factory.ort_api.CreateStatus(ORT_FAIL, "Failed to initialize CUDA for CIG context creation");
+    }
+
+    // Get LUID from CUDA device
+    cudaDeviceProp cuda_prop;
+    cudaError_t cuda_err = cudaGetDeviceProperties(&cuda_prop, device_id);
+    if (cuda_err != cudaSuccess) {
+      return factory.ort_api.CreateStatus(ORT_FAIL, "Failed to get CUDA device properties");
+    }
+
+    // Create CIG context based on graphics API type
+    CUcontext cig_context = nullptr;
+
+    if (graphicsInteropParams->extSyncPrimitive == ExternalSyncPrimitive_D3D12Fence) {
+#if DX_FOR_INTEROP && _WIN32
+      void* d3d12_device = graphicsInteropParams->DevicePtr.DXDeviceParams.pDevice;
+      void* d3d12_queue = graphicsInteropParams->DevicePtr.DXDeviceParams.pCommandQueue;
+      if (d3d12_device == nullptr || d3d12_queue == nullptr) {
+        return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "D3D12 device and command queue must not be null");
+      }
+
+      // The CUDA device and the D3D12 device must be the same adapter: compare their LUIDs.
+      uint64_t cuda_luid = *reinterpret_cast<const uint64_t*>(cuda_prop.luid);
+      LUID d3d12_luid = reinterpret_cast<ID3D12Device*>(d3d12_device)->GetAdapterLuid();
+      uint64_t d3d12_luid_64 = (static_cast<uint64_t>(d3d12_luid.HighPart) << 32) | d3d12_luid.LowPart;
+      if (d3d12_luid_64 != cuda_luid) {
+        return factory.ort_api.CreateStatus(ORT_FAIL, "D3D12 device LUID does not match CUDA device LUID");
+      }
+
+      // Create CIG context bound to D3D12 command queue
+      CUctxCigParam ctxCigParams = {CIG_DATA_TYPE_D3D12_COMMAND_QUEUE, d3d12_queue};
+      CUctxCreateParams ctxParams = {nullptr, 0, &ctxCigParams};
+
+      result = cuCtxCreate_v4(&cig_context, &ctxParams, 0, device_id);
+      if (result != CUDA_SUCCESS) {
+        return factory.ort_api.CreateStatus(ORT_FAIL, "Failed to create CIG context for D3D12");
+      }
+#else
+      return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, "CIG context creation not supported on this platform");
+#endif
+    } else if (graphicsInteropParams->extSyncPrimitive == ExternalSyncPrimitive_VulkanSemaphore) {
+      // TODO: Add Vulkan CIG context support if needed
+      return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, "Vulkan CIG context not yet implemented");
+    } else {
+      return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Unsupported graphics API for CIG context");
+    }
+
+    // Store the CIG context for this device
+    // NOTE(review): a context stored by an earlier call for the same device is overwritten without
+    // being destroyed, and the map is not synchronized - confirm the intended lifetime and threading.
+    factory.cig_contexts_[device_id] = cig_context;
+
+    return nullptr;
+  }
+
static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr,
const OrtMemoryDevice* memory_device,
const OrtKeyValuePairs* /*stream_options*/,
@@ -714,8 +771,23 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory {
auto device_id = factory.ep_api.MemoryDevice_GetDeviceId(memory_device);
cudaStream_t stream = nullptr;
- CUDA_RETURN_IF_ERROR(cudaSetDevice(device_id));
- CUDA_RETURN_IF_ERROR(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
+
+ // Check if we have a CIG context for this device
+ auto cig_it = factory.cig_contexts_.find(device_id);
+ if (cig_it != factory.cig_contexts_.end()) {
+ // We have a CIG context - make it current and create stream on it
+ CUresult result = cuCtxSetCurrent(cig_it->second);
+ if (result != CUDA_SUCCESS) {
+ return factory.ort_api.CreateStatus(ORT_FAIL, "Failed to set CIG context current");
+ }
+
+ // Create stream on the CIG context
+ CUDA_RETURN_IF_ERROR(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
+ } else {
+ // No CIG context - use default behavior
+ CUDA_RETURN_IF_ERROR(cudaSetDevice(device_id));
+ CUDA_RETURN_IF_ERROR(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
+ }
const OrtDevice* ort_device = static_cast(memory_device);
@@ -770,6 +842,9 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory {
// we use a shared instance for the OrtDataTransferImpl instead of creating a new one on every call to
NvTrtRtxDataTransferImpl data_transfer_impl;
+ // Map to store CIG context per device ID (for D3D12/Vulkan interop)
+ std::unordered_map cig_contexts_;
+
NvTensorRtRtxEpFactory(const NvTensorRtRtxEpFactory&) = delete;
NvTensorRtRtxEpFactory& operator=(const NvTensorRtRtxEpFactory&) = delete;
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index 14f0892687ad1..3d8267e5e2e1f 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -730,6 +730,12 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const
#endif // !defined(ORT_MINIMAL_BUILD)
InferenceSession::~InferenceSession() {
+
+ for (auto* sptr_ptr : ort_fences_for_cleanup_) {
+ delete sptr_ptr;
+ }
+ ort_fences_for_cleanup_.clear();
+
if (session_options_.enable_profiling) {
ORT_TRY {
EndProfiling();
diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h
index 8bea15c169ed4..61c7c611f87e6 100644
--- a/onnxruntime/core/session/inference_session.h
+++ b/onnxruntime/core/session/inference_session.h
@@ -653,6 +653,14 @@ class InferenceSession {
return session_id_;
}
+ // Records a heap-allocated fence wrapper so it is released together with the session:
+ // ~InferenceSession() deletes every pointer stored in ort_fences_for_cleanup_.
+ // Throws via ORT_THROW when `ortFence` is null. The append is serialized with
+ // ort_fences_mutex_.
+ void RegisterOrtFenceForCleanup(std::shared_ptr* ortFence) {
+ if(!ortFence) {
+ ORT_THROW("Cannot register null fence for cleanup");
+ }
+ std::lock_guard lock(ort_fences_mutex_);
+ ort_fences_for_cleanup_.push_back(ortFence);
+ }
+
protected:
#if !defined(ORT_MINIMAL_BUILD)
@@ -1045,6 +1053,8 @@ class InferenceSession {
// Enable nodestats collection
std::optional node_stats_recorder_;
#endif
+ mutable std::mutex ort_fences_mutex_;
+ std::vector*> ort_fences_for_cleanup_;
};
struct SessionIOBinding {
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 546b11ae580d5..1663730719a91 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -3366,6 +3366,43 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForInputs, _In_ const OrtSession*
API_IMPL_END
}
+// Asks the EP factory that owns `ep_device` to set up graphics interop for that device.
+//
+// Returns ORT_INVALID_ARGUMENT when an argument is null, when the device carries no
+// device_memory_info, or when its memory type is not DEFAULT. Returns ORT_NOT_IMPLEMENTED
+// when the factory is not stream aware or exposes no SetupGraphicsInterop callback.
+// Any error reported by the factory callback is returned unchanged.
+ORT_API_STATUS_IMPL(OrtApis::SetupGraphicsInteropForEpDevice, _In_ const OrtEpDevice* ep_device,
+                    _In_ const struct GraphicsInteropParams* graphicsInteropParams) {
+  API_IMPL_BEGIN
+  if (ep_device == nullptr || graphicsInteropParams == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device and graphicsInteropParams must be provided.");
+  }
+
+  if (ep_device->device_memory_info == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device must have valid device_memory_info.");
+  }
+
+  // device_memory_info was validated above, so the device pointer cannot be null here.
+  const OrtDevice* device = &ep_device->device_memory_info->device;
+  if (device->MemType() != OrtDevice::MemType::DEFAULT) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device does not use DEFAULT memory of a non-CPU device.");
+  }
+
+  const auto* factory = ep_device->ep_factory;
+  if (!factory->IsStreamAware(factory)) {
+    return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "The execution provider does not support streams.");
+  }
+
+  // The callback is optional; a factory that does not provide it cannot do graphics interop.
+  // NOTE(review): factories built against an older OrtEpFactory layout may not have this
+  // member at all - confirm whether an ort_version_supported check is needed before reading it.
+  if (factory->SetupGraphicsInterop == nullptr) {
+    return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "The execution provider does not support Graphics interop setup.");
+  }
+
+  // Hand the request to the EP factory. The callback takes the device as an OrtMemoryDevice.
+  ORT_API_RETURN_IF_ERROR(factory->SetupGraphicsInterop(ep_device->GetMutableFactory(),
+                                                        static_cast<const OrtMemoryDevice*>(device),  // alias
+                                                        graphicsInteropParams));
+
+  return nullptr;
+
+  API_IMPL_END
+}
+
ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice* ep_device,
_In_opt_ const OrtKeyValuePairs* stream_options,
_Outptr_ OrtSyncStream** ort_stream) {
@@ -3374,6 +3411,10 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device and stream must be provided.");
}
+ if (ep_device->device_memory_info == nullptr)
+ {
+ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device must have valid device_memory_info.");
+ }
const OrtDevice* device = ep_device->device_memory_info ? &ep_device->device_memory_info->device : nullptr;
if (device == nullptr || device->MemType() != OrtDevice::MemType::DEFAULT) {
@@ -3407,6 +3448,108 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice
API_IMPL_END
}
+// Obtains an OrtFence that wraps an external semaphore for graphics interop with `session`.
+//
+// Only execution providers that are assigned to at least one node of the session graph are
+// queried (in unspecified order). The first provider whose GetExtSemaphore call succeeds is
+// recorded as the selected EP and its fence is returned through `ortFence`. Providers that
+// report NOT_IMPLEMENTED are skipped; any other failure is returned to the caller.
+//
+// The returned handle is registered with the session, whose destructor deletes it, so it
+// must not be used after the session has been released. `*ortFence` is set to nullptr on
+// every failure path.
+ORT_API_STATUS_IMPL(OrtApis::GetOrtFenceForGraphicsInterop, _In_ OrtSession* session,
+                    _In_ const struct GraphicsInteropParams* graphicsInteropParams,
+                    _In_ const struct FenceInteropParams* fenceInteropParams,
+                    _Outptr_ OrtFence** ortFence) {
+  API_IMPL_BEGIN
+  // The output slot is written on success, so it has to be validated before anything else.
+  if (ortFence == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ortFence output pointer is null");
+  }
+  *ortFence = nullptr;
+
+  auto* inference_session = reinterpret_cast(session);
+  if (!inference_session) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Session is null");
+  }
+
+  // Both parameter blocks are handed to the execution providers unchanged.
+  if (graphicsInteropParams == nullptr || fenceInteropParams == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
+                                 "graphicsInteropParams and fenceInteropParams must be provided.");
+  }
+
+  auto semaphore_ep_map_sptr = std::make_shared();
+  semaphore_ep_map_sptr->extSemFence = nullptr;
+  semaphore_ep_map_sptr->selectedEp = nullptr;
+
+  const auto& session_state = inference_session->GetSessionState();
+  const auto& execution_providers = session_state.GetExecutionProviders();
+  const auto& graph_viewer = session_state.GetGraphViewer();
+
+  // Collect the unique set of execution providers assigned to nodes in the graph.
+  std::unordered_set active_provider_types;
+  for (const auto& node : graph_viewer.Nodes()) {
+    if (!node.GetExecutionProviderType().empty()) {
+      active_provider_types.insert(node.GetExecutionProviderType());
+    }
+  }
+
+  // Call GetExtSemaphore only for the providers that are actively being used.
+  for (const auto& provider_type : active_provider_types) {
+    const onnxruntime::IExecutionProvider* const_provider = execution_providers.Get(provider_type);
+    if (const_provider) {
+      auto* provider = const_cast(const_provider);
+      auto status = provider->GetExtSemaphore(graphicsInteropParams, fenceInteropParams, &semaphore_ep_map_sptr->extSemFence);
+      if (status.IsOK()) {
+        semaphore_ep_map_sptr->selectedEp = provider;
+        // The session takes ownership of the wrapper and deletes it in its destructor.
+        auto* wrapper = new std::shared_ptr(semaphore_ep_map_sptr);
+        inference_session->RegisterOrtFenceForCleanup(wrapper);
+        *ortFence = reinterpret_cast(wrapper);
+        return nullptr;
+      }
+      // NOT_IMPLEMENTED only means this provider has no interop support: try the next one.
+      if (status.Code() != onnxruntime::common::StatusCode::NOT_IMPLEMENTED) {
+        return ToOrtStatus(status);
+      }
+    }
+  }
+
+  return OrtApis::CreateStatus(ORT_FAIL, "No active execution provider returned a semaphore for the given session");
+  API_IMPL_END
+}
+
+// Forwards a wait request for `fenceValue` on `stream` to the execution provider that
+// produced `ortFence` (its SetupInteropEpWait method).
+//
+// Returns ORT_INVALID_ARGUMENT when the fence handle, its external semaphore or the stream
+// is null. A NOT_IMPLEMENTED result from the provider, or a fence without a selected
+// provider, is reported as ORT_FAIL; any other provider error is returned as is.
+ORT_API_STATUS_IMPL(OrtApis::InteropEpWait, _In_ OrtFence* ortFence, _In_ OrtSyncStream* stream,
+                    _In_ uint64_t fenceValue) {
+  API_IMPL_BEGIN
+  // Reject a null handle before it is reinterpreted and dereferenced below.
+  if (ortFence == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtFence is null.");
+  }
+  auto* sptr_ptr = reinterpret_cast*>(ortFence);
+  auto* semaphoreEpMap = sptr_ptr->get();
+
+  if (semaphoreEpMap == nullptr || semaphoreEpMap->extSemFence == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "External Fence Semaphore is null.");
+  }
+  if (stream == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Stream is null.");
+  }
+
+  const onnxruntime::IExecutionProvider* selectedEp = static_cast(semaphoreEpMap->selectedEp);
+  if (selectedEp) {
+    auto* execution_provider = const_cast(selectedEp);
+    auto status = execution_provider->SetupInteropEpWait(semaphoreEpMap->extSemFence, stream, fenceValue);
+    if (status.IsOK()) {
+      return nullptr;
+    }
+    // NOT_IMPLEMENTED falls through to the generic failure below.
+    if (status.Code() != onnxruntime::common::StatusCode::NOT_IMPLEMENTED) {
+      return ToOrtStatus(status);
+    }
+  }
+
+  return OrtApis::CreateStatus(ORT_FAIL, "No execution provider is actively being used");
+  API_IMPL_END
+}
+
+// Forwards a signal request for `fenceValue` on `stream` to the execution provider that
+// produced `ortFence` (its SetupInteropEpSignal method, which also receives the EP API).
+//
+// Returns ORT_INVALID_ARGUMENT when the fence handle, its external semaphore or the stream
+// is null. A NOT_IMPLEMENTED result from the provider, or a fence without a selected
+// provider, is reported as ORT_FAIL; any other provider error is returned as is.
+ORT_API_STATUS_IMPL(OrtApis::InteropEpSignal, _In_ OrtFence* ortFence, _In_ OrtSyncStream* stream,
+                    _In_ uint64_t fenceValue) {
+  API_IMPL_BEGIN
+  // Reject a null handle before it is reinterpreted and dereferenced below.
+  if (ortFence == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtFence is null.");
+  }
+  auto* sptr_ptr = reinterpret_cast*>(ortFence);
+  auto* semaphoreEpMap = sptr_ptr->get();  // Get raw pointer for existing code
+
+  if (semaphoreEpMap == nullptr || semaphoreEpMap->extSemFence == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "External Fence Semaphore is null.");
+  }
+  if (stream == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Stream is null.");
+  }
+
+  const onnxruntime::IExecutionProvider* selectedEp = static_cast(semaphoreEpMap->selectedEp);
+  if (selectedEp) {
+    auto* execution_provider = const_cast(selectedEp);
+    auto status = execution_provider->SetupInteropEpSignal(OrtApis::GetEpApi(), semaphoreEpMap->extSemFence, stream, fenceValue);
+    if (status.IsOK()) {
+      return nullptr;
+    }
+    // NOT_IMPLEMENTED falls through to the generic failure below.
+    if (status.Code() != onnxruntime::common::StatusCode::NOT_IMPLEMENTED) {
+      return ToOrtStatus(status);
+    }
+  }
+
+  return OrtApis::CreateStatus(ORT_FAIL, "No execution provider is actively being used");
+  API_IMPL_END
+}
+
ORT_API(void*, OrtApis::SyncStream_GetHandle, _In_ OrtSyncStream* stream) {
return stream->GetHandle();
}
@@ -3588,6 +3731,13 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForInputs, _In_ const OrtSession*
API_IMPL_END
}
+// Minimal-build stand-in: ignores its arguments and always reports ORT_NOT_IMPLEMENTED.
+ORT_API_STATUS_IMPL(OrtApis::SetupGraphicsInteropForEpDevice, _In_ const OrtEpDevice* /*ep_device*/,
+ _In_ const struct GraphicsInteropParams* /*graphicsInteropParams*/) {
+ API_IMPL_BEGIN
+ return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SetupGraphicsInteropForEpDevice is not supported in a minimal build.");
+ API_IMPL_END
+}
+
ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice* /*ep_device*/,
_In_opt_ const OrtKeyValuePairs* /*stream_options*/,
_Outptr_ OrtSyncStream** /*ort_stream*/) {
@@ -3596,6 +3746,24 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice
API_IMPL_END
}
+// Minimal-build stand-in: ignores its arguments and always reports ORT_NOT_IMPLEMENTED.
+// The parameter names are commented out, as in the neighbouring stubs, because none of
+// them is used and named-but-unused parameters trigger compiler warnings.
+ORT_API_STATUS_IMPL(OrtApis::GetOrtFenceForGraphicsInterop, _In_ OrtSession* /*session*/,
+                    _In_ const struct GraphicsInteropParams* /*graphicsInteropParams*/,
+                    _In_ const struct FenceInteropParams* /*fenceInteropParams*/,
+                    _Outptr_ OrtFence** /*ortFence*/) {
+  API_IMPL_BEGIN
+  return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetOrtFenceForGraphicsInterop is not supported in a minimal build.");
+  API_IMPL_END
+}
+
+// Minimal-build stand-in: ignores its arguments and always reports ORT_NOT_IMPLEMENTED.
+ORT_API_STATUS_IMPL(OrtApis::InteropEpWait, _In_ OrtFence* /*ortFence*/, _In_ OrtSyncStream* /*stream*/, _In_ uint64_t /*fenceValue*/) {
+ API_IMPL_BEGIN
+ return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "InteropEpWait is not supported in a minimal build.");
+ API_IMPL_END
+}
+
+// Minimal-build stand-in: ignores its arguments and always reports ORT_NOT_IMPLEMENTED.
+ORT_API_STATUS_IMPL(OrtApis::InteropEpSignal, _In_ OrtFence* /*ortFence*/, _In_ OrtSyncStream* /*stream*/, _In_ uint64_t /*fenceValue*/) {
+ API_IMPL_BEGIN
+ return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "InteropEpSignal is not supported in a minimal build.");
+ API_IMPL_END
+}
+
ORT_API(void*, OrtApis::SyncStream_GetHandle, _In_ OrtSyncStream* /*stream*/) {
fprintf(stderr, "OrtSyncStream is not supported in a minimal build.\n");
return nullptr;
@@ -4231,6 +4399,11 @@ static constexpr OrtApi ort_api_1_to_24 = {
// End of Version 23 - DO NOT MODIFY ABOVE (see above text for more information)
&OrtApis::TensorTypeAndShape_HasShape,
+
+ &OrtApis::SetupGraphicsInteropForEpDevice,
+ &OrtApis::GetOrtFenceForGraphicsInterop,
+ &OrtApis::InteropEpWait,
+ &OrtApis::InteropEpSignal,
};
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
@@ -4267,6 +4440,7 @@ static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Siz
static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change");
static_assert(offsetof(OrtApi, CreateExternalInitializerInfo) / sizeof(void*) == 389, "Size of version 23 API cannot change");
+static_assert(offsetof(OrtApi, InteropEpSignal) / sizeof(void*) == 394, "Size of version 24 API cannot change");
// So that nobody forgets to finish an API version, this check will serve as a reminder:
static_assert(std::string_view(ORT_VERSION) == "1.24.0",
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index f016bb3215330..2104c2530798b 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -751,4 +751,11 @@ ORT_API_STATUS_IMPL(CopyTensors, _In_ const OrtEnv* env,
_In_reads_(num_tensors) OrtValue* const* dst_tensors,
_In_opt_ OrtSyncStream* stream,
_In_ size_t num_tensors);
+
+ORT_API_STATUS_IMPL(SetupGraphicsInteropForEpDevice, _In_ const OrtEpDevice* ep_device,
+ _In_ const struct GraphicsInteropParams* graphicsInteropParams);
+
+ORT_API_STATUS_IMPL(GetOrtFenceForGraphicsInterop, _In_ OrtSession* session, _In_ const struct GraphicsInteropParams* graphicsInteropParams, _In_ const struct FenceInteropParams* fenceInteropParams, _Outptr_ OrtFence** ortFence);
+ORT_API_STATUS_IMPL(InteropEpWait, _In_ OrtFence* ortFence, _In_ OrtSyncStream* stream, _In_ uint64_t fenceValue);
+ORT_API_STATUS_IMPL(InteropEpSignal, _In_ OrtFence* ortFence, _In_ OrtSyncStream* stream, _In_ uint64_t fenceValue);
} // namespace OrtApis
diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc
index 364bab471ddbe..5709f4c29cb03 100644
--- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc
+++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc
@@ -31,6 +31,7 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl
OrtEpFactory::ReleaseAllocator = Forward::ReleaseAllocator;
OrtEpFactory::CreateDataTransfer = Forward::CreateDataTransfer;
OrtEpFactory::IsStreamAware = Forward::IsStreamAware;
+ OrtEpFactory::SetupGraphicsInterop = Forward::SetupGraphicsInterop;
OrtEpFactory::CreateSyncStreamForDevice = Forward::CreateSyncStreamForDevice;
OrtEpFactory::SetEnvironmentOptions = Forward::SetEnvironmentOptions;
}
diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h
index 6eb83a117fb63..88f8aba15b813 100644
--- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h
+++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h
@@ -74,6 +74,11 @@ class EpFactoryInternal : public OrtEpFactory {
return impl_->IsStreamAware();
}
+ // Passes the graphics interop setup request on to the wrapped implementation (impl_)
+ // and returns whatever status it produces.
+ OrtStatus* SetupGraphicsInterop(_In_ const OrtMemoryDevice* memory_device,
+ _In_ const struct GraphicsInteropParams* graphicsInteropParams) noexcept {
+ return impl_->SetupGraphicsInterop(memory_device, graphicsInteropParams);
+ }
+
OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* memory_device,
_In_opt_ const OrtKeyValuePairs* stream_options,
_Outptr_result_maybenull_ OrtSyncStreamImpl** stream) noexcept {
diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h
index de9e2d44431bf..4275c56cdd42c 100644
--- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h
+++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h
@@ -62,6 +62,12 @@ class EpFactoryInternalImpl {
return false;
}
+ // Default for factories without graphics interop support: both arguments are ignored and
+ // an ORT_NOT_IMPLEMENTED status is returned. Virtual so that derived factories can
+ // supply a real implementation.
+ virtual OrtStatus* SetupGraphicsInterop(_In_ const OrtMemoryDevice* /*memory_device*/,
+ _In_ const struct GraphicsInteropParams* /*graphicsInteropParams*/) noexcept {
+ return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED,
+ "SetupGraphicsInterop is not implemented for this EP factory.");
+ }
+
virtual OrtStatus* ValidateCompiledModelCompatibilityInfo(
_In_reads_(num_devices) const OrtHardwareDevice* const* devices,
_In_ size_t num_devices,
diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h
index 8c5ef526baba1..7b1dfd329247d 100644
--- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h
+++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h
@@ -59,6 +59,14 @@ class ProviderBridgeEpFactory : public EpFactoryInternalImpl {
return ep_factory_.IsStreamAware(&ep_factory_);
}
+  // Delegates graphics interop setup to the wrapped OrtEpFactory. A wrapped factory that
+  // leaves its SetupGraphicsInterop callback unset is reported as ORT_NOT_IMPLEMENTED.
+  // NOTE(review): a factory built against an older OrtEpFactory layout may not have this
+  // member at all - confirm whether a version check is required before reading it.
+  OrtStatus* SetupGraphicsInterop(const OrtMemoryDevice* memory_device,
+                                  const struct GraphicsInteropParams* graphicsInteropParams) noexcept override {
+    const bool has_callback = ep_factory_.SetupGraphicsInterop != nullptr;
+    return has_callback
+               ? ep_factory_.SetupGraphicsInterop(&ep_factory_, memory_device, graphicsInteropParams)
+               : OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Graphics interop is not supported by this EP");
+  }
+
OrtStatus* CreateSyncStreamForDevice(const OrtMemoryDevice* device,
const OrtKeyValuePairs* stream_options,
OrtSyncStreamImpl** stream) noexcept override {
diff --git a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h
index 65c396181f0a7..51d63e192a86a 100644
--- a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h
+++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h
@@ -74,6 +74,12 @@ struct ForwardToFactoryImpl {
return static_cast(this_ptr)->IsStreamAware();
}
+ // Static entry point with the OrtEpFactory callback signature: casts `this_ptr` back to
+ // the concrete factory and forwards the call to its SetupGraphicsInterop member.
+ static OrtStatus* ORT_API_CALL SetupGraphicsInterop(_In_ OrtEpFactory* this_ptr,
+ _In_ const OrtMemoryDevice* memory_device,
+ _In_ const struct GraphicsInteropParams* graphicsInteropParams) noexcept {
+ return static_cast(this_ptr)->SetupGraphicsInterop(memory_device, graphicsInteropParams);
+ }
+
static OrtStatus* ORT_API_CALL CreateSyncStreamForDevice(_In_ OrtEpFactory* this_ptr,
_In_ const OrtMemoryDevice* memory_device,
_In_opt_ const OrtKeyValuePairs* stream_options,
diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_ort_interop_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_ort_interop_test.cc
new file mode 100644
index 0000000000000..82983adcfe830
--- /dev/null
+++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_ort_interop_test.cc
@@ -0,0 +1,438 @@
+// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Licensed under the MIT License.
+#include "core/graph/onnx_protobuf.h"
+#include "core/session/inference_session.h"
+#include "test/providers/provider_test_utils.h"
+#include "test/unittest_util/framework_test_utils.h"
+
+#include "test/util/include/scoped_env_vars.h"
+#include "test/common/trt_op_test_utils.h"
+#include "test/common/random_generator.h"
+#include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h"
+
+#include
+#include
+
+#if DX_FOR_INTEROP && _WIN32
+#include
+#include
+using Microsoft::WRL::ComPtr;
+#endif
+
+using namespace std;
+using namespace ONNX_NAMESPACE;
+using namespace ::onnxruntime::logging;
+extern std::unique_ptr ort_env;
+namespace onnxruntime {
+
+ #if DX_FOR_INTEROP && _WIN32
+// Creates a committed buffer of `size` bytes in a default heap, with the unordered-access
+// flag set, starting in resource state `initState`. The new resource is returned through
+// `ppResource`. A failing HRESULT is reported as a fatal GoogleTest failure.
+void CreateD3D12Buffer(ID3D12Device* pDevice, const size_t size, ID3D12Resource** ppResource, D3D12_RESOURCE_STATES initState)
+{
+  // Default heap; node mask 1 for both creation and visibility.
+  D3D12_HEAP_PROPERTIES defaultHeap = {};
+  defaultHeap.Type = D3D12_HEAP_TYPE_DEFAULT;
+  defaultHeap.CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN;
+  defaultHeap.MemoryPoolPreference = D3D12_MEMORY_POOL_UNKNOWN;
+  defaultHeap.CreationNodeMask = 1;
+  defaultHeap.VisibleNodeMask = 1;
+
+  // One-dimensional, format-less buffer of `size` bytes.
+  D3D12_RESOURCE_DESC desc = {};
+  desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
+  desc.Width = size;
+  desc.Height = 1;
+  desc.DepthOrArraySize = 1;
+  desc.MipLevels = 1;
+  desc.Format = DXGI_FORMAT_UNKNOWN;
+  desc.SampleDesc.Count = 1;
+  desc.SampleDesc.Quality = 0;
+  desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
+  desc.Flags = D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS;
+
+  const HRESULT result = pDevice->CreateCommittedResource(&defaultHeap, D3D12_HEAP_FLAG_NONE, &desc,
+                                                          initState, nullptr, IID_PPV_ARGS(ppResource));
+  if (!SUCCEEDED(result))
+  {
+    GTEST_FAIL() << "Failed creating a D3D12 resource, HRESULT: 0x" << std::hex << result;
+  }
+}
+
+// Creates a committed buffer of `size` bytes in an upload heap, starting in the
+// GENERIC_READ state. The new resource is returned through `ppResource`. A failing
+// HRESULT is reported as a fatal GoogleTest failure.
+void CreateUploadBuffer(ID3D12Device* pDevice, const size_t size, ID3D12Resource** ppResource)
+{
+  // One-dimensional, format-less buffer of `size` bytes without extra resource flags.
+  D3D12_RESOURCE_DESC desc = {};
+  desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
+  desc.Alignment = 0;
+  desc.Width = size;
+  desc.Height = 1;
+  desc.DepthOrArraySize = 1;
+  desc.MipLevels = 1;
+  desc.Format = DXGI_FORMAT_UNKNOWN;
+  desc.SampleDesc.Count = 1;
+  desc.SampleDesc.Quality = 0;
+  desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
+  desc.Flags = D3D12_RESOURCE_FLAG_NONE;
+
+  // Upload heap; node mask 1 for both creation and visibility.
+  D3D12_HEAP_PROPERTIES uploadHeap = {};
+  uploadHeap.Type = D3D12_HEAP_TYPE_UPLOAD;
+  uploadHeap.CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN;
+  uploadHeap.MemoryPoolPreference = D3D12_MEMORY_POOL_UNKNOWN;
+  uploadHeap.CreationNodeMask = 1;
+  uploadHeap.VisibleNodeMask = 1;
+
+  const HRESULT result = pDevice->CreateCommittedResource(&uploadHeap, D3D12_HEAP_FLAG_NONE, &desc,
+                                                          D3D12_RESOURCE_STATE_GENERIC_READ, nullptr,
+                                                          IID_PPV_ARGS(ppResource));
+  if (!SUCCEEDED(result))
+  {
+    GTEST_FAIL() << "Failed creating a D3D12 upload resource, HRESULT: 0x" << std::hex << result;
+  }
+}
+
+// Creates a committed buffer of `size` bytes in a readback heap, starting in the
+// COPY_DEST state. The new resource is returned through `ppResource`. A failing HRESULT
+// is reported as a fatal GoogleTest failure.
+void CreateReadBackBuffer(ID3D12Device* pDevice, const size_t size, ID3D12Resource** ppResource)
+{
+  // One-dimensional, format-less buffer of `size` bytes without extra resource flags.
+  D3D12_RESOURCE_DESC desc = {};
+  desc.Dimension = D3D12_RESOURCE_DIMENSION_BUFFER;
+  desc.Alignment = 0;
+  desc.Width = size;
+  desc.Height = 1;
+  desc.DepthOrArraySize = 1;
+  desc.MipLevels = 1;
+  desc.Format = DXGI_FORMAT_UNKNOWN;
+  desc.SampleDesc.Count = 1;
+  desc.SampleDesc.Quality = 0;
+  desc.Layout = D3D12_TEXTURE_LAYOUT_ROW_MAJOR;
+  desc.Flags = D3D12_RESOURCE_FLAG_NONE;
+
+  // Readback heap; node mask 1 for both creation and visibility.
+  D3D12_HEAP_PROPERTIES readbackHeap = {};
+  readbackHeap.Type = D3D12_HEAP_TYPE_READBACK;
+  readbackHeap.CPUPageProperty = D3D12_CPU_PAGE_PROPERTY_UNKNOWN;
+  readbackHeap.MemoryPoolPreference = D3D12_MEMORY_POOL_UNKNOWN;
+  readbackHeap.CreationNodeMask = 1;
+  readbackHeap.VisibleNodeMask = 1;
+
+  const HRESULT result = pDevice->CreateCommittedResource(&readbackHeap, D3D12_HEAP_FLAG_NONE, &desc,
+                                                          D3D12_RESOURCE_STATE_COPY_DEST, nullptr,
+                                                          IID_PPV_ARGS(ppResource));
+  if (!SUCCEEDED(result))
+  {
+    GTEST_FAIL() << "Failed creating a D3D12 read back resource, HRESULT: 0x" << std::hex << result;
+  }
+}
+
+
+// Blocks the calling thread until the work already submitted to `pQueue` has completed.
+//
+// A temporary fence is signalled from the queue with value 1 and the CPU waits on an
+// event that fires once the fence reaches that value. Every step is checked; a failure
+// is reported as a fatal GoogleTest failure instead of dereferencing a null fence or
+// waiting forever on an event that will never be set.
+void FlushAndWait(ID3D12Device* pDevice, ID3D12CommandQueue* pQueue)
+{
+  ComPtr<ID3D12Fence> pFence;
+  HRESULT hr = pDevice->CreateFence(0, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(&pFence));
+  if (FAILED(hr))
+  {
+    GTEST_FAIL() << "Failed creating a D3D12 fence, HRESULT: 0x" << std::hex << hr;
+  }
+
+  // Auto-reset event, initially unsignalled.
+  HANDLE hEvent = CreateEvent(nullptr, FALSE, FALSE, nullptr);
+  if (hEvent == nullptr)
+  {
+    GTEST_FAIL() << "Failed creating the completion event, error: " << GetLastError();
+  }
+
+  hr = pQueue->Signal(pFence.Get(), 1);
+  if (SUCCEEDED(hr))
+  {
+    hr = pFence->SetEventOnCompletion(1, hEvent);
+  }
+  // Only wait when the event is guaranteed to be signalled eventually.
+  if (SUCCEEDED(hr) && WaitForSingleObject(hEvent, INFINITE) != WAIT_OBJECT_0)
+  {
+    hr = E_FAIL;
+  }
+
+  // The event is closed on every path; ComPtr releases the fence automatically.
+  CloseHandle(hEvent);
+
+  if (FAILED(hr))
+  {
+    GTEST_FAIL() << "Failed waiting for the D3D12 command queue, HRESULT: 0x" << std::hex << hr;
+  }
+}
+#endif
+namespace test {
+
+TEST(NvExecutionProviderTest, GraphicsORTInteropTest) {
+#if TRT_MAJOR_RTX > 1 || TRT_MINOR_RTX >= 3
+ PathString model_name = ORT_TSTR("nv_execution_provider_test.onnx");
+ std::string graph_name = "test";
+ constexpr int image_dim = 1080;
+
+ // Create a simple 1-input, 1-output Relu model
+ onnxruntime::Model model(graph_name, false, DefaultLoggingManager().DefaultLogger());
+ auto& graph = model.MainGraph();
+
+ ONNX_NAMESPACE::TypeProto tensor_type;
+ tensor_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
+ tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1);
+ tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(3);
+ tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(image_dim);
+ tensor_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(image_dim);
+
+ auto& input_arg = graph.GetOrCreateNodeArg("input", &tensor_type);
+ auto& output_arg = graph.GetOrCreateNodeArg("output", &tensor_type);
+ graph.AddNode("relu_node", "Relu", "Relu operation", {&input_arg}, {&output_arg});
+
+ ASSERT_STATUS_OK(graph.Resolve());
+ ASSERT_STATUS_OK(onnxruntime::Model::Save(model, model_name));
+
+#if DX_FOR_INTEROP && _WIN32
+ {
+ std::vector cpuInputHalf(3 * image_dim * image_dim);
+ std::vector cpuOutputHalf(3 * image_dim * image_dim);
+ using StreamUniquePtr = std::unique_ptr>;
+
+ // Generate random data for input
+ {
+ RandomValueGenerator random{};
+ std::vector shape{3, image_dim, image_dim};
+ std::vector input_data = random.Uniform(shape, static_cast(0), static_cast(65535));
+ memcpy(cpuInputHalf.data(), input_data.data(), cpuInputHalf.size() * sizeof(uint16_t));
+ } // input_data is freed here
+
+
+ // set up d3d12
+ ComPtr pDevice;
+ ComPtr pCommandQueue;
+ ComPtr pInput;
+ ComPtr pOutput;
+ ComPtr pUploadRes;
+ ComPtr pUploadResCorrupt;
+ ComPtr pDownloadRes;
+ ComPtr pUploadCommandList;
+ ComPtr pDownloadCommandList;
+ ComPtr pAllocatorCopy;
+
+ uint64_t fenceValue = 0;
+ GraphicsInteropParams graphicsInteropParams;
+ graphicsInteropParams.extSyncPrimitive = ExternalSyncPrimitive_D3D12Fence;
+ graphicsInteropParams.DevicePtr.DXDeviceParams.pDevice = nullptr;
+ graphicsInteropParams.DevicePtr.DXDeviceParams.pCommandQueue = nullptr;
+ HANDLE sharedFenceHandle = nullptr;
+ FenceInteropParams fenceInteropParams;
+ fenceInteropParams.extSyncPrimitive = ExternalSyncPrimitive_D3D12Fence;
+ fenceInteropParams.FencePtr.pFence = nullptr;
+ OrtFence* ortFence = nullptr;
+
+ HRESULT hr = D3D12CreateDevice(nullptr, D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&pDevice));
+ if (FAILED(hr))
+ {
+ GTEST_SKIP() << "Failed to create D3D12 device, HRESULT: 0x" << std::hex << hr << " - D3D12 may not be available on this system";
+ }
+ graphicsInteropParams.DevicePtr.DXDeviceParams.pDevice = pDevice.Get();
+
+ D3D12_COMMAND_QUEUE_DESC queueDesc = {};
+ queueDesc.Type = D3D12_COMMAND_LIST_TYPE_COMPUTE;
+ queueDesc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;
+ hr = pDevice->CreateCommandQueue(&queueDesc, IID_PPV_ARGS(&pCommandQueue));
+ if (FAILED(hr))
+ {
+ GTEST_SKIP() << "Failed to create D3D12 command queue, HRESULT: 0x" << std::hex << hr << " - Command queue may not be available on this system";
+ }
+ graphicsInteropParams.DevicePtr.DXDeviceParams.pCommandQueue = pCommandQueue.Get();
+
+ // Use ORT APIs to load the model
+ OrtApi const& ortApi = Ort::GetApi();
+ Ort::SessionOptions sessionOptions;
+ sessionOptions.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
+ sessionOptions.DisableMemPattern();
+ sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
+ sessionOptions.AddConfigEntry("session.disable_cpu_ep_fallback", "1");
+ ortApi.AddFreeDimensionOverrideByName(sessionOptions, "batch_size", 1);
+
+ std::string trtLibPath = "onnxruntime_providers_nv_tensorrt_rtx.dll";
+ std::wstring wideTrtLibPath = std::wstring(trtLibPath.begin(), trtLibPath.end());
+
+ OrtStatus* status = ortApi.RegisterExecutionProviderLibrary(*ort_env, "NvTensorRtRtx", wideTrtLibPath.c_str());
+ if (status != nullptr) {
+ std::string error_message = ortApi.GetErrorMessage(status);
+ ortApi.ReleaseStatus(status);
+ FAIL() << "Failed to register EP library: " << error_message;
+ }
+
+ sessionOptions.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_GPU);
+
+ ComPtr pFence;
+ pDevice->CreateFence(fenceValue, D3D12_FENCE_FLAG_SHARED, IID_PPV_ARGS(&pFence));
+ fenceInteropParams.FencePtr.pFence = pFence.Get();
+ pDevice->CreateSharedHandle(pFence.Get(), nullptr, GENERIC_ALL, nullptr, &sharedFenceHandle);
+
+ const OrtEpDevice* const* ep_devices = nullptr;
+ size_t num_ep_devices;
+ ortApi.GetEpDevices(*ort_env, &ep_devices, &num_ep_devices);
+ const OrtEpDevice* trt_ep_device = nullptr;
+ for (UINT i = 0; i < num_ep_devices; i++)
+ {
+ if (strcmp(ortApi.EpDevice_EpName(ep_devices[i]), "NvTensorRTRTXExecutionProvider") == 0)
+ {
+ trt_ep_device = ep_devices[i];
+ break;
+ }
+ }
+
+ // Must be called before other interop functions to create the context
+ ortApi.SetupGraphicsInteropForEpDevice(trt_ep_device, &graphicsInteropParams);
+
+ // Create ORT stream - this will be created on the context we just set up
+ OrtSyncStream* stream = nullptr;
+ StreamUniquePtr stream_ptr;
+ ortApi.CreateSyncStreamForEpDevice(trt_ep_device, nullptr, &stream);
+ stream_ptr = StreamUniquePtr(stream, [ortApi](OrtSyncStream* stream) { ortApi.ReleaseSyncStream(stream); });
+
+ // Create IHV-agnostic memory info using hardware device vendor ID
+ OrtMemoryInfo* memory_info_agnostic = nullptr;
+ const OrtHardwareDevice* hw_device = ortApi.EpDevice_Device(trt_ep_device);
+ UINT vID = ortApi.HardwareDevice_VendorId(hw_device);
+ ortApi.CreateMemoryInfo_V2("Device_Agnostic", OrtMemoryInfoDeviceType_GPU,
+ /*vendor_id*/vID, /*device_id*/0,
+ OrtDeviceMemoryType_DEFAULT, /*default alignment*/0,
+ OrtArenaAllocator, &memory_info_agnostic);
+
+ auto memory_info_cleanup = std::unique_ptr>(
+ memory_info_agnostic,
+ [&ortApi](OrtMemoryInfo* ptr) {
+ if (ptr) ortApi.ReleaseMemoryInfo(ptr);
+ }
+ );
+
+ char streamAddress[32];
+ size_t stream_addr_val = reinterpret_cast(ortApi.SyncStream_GetHandle(stream));
+ sprintf_s(streamAddress, "%llu", static_cast(stream_addr_val));
+ const char* option_keys[] = { "user_compute_stream", "has_user_compute_stream" };
+ const char* option_values[] = { streamAddress, "1" };
+ for (size_t i = 0; i < num_ep_devices; i++)
+ {
+ if (strcmp(ortApi.EpDevice_EpName(ep_devices[i]), "CPUExecutionProvider") != 0)
+ ortApi.SessionOptionsAppendExecutionProvider_V2(sessionOptions, *ort_env, &ep_devices[i], 1, option_keys, option_values, 2);
+ }
+
+ // default resources
+ CreateD3D12Buffer(pDevice.Get(), 3 * image_dim * image_dim * sizeof(uint16_t), pInput.GetAddressOf(), D3D12_RESOURCE_STATE_COPY_DEST);
+ CreateD3D12Buffer(pDevice.Get(), 3 * image_dim * image_dim * sizeof(uint16_t), pOutput.GetAddressOf(), D3D12_RESOURCE_STATE_COPY_SOURCE);
+
+ // upload and download resources
+ CreateUploadBuffer(pDevice.Get(), 3 * image_dim * image_dim * sizeof(uint16_t), pUploadRes.GetAddressOf());
+ CreateUploadBuffer(pDevice.Get(), 3 * image_dim * image_dim * sizeof(uint16_t), pUploadResCorrupt.GetAddressOf());
+
+ CreateReadBackBuffer(pDevice.Get(), 3 * image_dim * image_dim * sizeof(uint16_t), pDownloadRes.GetAddressOf());
+
+ hr = pDevice->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_COMPUTE, IID_PPV_ARGS(&pAllocatorCopy));
+
+ hr = pDevice->CreateCommandList(1, D3D12_COMMAND_LIST_TYPE_COMPUTE, pAllocatorCopy.Get(), NULL, IID_PPV_ARGS(&pUploadCommandList));
+
+
+ // heavy GPU load for reproducing race condition
+ for (int i = 0; i < 1000; i++)
+ {
+ pUploadCommandList->CopyResource(pInput.Get(), pUploadResCorrupt.Get());
+
+ D3D12_RESOURCE_BARRIER barrier = {};
+ barrier.Type = D3D12_RESOURCE_BARRIER_TYPE_UAV;
+ barrier.Flags = D3D12_RESOURCE_BARRIER_FLAG_NONE;
+ barrier.UAV.pResource = nullptr; // This makes it a NULL UAV barrier
+
+ pUploadCommandList->ResourceBarrier(1, &barrier);
+ }
+
+ pUploadCommandList->CopyResource(pInput.Get(), pUploadRes.Get());
+ pUploadCommandList->Close();
+
+ std::cerr << "Test completed successfully1" <CreateCommandList(1, D3D12_COMMAND_LIST_TYPE_COMPUTE, pAllocatorCopy.Get(), NULL, IID_PPV_ARGS(&pDownloadCommandList));
+ pDownloadCommandList->CopyResource(pDownloadRes.Get(), pOutput.Get());
+ pDownloadCommandList->Close();
+
+ Ort::Session session = Ort::Session(*ort_env, L"nv_execution_provider_test.onnx", sessionOptions);
+
+ ortApi.GetOrtFenceForGraphicsInterop(session, &graphicsInteropParams, &fenceInteropParams, &ortFence);
+
+ Ort::IoBinding ioBinding = Ort::IoBinding::IoBinding(session);
+
+ Ort::AllocatorWithDefaultOptions allocator;
+ Ort::AllocatedStringPtr InputTensorName = session.GetInputNameAllocated(0, allocator);
+ Ort::AllocatedStringPtr OuptutTensorName = session.GetOutputNameAllocated(0, allocator);
+
+ int64_t inputDim[] = { 1, 3, image_dim, image_dim };
+ int64_t outputDim[] = { 1, 3, image_dim, image_dim };
+
+ // upload the input
+ void* pData;
+ pUploadRes->Map(0, nullptr, (void**)&pData);
+ memcpy(pData, cpuInputHalf.data(), cpuInputHalf.size() * sizeof(uint16_t));
+ pUploadRes->Unmap(0, nullptr);
+
+ // Upload corrupted data to test synchronization (should not affect the output)
+ void* pDataCorrupt;
+ pUploadResCorrupt->Map(0, nullptr, (void**)&pDataCorrupt);
+ std::fill_n((uint8_t*)pDataCorrupt, 3 * image_dim * image_dim * sizeof(uint16_t), 0xFF);
+ pUploadResCorrupt->Unmap(0, nullptr);
+
+ // bind the resources using IHV-agnostic memory info but keep zero-copy external memory sharing
+ Ort::Value inputTensor(Ort::Value::CreateTensor(memory_info_agnostic, (void*)pInput->GetGPUVirtualAddress(), cpuInputHalf.size() * sizeof(uint16_t), inputDim, 4, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16));
+ Ort::Value outputTensor(Ort::Value::CreateTensor(memory_info_agnostic, (void*)pOutput->GetGPUVirtualAddress(), cpuOutputHalf.size() * sizeof(uint16_t), outputDim, 4, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16));
+ ioBinding.BindInput(InputTensorName.get(), inputTensor);
+ ioBinding.BindOutput(OuptutTensorName.get(), outputTensor);
+ ioBinding.SynchronizeInputs();
+
+ std::cerr << "Test completed successfully2" <ExecuteCommandLists(1, &pUploadCmdList);
+
+ // make ORT wait for upload
+ pCommandQueue->Signal(pFence.Get(), fenceValue);
+ ortApi.InteropEpWait(ortFence, stream, fenceValue); // make ORT wait on the fence (on CUDA side internally)
+
+ // run the model
+ Ort::RunOptions runOptions;
+ runOptions.AddConfigEntry("disable_synchronize_execution_providers", "1");
+ session.Run(runOptions, ioBinding);
+
+ fenceValue++;
+ // make DX wait for ORT
+ ortApi.InteropEpSignal(ortFence, stream, fenceValue); // signal from CUDA side (internally)
+ pCommandQueue->Wait(pFence.Get(), fenceValue);
+
+ // download the output to cpu memory (again using DX)
+ ID3D12CommandList* pDownloadCmdList = pDownloadCommandList.Get();
+ pCommandQueue->ExecuteCommandLists(1, &pDownloadCmdList);
+ FlushAndWait(pDevice.Get(), pCommandQueue.Get());
+ }
+
+
+ std::cerr << "Test completed successfully3" <Map(0, nullptr, (void**)&pOutputData);
+ memcpy(cpuOutputHalf.data(), pOutputData, cpuOutputHalf.size() * sizeof(uint16_t));
+ pDownloadRes->Unmap(0, nullptr);
+
+ std::cerr << "First 50 elements of cpuInputHalf:\n";
+ for (int i = 0; i < 50; i++) {
+ std::cerr << cpuInputHalf[i] << " ";
+ }
+ std::cerr << std::endl;
+
+ std::cerr << "First 50 elements of cpuOutputHalf:\n";
+ for (int i = 0; i < 50; i++) {
+ std::cerr << cpuOutputHalf[i] << " ";
+ }
+ std::cerr << std::endl;
+
+ // ComPtr automatically handles cleanup via RAII
+ CloseHandle(sharedFenceHandle);
+
+ std::cerr << "\nInference done. Check output image." < None:
azure_group = parser.add_argument_group("Azure Execution Provider")
azure_group.add_argument("--use_azure", action="store_true", help="Enable Azure EP.")
+ # --- DX for Interop Feature ---
+ dx_group = parser.add_argument_group("DX for Interop Feature")
+ dx_group.add_argument("--use_dx_for_interop", action="store_true", help="Enable DX for Interop feature.")
+
+ # --- Vulkan for Interop Feature ---
+ vulkan_group = parser.add_argument_group("Vulkan for Interop Feature")
+ vulkan_group.add_argument("--use_vulkan_for_interop", action="store_true", help="Enable VULKAN for Interop feature.")
def add_other_feature_args(parser: argparse.ArgumentParser) -> None:
"""Adds arguments for other miscellaneous features."""