diff --git a/src/dml/dml_helpers.cpp b/src/dml/dml_helpers.cpp index 876ed7036..3637b36e9 100644 --- a/src/dml/dml_helpers.cpp +++ b/src/dml/dml_helpers.cpp @@ -21,14 +21,14 @@ static bool IsSoftwareAdapter(IDXGIAdapter1* adapter) { return desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE || (is_basic_render_driver_vendor_id && is_basic_render_driver_device_id); }; -static std::vector> EnumerateAdapters() { +static std::vector> EnumerateAdapters(PLUID device_luid = nullptr) { ComPtr dxgi_factory; THROW_IF_FAILED(CreateDXGIFactory(IID_PPV_ARGS(&dxgi_factory))); std::vector> adapter_infos; ComPtr dxgi_factory6; - if (SUCCEEDED(dxgi_factory.As(&dxgi_factory6))) { + if (SUCCEEDED(dxgi_factory.As(&dxgi_factory6)) && !device_luid) { // Enumerate adapters by performance. This only works in Windows 10 Version 1803 and later. ComPtr adapter; for (uint32_t adapter_index = 0; @@ -52,6 +52,31 @@ static std::vector> EnumerateAdapters() { adapter_infos.emplace_back(std::move(adapter)); } } + } else if (device_luid) { + // Find the adapter with the required LUID. + ComPtr adapter; + for (uint32_t adapter_index = 0; dxgi_factory->EnumAdapters1(adapter_index, &adapter) != DXGI_ERROR_NOT_FOUND; adapter_index++) { + // We can't assume the ordering of hardware and software adapters, so keep looping. This path should only execute on Windows 10 + // version 1709 or earlier; IDD (e.g. remote desktop) adapters do not exist when taking this code path. + if (IsSoftwareAdapter(adapter.Get())) { + continue; + } + + // Make sure that we are able to create the device + ComPtr d3d12_device; + THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&d3d12_device))); + + if (d3d12_device) { + DXGI_ADAPTER_DESC1 description = {}; + THROW_IF_FAILED(adapter->GetDesc1(&description)); + + // Check if current adapter LUID is the same as the target one + if (device_luid->HighPart == description.AdapterLuid.HighPart && device_luid->LowPart == description.AdapterLuid.LowPart) { + adapter_infos.emplace_back(std::move(adapter)); + break; + } + } + } } else { // Enumerate adapters without ordering. ComPtr adapter; @@ -75,15 +100,15 @@ static std::vector> EnumerateAdapters() { return adapter_infos; } -static ComPtr CreatePerformantAdapter() { - auto filtered_adapters = EnumerateAdapters(); +static ComPtr CreateAdapter(PLUID device_luid = nullptr) { + auto filtered_adapters = EnumerateAdapters(device_luid); if (filtered_adapters.empty()) { throw std::runtime_error("No adapter is available for DML."); } return filtered_adapters.front(); } -DmlObjects CreateDmlObjects(const std::string& current_module_path) { +DmlObjects CreateDmlObjects(const std::string& current_module_path, PLUID device_luid) { D3D12_COMMAND_QUEUE_DESC command_queue_description = { D3D12_COMMAND_LIST_TYPE_COMPUTE, 0, @@ -93,8 +118,7 @@ DmlObjects CreateDmlObjects(const std::string& current_module_path) { DmlObjects dml_objects; - auto adapter = CreatePerformantAdapter(); - + auto adapter = CreateAdapter(device_luid); ComPtr d3d12_sdk_config; ComPtr d3d12_factory; diff --git a/src/dml/dml_helpers.h b/src/dml/dml_helpers.h index 2f530a0e4..c65c23672 100644 --- a/src/dml/dml_helpers.h +++ b/src/dml/dml_helpers.h @@ -31,7 +31,7 @@ struct DmlObjects { }; namespace DmlHelpers { -DmlObjects CreateDmlObjects(const std::string& current_module_path); +DmlObjects CreateDmlObjects(const std::string& current_module_path, PLUID device_luid = nullptr); DmlReusedCommandListState BuildReusableCommandList( IDMLDevice* dml_device, diff --git a/src/models/model.cpp b/src/models/model.cpp index 858fc9d91..a9dd1271e 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -437,7 +437,26 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_ } else if (provider_options.name == "dml") { if (!p_dml_api_) { auto current_module_path = CurrentModulePath(); - dml_objects_ = DmlHelpers::CreateDmlObjects(current_module_path); + + bool contains_device_luid = false; + LUID device_luid{}; + for (const auto& [name, value]: provider_options.options) { + if (name == "luid_high_part") { + device_luid.HighPart = std::stol(value); + contains_device_luid = true; + } + else if (name == "luid_low_part") { + device_luid.LowPart = std::stol(value); + contains_device_luid = true; + } + } + + if (contains_device_luid) { + dml_objects_ = DmlHelpers::CreateDmlObjects(current_module_path, &device_luid); + } + else { + dml_objects_ = DmlHelpers::CreateDmlObjects(current_module_path); + } constexpr auto directml_dll = "DirectML.dll"; wil::unique_hmodule smart_directml_dll(LoadLibraryEx(directml_dll, nullptr, 0));