Distinguish between DML and the generic 'GPU' term. This is needed for packaging DML EP in the same ORT GPU pkg. (#22657)

### Description
Distinguish between DML and the generic 'GPU' term. This is needed for
packaging DML EP in the same ORT GPU pkg.
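
For illustration, a minimal sketch (a hypothetical helper, mirroring the GetDeviceName change in this commit) of what the dedicated device type enables once the CUDA and DML EPs ship in a single GPU package:

```cpp
#include "core/framework/ortdevice.h"

// Hypothetical helper: with a dedicated DML device type, "GPU" no longer
// has to mean "DML" in builds that package both CUDA and DML EPs.
const char* BackendForDevice(const OrtDevice& device) {
  switch (device.Type()) {
    case OrtDevice::GPU: return "CUDA/ROCm";  // generic GPU device type
    case OrtDevice::DML: return "DirectML";   // new dedicated type (= 4)
    default:             return "other";
  }
}
```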

### Motivation and Context
Customer requirement.
pranavsharma authored Oct 30, 2024
1 parent df236c7 commit 03ea5dc
Showing 11 changed files with 96 additions and 70 deletions.
1 change: 1 addition & 0 deletions include/onnxruntime/core/framework/ortdevice.h
@@ -17,6 +17,7 @@ struct OrtDevice {
   static const DeviceType GPU = 1;  // Nvidia or AMD
   static const DeviceType FPGA = 2;
   static const DeviceType NPU = 3;  // Ascend
+  static const DeviceType DML = 4;

   struct MemType {
     // Pre-defined memory types.
5 changes: 4 additions & 1 deletion onnxruntime/core/framework/allocator.cc
@@ -139,13 +139,16 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA
     *out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), id1, mem_type1);
   } else if (strcmp(name1, onnxruntime::CUDA) == 0 ||
              strcmp(name1, onnxruntime::OpenVINO_GPU) == 0 ||
-             strcmp(name1, onnxruntime::DML) == 0 ||
              strcmp(name1, onnxruntime::HIP) == 0 ||
              strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 ||
              strcmp(name1, onnxruntime::WEBNN_TENSOR) == 0) {
     *out = new OrtMemoryInfo(
         name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
         mem_type1);
+  } else if (strcmp(name1, onnxruntime::DML) == 0) {
+    *out = new OrtMemoryInfo(
+        name1, type, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
+        mem_type1);
   } else if (strcmp(name1, onnxruntime::OpenVINO_RT_NPU) == 0) {
     *out = new OrtMemoryInfo(
         name1, type, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast<OrtDevice::DeviceId>(id1)), id1,
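
A minimal usage sketch of this code path through the public C API (not part of the diff; CreateMemoryInfo is the documented OrtApi entry point). After this change, the "DML" allocator name maps to OrtDevice::DML rather than the generic OrtDevice::GPU:

```cpp
#include <onnxruntime_c_api.h>

int main() {
  const OrtApi* ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  OrtMemoryInfo* mem_info = nullptr;
  // Resolves to the new OrtDevice::DML device type internally.
  OrtStatus* status = ort->CreateMemoryInfo(
      "DML", OrtDeviceAllocator, /*id*/ 0, OrtMemTypeDefault, &mem_info);
  if (status != nullptr) {
    ort->ReleaseStatus(status);
    return 1;
  }
  ort->ReleaseMemoryInfo(mem_info);
  return 0;
}
```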
@@ -41,23 +41,19 @@ namespace Dml
         D3D12_HEAP_FLAGS heapFlags,
         D3D12_RESOURCE_FLAGS resourceFlags,
         D3D12_RESOURCE_STATES initialState,
-        std::unique_ptr<DmlSubAllocator>&& subAllocator
-        )
+        std::unique_ptr<DmlSubAllocator>&& subAllocator)
         : onnxruntime::IAllocator(
-            OrtMemoryInfo(
-                "DML",
-                OrtAllocatorType::OrtDeviceAllocator,
-                OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)
-            )
-        ),
-        m_device(device),
-        m_heapProperties(heapProps),
-        m_heapFlags(heapFlags),
-        m_resourceFlags(resourceFlags),
-        m_initialState(initialState),
-        m_context(context),
-        m_subAllocator(std::move(subAllocator))
-    {
+              OrtMemoryInfo(
+                  "DML",
+                  OrtAllocatorType::OrtDeviceAllocator,
+                  OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0))),
+          m_device(device),
+          m_heapProperties(heapProps),
+          m_heapFlags(heapFlags),
+          m_resourceFlags(resourceFlags),
+          m_initialState(initialState),
+          m_context(context),
+          m_subAllocator(std::move(subAllocator)) {
     }

     /*static*/ gsl::index BucketizedBufferAllocator::GetBucketIndexFromSize(uint64_t size)
@@ -20,15 +20,13 @@ namespace Dml
     class DmlExternalBufferAllocator : public onnxruntime::IAllocator
     {
     public:
-        DmlExternalBufferAllocator(int device_id) : onnxruntime::IAllocator(
-            OrtMemoryInfo(
-                "DML",
-                OrtAllocatorType::OrtDeviceAllocator,
-                OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)
-            ))
-        {
-            m_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false);
-        }
+        DmlExternalBufferAllocator(int device_id) : onnxruntime::IAllocator(
+            OrtMemoryInfo(
+                "DML",
+                OrtAllocatorType::OrtDeviceAllocator,
+                OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0))) {
+            m_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false);
+        }

     void* Alloc(size_t size) final
     {
@@ -73,20 +73,17 @@ namespace Dml
         bool enableMetacommands,
         bool enableGraphCapture,
         bool enableSyncSpinning,
-        bool disableMemoryArena) :
-        IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0))
-    {
-        D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue();
-        if (queueType != D3D12_COMMAND_LIST_TYPE_DIRECT && queueType != D3D12_COMMAND_LIST_TYPE_COMPUTE)
-        {
-            // DML requires either DIRECT or COMPUTE command queues.
-            ORT_THROW_HR(E_INVALIDARG);
-        }
+        bool disableMemoryArena) : IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, 0)) {
+        D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue();
+        if (queueType != D3D12_COMMAND_LIST_TYPE_DIRECT && queueType != D3D12_COMMAND_LIST_TYPE_COMPUTE) {
+            // DML requires either DIRECT or COMPUTE command queues.
+            ORT_THROW_HR(E_INVALIDARG);
+        }

-        ComPtr<ID3D12Device> device;
-        GRAPHICS_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf())));
+        ComPtr<ID3D12Device> device;
+        GRAPHICS_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf())));

-        m_impl = wil::MakeOrThrow<ExecutionProviderImpl>(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning, disableMemoryArena);
+        m_impl = wil::MakeOrThrow<ExecutionProviderImpl>(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning, disableMemoryArena);
     }

     std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
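
The provider rejects any queue that is neither DIRECT nor COMPUTE, so a caller supplying its own command queue must create it accordingly. A sketch using the standard D3D12 API (not part of this commit; device creation elided):

```cpp
#include <d3d12.h>
#include <wrl/client.h>

using Microsoft::WRL::ComPtr;

// Create a COMPUTE queue that satisfies the DML EP's requirement of either
// D3D12_COMMAND_LIST_TYPE_DIRECT or D3D12_COMMAND_LIST_TYPE_COMPUTE.
ComPtr<ID3D12CommandQueue> CreateComputeQueue(ID3D12Device* device) {
  D3D12_COMMAND_QUEUE_DESC desc = {};
  desc.Type = D3D12_COMMAND_LIST_TYPE_COMPUTE;  // DIRECT would also be accepted
  desc.Flags = D3D12_COMMAND_QUEUE_FLAG_NONE;

  ComPtr<ID3D12CommandQueue> queue;
  if (FAILED(device->CreateCommandQueue(&desc, IID_PPV_ARGS(&queue)))) {
    return nullptr;  // queue creation failed
  }
  return queue;
}
```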
@@ -242,8 +242,8 @@

         bool CanCopy(const OrtDevice& srcDevice, const OrtDevice& dstDevice) const final
         {
-            return (srcDevice.Type() == OrtDevice::GPU) ||
-                   (dstDevice.Type() == OrtDevice::GPU);
+            return (srcDevice.Type() == OrtDevice::DML) ||
+                   (dstDevice.Type() == OrtDevice::DML);
         }

     private:
22 changes: 22 additions & 0 deletions onnxruntime/core/session/inference_session.cc
@@ -1662,6 +1662,23 @@ static void ResolveMemoryPatternFlags(SessionState& session_state) {
     }
   }
 }
+
+// This function is called when the session is being initialized.
+// For now, this function only checks for invalid combination of DML EP with other EPs.
+// TODO: extend this function to check for other invalid combinations of EPs.
+common::Status InferenceSession::HasInvalidCombinationOfExecutionProviders() const {
+  // DML EP is only allowed with CPU EP
+  bool has_dml_ep = execution_providers_.Get(kDmlExecutionProvider) != nullptr;
+  if (has_dml_ep) {
+    const auto& ep_list = execution_providers_.GetIds();
+    for (const auto& ep : ep_list) {
+      if (ep == kDmlExecutionProvider || ep == kCpuExecutionProvider) continue;
+      return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "DML EP can be used with only CPU EP.");
+    }
+  }
+  return Status::OK();
+}
+
 #if defined(_MSC_VER) && !defined(__clang__)
 #pragma warning(push)
 // VC++ reports: "Releasing unheld lock 'l' in function 'onnxruntime::InferenceSession::Initialize'". But I don't see anything wrong.
@@ -1719,6 +1736,11 @@ common::Status InferenceSession::Initialize() {
     execution_providers_.SetCpuProviderWasImplicitlyAdded(true);
   }

+  // Check for the presence of an invalid combination of execution providers in the session
+  // For e.g. we don't support DML EP and other GPU EPs to be present in the same session
+  // This check is placed here because it serves as a common place for all language bindings.
+  ORT_RETURN_IF_ERROR_SESSIONID_(HasInvalidCombinationOfExecutionProviders());
+
   // re-acquire mutex
   std::lock_guard<std::mutex> l(session_mutex_);

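
A usage sketch of the new check from the C++ API (not part of the diff; assumes a DML-enabled build, and the model path is hypothetical). DML EP plus the implicitly added CPU EP initializes fine; appending another GPU EP such as CUDA to the same options would make Initialize() return INVALID_ARGUMENT:

```cpp
#include <onnxruntime_cxx_api.h>
#include <dml_provider_factory.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "demo");
  Ort::SessionOptions so;

  // DML EP + implicit CPU EP: an allowed combination.
  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(so, /*device_id*/ 0));

  // Appending e.g. the CUDA EP to the same options would now fail session
  // initialization with: "DML EP can be used with only CPU EP."
  Ort::Session session(env, ORT_TSTR("model.onnx"), so);  // hypothetical model path
  return 0;
}
```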
2 changes: 1 addition & 1 deletion onnxruntime/core/session/inference_session.h
@@ -620,7 +620,7 @@ class InferenceSession {
                           const Environment& session_env);
   void ConstructorCommon(const SessionOptions& session_options,
                          const Environment& session_env);
-
+  [[nodiscard]] common::Status HasInvalidCombinationOfExecutionProviders() const;
   [[nodiscard]] common::Status SaveModelMetadata(const onnxruntime::Model& model);

 #if !defined(ORT_MINIMAL_BUILD)
2 changes: 1 addition & 1 deletion onnxruntime/python/onnxruntime_pybind_mlvalue.cc
@@ -291,7 +291,7 @@ void DmlToCpuMemCpy(void* dst, const void* src, size_t num_bytes) {

 const std::unordered_map<OrtDevice::DeviceType, MemCpyFunc>* GetDmlToHostMemCpyFunction() {
   static std::unordered_map<OrtDevice::DeviceType, MemCpyFunc> map{
-      {OrtDevice::GPU, DmlToCpuMemCpy}};
+      {OrtDevice::DML, DmlToCpuMemCpy}};

   return &map;
 }
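
For context, a hedged sketch of how such a table is typically consumed (the surrounding copy routine is hypothetical): the lookup is now keyed on OrtDevice::DML, so a tensor on the generic GPU device type no longer resolves to the DML copy path.

```cpp
// Hypothetical call site, assuming the map from the hunk above.
void CopyDeviceTensorToHost(void* dst, const void* src, size_t num_bytes,
                            const OrtDevice& device) {
  const auto* copy_map = GetDmlToHostMemCpyFunction();
  auto it = copy_map->find(device.Type());  // matches OrtDevice::DML after this change
  if (it != copy_map->end()) {
    it->second(dst, src, num_bytes);  // dispatches to DmlToCpuMemCpy
  }
}
```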
57 changes: 34 additions & 23 deletions onnxruntime/python/onnxruntime_pybind_ortvalue.cc
@@ -96,16 +96,22 @@ void addOrtValueMethods(pybind11::module& m) {
       // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
       // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in CUDA
       CreateGenericMLValue(nullptr, GetRocmAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToRocmMemCpy);
-#elif USE_DML
-      // InputDeflist is null because OrtValue creation is not tied to a specific model
-      // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
-      // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML
-      CreateGenericMLValue(
-          nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy);
 #else
-      throw std::runtime_error(
-        "Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
-        "Please use the CUDA package of OnnxRuntime to use this feature.");
+      throw std::runtime_error(
+          "Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
+          "Please use the CUDA package of OnnxRuntime to use this feature.");
 #endif
+    } else if (device.Type() == OrtDevice::DML) {
+#if USE_DML
+      // InputDeflist is null because OrtValue creation is not tied to a specific model
+      // Likewise, there is no need to specify the name (as the name was previously used to lookup the def list)
+      // TODO: Add check to ensure that string arrays are not passed - we currently don't support string tensors in DML
+      CreateGenericMLValue(
+          nullptr, GetDmlAllocator(device.Id()), "", array_on_cpu, ml_value.get(), true, false, CpuToDmlMemCpy);
+#else
+      throw std::runtime_error(
+          "Can't allocate memory on the CUDA device using this package of OnnxRuntime. "
+          "Please use the CUDA package of OnnxRuntime to use this feature.");
+#endif
     } else if (device.Type() == OrtDevice::NPU) {
 #ifdef USE_CANN
@@ -116,9 +122,9 @@
       CreateGenericMLValue(nullptr, GetCannAllocator(device.Id()), "", array_on_cpu, ml_value.get(),
                            true, false, CpuToCannMemCpy);
 #else
-      throw std::runtime_error(
-        "Can't allocate memory on the CANN device using this package of OnnxRuntime. "
-        "Please use the CANN package of OnnxRuntime to use this feature.");
+      throw std::runtime_error(
+          "Can't allocate memory on the CANN device using this package of OnnxRuntime. "
+          "Please use the CANN package of OnnxRuntime to use this feature.");
 #endif
     } else {
       throw std::runtime_error("Unsupported device: Cannot place the OrtValue on this device");
@@ -160,19 +166,24 @@
       }

       onnxruntime::python::CopyDataToTensor(
-        py_values,
-        values_type,
-        *(ml_value->GetMutable<Tensor>()),
-        CpuToRocmMemCpy);
-#elif USE_DML
-      onnxruntime::python::CopyDataToTensor(
-        py_values,
-        values_type,
-        *(ml_value->GetMutable<Tensor>()),
-        CpuToDmlMemCpy);
+          py_values,
+          values_type,
+          *(ml_value->GetMutable<Tensor>()),
+          CpuToRocmMemCpy);
 #else
-      throw std::runtime_error(
-        "Unsupported GPU device: Cannot find the supported GPU device.");
+      throw std::runtime_error(
+          "Unsupported GPU device: Cannot find the supported GPU device.");
+#endif
+    } else if (device.Type() == OrtDevice::DML) {
+#if USE_DML
+      onnxruntime::python::CopyDataToTensor(
+          py_values,
+          values_type,
+          *(ml_value->GetMutable<Tensor>()),
+          CpuToDmlMemCpy);
+#else
+      throw std::runtime_error(
+          "Unsupported GPU device: Cannot find the supported GPU device.");
 #endif
     } else {
       throw std::runtime_error("Unsupported device: Cannot update the OrtValue on this device");
8 changes: 3 additions & 5 deletions onnxruntime/python/onnxruntime_pybind_state.cc
@@ -288,11 +288,9 @@ const char* GetDeviceName(const OrtDevice& device) {
     case OrtDevice::CPU:
       return CPU;
     case OrtDevice::GPU:
-#ifdef USE_DML
-      return DML;
-#else
       return CUDA;
-#endif
+    case OrtDevice::DML:
+      return DML;
     case OrtDevice::FPGA:
       return "FPGA";
     case OrtDevice::NPU:
@@ -1579,7 +1577,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
       .def_static("cann", []() { return OrtDevice::NPU; })
       .def_static("fpga", []() { return OrtDevice::FPGA; })
       .def_static("npu", []() { return OrtDevice::NPU; })
-      .def_static("dml", []() { return OrtDevice::GPU; })
+      .def_static("dml", []() { return OrtDevice::DML; })
       .def_static("webgpu", []() { return OrtDevice::GPU; })
       .def_static("default_memory", []() { return OrtDevice::MemType::DEFAULT; });

