Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support GenAI DML on MCDM compute only devices #1086

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 58 additions & 8 deletions src/dml/dml_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <assert.h>
#include <stdexcept>
#include <variant>
#include <dxcore.h>
#include <dxcore_interface.h>
#include <dxgi1_6.h>
Expand Down Expand Up @@ -75,10 +76,52 @@ static std::vector<ComPtr<IDXGIAdapter1>> EnumerateAdapters() {
return adapter_infos;
}

static ComPtr<IDXGIAdapter1> CreatePerformantAdapter() {
static std::vector<ComPtr<IDXCoreAdapter>> EnumerateCoreAdapters() {
// Try to find Direct3D 12 Core Compute adapters.
ComPtr<IDXCoreAdapterFactory> adapterFactory;
THROW_IF_FAILED(::DXCoreCreateAdapterFactory(IID_PPV_ARGS(&adapterFactory)));

std::vector<ComPtr<IDXCoreAdapter>> adapter_infos;

ComPtr<IDXCoreAdapterList> d3D12CoreComputeAdapters;
GUID attributes[]{DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE};
THROW_IF_FAILED(adapterFactory->CreateAdapterList(_countof(attributes), attributes, IID_PPV_ARGS(&d3D12CoreComputeAdapters)));

// Ask the OS to sort for the highest performance hardware adapter.
DXCoreAdapterPreference sortPreferences[]{
DXCoreAdapterPreference::Hardware, DXCoreAdapterPreference::HighPerformance };
THROW_IF_FAILED(d3D12CoreComputeAdapters->Sort(_countof(sortPreferences), sortPreferences));

const uint32_t count{ d3D12CoreComputeAdapters->GetAdapterCount() };
for (uint32_t i = 0; i < count; ++i) {
ComPtr<IDXCoreAdapter> adapter;
THROW_IF_FAILED(d3D12CoreComputeAdapters->GetAdapter(i, IID_PPV_ARGS(&adapter)));
if (adapter) {
bool isHardware{ false };
THROW_IF_FAILED(adapter->GetProperty(DXCoreAdapterProperty::IsHardware, &isHardware));
bool hasCoreCompute = adapter->IsAttributeSupported(DXCORE_ADAPTER_ATTRIBUTE_D3D12_CORE_COMPUTE);
if (isHardware && hasCoreCompute) {
// Make sure that we are able to create the device
ComPtr<ID3D12Device> d3d12_device;
THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_1_0_CORE, IID_PPV_ARGS(&d3d12_device)));
if (d3d12_device) {
adapter_infos.emplace_back(std::move(adapter));
}
}
}
}

return adapter_infos;
}

static std::variant<ComPtr<IDXCoreAdapter>, ComPtr<IDXGIAdapter1>> CreatePerformantAdapter() {
auto filtered_adapters = EnumerateAdapters();
if (filtered_adapters.empty()) {
throw std::runtime_error("No adapter is available for DML.");
auto core_adapters = EnumerateCoreAdapters();
if (core_adapters.empty()) {
throw std::runtime_error("No adapter is available for DML.");
}
return core_adapters.front();
}
return filtered_adapters.front();
}
Expand All @@ -101,13 +144,20 @@ DmlObjects CreateDmlObjects(const std::string& current_module_path) {
// Get the version from https://devblogs.microsoft.com/directx/directx12agility/. We are currently using 1.614.0.
constexpr uint32_t agility_sdk_version = 614;

if (SUCCEEDED(D3D12GetInterface(CLSID_D3D12SDKConfiguration, IID_PPV_ARGS(&d3d12_sdk_config))) &&
SUCCEEDED(d3d12_sdk_config->CreateDeviceFactory(agility_sdk_version, current_module_path.c_str(), IID_PPV_ARGS(&d3d12_factory)))) {
THROW_IF_FAILED(d3d12_factory->CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&dml_objects.d3d12_device)));
if (std::holds_alternative<ComPtr<IDXGIAdapter1>>(adapter)) {
ComPtr<IDXGIAdapter1> dxgiAdapter = std::get<ComPtr<IDXGIAdapter1>>(adapter);
if (SUCCEEDED(D3D12GetInterface(CLSID_D3D12SDKConfiguration, IID_PPV_ARGS(&d3d12_sdk_config))) &&
SUCCEEDED(d3d12_sdk_config->CreateDeviceFactory(agility_sdk_version, current_module_path.c_str(), IID_PPV_ARGS(&d3d12_factory)))) {
THROW_IF_FAILED(d3d12_factory->CreateDevice(dxgiAdapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&dml_objects.d3d12_device)));
} else {
printf("Warning: Unable to create a device from version 1.614.0 of the DirectX 12 Agility SDK. You can still use this library, but some scenarios may not work.\n");
printf("The given module path: %s", current_module_path.c_str());
THROW_IF_FAILED(D3D12CreateDevice(dxgiAdapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&dml_objects.d3d12_device)));
}
} else {
printf("Warning: Unable to create a device from version 1.614.0 of the DirectX 12 Agility SDK. You can still use this library, but some scenarios may not work.\n");
printf("The given module path: %s", current_module_path.c_str());
THROW_IF_FAILED(D3D12CreateDevice(adapter.Get(), D3D_FEATURE_LEVEL_11_0, IID_PPV_ARGS(&dml_objects.d3d12_device)));
// Use DXCore adapter
ComPtr<IDXCoreAdapter> dxcoreAdapter = std::get<ComPtr<IDXCoreAdapter>>(adapter);
THROW_IF_FAILED(D3D12CreateDevice(dxcoreAdapter.Get(), D3D_FEATURE_LEVEL_1_0_CORE, IID_PPV_ARGS(&dml_objects.d3d12_device)));
}

THROW_IF_FAILED(dml_objects.d3d12_device->CreateCommandQueue(&command_queue_description, IID_PPV_ARGS(&dml_objects.command_queue)));
Expand Down