Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added functionality to choose dml adapter by luid #1041

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions src/dml/dml_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ static bool IsSoftwareAdapter(IDXGIAdapter1* adapter) {
return desc.Flags & DXGI_ADAPTER_FLAG_SOFTWARE || (is_basic_render_driver_vendor_id && is_basic_render_driver_device_id);
};

static std::vector<ComPtr<IDXGIAdapter1>> EnumerateAdapters() {
static std::vector<ComPtr<IDXGIAdapter1>> EnumerateAdapters(PLUID device_luid = nullptr) {
ComPtr<IDXGIFactory4> dxgi_factory;
THROW_IF_FAILED(CreateDXGIFactory(IID_PPV_ARGS(&dxgi_factory)));

std::vector<ComPtr<IDXGIAdapter1>> adapter_infos;

ComPtr<IDXGIFactory6> dxgi_factory6;
if (SUCCEEDED(dxgi_factory.As(&dxgi_factory6))) {
if (SUCCEEDED(dxgi_factory.As(&dxgi_factory6)) && !device_luid) {
// Enumerate adapters by performance. This only works in Windows 10 Version 1803 and later.
ComPtr<IDXGIAdapter1> adapter;
for (uint32_t adapter_index = 0;
Expand Down Expand Up @@ -66,7 +66,16 @@ static std::vector<ComPtr<IDXGIAdapter1>> EnumerateAdapters() {
ComPtr<ID3D12Device> d3d12_device;
THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&d3d12_device)));

if (d3d12_device) {
if (d3d12_device && device_luid) {
DXGI_ADAPTER_DESC1 description = {};
THROW_IF_FAILED(adapter->GetDesc1(&description));

// Check if current adapter LUID is the same as the target one
if (device_luid->HighPart == description.AdapterLuid.HighPart && device_luid->LowPart == description.AdapterLuid.LowPart) {
adapter_infos.emplace_back(std::move(adapter));
break;
}
} else if (d3d12_device) {
adapter_infos.emplace_back(std::move(adapter));
}
}
Expand All @@ -75,15 +84,15 @@ static std::vector<ComPtr<IDXGIAdapter1>> EnumerateAdapters() {
return adapter_infos;
}

static ComPtr<IDXGIAdapter1> CreatePerformantAdapter() {
auto filtered_adapters = EnumerateAdapters();
// Returns the adapter that DML objects should be created on.
//
// device_luid: optional LUID of the adapter to select; it is forwarded to
//              EnumerateAdapters. With the default nullptr, the first adapter
//              EnumerateAdapters returns is used.
// Throws std::runtime_error when EnumerateAdapters returns no adapter.
static ComPtr<IDXGIAdapter1> CreateAdapter(PLUID device_luid = nullptr) {
  auto filtered_adapters = EnumerateAdapters(device_luid);
  if (filtered_adapters.empty()) {
    if (device_luid) {
      // A specific adapter was requested: report which LUID failed to match
      // instead of implying that the machine has no DML-capable adapter.
      throw std::runtime_error("No adapter usable for DML matches the requested LUID " +
                               std::to_string(device_luid->HighPart) + ":" +
                               std::to_string(device_luid->LowPart) + ".");
    }
    throw std::runtime_error("No adapter is available for DML.");
  }
  // The vector is discarded on return, so hand its first element over
  // rather than copying the ComPtr.
  return std::move(filtered_adapters.front());
}

DmlObjects CreateDmlObjects(const std::string& current_module_path) {
DmlObjects CreateDmlObjects(const std::string& current_module_path, PLUID device_luid) {
D3D12_COMMAND_QUEUE_DESC command_queue_description = {
D3D12_COMMAND_LIST_TYPE_COMPUTE,
0,
Expand All @@ -93,8 +102,7 @@ DmlObjects CreateDmlObjects(const std::string& current_module_path) {

DmlObjects dml_objects;

auto adapter = CreatePerformantAdapter();

auto adapter = CreateAdapter(device_luid);
ComPtr<ID3D12SDKConfiguration1> d3d12_sdk_config;
ComPtr<ID3D12DeviceFactory> d3d12_factory;

Expand Down
2 changes: 1 addition & 1 deletion src/dml/dml_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ struct DmlObjects {
};

namespace DmlHelpers {
DmlObjects CreateDmlObjects(const std::string& current_module_path);
DmlObjects CreateDmlObjects(const std::string& current_module_path, PLUID device_luid = nullptr);

DmlReusedCommandListState BuildReusableCommandList(
IDMLDevice* dml_device,
Expand Down
19 changes: 18 additions & 1 deletion src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,24 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
} else if (provider_options.name == "dml") {
if (!p_dml_api_) {
auto current_module_path = CurrentModulePath();
dml_objects_ = DmlHelpers::CreateDmlObjects(current_module_path);

bool contains_device_luid = false;
LUID device_luid{};
for (const auto& [name, value] : provider_options.options) {
if (name == "luid") {
if (auto separator_position = value.find(":"); separator_position != std::string::npos) {
device_luid.HighPart = std::stol(value.substr(0, separator_position));
device_luid.LowPart = std::stol(value.substr(separator_position + 1));
contains_device_luid = true;
}
}
}

if (contains_device_luid) {
dml_objects_ = DmlHelpers::CreateDmlObjects(current_module_path, &device_luid);
} else {
dml_objects_ = DmlHelpers::CreateDmlObjects(current_module_path);
}

constexpr auto directml_dll = "DirectML.dll";
wil::unique_hmodule smart_directml_dll(LoadLibraryEx(directml_dll, nullptr, 0));
Expand Down
29 changes: 29 additions & 0 deletions test/model_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
#define PHI2_PATH MODEL_PATH "phi-2/int4/cpu"
#endif
#endif
#if USE_DML
#include <DirectML.h>
#include <wrl.h>
#include <d3d12.h>
#include <dxgi1_6.h>
#endif

// To generate this file:
// python convert_generation.py --model_type gpt2 -m hf-internal-testing/tiny-random-gpt2 --output tiny_gpt2_greedysearch_fp16.onnx --use_gpu --max_length 20
Expand All @@ -26,6 +32,29 @@ static const std::pair<const char*, const char*> c_tiny_gpt2_model_paths[] = {
{MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp16-cuda", "fp16"},
};

#if USE_DML
// Checks that the D3D12 device created for the PHI2 model sits on the adapter
// named by the model config. If the "dml" provider options contain a "luid"
// entry of the form "<HighPart>:<LowPart>", both halves must equal the LUID
// reported by the device. Entries without a ':' separator, and configs with
// no "luid" entry, produce no expectations.
TEST(ModelTests, DMLAdapterSelection) {
#if TEST_PHI2
  auto phi2_model = Generators::CreateModel(Generators::GetOrtEnv(), PHI2_PATH);
  auto device = phi2_model->GetD3D12Device();
  auto adapterLuid = device->GetAdapterLuid();

  const auto& configured_providers = phi2_model->config_->model.decoder.session_options.provider_options;
  for (const auto& provider : configured_providers) {
    if (provider.name != "dml") {
      continue;
    }

    for (const auto& [key, value] : provider.options) {
      if (key != "luid") {
        continue;
      }

      auto separator_position = value.find(":");
      if (separator_position == std::string::npos) {
        continue;
      }

      // HighPart is compared as a signed value, LowPart as an unsigned one.
      EXPECT_EQ(adapterLuid.HighPart, std::stol(value.substr(0, separator_position)));
      EXPECT_EQ(adapterLuid.LowPart, std::stoul(value.substr(separator_position + 1)));
    }
  }
#endif
}
#endif

// DML doesn't support GPT attention
#if !USE_DML
TEST(ModelTests, GreedySearchGptFp32) {
Expand Down
Loading