diff --git a/src/dml/dml_helpers.cpp b/src/dml/dml_helpers.cpp index 876ed7036..9e4c8b93c 100644 --- a/src/dml/dml_helpers.cpp +++ b/src/dml/dml_helpers.cpp @@ -21,14 +21,14 @@ static bool IsSoftwareAdapter(IDXGIAdapter1* adapter) { return desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE || (is_basic_render_driver_vendor_id && is_basic_render_driver_device_id); }; -static std::vector> EnumerateAdapters() { +static std::vector> EnumerateAdapters(PLUID device_luid = nullptr) { ComPtr dxgi_factory; THROW_IF_FAILED(CreateDXGIFactory(IID_PPV_ARGS(&dxgi_factory))); std::vector> adapter_infos; ComPtr dxgi_factory6; - if (SUCCEEDED(dxgi_factory.As(&dxgi_factory6))) { + if (SUCCEEDED(dxgi_factory.As(&dxgi_factory6)) && !device_luid) { // Enumerate adapters by performance. This only works in Windows 10 Version 1803 and later. ComPtr adapter; for (uint32_t adapter_index = 0; @@ -66,7 +66,16 @@ static std::vector> EnumerateAdapters() { ComPtr d3d12_device; THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&d3d12_device))); - if (d3d12_device) { + if (d3d12_device && device_luid) { + DXGI_ADAPTER_DESC1 description = {}; + THROW_IF_FAILED(adapter->GetDesc1(&description)); + + // Check if current adapter LUID is the same as the target one + if (device_luid->HighPart == description.AdapterLuid.HighPart && device_luid->LowPart == description.AdapterLuid.LowPart) { + adapter_infos.emplace_back(std::move(adapter)); + break; + } + } else if (d3d12_device) { adapter_infos.emplace_back(std::move(adapter)); } } @@ -75,15 +84,15 @@ static std::vector> EnumerateAdapters() { return adapter_infos; } -static ComPtr CreatePerformantAdapter() { - auto filtered_adapters = EnumerateAdapters(); +static ComPtr CreateAdapter(PLUID device_luid = nullptr) { + auto filtered_adapters = EnumerateAdapters(device_luid); if (filtered_adapters.empty()) { throw std::runtime_error("No adapter is available for DML."); } return filtered_adapters.front(); } -DmlObjects CreateDmlObjects(const std::string& current_module_path) { +DmlObjects CreateDmlObjects(const std::string& current_module_path, PLUID device_luid) { D3D12_COMMAND_QUEUE_DESC command_queue_description = { D3D12_COMMAND_LIST_TYPE_COMPUTE, 0, @@ -93,8 +102,7 @@ DmlObjects CreateDmlObjects(const std::string& current_module_path) { DmlObjects dml_objects; - auto adapter = CreatePerformantAdapter(); - + auto adapter = CreateAdapter(device_luid); ComPtr d3d12_sdk_config; ComPtr d3d12_factory; diff --git a/src/dml/dml_helpers.h b/src/dml/dml_helpers.h index 2f530a0e4..c65c23672 100644 --- a/src/dml/dml_helpers.h +++ b/src/dml/dml_helpers.h @@ -31,7 +31,7 @@ struct DmlObjects { }; namespace DmlHelpers { -DmlObjects CreateDmlObjects(const std::string& current_module_path); +DmlObjects CreateDmlObjects(const std::string& current_module_path, PLUID device_luid = nullptr); DmlReusedCommandListState BuildReusableCommandList( IDMLDevice* dml_device, diff --git a/src/models/model.cpp b/src/models/model.cpp index f6b2b3111..4430fdb41 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -437,7 +437,24 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_ } else if (provider_options.name == "dml") { if (!p_dml_api_) { auto current_module_path = CurrentModulePath(); - dml_objects_ = DmlHelpers::CreateDmlObjects(current_module_path); + + bool contains_device_luid = false; + LUID device_luid{}; + for (const auto& [name, value] : provider_options.options) { + if (name == "luid") { + if (auto separator_position = value.find(":"); separator_position != std::string::npos) { + device_luid.HighPart = std::stol(value.substr(0, separator_position)); + device_luid.LowPart = std::stol(value.substr(separator_position + 1)); + contains_device_luid = true; + } + } + } + + if (contains_device_luid) { + dml_objects_ = DmlHelpers::CreateDmlObjects(current_module_path, &device_luid); + } else { + dml_objects_ = DmlHelpers::CreateDmlObjects(current_module_path); + } constexpr auto directml_dll = "DirectML.dll"; wil::unique_hmodule smart_directml_dll(LoadLibraryEx(directml_dll, nullptr, 0)); diff --git a/test/model_tests.cpp b/test/model_tests.cpp index d60b6161e..7fb60b4a8 100644 --- a/test/model_tests.cpp +++ b/test/model_tests.cpp @@ -17,6 +17,12 @@ #define PHI2_PATH MODEL_PATH "phi-2/int4/cpu" #endif #endif +#if USE_DML +#include +#include +#include +#include +#endif // To generate this file: // python convert_generation.py --model_type gpt2 -m hf-internal-testing/tiny-random-gpt2 --output tiny_gpt2_greedysearch_fp16.onnx --use_gpu --max_length 20 @@ -26,6 +32,29 @@ static const std::pair c_tiny_gpt2_model_paths[] = { {MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp16-cuda", "fp16"}, }; +#if USE_DML +TEST(ModelTests, DMLAdapterSelection) { +#if TEST_PHI2 + auto model = Generators::CreateModel(Generators::GetOrtEnv(), PHI2_PATH); + auto d3d12Device = model->GetD3D12Device(); + + auto adapterLuid = d3d12Device->GetAdapterLuid(); + for (const auto& provider_option: model->config_->model.decoder.session_options.provider_options) { + if (provider_option.name == "dml") { + for (const auto& [name, value] : provider_option.options) { + if (name == "luid") { + if (auto separator_position = value.find(":"); separator_position != std::string::npos) { + EXPECT_EQ(adapterLuid.HighPart, std::stol(value.substr(0, separator_position))); + EXPECT_EQ(adapterLuid.LowPart, std::stoul(value.substr(separator_position + 1))); + } + } + } + } + } +#endif +} +#endif + // DML doesn't support GPT attention #if !USE_DML TEST(ModelTests, GreedySearchGptFp32) {