Skip to content

Commit

Permalink
Updated LUID config format and added test
Browse files Browse the repository at this point in the history
  • Loading branch information
DavitGrigoryan132 committed Nov 18, 2024
1 parent 79374c4 commit c15cc48
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,12 +441,12 @@ void Model::CreateSessionOptionsFromConfig(const Config::SessionOptions& config_
bool contains_device_luid = false;
LUID device_luid{};
for (const auto& [name, value] : provider_options.options) {
if (name == "luid_high_part") {
device_luid.HighPart = std::stol(value);
contains_device_luid = true;
} else if (name == "luid_low_part") {
device_luid.LowPart = std::stol(value);
contains_device_luid = true;
if (name == "luid") {
if (auto separator_position = value.find(":"); separator_position != std::string::npos) {
device_luid.HighPart = std::stol(value.substr(0, separator_position));
device_luid.LowPart = std::stol(value.substr(separator_position + 1));
contains_device_luid = true;
}
}
}

Expand Down
29 changes: 29 additions & 0 deletions test/model_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
#define PHI2_PATH MODEL_PATH "phi-2/int4/cpu"
#endif
#endif
#if USE_DML
#include <DirectML.h>
#include <wrl.h>
#include <d3d12.h>
#include <dxgi1_6.h>
#endif

// To generate this file:
// python convert_generation.py --model_type gpt2 -m hf-internal-testing/tiny-random-gpt2 --output tiny_gpt2_greedysearch_fp16.onnx --use_gpu --max_length 20
Expand All @@ -26,6 +32,29 @@ static const std::pair<const char*, const char*> c_tiny_gpt2_model_paths[] = {
{MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp16-cuda", "fp16"},
};

#if USE_DML
TEST(ModelTests, DMLAdapterSelection) {
#if TEST_PHI2
auto model = Generators::CreateModel(Generators::GetOrtEnv(), PHI2_PATH);
auto d3d12Device = model->GetD3D12Device();

auto adapterLuid = d3d12Device->GetAdapterLuid();
for (const auto& provider_option: model->config_->model.decoder.session_options.provider_options) {
if (provider_option.name == "dml") {
for (const auto& [name, value] : provider_option.options) {
if (name == "luid") {
if (auto separator_position = value.find(":"); separator_position != std::string::npos) {
EXPECT_EQ(adapterLuid.HighPart, std::stol(value.substr(0, separator_position)));
EXPECT_EQ(adapterLuid.LowPart, std::stoul(value.substr(separator_position + 1)));
}
}
}
}
}
#endif
}
#endif

// DML doesn't support GPT attention
#if !USE_DML
TEST(ModelTests, GreedySearchGptFp32) {
Expand Down

0 comments on commit c15cc48

Please sign in to comment.