diff --git a/src/dml/dml_helpers.cpp b/src/dml/dml_helpers.cpp index 876ed7036..c423be033 100644 --- a/src/dml/dml_helpers.cpp +++ b/src/dml/dml_helpers.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -75,10 +76,52 @@ static std::vector> EnumerateAdapters() { return adapter_infos; } -static ComPtr CreatePerformantAdapter() { +static std::vector> EnumerateCoreAdapters() { + // Try to find Direct3D 12 Core Compute adapters. + ComPtr adapterFactory; + THROW_IF_FAILED(::DXCoreCreateAdapterFactory(IID_PPV_ARGS(&adapterFactory))); + + std::vector> adapter_infos; + + ComPtr d3D12CoreComputeAdapters; + GUID attributes[]{DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE}; + THROW_IF_FAILED(adapterFactory->CreateAdapterList(_countof(attributes), attributes, IID_PPV_ARGS(&d3D12CoreComputeAdapters))); + + // Ask the OS to sort for the highest performance hardware adapter. + DXCoreAdapterPreference sortPreferences[]{ + DXCoreAdapterPreference::Hardware, DXCoreAdapterPreference::HighPerformance }; + THROW_IF_FAILED(d3D12CoreComputeAdapters->Sort(_countof(sortPreferences), sortPreferences)); + + const uint32_t count{ d3D12CoreComputeAdapters->GetAdapterCount() }; + for (uint32_t i = 0; i < count; ++i) { + ComPtr adapter; + THROW_IF_FAILED(d3D12CoreComputeAdapters->GetAdapter(i, IID_PPV_ARGS(&adapter))); + if (adapter) { + bool isHardware{ false }; + THROW_IF_FAILED(adapter->GetProperty(DXCoreAdapterProperty::IsHardware, &isHardware)); + bool hasCoreCompute = adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE); + if (isHardware && hasCoreCompute) { + // Make sure that we are able to create the device + ComPtr d3d12_device; + THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_1_0_CORE, IID_PPV_ARGS(&d3d12_device))); + if (d3d12_device) { + adapter_infos.emplace_back(std::move(adapter)); + } + } + } + } + + return adapter_infos; +} + +static std::variant, ComPtr> CreatePerformantAdapter() { auto filtered_adapters = EnumerateAdapters(); if (filtered_adapters.empty()) { - throw std::runtime_error("No adapter is available for DML."); + auto core_adapters = EnumerateCoreAdapters(); + if (core_adapters.empty()) { + throw std::runtime_error("No adapter is available for DML."); + } + return core_adapters.front(); } return filtered_adapters.front(); } @@ -101,13 +144,20 @@ DmlObjects CreateDmlObjects(const std::string& current_module_path) { // Get the version from https://devblogs.microsoft.com/directx/directx12agility/. We are currently using 1.614.0. constexpr uint32_t agility_sdk_version = 614; - if (SUCCEEDED(D3D12GetInterface(CLSID_D3D12SDKConfiguration, IID_PPV_ARGS(&d3d12_sdk_config))) && - SUCCEEDED(d3d12_sdk_config->CreateDeviceFactory(agility_sdk_version, current_module_path.c_str(), IID_PPV_ARGS(&d3d12_factory)))) { - THROW_IF_FAILED(d3d12_factory->CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&dml_objects.d3d12_device))); + if (std::holds_alternative>(adapter)) { + ComPtr dxgiAdapter = std::get>(adapter); + if (SUCCEEDED(D3D12GetInterface(CLSID_D3D12SDKConfiguration, IID_PPV_ARGS(&d3d12_sdk_config))) && + SUCCEEDED(d3d12_sdk_config->CreateDeviceFactory(agility_sdk_version, current_module_path.c_str(), IID_PPV_ARGS(&d3d12_factory)))) { + THROW_IF_FAILED(d3d12_factory->CreateDevice(dxgiAdapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&dml_objects.d3d12_device))); + } else { + printf("Warning: Unable to create a device from version 1.614.0 of the DirectX 12 Agility SDK. You can still use this library, but some scenarios may not work.\n"); + printf("The given module path: %s", current_module_path.c_str()); + THROW_IF_FAILED(D3D12CreateDevice(dxgiAdapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&dml_objects.d3d12_device))); + } } else { - printf("Warning: Unable to create a device from version 1.614.0 of the DirectX 12 Agility SDK. You can still use this library, but some scenarios may not work.\n"); - printf("The given module path: %s", current_module_path.c_str()); - THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&dml_objects.d3d12_device))); + // Use DXCore adapter + ComPtr dxcoreAdapter = std::get>(adapter); + THROW_IF_FAILED(D3D12CreateDevice(dxcoreAdapter.Get(), D3D_FEATURE_LEVEL_1_0_CORE, IID_PPV_ARGS(&dml_objects.d3d12_device))); } THROW_IF_FAILED(dml_objects.d3d12_device->CreateCommandQueue(&command_queue_description, IID_PPV_ARGS(&dml_objects.command_queue)));