Skip to content

Commit

Permalink
Added functionality to choose dml adapter by luid
Browse files Browse the repository at this point in the history
  • Loading branch information
DavitGrigoryan132 committed Nov 7, 2024
1 parent 8965fed commit 691d9e3
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 9 deletions.
38 changes: 31 additions & 7 deletions src/dml/dml_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ static bool IsSoftwareAdapter(IDXGIAdapter1* adapter) {
return desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE || (is_basic_render_driver_vendor_id && is_basic_render_driver_device_id);
};

static std::vector<ComPtr<IDXGIAdapter1>> EnumerateAdapters() {
static std::vector<ComPtr<IDXGIAdapter1>> EnumerateAdapters(PLUID device_luid = nullptr) {
ComPtr<IDXGIFactory4> dxgi_factory;
THROW_IF_FAILED(CreateDXGIFactory(IID_PPV_ARGS(&dxgi_factory)));

std::vector<ComPtr<IDXGIAdapter1>> adapter_infos;

ComPtr<IDXGIFactory6> dxgi_factory6;
if (SUCCEEDED(dxgi_factory.As(&dxgi_factory6))) {
if (SUCCEEDED(dxgi_factory.As(&dxgi_factory6)) && !device_luid) {
// Enumerate adapters by performance. This only works in Windows 10 Version 1803 and later.
ComPtr<IDXGIAdapter1> adapter;
for (uint32_t adapter_index = 0;
Expand All @@ -52,6 +52,31 @@ static std::vector<ComPtr<IDXGIAdapter1>> EnumerateAdapters() {
adapter_infos.emplace_back(std::move(adapter));
}
}
} else if (device_luid) {
// Find the adapter with the required LUID.
ComPtr<IDXGIAdapter1> adapter;
for (uint32_t adapter_index = 0; dxgi_factory->EnumAdapters1(adapter_index, &adapter) != DXGI_ERROR_NOT_FOUND; adapter_index++) {
// We can't assume the ordering of hardware and software adapters, so keep looping. This path should only execute on Windows 10
// version 1709 or earlier; IDD (e.g. remote desktop) adapters do not exist when taking this code path.
if (IsSoftwareAdapter(adapter.Get())) {
continue;
}

// Make sure that we are able to create the device
ComPtr<ID3D12Device> d3d12_device;
THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&d3d12_device)));

if (d3d12_device) {
DXGI_ADAPTER_DESC1 description = {};
THROW_IF_FAILED(adapter->GetDesc1(&description));

// Check if current adapter LUID is the same as the target one
if (device_luid->HighPart == description.AdapterLuid.HighPart && device_luid->LowPart == description.AdapterLuid.LowPart) {
adapter_infos.emplace_back(std::move(adapter));
break;
}
}
}
} else {
// Enumerate adapters without ordering.
ComPtr<IDXGIAdapter1> adapter;
Expand All @@ -75,15 +100,15 @@ static std::vector<ComPtr<IDXGIAdapter1>> EnumerateAdapters() {
return adapter_infos;
}

static ComPtr<IDXGIAdapter1> CreatePerformantAdapter() {
auto filtered_adapters = EnumerateAdapters();
static ComPtr<IDXGIAdapter1> CreateAdapter(PLUID device_luid = nullptr) {
auto filtered_adapters = EnumerateAdapters(device_luid);
if (filtered_adapters.empty()) {
throw std::runtime_error("No adapter is available for DML.");
}
return filtered_adapters.front();
}

DmlObjects CreateDmlObjects(const std::string& current_module_path) {
DmlObjects CreateDmlObjects(const std::string& current_module_path, PLUID device_luid) {
D3D12_COMMAND_QUEUE_DESC command_queue_description = {
D3D12_COMMAND_LIST_TYPE_COMPUTE,
0,
Expand All @@ -93,8 +118,7 @@ DmlObjects CreateDmlObjects(const std::string& current_module_path) {

DmlObjects dml_objects;

auto adapter = CreatePerformantAdapter();

auto adapter = CreateAdapter(device_luid);
ComPtr<ID3D12SDKConfiguration1> d3d12_sdk_config;
ComPtr<ID3D12DeviceFactory> d3d12_factory;

Expand Down
2 changes: 1 addition & 1 deletion src/dml/dml_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct DmlObjects {
};

namespace DmlHelpers {
DmlObjects CreateDmlObjects(const std::string& current_module_path);
DmlObjects CreateDmlObjects(const std::string& current_module_path, PLUID device_luid = nullptr);

DmlReusedCommandListState BuildReusableCommandList(
IDMLDevice* dml_device,
Expand Down
19 changes: 18 additions & 1 deletion src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,24 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
} else if (provider_options.name == "dml") {
if (!p_dml_api_) {
auto current_module_path = CurrentModulePath();
dml_objects_ = DmlHelpers::CreateDmlObjects(current_module_path);

bool contains_device_luid = false;
LUID device_luid{};
for (const auto& [name, value] : provider_options.options) {
if (name == "luid_high_part") {
device_luid.HighPart = std::stol(value);
contains_device_luid = true;
} else if (name == "luid_low_part") {
device_luid.LowPart = std::stol(value);
contains_device_luid = true;
}
}

if (contains_device_luid) {
dml_objects_ = DmlHelpers::CreateDmlObjects(current_module_path, &device_luid);
} else {
dml_objects_ = DmlHelpers::CreateDmlObjects(current_module_path);
}

constexpr auto directml_dll = "DirectML.dll";
wil::unique_hmodule smart_directml_dll(LoadLibraryEx(directml_dll, nullptr, 0));
Expand Down

0 comments on commit 691d9e3

Please sign in to comment.