diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 59ca1a1df762e..3a423a64b9047 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -146,6 +146,17 @@ class Environment { return execution_devices_; } + /// Get hardware device incompatibility details for a specific EP. + /// @param ep_name The name of the execution provider to check. + /// @param hw The hardware device to check for incompatibility. + /// @param details Output: Incompatibility details including reasons for incompatibility if any. + /// @returns Status indicating success or failure. + Status GetHardwareDeviceEpIncompatibilityDetails(const std::string& ep_name, + const OrtHardwareDevice* hw, + std::unique_ptr& details) const; + + const std::vector& GetSortedOrtHardwareDevices() const; + Status CreateSharedAllocator(const OrtEpDevice& ep_device, OrtDeviceMemoryType mem_type, OrtAllocatorType allocator_type, const OrtKeyValuePairs* allocator_options, OrtAllocator** allocator); diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 303bb5411ffd9..5acac571f3f3b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. // See docs\c_cxx\README.md on generating the Doxygen documentation from this file @@ -333,6 +333,7 @@ ORT_RUNTIME_CLASS(ExternalInitializerInfo); ORT_RUNTIME_CLASS(ExternalResourceImporter); // Capability object for external resource import ORT_RUNTIME_CLASS(ExternalMemoryHandle); // EP-imported view of shared external allocation ORT_RUNTIME_CLASS(ExternalSemaphoreHandle); // EP-imported view of shared external semaphore +ORT_RUNTIME_CLASS(DeviceEpIncompatibilityDetails); #ifdef _MSC_VER typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -510,6 +511,16 @@ typedef enum OrtExecutionProviderDevicePolicy { OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER, } OrtExecutionProviderDevicePolicy; +/** \brief Reasons why an execution provider might not be compatible with a device + */ +typedef enum OrtDeviceEpIncompatibilityReason { + OrtDeviceEpIncompatibility_NONE = 0, + OrtDeviceEpIncompatibility_DRIVER_INCOMPATIBLE = 1 << 0, + OrtDeviceEpIncompatibility_DEVICE_INCOMPATIBLE = 1 << 1, + OrtDeviceEpIncompatibility_MISSING_DEPENDENCY = 1 << 2, + OrtDeviceEpIncompatibility_UNKNOWN = 1 << 31 +} OrtDeviceEpIncompatibilityReason; + /** \brief Delegate to allow providing custom OrtEpDevice selection logic * * This delegate is called by the EP selection code to allow the user to provide custom device selection logic. @@ -6784,6 +6795,121 @@ struct OrtApi { ORT_API2_STATUS(SessionGetEpDeviceForOutputs, _In_ const OrtSession* session, _Out_writes_(num_outputs) const OrtEpDevice** outputs_ep_devices, _In_ size_t num_outputs); + /** \brief Get the number of available hardware devices. + * + * Returns the count of hardware devices discovered on the system. + * Use this to allocate an array before calling GetHardwareDevices(). + * + * \param[in] env The OrtEnv instance where device discovery results are stored. + * \param[out] num_devices The number of OrtHardwareDevice instances available. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(GetNumHardwareDevices, _In_ const OrtEnv* env, _Out_ size_t* num_devices); + + /** \brief Get the list of available hardware devices. + * + * Enumerates hardware devices available on the system. + * Populates a user-provided array with pointers to OrtHardwareDevice instances. The caller is responsible + * for allocating the array with sufficient space (use GetNumHardwareDevices() to get the count). + * + * The returned pointers reference internal ORT data structures that are discovered once at process + * startup and remain valid for the lifetime of the OrtEnv. The caller does not need to release these + * pointers, but should not use them after calling ReleaseEnv(). + * + * \param[in] env The OrtEnv instance where device discovery results are stored. + * \param[out] devices User-allocated array to receive pointers to OrtHardwareDevice instances. + * The array must have space for at least num_devices elements. + * \param[in] num_devices The size of the user-allocated devices array. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(GetHardwareDevices, _In_ const OrtEnv* env, + _Out_writes_(num_devices) const OrtHardwareDevice** devices, + _In_ size_t num_devices); + + /** \brief Check for known incompatibility issues between hardware device and a specific execution provider. + * + * This function checks for known incompatibility issues between the specified hardware device + * and a specific execution provider. + * If returned incompatibility details have non-zero reasons, it indicates the device is not compatible. + * However, if returned detail have reason == 0, it doesn't guarantee 100% compatibility for all models, + * as models may have specific requirements. + * + * Note: This method should only be called when the OrtEnv has been initialized with execution + * providers (after RegisterExecutionProviderLibrary is called). + * + * \param[in] env The OrtEnv instance with registered execution providers. + * \param[in] ep_name The name of the execution provider to check. Required and cannot be null or empty. + * \param[in] hw The hardware device to check for incompatibility. + * \param[out] details Compatibility details including reasons for incompatibility if any. + * Must be freed with OrtApi::ReleaseDeviceEpIncompatibilityDetails. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(GetHardwareDeviceEpIncompatibilityDetails, _In_ const OrtEnv* env, + _In_ const char* ep_name, + _In_ const OrtHardwareDevice* hw, + _Outptr_ OrtDeviceEpIncompatibilityDetails** details); + + /// \name OrtDeviceEpIncompatibilityDetails + /// Accessor functions for device incompatibility details + /// @{ + + /** \brief Get the incompatibility reasons bitmask from OrtDeviceEpIncompatibilityDetails. + * + * \param[in] details The OrtDeviceEpIncompatibilityDetails instance to query. + * \param[out] reasons_bitmask Pointer to store the bitmask of incompatibility reasons. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(DeviceEpIncompatibilityDetails_GetReasonsBitmask, + _In_ const OrtDeviceEpIncompatibilityDetails* details, + _Out_ uint32_t* reasons_bitmask); + + /** \brief Get the notes from OrtDeviceEpIncompatibilityDetails. + * + * \param[in] details The OrtDeviceEpIncompatibilityDetails instance to query. + * \param[out] notes Pointer to the notes string. May be nullptr if no notes are available. + * The returned string is owned by the details object and should not be freed. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(DeviceEpIncompatibilityDetails_GetNotes, + _In_ const OrtDeviceEpIncompatibilityDetails* details, + _Outptr_result_maybenull_ const char** notes); + + /** \brief Get the execution provider error code from OrtDeviceEpIncompatibilityDetails. + * + * This allows Independent Hardware Vendors (IHVs) to define their own error codes + * to provide additional details about device incompatibility. + * + * \param[in] details The OrtDeviceEpIncompatibilityDetails instance to query. + * \param[out] error_code Pointer to store the EP-specific error code. A value of 0 indicates no error code was set. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(DeviceEpIncompatibilityDetails_GetErrorCode, + _In_ const OrtDeviceEpIncompatibilityDetails* details, + _Out_ int32_t* error_code); + + /** \brief Release an OrtDeviceEpIncompatibilityDetails instance. + * + * \since Version 1.24. + */ + ORT_CLASS_RELEASE(DeviceEpIncompatibilityDetails); /// @} }; diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 0db06f23dcd4a..6bb454cd47623 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -1308,6 +1308,27 @@ struct OrtEpApi { */ ORT_API2_STATUS(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ const OrtEp** ep); + /** \brief Set the details of an OrtDeviceEpIncompatibilityDetails instance. + * + * Used by execution provider factories to set incompatibility details in their + * GetHardwareDeviceIncompatibilityDetails implementation. ORT creates and initializes the object + * before passing it to the EP, so calling this function is optional. The EP uses this function + * to set incompatibility information when the device is not compatible. + * + * \param[in,out] details The OrtDeviceEpIncompatibilityDetails instance to update. + * \param[in] reasons_bitmask Bitmask of OrtDeviceEpIncompatibilityReason values. (0 = no incompatibility). + * \param[in] error_code Optional EP-specific error code (0 = no error). + * \param[in] notes Optional human-readable notes. Can be null. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(DeviceEpIncompatibilityDetails_SetDetails, _Inout_ OrtDeviceEpIncompatibilityDetails* details, + _In_ uint32_t reasons_bitmask, + _In_ int32_t error_code, + _In_opt_z_ const char* notes); + /** \brief Creates an OrtKernelImpl instance for an If operator. * * Control flow operators require access to ORT session internals to orchestrate subgraph operations. @@ -1990,6 +2011,30 @@ struct OrtEpFactory { */ ORT_API2_STATUS(SetEnvironmentOptions, _In_ OrtEpFactory* this_ptr, _In_ const OrtKeyValuePairs* options); + /** \brief Check for known incompatibility reasons between a hardware device and this execution provider. + * + * This function allows an execution provider to check if a specific hardware device is compatible + * with the execution provider. The EP can set specific incompatibility reasons via the + * OrtDeviceEpIncompatibilityDetails parameter using OrtEpApi::DeviceEpIncompatibilityDetails_SetDetails. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] hw The hardware device to check for incompatibility. + * \param[in,out] details Pre-allocated incompatibility details object created and initialized by ORT. + * The EP can use OrtEpApi::DeviceEpIncompatibilityDetails_SetDetails to set + * incompatibility information. If the device is compatible, the EP can + * leave the object unchanged (it defaults to no incompatibility). + * + * \note Implementation of this function is optional. + * If not implemented, ORT will assume the device is compatible with this EP. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(GetHardwareDeviceIncompatibilityDetails, _In_ OrtEpFactory* this_ptr, + _In_ const OrtHardwareDevice* hw, + _Inout_ OrtDeviceEpIncompatibilityDetails* details); + /** \brief Create an OrtExternalResourceImporterImpl for external resource import. * * This is used to create an external resource importer that enables zero-copy import of diff --git a/onnxruntime/core/session/abi_devices.h b/onnxruntime/core/session/abi_devices.h index 571a9eb2a54e2..aafe2ab114ab9 100644 --- a/onnxruntime/core/session/abi_devices.h +++ b/onnxruntime/core/session/abi_devices.h @@ -75,3 +75,9 @@ struct OrtEpDevice { // get/create methods to be as flexible as possible. this helper converts to a non-const factory instance. OrtEpFactory* GetMutableFactory() const { return ep_factory; } }; + +struct OrtDeviceEpIncompatibilityDetails { + uint32_t reasons_bitmask{0}; // Bitmask of OrtDeviceEpIncompatibilityReason values + int32_t error_code{0}; // EP-specific error code (0 = no error) + std::string notes; // Additional human-readable notes +}; diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 9008a906155fd..cd8a799115ce6 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -637,6 +637,106 @@ Status Environment::UnregisterExecutionProviderLibrary(const std::string& ep_nam return status; } +Status Environment::GetHardwareDeviceEpIncompatibilityDetails( + const std::string& ep_name, + const OrtHardwareDevice* hw, + std::unique_ptr& details) const { + std::lock_guard lock{mutex_}; + + OrtEpFactory* matched_factory = nullptr; + + // Search for a factory whose GetName() matches ep_name exactly. + for (const auto& [registration_name, ep_info] : ep_libraries_) { + for (OrtEpFactory* factory : ep_info->factories) { + if (factory != nullptr && factory->GetName != nullptr) { + const char* factory_name = factory->GetName(factory); + if (factory_name != nullptr && ep_name == factory_name) { + matched_factory = factory; + break; + } + } + } + if (matched_factory != nullptr) { + break; + } + } + + if (matched_factory == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "No valid factory found for execution provider '", ep_name, "'."); + } + + // ORT creates the details object with default values (compatible) + details = std::make_unique(); + // If the factory implements GetHardwareDeviceIncompatibilityDetails, let it initialize the details + if (matched_factory->GetHardwareDeviceIncompatibilityDetails != nullptr) { + OrtStatusPtr status = matched_factory->GetHardwareDeviceIncompatibilityDetails(matched_factory, hw, details.get()); + + if (status != nullptr) { + return ToStatusAndRelease(status); + } + } + + // Factory doesn't implement the hook - details remain with default values (compatible) + return Status::OK(); +} + +namespace { +std::vector SortDevicesByType() { + auto& devices = DeviceDiscovery::GetDevices(); + std::vector sorted_devices; + sorted_devices.reserve(devices.size()); + + const auto select_by_type = [&](OrtHardwareDeviceType type) { + for (const auto& device : devices) { + if (device.type == type) { + sorted_devices.push_back(&device); + } + } + }; + + select_by_type(OrtHardwareDeviceType_NPU); + select_by_type(OrtHardwareDeviceType_GPU); + select_by_type(OrtHardwareDeviceType_CPU); + + return sorted_devices; +} + +// Returns a static reference to sorted hardware devices. +// Hardware devices are discovered once at startup and don't change. +const std::vector& GetSortedHardwareDevices() { + static const auto sorted_devices = SortDevicesByType(); + return sorted_devices; +} + +bool AreVirtualDevicesAllowed(std::string_view lib_registration_name) { + constexpr std::string_view suffix{".virtual"}; + + return lib_registration_name.size() >= suffix.size() && + lib_registration_name.compare(lib_registration_name.size() - suffix.size(), + suffix.size(), suffix) == 0; +} + +Status SetEpFactoryEnvironmentOptions(OrtEpFactory& factory, std::string_view lib_registration_name) { + // OrtEpFactory::SetEnvironmentOptions was added in ORT 1.24 + if (factory.ort_version_supported < 24 || factory.SetEnvironmentOptions == nullptr) { + return Status::OK(); + } + + // We only set one option now but this can be generalized if necessary. + OrtKeyValuePairs options; + options.Add("allow_virtual_devices", AreVirtualDevicesAllowed(lib_registration_name) ? "1" : "0"); + + ORT_RETURN_IF_ERROR(ToStatusAndRelease(factory.SetEnvironmentOptions(&factory, &options))); + + return Status::OK(); +} +} // namespace + +const std::vector& Environment::GetSortedOrtHardwareDevices() const { + return GetSortedHardwareDevices(); +} + Status Environment::CreateSharedAllocator(const OrtEpDevice& ep_device, OrtDeviceMemoryType mem_type, OrtAllocatorType allocator_type, const OrtKeyValuePairs* allocator_options, @@ -728,51 +828,6 @@ Status Environment::ReleaseSharedAllocator(const OrtEpDevice& ep_device, OrtDevi return status; } -namespace { -std::vector SortDevicesByType() { - auto& devices = DeviceDiscovery::GetDevices(); - std::vector sorted_devices; - sorted_devices.reserve(devices.size()); - - const auto select_by_type = [&](OrtHardwareDeviceType type) { - for (const auto& device : devices) { - if (device.type == type) { - sorted_devices.push_back(&device); - } - } - }; - - select_by_type(OrtHardwareDeviceType_NPU); - select_by_type(OrtHardwareDeviceType_GPU); - select_by_type(OrtHardwareDeviceType_CPU); - - return sorted_devices; -} - -bool AreVirtualDevicesAllowed(std::string_view lib_registration_name) { - constexpr std::string_view suffix{".virtual"}; - - return lib_registration_name.size() >= suffix.size() && - lib_registration_name.compare(lib_registration_name.size() - suffix.size(), - suffix.size(), suffix) == 0; -} - -Status SetEpFactoryEnvironmentOptions(OrtEpFactory& factory, std::string_view lib_registration_name) { - // OrtEpFactory::SetEnvironmentOptions was added in ORT 1.24 - if (factory.ort_version_supported < 24 || factory.SetEnvironmentOptions == nullptr) { - return Status::OK(); - } - - // We only set one option now but this can be generalized if necessary. - OrtKeyValuePairs options; - options.Add("allow_virtual_devices", AreVirtualDevicesAllowed(lib_registration_name) ? "1" : "0"); - - ORT_RETURN_IF_ERROR(ToStatusAndRelease(factory.SetEnvironmentOptions(&factory, &options))); - - return Status::OK(); -} -} // namespace - Status Environment::EpInfo::Create(std::unique_ptr library_in, std::unique_ptr& out, const std::vector& internal_factories) { if (!library_in) { @@ -787,9 +842,8 @@ Status Environment::EpInfo::Create(std::unique_ptr library_in, std::u ORT_RETURN_IF_ERROR(instance.library->Load()); instance.factories = instance.library->GetFactories(); - // OrtHardwareDevice instances to pass to GetSupportedDevices. sorted by type to be slightly more structured. - // the set of hardware devices is static so this can also be static. - const static std::vector sorted_devices = SortDevicesByType(); + // OrtHardwareDevice instances to pass to GetSupportedDevices. + const auto& sorted_devices = GetSortedHardwareDevices(); for (auto* factory_ptr : instance.factories) { ORT_ENFORCE(factory_ptr != nullptr, "Factory pointer was null. EpLibrary should prevent this. Library:", diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index c3bf74a4607e8..8cca7f2872c44 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4291,6 +4291,13 @@ static constexpr OrtApi ort_api_1_to_24 = { &OrtApis::GetInteropApi, &OrtApis::SessionGetEpDeviceForOutputs, + &OrtApis::GetNumHardwareDevices, + &OrtApis::GetHardwareDevices, + &OrtApis::GetHardwareDeviceEpIncompatibilityDetails, + &OrtApis::DeviceEpIncompatibilityDetails_GetReasonsBitmask, + &OrtApis::DeviceEpIncompatibilityDetails_GetNotes, + &OrtApis::DeviceEpIncompatibilityDetails_GetErrorCode, + &OrtApis::ReleaseDeviceEpIncompatibilityDetails, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. @@ -4368,3 +4375,137 @@ DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Value, OrtValue) DEFINE_RELEASE_ORT_OBJECT_FUNCTION(RunOptions, OrtRunOptions) DEFINE_RELEASE_ORT_OBJECT_FUNCTION(Session, ::onnxruntime::InferenceSession) DEFINE_RELEASE_ORT_OBJECT_FUNCTION(ModelMetadata, ::onnxruntime::ModelMetadata) + +ORT_API_STATUS_IMPL(OrtApis::GetNumHardwareDevices, _In_ const OrtEnv* env, _Out_ size_t* num_devices) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + if (env == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "env must not be null"); + } + if (num_devices == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "num_devices must not be null"); + } + + const auto& device_vector = env->GetEnvironment().GetSortedOrtHardwareDevices(); + *num_devices = device_vector.size(); + return nullptr; +#else + ORT_UNUSED_PARAMETER(env); + ORT_UNUSED_PARAMETER(num_devices); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetNumHardwareDevices is not available in minimal build"); +#endif + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::GetHardwareDevices, _In_ const OrtEnv* env, + _Out_writes_(num_devices) const OrtHardwareDevice** devices, + _In_ size_t num_devices) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + if (env == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "env must not be null"); + } + if (devices == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "devices must not be null"); + } + + const auto& device_vector = env->GetEnvironment().GetSortedOrtHardwareDevices(); + size_t available_devices = device_vector.size(); + + if (num_devices < available_devices) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "num_devices is less than the number of available hardware devices. " + "Use GetNumHardwareDevices() to get the required array size."); + } + + for (size_t i = 0; i < available_devices; ++i) { + devices[i] = device_vector[i]; + } + + return nullptr; +#else + ORT_UNUSED_PARAMETER(env); + ORT_UNUSED_PARAMETER(devices); + ORT_UNUSED_PARAMETER(num_devices); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetHardwareDevices is not available in minimal build"); +#endif + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::GetHardwareDeviceEpIncompatibilityDetails, _In_ const OrtEnv* env, _In_ const char* ep_name, _In_ const OrtHardwareDevice* hw, _Outptr_ OrtDeviceEpIncompatibilityDetails** details) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + // Validate all input parameters + if (env == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "env is required and cannot be null"); + } + if (ep_name == nullptr || ep_name[0] == '\0') { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_name is required and cannot be null or empty"); + } + if (hw == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "hw is required and cannot be null"); + } + if (details == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "details output parameter cannot be null"); + } + + std::unique_ptr compat_details; + auto status = env->GetEnvironment().GetHardwareDeviceEpIncompatibilityDetails(ep_name, hw, compat_details); + if (!status.IsOK()) { + return ToOrtStatus(status); + } + + *details = compat_details.release(); + return nullptr; +#else + ORT_UNUSED_PARAMETER(env); + ORT_UNUSED_PARAMETER(ep_name); + ORT_UNUSED_PARAMETER(hw); + ORT_UNUSED_PARAMETER(details); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetHardwareDeviceEpIncompatibilityDetails is not available in minimal build"); +#endif + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::DeviceEpIncompatibilityDetails_GetReasonsBitmask, _In_ const OrtDeviceEpIncompatibilityDetails* details, _Out_ uint32_t* reasons_bitmask) { + API_IMPL_BEGIN + if (details == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "details cannot be null"); + } + if (reasons_bitmask == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "reasons_bitmask output parameter cannot be null"); + } + *reasons_bitmask = details->reasons_bitmask; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::DeviceEpIncompatibilityDetails_GetNotes, _In_ const OrtDeviceEpIncompatibilityDetails* details, _Outptr_result_maybenull_ const char** notes) { + API_IMPL_BEGIN + if (details == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "details cannot be null"); + } + if (notes == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "notes output parameter cannot be null"); + } + *notes = details->notes.empty() ? nullptr : details->notes.c_str(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::DeviceEpIncompatibilityDetails_GetErrorCode, _In_ const OrtDeviceEpIncompatibilityDetails* details, _Out_ int32_t* error_code) { + API_IMPL_BEGIN + if (details == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "details cannot be null"); + } + if (error_code == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "error_code output parameter cannot be null"); + } + *error_code = details->error_code; + return nullptr; + API_IMPL_END +} + +void ORT_API_CALL OrtApis::ReleaseDeviceEpIncompatibilityDetails(OrtDeviceEpIncompatibilityDetails* details) noexcept { + delete details; +} diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 7aa09adfd32d1..a93d853592dea 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -480,6 +480,20 @@ ORT_API_STATUS_IMPL(KernelContext_GetAllocator, _In_ const OrtKernelContext* con ORT_API(const char*, GetBuildInfoString); +ORT_API_STATUS_IMPL(GetNumHardwareDevices, _In_ const OrtEnv* env, _Out_ size_t* num_devices); + +ORT_API_STATUS_IMPL(GetHardwareDevices, _In_ const OrtEnv* env, _Out_writes_(num_devices) const OrtHardwareDevice** devices, _In_ size_t num_devices); + +ORT_API_STATUS_IMPL(GetHardwareDeviceEpIncompatibilityDetails, _In_ const OrtEnv* env, _In_ const char* ep_name, _In_ const OrtHardwareDevice* hw, _Outptr_ OrtDeviceEpIncompatibilityDetails** details); + +ORT_API_STATUS_IMPL(DeviceEpIncompatibilityDetails_GetReasonsBitmask, _In_ const OrtDeviceEpIncompatibilityDetails* details, _Out_ uint32_t* reasons_bitmask); + +ORT_API_STATUS_IMPL(DeviceEpIncompatibilityDetails_GetNotes, _In_ const OrtDeviceEpIncompatibilityDetails* details, _Outptr_result_maybenull_ const char** notes); + +ORT_API_STATUS_IMPL(DeviceEpIncompatibilityDetails_GetErrorCode, _In_ const OrtDeviceEpIncompatibilityDetails* details, _Out_ int32_t* error_code); + +ORT_API(void, ReleaseDeviceEpIncompatibilityDetails, _Frees_ptr_opt_ OrtDeviceEpIncompatibilityDetails*); + ORT_API_STATUS_IMPL(CreateROCMProviderOptions, _Outptr_ OrtROCMProviderOptions** out); ORT_API_STATUS_IMPL(UpdateROCMProviderOptions, _Inout_ OrtROCMProviderOptions* rocm_options, _In_reads_(num_keys) const char* const* provider_options_keys, diff --git a/onnxruntime/core/session/plugin_ep/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc index a598de3053133..21e6ae1525838 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -656,6 +656,27 @@ ORT_API_STATUS_IMPL(KernelInfo_GetEp, _In_ const OrtKernelInfo* info, _Outptr_ c API_IMPL_END } +ORT_API_STATUS_IMPL(DeviceEpIncompatibilityDetails_SetDetails, _Inout_ OrtDeviceEpIncompatibilityDetails* details, + _In_ uint32_t reasons_bitmask, + _In_ int32_t error_code, + _In_opt_z_ const char* notes) { + API_IMPL_BEGIN + if (details == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "details parameter must not be null"); + } + + details->reasons_bitmask = reasons_bitmask; + details->error_code = error_code; + if (notes != nullptr) { + details->notes = notes; + } else { + details->notes.clear(); + } + + return nullptr; + API_IMPL_END +} + // Control flow kernel APIs ORT_API_STATUS_IMPL(CreateIfKernel, _In_ const OrtKernelInfo* kernel_info, _Outptr_ OrtKernelImpl** kernel_out) { API_IMPL_BEGIN @@ -819,6 +840,7 @@ static constexpr OrtEpApi ort_ep_api = { &OrtExecutionProviderApi::EpGraphSupportInfo_LookUpKernel, &OrtExecutionProviderApi::SharedPrePackedWeightCache_StoreWeightData, &OrtExecutionProviderApi::KernelInfo_GetEp, + &OrtExecutionProviderApi::DeviceEpIncompatibilityDetails_SetDetails, &OrtExecutionProviderApi::CreateIfKernel, &OrtExecutionProviderApi::CreateLoopKernel, &OrtExecutionProviderApi::CreateScanKernel, diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc index 8dc92802aa84b..a13645e293844 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc @@ -34,6 +34,7 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl OrtEpFactory::CreateSyncStreamForDevice = Forward::CreateSyncStreamForDevice; OrtEpFactory::SetEnvironmentOptions = Forward::SetEnvironmentOptions; OrtEpFactory::CreateExternalResourceImporterForDevice = Forward::CreateExternalResourceImporterForDevice; + OrtEpFactory::GetHardwareDeviceIncompatibilityDetails = Forward::GetHardwareDeviceIncompatibilityDetails; } InternalExecutionProviderFactory::InternalExecutionProviderFactory(EpFactoryInternal& ep_factory, diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h index ae98f2c0ac589..6f4a37f44fb44 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -96,6 +96,11 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->CreateExternalResourceImporterForDevice(ep_device, importer); } + OrtStatus* GetHardwareDeviceIncompatibilityDetails(_In_ const OrtHardwareDevice* hw, + _Inout_ OrtDeviceEpIncompatibilityDetails* details) noexcept { + return impl_->GetHardwareDeviceIncompatibilityDetails(hw, details); + } + // Function ORT calls to release an EP instance. void ReleaseEp(OrtEp* /*ep*/) noexcept { // we never create an OrtEp so we should never be trying to release one diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index 20a47715df2b8..7f42cdda33a96 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -96,6 +96,12 @@ class EpFactoryInternalImpl { return nullptr; } + virtual OrtStatus* GetHardwareDeviceIncompatibilityDetails(_In_ const OrtHardwareDevice* /*hw*/, + _Inout_ OrtDeviceEpIncompatibilityDetails* /*details*/) noexcept { + // Default implementation: leave details unchanged (device assumed compatible) + return nullptr; + } + // Function ORT calls to release an EP instance. void ReleaseEp(OrtEp* ep); diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h index 26173f0055ed7..3a7a1b6504d12 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h @@ -77,6 +77,15 @@ class ProviderBridgeEpFactory : public EpFactoryInternalImpl { return ep_factory_.CreateExternalResourceImporterForDevice(&ep_factory_, ep_device, importer); } + OrtStatus* GetHardwareDeviceIncompatibilityDetails(_In_ const OrtHardwareDevice* hw, + _Inout_ OrtDeviceEpIncompatibilityDetails* details) noexcept override { + if (ep_factory_.GetHardwareDeviceIncompatibilityDetails == nullptr) { + // Factory doesn't implement this hook, leave details unchanged (device assumed compatible) + return nullptr; + } + return ep_factory_.GetHardwareDeviceIncompatibilityDetails(&ep_factory_, hw, details); + } + OrtEpFactory& ep_factory_; ProviderLibrary& provider_library_; std::optional library_path_; diff --git a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index 2530ae8eb3c2b..27c453b500017 100644 --- a/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -94,6 +94,12 @@ struct ForwardToFactoryImpl { return static_cast(this_ptr)->CreateExternalResourceImporterForDevice(ep_device, importer); } + static OrtStatus* ORT_API_CALL GetHardwareDeviceIncompatibilityDetails(_In_ OrtEpFactory* this_ptr, + _In_ const OrtHardwareDevice* hw, + _Inout_ OrtDeviceEpIncompatibilityDetails* details) noexcept { + return static_cast(this_ptr)->GetHardwareDeviceIncompatibilityDetails(hw, details); + } + static void ORT_API_CALL ReleaseEp(OrtEpFactory* this_ptr, OrtEp* ep) noexcept { static_cast(this_ptr)->ReleaseEp(ep); } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 7c2b8e59ade89..79ec3fe3a3780 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -36,6 +36,7 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL IsStreamAware = IsStreamAwareImpl; CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; + GetHardwareDeviceIncompatibilityDetails = GetHardwareDeviceIncompatibilityDetailsImpl; CreateExternalResourceImporterForDevice = CreateExternalResourceImporterForDeviceImpl; @@ -312,6 +313,29 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateSyncStreamForDeviceImpl(OrtEpFac } /*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::GetHardwareDeviceIncompatibilityDetailsImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* hw, + OrtDeviceEpIncompatibilityDetails* details) noexcept { + auto& factory = *static_cast(this_ptr); + + // Example: This EP only supports CPU devices. Report incompatibility for non-CPU devices. + OrtHardwareDeviceType device_type = factory.ort_api.HardwareDevice_Type(hw); + + if (device_type != OrtHardwareDeviceType_CPU) { + // Report that the device type is not supported + uint32_t reasons = OrtDeviceEpIncompatibility_DEVICE_INCOMPATIBLE; + return factory.ep_api.DeviceEpIncompatibilityDetails_SetDetails( + details, + reasons, + static_cast(device_type), // Use device type as the error code for testing + "ExampleEP only supports CPU devices"); + } + + // Device is compatible - details are already initialized with default values by ORT + return nullptr; +} + OrtStatus* ORT_API_CALL ExampleEpFactory::CreateExternalResourceImporterForDeviceImpl( OrtEpFactory* this_ptr, const OrtEpDevice* /*ep_device*/, diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index 230fdef772e2f..9306b0fc88ec9 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -73,6 +73,11 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { const OrtEpDevice* ep_device, OrtExternalResourceImporterImpl** out_importer) noexcept; + static OrtStatus* ORT_API_CALL GetHardwareDeviceIncompatibilityDetailsImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* hw, + OrtDeviceEpIncompatibilityDetails* details) noexcept; + const OrtLogger& default_logger_; // default logger for the EP factory const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index c20a5455e5eae..0970654b48ca1 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include // #include #include #include @@ -549,5 +550,102 @@ TEST(OrtEpLibrary, KernelPluginEp_ControlFlow_Scan) { ASSERT_NO_FATAL_FAILURE(RunScanMulModel(session_options)); } } + +// Tests the GetHardwareDeviceEpIncompatibilityDetails C API with the example plugin EP. +// The example plugin EP supports CPU devices, so this test verifies that a CPU device +// is reported as compatible (reasons_bitmask == 0). +TEST(OrtEpLibrary, PluginEp_CpuDevice_ReturnsCompatible) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtEnv* env = static_cast(*ort_env); + + // Register the example plugin EP + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + + // Get all hardware devices + size_t num_hw_devices = 0; + ASSERT_ORTSTATUS_OK(api->GetNumHardwareDevices(env, &num_hw_devices)); + ASSERT_GT(num_hw_devices, 0u); + std::vector hw_devices(num_hw_devices); + ASSERT_ORTSTATUS_OK(api->GetHardwareDevices(env, hw_devices.data(), num_hw_devices)); + + // Find a CPU device using the public accessor + const OrtHardwareDevice* cpu_device = nullptr; + for (size_t i = 0; i < num_hw_devices; ++i) { + if (api->HardwareDevice_Type(hw_devices[i]) == OrtHardwareDeviceType_CPU) { + cpu_device = hw_devices[i]; + break; + } + } + ASSERT_NE(cpu_device, nullptr) << "No CPU device found"; + + // Check compatibility - ExampleEP supports CPU, so should return no incompatibility reasons + OrtDeviceEpIncompatibilityDetails* details = nullptr; + ASSERT_ORTSTATUS_OK(api->GetHardwareDeviceEpIncompatibilityDetails(env, Utils::example_ep_info.registration_name.c_str(), + cpu_device, &details)); + ASSERT_NE(details, nullptr); + + // Verify compatible (no incompatibility reasons) + uint32_t reasons_bitmask = 0xFFFFFFFF; + ASSERT_ORTSTATUS_OK(api->DeviceEpIncompatibilityDetails_GetReasonsBitmask(details, &reasons_bitmask)); + EXPECT_EQ(reasons_bitmask, 0u) << "CPU device should be compatible with example_plugin_ep"; + + int32_t error_code = -1; + ASSERT_ORTSTATUS_OK(api->DeviceEpIncompatibilityDetails_GetErrorCode(details, &error_code)); + EXPECT_EQ(error_code, 0); + + api->ReleaseDeviceEpIncompatibilityDetails(details); +} + +// Tests the GetHardwareDeviceEpIncompatibilityDetails C API with the example plugin EP. +// The example plugin EP only supports CPU devices, so this test verifies that a GPU device +// is reported as incompatible (reasons_bitmask != 0). +TEST(OrtEpLibrary, PluginEp_GpuDevice_ReturnsInCompatible) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtEnv* env = static_cast(*ort_env); + + // Register the regular example plugin EP (CPU-only) + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + + // Get all hardware devices + size_t num_hw_devices = 0; + ASSERT_ORTSTATUS_OK(api->GetNumHardwareDevices(env, &num_hw_devices)); + ASSERT_GT(num_hw_devices, 0u); + std::vector hw_devices(num_hw_devices); + ASSERT_ORTSTATUS_OK(api->GetHardwareDevices(env, hw_devices.data(), num_hw_devices)); + + // Find a GPU device using the public accessor + const OrtHardwareDevice* gpu_device = nullptr; + for (size_t i = 0; i < num_hw_devices; ++i) { + if (api->HardwareDevice_Type(hw_devices[i]) == OrtHardwareDeviceType_GPU) { + gpu_device = hw_devices[i]; + break; + } + } + + if (gpu_device == nullptr) { + // GPU device not found, early exit + GTEST_SKIP() << "No GPU device found"; + } + + // Check compatibility - ExampleEP only supports CPU, so GPU should return incompatibility reasons + OrtDeviceEpIncompatibilityDetails* details = nullptr; + ASSERT_ORTSTATUS_OK(api->GetHardwareDeviceEpIncompatibilityDetails(env, Utils::example_ep_info.registration_name.c_str(), + gpu_device, &details)); + ASSERT_NE(details, nullptr); + + // Verify incompatible (should have incompatibility reasons) + uint32_t reasons_bitmask = 0; + ASSERT_ORTSTATUS_OK(api->DeviceEpIncompatibilityDetails_GetReasonsBitmask(details, &reasons_bitmask)); + EXPECT_NE(reasons_bitmask, 0u) << "GPU device should be incompatible with example_plugin_ep (CPU-only)"; + + api->ReleaseDeviceEpIncompatibilityDetails(details); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/framework/hardware_device_compatibility_test.cc b/onnxruntime/test/framework/hardware_device_compatibility_test.cc new file mode 100644 index 0000000000000..d5f0188cb683f --- /dev/null +++ b/onnxruntime/test/framework/hardware_device_compatibility_test.cc @@ -0,0 +1,374 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "test/test_environment.h" +#include "test/util/include/api_asserts.h" + +#include + +using namespace onnxruntime::test; + +namespace { +// Helper to get hardware devices using the two-step API pattern +void GetHardwareDevicesHelper(const OrtApi* api, OrtEnv* env, + std::vector& devices) { + size_t num_devices = 0; + ASSERT_ORTSTATUS_OK(api->GetNumHardwareDevices(env, &num_devices)); + devices.resize(num_devices); + if (num_devices > 0) { + ASSERT_ORTSTATUS_OK(api->GetHardwareDevices(env, devices.data(), num_devices)); + } +} +} // namespace + +// ----------------------------- +// GetHardwareDeviceEpIncompatibilityDetails C API unit tests +// ----------------------------- + +TEST(GetHardwareDeviceEpIncompatibilityDetailsCapiTest, InvalidArguments_NullEnv) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + // Create env for GetHardwareDevices + OrtEnv* env = nullptr; + EXPECT_EQ(nullptr, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "EpIncompatTest", &env)); + EXPECT_NE(env, nullptr); + + // Get a valid hardware device first + std::vector hw_devices; + ASSERT_NO_FATAL_FAILURE(GetHardwareDevicesHelper(api, env, hw_devices)); + ASSERT_GT(hw_devices.size(), 0u); + + // env == nullptr for GetHardwareDeviceEpIncompatibilityDetails + OrtDeviceEpIncompatibilityDetails* details = nullptr; + OrtStatus* st = api->GetHardwareDeviceEpIncompatibilityDetails(nullptr, "CPUExecutionProvider", hw_devices[0], &details); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + api->ReleaseEnv(env); +} + +TEST(GetHardwareDeviceEpIncompatibilityDetailsCapiTest, InvalidArguments_NullEpName) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtEnv* env = nullptr; + EXPECT_EQ(nullptr, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "EpIncompatTest", &env)); + EXPECT_NE(env, nullptr); + + std::vector hw_devices; + ASSERT_NO_FATAL_FAILURE(GetHardwareDevicesHelper(api, env, hw_devices)); + ASSERT_GT(hw_devices.size(), 0u); + + // ep_name == nullptr + OrtDeviceEpIncompatibilityDetails* details = nullptr; + OrtStatus* st = api->GetHardwareDeviceEpIncompatibilityDetails(env, nullptr, hw_devices[0], &details); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + api->ReleaseEnv(env); +} + +TEST(GetHardwareDeviceEpIncompatibilityDetailsCapiTest, InvalidArguments_EmptyEpName) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtEnv* env = nullptr; + EXPECT_EQ(nullptr, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "EpIncompatTest", &env)); + EXPECT_NE(env, nullptr); + + std::vector hw_devices; + ASSERT_NO_FATAL_FAILURE(GetHardwareDevicesHelper(api, env, hw_devices)); + ASSERT_GT(hw_devices.size(), 0u); + + // ep_name == "" + OrtDeviceEpIncompatibilityDetails* details = nullptr; + OrtStatus* st = api->GetHardwareDeviceEpIncompatibilityDetails(env, "", hw_devices[0], &details); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + api->ReleaseEnv(env); +} + +TEST(GetHardwareDeviceEpIncompatibilityDetailsCapiTest, InvalidArguments_NullHardwareDevice) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtEnv* env = nullptr; + EXPECT_EQ(nullptr, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "EpIncompatTest", &env)); + EXPECT_NE(env, nullptr); + + // hw == nullptr + OrtDeviceEpIncompatibilityDetails* details = nullptr; + OrtStatus* st = api->GetHardwareDeviceEpIncompatibilityDetails(env, "CPUExecutionProvider", nullptr, &details); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + api->ReleaseEnv(env); +} + +TEST(GetHardwareDeviceEpIncompatibilityDetailsCapiTest, InvalidArguments_NullDetailsOutput) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtEnv* env = nullptr; + EXPECT_EQ(nullptr, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "EpIncompatTest", &env)); + EXPECT_NE(env, nullptr); + + std::vector hw_devices; + ASSERT_NO_FATAL_FAILURE(GetHardwareDevicesHelper(api, env, hw_devices)); + ASSERT_GT(hw_devices.size(), 0u); + + // details == nullptr + OrtStatus* st = api->GetHardwareDeviceEpIncompatibilityDetails(env, "CPUExecutionProvider", hw_devices[0], nullptr); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + api->ReleaseEnv(env); +} + +TEST(GetHardwareDeviceEpIncompatibilityDetailsCapiTest, UnregisteredEp_ReturnsInvalidArgument) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtEnv* env = nullptr; + EXPECT_EQ(nullptr, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "EpIncompatTest", &env)); + EXPECT_NE(env, nullptr); + + std::vector hw_devices; + ASSERT_NO_FATAL_FAILURE(GetHardwareDevicesHelper(api, env, hw_devices)); + ASSERT_GT(hw_devices.size(), 0u); + + // Non-existent EP name should return INVALID_ARGUMENT + OrtDeviceEpIncompatibilityDetails* details = nullptr; + OrtStatus* st = api->GetHardwareDeviceEpIncompatibilityDetails(env, "NonExistentExecutionProvider", hw_devices[0], &details); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + EXPECT_THAT(api->GetErrorMessage(st), testing::HasSubstr("No valid factory found")); + api->ReleaseStatus(st); + + api->ReleaseEnv(env); +} + +TEST(GetHardwareDeviceEpIncompatibilityDetailsCapiTest, CpuEp_ReturnsEmptyDetails) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtEnv* env = nullptr; + EXPECT_EQ(nullptr, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "EpIncompatTest", &env)); + EXPECT_NE(env, nullptr); + + std::vector hw_devices; + ASSERT_NO_FATAL_FAILURE(GetHardwareDevicesHelper(api, env, hw_devices)); + ASSERT_GT(hw_devices.size(), 0u); + + // CPU EP doesn't implement GetHardwareDeviceIncompatibilityDetails, so should return empty details + OrtDeviceEpIncompatibilityDetails* details = nullptr; + ASSERT_ORTSTATUS_OK(api->GetHardwareDeviceEpIncompatibilityDetails(env, "CPUExecutionProvider", hw_devices[0], &details)); + ASSERT_NE(details, nullptr); + + // Verify empty details + uint32_t reasons_bitmask = 0xFFFFFFFF; // Initialize to non-zero to verify it gets set + ASSERT_ORTSTATUS_OK(api->DeviceEpIncompatibilityDetails_GetReasonsBitmask(details, &reasons_bitmask)); + EXPECT_EQ(reasons_bitmask, 0u); + + int32_t error_code = -1; // Initialize to non-zero to verify it gets set + ASSERT_ORTSTATUS_OK(api->DeviceEpIncompatibilityDetails_GetErrorCode(details, &error_code)); + EXPECT_EQ(error_code, 0); + + const char* notes = reinterpret_cast(0xDEADBEEF); // Initialize to non-null + ASSERT_ORTSTATUS_OK(api->DeviceEpIncompatibilityDetails_GetNotes(details, ¬es)); + EXPECT_TRUE(notes == nullptr || strlen(notes) == 0); + + api->ReleaseDeviceEpIncompatibilityDetails(details); + api->ReleaseEnv(env); +} + +TEST(GetHardwareDeviceEpIncompatibilityDetailsCapiTest, AccessorFunctions_NullDetails) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + // Test accessor functions with null details + uint32_t reasons_bitmask = 0; + OrtStatus* st = api->DeviceEpIncompatibilityDetails_GetReasonsBitmask(nullptr, &reasons_bitmask); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + int32_t error_code = 0; + st = api->DeviceEpIncompatibilityDetails_GetErrorCode(nullptr, &error_code); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + const char* notes = nullptr; + st = api->DeviceEpIncompatibilityDetails_GetNotes(nullptr, ¬es); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); +} + +TEST(GetHardwareDeviceEpIncompatibilityDetailsCapiTest, AccessorFunctions_NullOutputPtr) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtEnv* env = nullptr; + EXPECT_EQ(nullptr, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "EpIncompatTest", &env)); + EXPECT_NE(env, nullptr); + + std::vector hw_devices; + ASSERT_NO_FATAL_FAILURE(GetHardwareDevicesHelper(api, env, hw_devices)); + ASSERT_GT(hw_devices.size(), 0u); + + // Get a valid details object first + OrtDeviceEpIncompatibilityDetails* details = nullptr; + ASSERT_ORTSTATUS_OK(api->GetHardwareDeviceEpIncompatibilityDetails(env, "CPUExecutionProvider", hw_devices[0], &details)); + ASSERT_NE(details, nullptr); + + // Test accessor functions with null output pointers + OrtStatus* st = api->DeviceEpIncompatibilityDetails_GetReasonsBitmask(details, nullptr); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + st = api->DeviceEpIncompatibilityDetails_GetErrorCode(details, nullptr); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + st = api->DeviceEpIncompatibilityDetails_GetNotes(details, nullptr); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + api->ReleaseDeviceEpIncompatibilityDetails(details); + api->ReleaseEnv(env); +} + +// ----------------------------- +// GetNumHardwareDevices / GetHardwareDevices C API unit tests +// ----------------------------- + +TEST(GetHardwareDevicesCapiTest, GetNumHardwareDevices_InvalidArguments_NullEnv) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + size_t num_devices = 0; + OrtStatus* st = api->GetNumHardwareDevices(nullptr, &num_devices); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); +} + +TEST(GetHardwareDevicesCapiTest, GetNumHardwareDevices_InvalidArguments_NullNumDevices) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtEnv* env = nullptr; + EXPECT_EQ(nullptr, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "HwDevicesTest", &env)); + EXPECT_NE(env, nullptr); + + OrtStatus* st = api->GetNumHardwareDevices(env, nullptr); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + api->ReleaseEnv(env); +} + +TEST(GetHardwareDevicesCapiTest, GetHardwareDevices_InvalidArguments_NullEnv) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + const OrtHardwareDevice* devices[1] = {nullptr}; + OrtStatus* st = api->GetHardwareDevices(nullptr, devices, 1); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); +} + +TEST(GetHardwareDevicesCapiTest, GetHardwareDevices_InvalidArguments_NullDevices) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtEnv* env = nullptr; + EXPECT_EQ(nullptr, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "HwDevicesTest", &env)); + EXPECT_NE(env, nullptr); + + OrtStatus* st = api->GetHardwareDevices(env, nullptr, 1); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + api->ReleaseEnv(env); +} + +TEST(GetHardwareDevicesCapiTest, GetHardwareDevices_InvalidArguments_ArrayTooSmall) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtEnv* env = nullptr; + EXPECT_EQ(nullptr, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "HwDevicesTest", &env)); + EXPECT_NE(env, nullptr); + + // Get number of devices first + size_t num_devices = 0; + ASSERT_ORTSTATUS_OK(api->GetNumHardwareDevices(env, &num_devices)); + ASSERT_GT(num_devices, 0u); + + // Try to get devices with an undersized array (pass a valid pointer but claim size is 0) + std::vector devices(1); // Allocate at least 1 element to avoid nullptr + OrtStatus* st = api->GetHardwareDevices(env, devices.data(), 0); // But claim size is 0 + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + EXPECT_THAT(api->GetErrorMessage(st), testing::HasSubstr("num_devices is less than")); + api->ReleaseStatus(st); + + api->ReleaseEnv(env); +} + +TEST(GetHardwareDevicesCapiTest, ReturnsDevices) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtEnv* env = nullptr; + EXPECT_EQ(nullptr, api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "HwDevicesTest", &env)); + EXPECT_NE(env, nullptr); + + // Get number of devices first + size_t num_devices = 0; + ASSERT_ORTSTATUS_OK(api->GetNumHardwareDevices(env, &num_devices)); + + // Should return at least one device (CPU) + EXPECT_GT(num_devices, 0u); + + // Allocate array and get devices + std::vector devices(num_devices); + ASSERT_ORTSTATUS_OK(api->GetHardwareDevices(env, devices.data(), num_devices)); + + // Verify we can access device properties via C API accessor functions + for (size_t i = 0; i < num_devices; ++i) { + const OrtHardwareDevice* device = devices[i]; + ASSERT_NE(device, nullptr); + // Device type should be valid (CPU, GPU, or NPU) + OrtHardwareDeviceType device_type = api->HardwareDevice_Type(device); + EXPECT_TRUE(device_type == OrtHardwareDeviceType_CPU || + device_type == OrtHardwareDeviceType_GPU || + device_type == OrtHardwareDeviceType_NPU); + // Vendor should not be null + const char* vendor = api->HardwareDevice_Vendor(device); + EXPECT_NE(vendor, nullptr); + } + + api->ReleaseEnv(env); +}