diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index a5b5d2edde46c..585db158f8f10 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -324,6 +324,13 @@ class IExecutionProvider { return default_device_; }; + /** + * Return the appropriate OrtDevice object given OrtMemType that can be used directly by external callers. + */ + virtual OrtDevice GetExternalOrtDeviceByMemType(OrtMemType mem_type) const { + return GetOrtDeviceByMemType(mem_type); + }; + /** * Create Preferred allocators for the current Execution Provider * This function is a stateless function which creates new instances of Allocator, without storing them in EP. diff --git a/include/onnxruntime/core/framework/ortdevice.h b/include/onnxruntime/core/framework/ortdevice.h index f15543f22f21d..3abe5e0f76217 100644 --- a/include/onnxruntime/core/framework/ortdevice.h +++ b/include/onnxruntime/core/framework/ortdevice.h @@ -24,6 +24,7 @@ struct OrtDevice { static const MemoryType CUDA_PINNED = 1; static const MemoryType HIP_PINNED = 2; static const MemoryType CANN_PINNED = 3; + static const MemoryType DML_EXTERNAL = 4; }; constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_) diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index 5e66f2b99fded..87be13f5e539f 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -139,12 +139,17 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA *out = new OrtMemoryInfo(onnxruntime::CPU, type, OrtDevice(), id1, mem_type1); } else if (strcmp(name1, onnxruntime::CUDA) == 0 || strcmp(name1, onnxruntime::OpenVINO_GPU) == 0 || - strcmp(name1, onnxruntime::DML) == 0 || strcmp(name1, onnxruntime::HIP) == 0 || strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0) { *out = new OrtMemoryInfo( name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, mem_type1); + } else if (strcmp(name1, onnxruntime::DML) == 0) { + // Since EPs cannot have 2 allocators with the same OrtMemType and Memory ID, + // we use -1 as the memory ID to represent external allocations that don't have any allocator. + *out = new OrtMemoryInfo( + name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DML_EXTERNAL, static_cast(id1)), + -1, mem_type1); } else if (strcmp(name1, onnxruntime::OpenVINO_RT_NPU) == 0) { *out = new OrtMemoryInfo( name1, type, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, diff --git a/onnxruntime/core/framework/bfc_arena.cc b/onnxruntime/core/framework/bfc_arena.cc index 13f9656ae0595..9d58bf52de3e6 100644 --- a/onnxruntime/core/framework/bfc_arena.cc +++ b/onnxruntime/core/framework/bfc_arena.cc @@ -41,23 +41,8 @@ BFCArena::BFCArena(std::unique_ptr resource_allocator, memory_limit_ = total_memory; stats_.bytes_limit = static_cast(total_memory); - arena_extend_strategy_ = arena_extend_strategy; + SetArenaExtendStrategy(arena_extend_strategy); - // We never want to shrink the initial allocation if the arena extend strategy is kNextPowerOfTwo. - // This could seem confusingly arbitrary but the rationale is as follows: - // The user selected initial allocation chunk is only valid for the arena extend strategy kNextPowerOfTwo - // and the user has likely chosen this initial value so that any ad-hoc arena extensions/shrinkages could potentially - // be avoided. So we do not consider the initial allocation for shrinkage whatever its usage status. - // On the other hand, if the arena extension strategy is kSameAsRequested, any initial chunk set by the user or otherwise, - // is moot and the arena will only extend based on the request size. In these cases, we consider any allocation for shrinkage - // if it is left unused (even if it is the first allocation). - if (arena_extend_strategy_ == ArenaExtendStrategy::kSameAsRequested) { - // Consider all allocation regions (including first allocation region) for shrinkage - consider_first_allocation_region_for_shrinkage_ = true; - } else { // arena_extend_strategy_ == kNextPowerOfTwo - // Do not consider the first allocation region for shrinkage - consider_first_allocation_region_for_shrinkage_ = false; - } // Create a bunch of bins of various good sizes. // We create bins to fit all possible ranges that cover the @@ -91,6 +76,29 @@ BFCArena::~BFCArena() { } } +void BFCArena::UpdateFirstAllocationShrinkageLogic() { + // We never want to shrink the initial allocation if the arena extend strategy is kNextPowerOfTwo. + // This could seem confusingly arbitrary but the rationale is as follows: + // The user selected initial allocation chunk is only valid for the arena extend strategy kNextPowerOfTwo + // and the user has likely chosen this initial value so that any ad-hoc arena extensions/shrinkages could potentially + // be avoided. So we do not consider the initial allocation for shrinkage whatever its usage status. + // On the other hand, if the arena extension strategy is kSameAsRequested, any initial chunk set by the user or otherwise, + // is moot and the arena will only extend based on the request size. In these cases, we consider any allocation for shrinkage + // if it is left unused (even if it is the first allocation). + if (arena_extend_strategy_ == ArenaExtendStrategy::kSameAsRequested) { + // Consider all allocation regions (including first allocation region) for shrinkage + consider_first_allocation_region_for_shrinkage_ = true; + } else { // arena_extend_strategy_ == kNextPowerOfTwo + // Do not consider the first allocation region for shrinkage + consider_first_allocation_region_for_shrinkage_ = false; + } +} + +void BFCArena::SetArenaExtendStrategy(ArenaExtendStrategy arena_extend_strategy) { + arena_extend_strategy_ = arena_extend_strategy; + UpdateFirstAllocationShrinkageLogic(); +} + BFCArena::Chunk* BFCArena::ChunkFromHandle(ChunkHandle h) { ORT_ENFORCE(h < chunks_.size()); return &(chunks_[h]); diff --git a/onnxruntime/core/framework/bfc_arena.h b/onnxruntime/core/framework/bfc_arena.h index 5e4cd9f62f11b..429ccd1633a30 100644 --- a/onnxruntime/core/framework/bfc_arena.h +++ b/onnxruntime/core/framework/bfc_arena.h @@ -77,6 +77,11 @@ class BFCArena : public IAllocator { ~BFCArena() override; + // Allows the caller to change the arena extend strategy after the allocator is done initializing. + // For example, kSameAsRequested may be desirable in certain situations and kNextPowerOfTwo may be + // desirable in others. + void SetArenaExtendStrategy(ArenaExtendStrategy arena_extend_strategy); + // If size is 0, then this function returns either NULL, // or a unique pointer value that can later be successfully // passed to free(). Whatever, do not dereference that pointer @@ -123,6 +128,9 @@ class BFCArena : public IAllocator { private: void DeallocateRawInternal(void* ptr); + // Updates whether the first allocation should be considered for shrinkage depending on the strategy type. + void UpdateFirstAllocationShrinkageLogic(); + // A ChunkHandle is an index into the chunks_ vector in BFCAllocator // kInvalidChunkHandle means an invalid chunk using ChunkHandle = size_t; diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 9eed0249711f9..b4691a81e7bdb 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -163,6 +163,19 @@ static Status BatchOrCopyMLValue(const SessionState& session_state, return Status::OK(); } +#ifdef USE_DML + const bool bothValuesOnGPU = copy_info.source_device.Type() == OrtDevice::GPU && copy_info.target_device.Type() == OrtDevice::GPU; + const bool sourceIsDmlAlloc = copy_info.source_device.MemType() == OrtDevice::MemType::DEFAULT || copy_info.source_device.MemType() == OrtDevice::MemType::DML_EXTERNAL; + const bool targetIsInternalAlloc = copy_info.target_device.MemType() == OrtDevice::MemType::DEFAULT; + const bool bothValuesOnSameDevice = copy_info.source_device.Id() == copy_info.target_device.Id(); + + // The DML EP supports binding external allocations directly, even if the memory types don't match, as long as they are on the same D3D12 device + if (bothValuesOnGPU && sourceIsDmlAlloc && targetIsInternalAlloc && bothValuesOnSameDevice) { + target_mlvalue = source_mlvalue; + return Status::OK(); + } +#endif + auto allocator = session_state.GetAllocator(copy_info.target_device); if (!target_mlvalue.IsAllocated()) { ORT_ENFORCE(allocator != nullptr, "Failed to find allocator for device ", copy_info.target_device.ToString()); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h index ecef48dc6d480..585a1c4b9859f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h @@ -10,6 +10,7 @@ interface ID3D12Resource; #include "core/common/status.h" #include "core/framework/data_transfer.h" #include "IWinmlExecutionProvider.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlBufferRegion.h" #include "core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h" namespace onnxruntime @@ -21,12 +22,6 @@ namespace onnxruntime class KernelRegistry; } -enum class AllocatorRoundingMode -{ - Disabled = 0, - Enabled = 1, -}; - namespace Dml { std::unique_ptr CreateExecutionProvider( @@ -35,9 +30,9 @@ namespace Dml bool enableMetacommands, bool enableGraphCapture, bool enableCpuSyncSpinning, - bool disableMemoryArena); + bool disableMemoryArena, + bool enableBfcAllocator); - ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr); void FlushContext(onnxruntime::IExecutionProvider* provider); void ReleaseCompletedReferences(onnxruntime::IExecutionProvider* provider); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 88e3dd487d427..b7f9b2879be98 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -10,6 +10,7 @@ #include #include "core/framework/op_kernel.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlBufferRegion.h" #include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h" struct AbstractOperatorDesc; @@ -25,6 +26,11 @@ namespace onnxruntime class Node; } +namespace Dml +{ + struct TaggedPointer; +} + namespace Windows::AI::MachineLearning::Adapter { interface __declspec(uuid("5b19a18a-5ed5-4df2-a363-21b89380a698")) @@ -37,19 +43,9 @@ namespace Windows::AI::MachineLearning::Adapter // the provider's underlying queues. virtual void QueueReference(IUnknown *object) = 0; - virtual void GetShadowCopyIfRequired( - bool isInternalOperator, - IUnknown* data, - IUnknown** dataCopy) const = 0; - - virtual void GetABIDataInterface( - bool isInternalOperator, - IUnknown* data, - IUnknown** abiData) const = 0; + virtual Dml::D3D12BufferRegion GetBufferRegion(void* opaquePointer, uint64_t size) const = 0; - virtual uint64_t TryGetPooledAllocationId( - IUnknown* data, - bool isInternalOperator) = 0; + virtual uint64_t GetUniqueId(void* opaquePointer) = 0; virtual void GetABIExecutionInterfaceAndInvalidateState( bool isInternalOperator, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp index 353f698bb6f2c..fc900b9444b6e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp @@ -564,6 +564,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( // // For backward compatibility, this does not propagate errors for external operators static_cast(m_kernelRegistry->RegisterCustomKernel(create_info)); // ignore result + m_hasExternalOperators = true; } return S_OK; @@ -571,4 +572,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( ORT_CATCH_RETURN } +bool STDMETHODCALLTYPE AbiCustomRegistry::HasExternalOperators() const noexcept +{ + return m_hasExternalOperators; +} + } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h index eb84b4f822e92..79591abe8019d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h @@ -50,6 +50,8 @@ class AbiCustomRegistry : public WRL::Base> GetRegistries() { std::list> registries; @@ -108,6 +110,8 @@ class AbiCustomRegistry : public WRL::Base m_internalRegInfoMap; + mutable bool m_hasExternalOperators = false; + }; } // namespace Windows::AI::MachineLearning::Adapter diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp index b1714a8220cd1..385897b4c1761 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.cpp @@ -6,19 +6,13 @@ #include "core/session/onnxruntime_c_api.h" #include "BucketizedBufferAllocator.h" -#include "DmlSubAllocator.h" +#include "DmlAllocationInfo.h" +#include "DmlCommittedResourceWrapper.h" +#include "DmlAllocatorRoundingMode.h" // #define PRINT_OUTSTANDING_ALLOCATIONS namespace Dml { - AllocationInfo::~AllocationInfo() - { - if (m_owner) - { - m_owner->FreeResource(this, m_pooledResourceId); - } - } - BucketizedBufferAllocator::~BucketizedBufferAllocator() { #ifdef PRINT_OUTSTANDING_ALLOCATIONS @@ -40,23 +34,14 @@ namespace Dml const D3D12_HEAP_PROPERTIES& heapProps, D3D12_HEAP_FLAGS heapFlags, D3D12_RESOURCE_FLAGS resourceFlags, - D3D12_RESOURCE_STATES initialState, - std::unique_ptr&& subAllocator - ) - : onnxruntime::IAllocator( - OrtMemoryInfo( - "DML", - OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0) - ) - ), + D3D12_RESOURCE_STATES initialState) + : m_device(device), m_heapProperties(heapProps), m_heapFlags(heapFlags), m_resourceFlags(resourceFlags), m_initialState(initialState), - m_context(context), - m_subAllocator(std::move(subAllocator)) + m_context(context) { } @@ -68,21 +53,50 @@ namespace Dml gsl::index index = static_cast(ceil(log2(size))); assert((1ull << index) >= size); // This must be true unless there were some strange rounding issues - // The smallest bucket is 2^n bytes large, where n = c_minResourceSizeExponent - index = std::max(index, c_minResourceSizeExponent); - index -= c_minResourceSizeExponent; + // The smallest bucket is 2^n bytes large, where n = MinResourceSizeExponent + index = std::max(index, MinResourceSizeExponent); + index -= MinResourceSizeExponent; return index; } /*static*/ uint64_t BucketizedBufferAllocator::GetBucketSizeFromIndex(gsl::index index) { - return (1ull << (index + c_minResourceSizeExponent)); + return (1ull << (index + MinResourceSizeExponent)); + } + + ComPtr BucketizedBufferAllocator::AllocCommittedResource(size_t size) + { + ComPtr resource; + auto buffer = CD3DX12_RESOURCE_DESC::Buffer(size, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); + ORT_THROW_IF_FAILED(m_device->CreateCommittedResource( + &CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT), + D3D12_HEAP_FLAG_NONE, + &buffer, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + nullptr, + IID_GRAPHICS_PPV_ARGS(resource.GetAddressOf()) + )); + + ComPtr resourceWrapper; + wil::MakeOrThrow(std::move(resource)).As(&resourceWrapper); + return resourceWrapper; } - void* BucketizedBufferAllocator::Alloc(size_t size) + AllocationInfo* BucketizedBufferAllocator::GetAllocationInfo(void* opaquePointer) { - return Alloc(size, m_defaultRoundingMode); + return static_cast(opaquePointer); + } + + D3D12BufferRegion BucketizedBufferAllocator::CreateBufferRegion(void* opaquePointer, uint64_t sizeInBytes) const + { + auto allocationInfo = static_cast(opaquePointer); + + // Make sure that we are aligned to 4 bytes to satisfy DML's requirements + constexpr uint64_t DML_ALIGNMENT = 4; + sizeInBytes = (1 + (sizeInBytes - 1) / DML_ALIGNMENT) * DML_ALIGNMENT; + + return D3D12BufferRegion(0, sizeInBytes, allocationInfo->GetD3D12Resource()); } void* BucketizedBufferAllocator::Alloc(size_t size, AllocatorRoundingMode roundingMode) @@ -114,7 +128,7 @@ namespace Dml if (bucket->resources.empty()) { // No more resources in this bucket - allocate a new one - resourceWrapper = m_subAllocator->Alloc(onnxruntime::narrow(bucketSize)); + resourceWrapper = AllocCommittedResource(onnxruntime::narrow(bucketSize)); resourceId = ++m_currentResourceId; } else @@ -129,7 +143,7 @@ namespace Dml { // The allocation will not be pooled. Construct a new one bucketSize = (size + 3) & ~3; - resourceWrapper = m_subAllocator->Alloc(onnxruntime::narrow(bucketSize)); + resourceWrapper = AllocCommittedResource(onnxruntime::narrow(bucketSize)); resourceId = ++m_currentResourceId; } @@ -160,10 +174,14 @@ namespace Dml allocInfo.Attach(static_cast(p)); } - void BucketizedBufferAllocator::FreeResource(void* p, uint64_t pooledResourceId) + uint64_t BucketizedBufferAllocator::GetUniqueId(void* opaquePointer) { - AllocationInfo *allocInfo = static_cast(p); + const auto* allocInfo = static_cast(opaquePointer); + return allocInfo->GetPooledResourceId(); + } + void BucketizedBufferAllocator::FreeResource(AllocationInfo* allocInfo, uint64_t pooledResourceId) + { assert(allocInfo != nullptr); // Can't free nullptr if (allocInfo->GetOwner() != this) @@ -174,7 +192,7 @@ namespace Dml // Free the resource to the pool if its size matches a bucket size gsl::index bucketIndex = GetBucketIndexFromSize(allocInfo->GetRequestedSize()); - if (GetBucketSizeFromIndex(bucketIndex) == allocInfo->GetResource()->GetDesc().Width) + if (GetBucketSizeFromIndex(bucketIndex) == allocInfo->GetD3D12Resource()->GetDesc().Width) { assert(gsl::narrow_cast(m_pool.size()) > bucketIndex); @@ -189,11 +207,11 @@ namespace Dml if (!m_context->IsClosed()) { // Free the underlying allocation once queued work has completed. - #ifdef _GAMING_XBOX - m_context->QueueReference(WRAP_GRAPHICS_UNKNOWN(allocInfo->GetResource()).Get()); - #else - m_context->QueueReference(allocInfo->GetResource()); - #endif +#ifdef _GAMING_XBOX + m_context->QueueReference(WRAP_GRAPHICS_UNKNOWN(allocInfo->GetD3D12Resource()).Get()); +#else + m_context->QueueReference(allocInfo->GetD3D12Resource()); +#endif } allocInfo->DetachResourceWrapper(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h index 16283d5b19c9c..8e255f2decd2d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h @@ -6,26 +6,16 @@ #include "core/framework/allocator.h" #include "ExecutionContext.h" #include "DmlResourceWrapper.h" -#include "AllocationInfo.h" +#include "DmlSubAllocator.h" +#include "DmlAllocatorRoundingMode.h" namespace Dml { - class DmlSubAllocator; - - class CPUAllocator : public onnxruntime::IAllocator - { - public: - explicit CPUAllocator(OrtMemType memType); - - void* Alloc(size_t size) override; - void Free(void* p) override; - }; - // Implements a Lotus allocator for D3D12 heap buffers, using a bucket allocation strategy. The allocator // maintains a set of fixed-size buckets, with each bucket containing one or more D3D12 buffers of that fixed size. // All requested allocation sizes are rounded up to the nearest bucket size, which ensures minimal fragmentation // while providing an upper bound on the amount of memory "wasted" with each allocation. - class BucketizedBufferAllocator : public onnxruntime::IAllocator + class BucketizedBufferAllocator : public DmlSubAllocator { public: ~BucketizedBufferAllocator(); @@ -38,21 +28,25 @@ namespace Dml const D3D12_HEAP_PROPERTIES& heapProps, D3D12_HEAP_FLAGS heapFlags, D3D12_RESOURCE_FLAGS resourceFlags, - D3D12_RESOURCE_STATES initialState, - std::unique_ptr&& subAllocator); + D3D12_RESOURCE_STATES initialState); + + ComPtr AllocCommittedResource(size_t size); // Returns the information associated with an opaque allocation handle returned by IAllocator::Alloc. const AllocationInfo* DecodeDataHandle(const void* opaqueHandle); void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode); + AllocationInfo* GetAllocationInfo(void* opaquePointer); + D3D12BufferRegion CreateBufferRegion(void* opaquePointer, uint64_t sizeInBytes) const; + uint64_t GetUniqueId(void* opaquePointer); + public: // onnxruntime::IAllocator void* Alloc(size_t size, AllocatorRoundingMode roundingMode); - void* Alloc(size_t size) final; - void Free(void* p) final; + void Free(void* p); private: - static const uint32_t c_minResourceSizeExponent = 16; // 2^16 = 64KB + static const uint32_t MinResourceSizeExponent = 16; // 2^16 = 64KB // The pool consists of a number of buckets, and each bucket contains a number of resources of the same size. // The resources in each bucket are always sized as a power of two, and each bucket contains resources twice @@ -72,7 +66,7 @@ namespace Dml static uint64_t GetBucketSizeFromIndex(gsl::index index); friend class AllocationInfo; - void FreeResource(void* p, uint64_t resourceId); + void FreeResource(AllocationInfo* allocInfo, uint64_t resourceId) final; ComPtr m_device; D3D12_HEAP_PROPERTIES m_heapProperties; @@ -84,12 +78,13 @@ namespace Dml size_t m_currentAllocationId = 0; uint64_t m_currentResourceId = 0; + ComPtr m_context; + // Unless specifically requested, allocation sizes are not rounded to enable pooling // until SetDefaultRoundingMode is called. This should be done at completion of session // initialization. AllocatorRoundingMode m_defaultRoundingMode = AllocatorRoundingMode::Disabled; - ComPtr m_context; std::unique_ptr m_subAllocator; #ifndef NDEBUG diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlAllocationInfo.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlAllocationInfo.cpp new file mode 100644 index 0000000000000..5db1289778819 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlAllocationInfo.cpp @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" +#include "DmlAllocationInfo.h" +#include "DmlReservedResourceSubAllocator.h" +#include "DmlSubAllocator.h" + +namespace Dml +{ + + AllocationInfo::~AllocationInfo() + { + if (m_owner) + { + m_owner->FreeResource(this, m_pooledResourceId); + } + } + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlAllocationInfo.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlAllocationInfo.h new file mode 100644 index 0000000000000..f61e59edd5159 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlAllocationInfo.h @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "DmlReservedResourceWrapper.h" + +namespace Dml +{ + class DmlSubAllocator; + + class AllocationInfo : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, IUnknown> + { + public: + AllocationInfo( + DmlSubAllocator* owner, + size_t id, + uint64_t pooledResourceId, + DmlResourceWrapper* resourceWrapper, + size_t requestedSize) + : m_owner(owner) + , m_allocationId(id) + , m_pooledResourceId(pooledResourceId) + , m_resourceWrapper(resourceWrapper) + , m_requestedSize(requestedSize) + {} + + ~AllocationInfo(); + + DmlSubAllocator* GetOwner() const + { + return m_owner; + } + + ID3D12Resource* GetD3D12Resource() const + { + return m_resourceWrapper->GetD3D12Resource(); + } + + ComPtr DetachResourceWrapper() + { + return std::move(m_resourceWrapper); + } + + size_t GetRequestedSize() const + { + return m_requestedSize; + } + + size_t GetId() const + { + return m_allocationId; + } + + uint64_t GetPooledResourceId() const + { + return m_pooledResourceId; + } + + private: + DmlSubAllocator* m_owner; + size_t m_allocationId; // For debugging purposes + uint64_t m_pooledResourceId; + Microsoft::WRL::ComPtr m_resourceWrapper; + + // The size requested during Alloc(), which may be smaller than the physical resource size + size_t m_requestedSize; + }; +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlAllocatorRoundingMode.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlAllocatorRoundingMode.h new file mode 100644 index 0000000000000..9dfd63ad4c2a2 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlAllocatorRoundingMode.h @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace Dml +{ + enum class AllocatorRoundingMode + { + Disabled = 0, + Enabled = 1, + }; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlBuffer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlBuffer.cpp new file mode 100644 index 0000000000000..4801ede60038e --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlBuffer.cpp @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" +#include "DmlBuffer.h" +#include "DmlGpuAllocator.h" +#include "DmlAllocatorRoundingMode.h" + +namespace Dml +{ + +/*explicit*/ DmlBuffer::DmlBuffer(DmlGpuAllocator* allocator, uint64_t sizeInBytes, AllocatorRoundingMode roundingMode) + : m_allocator(allocator) +{ + m_opaqueData = m_allocator->Alloc(sizeInBytes, roundingMode); + ORT_THROW_HR_IF(E_OUTOFMEMORY, m_opaqueData == nullptr); + + m_bufferRegion = m_allocator->CreateBufferRegion(m_opaqueData, sizeInBytes); +} + +DmlBuffer::~DmlBuffer() +{ + if (m_opaqueData != nullptr) + { + m_allocator->Free(m_opaqueData); + } +} + +DmlBuffer::DmlBuffer(DmlBuffer&& other) noexcept +{ + m_opaqueData = other.m_opaqueData; + m_allocator = other.m_allocator; + m_bufferRegion = std::move(other.m_bufferRegion); + other.m_opaqueData = nullptr; +} + +DmlBuffer& DmlBuffer::operator=(DmlBuffer&& other) noexcept +{ + m_opaqueData = other.m_opaqueData; + m_allocator = other.m_allocator; + m_bufferRegion = std::move(other.m_bufferRegion); + other.m_opaqueData = nullptr; + return *this; +} + +ID3D12Resource* DmlBuffer::GetD3D12Resource() const +{ + return m_bufferRegion.GetD3D12Resource(); +} + +uint64_t DmlBuffer::Offset() const +{ + return m_bufferRegion ? m_bufferRegion.Offset() : 0; +} + +uint64_t DmlBuffer::SizeInBytes() const +{ + return m_bufferRegion ? m_bufferRegion.SizeInBytes() : 0; +} + +DML_BUFFER_BINDING DmlBuffer::GetBufferBinding() const +{ + return m_bufferRegion ? m_bufferRegion.GetBufferBinding() + : DML_BUFFER_BINDING{}; +} + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlBuffer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlBuffer.h new file mode 100644 index 0000000000000..ba46cbd08d5fd --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlBuffer.h @@ -0,0 +1,45 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "DmlBufferRegion.h" +#include "DmlAllocatorRoundingMode.h" + +namespace Dml +{ + +class DmlGpuAllocator; +class OpKernelContext; + +// Owns a D3D12 default heap buffer allocated using the DML device's +// allocator. This is essentially a convenience wrapper over a device memory +// allocation as well as the buffer region that spans it. When this object is +// destructed, the device memory is freed to the allocator. +class DmlBuffer +{ + public: + explicit DmlBuffer(DmlGpuAllocator* allocator, uint64_t sizeInBytes, AllocatorRoundingMode roundingMode); + ~DmlBuffer(); + + // Move-only + DmlBuffer(const DmlBuffer&) = delete; + DmlBuffer& operator=(const DmlBuffer&) = delete; + DmlBuffer(DmlBuffer&&) noexcept; + DmlBuffer& operator=(DmlBuffer&&) noexcept; + + ID3D12Resource* GetD3D12Resource() const; + uint64_t Offset() const; + uint64_t SizeInBytes() const; + const D3D12BufferRegion& Region() const { return m_bufferRegion; } + DML_BUFFER_BINDING GetBufferBinding() const; + + explicit operator bool() const { return !!m_bufferRegion; } + + private: + DmlGpuAllocator* m_allocator; + D3D12BufferRegion m_bufferRegion; + void* m_opaqueData; +}; + +} // namespace tfdml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlBufferRegion.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlBufferRegion.cpp new file mode 100644 index 0000000000000..627e383a17195 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlBufferRegion.cpp @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" +#include "DmlBufferRegion.h" + +namespace Dml +{ + + D3D12BufferRegion::D3D12BufferRegion(uint64_t offset, uint64_t sizeInBytes, ID3D12Resource* resource) + : m_resource(resource), + m_offset(offset), + m_sizeInBytes(sizeInBytes) + { + ORT_THROW_HR_IF(E_INVALIDARG, m_resource == nullptr); + + // Regions cannot be empty. + ORT_THROW_HR_IF(E_INVALIDARG, m_sizeInBytes == 0); + + // Regions cannot extend beyond the size of the resource. + uint64_t bufferSize = m_resource->GetDesc().Width; + ORT_THROW_HR_IF(E_INVALIDARG, m_offset >= bufferSize); + ORT_THROW_HR_IF(E_INVALIDARG, m_sizeInBytes > bufferSize - offset); + + // All three resources, if provided, must be identical aside from state. + assert(m_resource->GetDesc().Dimension == D3D12_RESOURCE_DIMENSION_BUFFER); + assert(m_resource->GetDesc().Width == bufferSize); + } + + D3D12BufferRegion::D3D12BufferRegion(D3D12BufferRegion&& that) noexcept + { + std::swap(this->m_resource, that.m_resource); + std::swap(this->m_offset, that.m_offset); + std::swap(this->m_sizeInBytes, that.m_sizeInBytes); + } + + D3D12BufferRegion& D3D12BufferRegion::operator=(D3D12BufferRegion&& that) noexcept + { + std::swap(this->m_resource, that.m_resource); + std::swap(this->m_offset, that.m_offset); + std::swap(this->m_sizeInBytes, that.m_sizeInBytes); + return *this; + } + + ID3D12Resource* D3D12BufferRegion::GetD3D12Resource() const + { + return m_resource; + } + + uint64_t D3D12BufferRegion::Offset() const + { + return m_resource ? m_offset : 0; + } + + uint64_t D3D12BufferRegion::SizeInBytes() const + { + return m_resource ? m_sizeInBytes : 0; + } + + DML_BUFFER_BINDING D3D12BufferRegion::GetBufferBinding() const + { + if (!m_resource) + { + return DML_BUFFER_BINDING{}; + } + + return DML_BUFFER_BINDING{m_resource, m_offset, m_sizeInBytes}; + } + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlBufferRegion.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlBufferRegion.h new file mode 100644 index 0000000000000..d312cad247981 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlBufferRegion.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "directx/d3d12.h" +#include +#include +#include "core/providers/dml/DmlExecutionProvider/src/ErrorHandling.h" + +namespace Dml +{ + // Represents a region of a D3D12 buffer resource. A buffer region has an + // underlying ID3D12Resource* (of D3D12_RESOURCE_DIMENSION_BUFFER), an offset in + // bytes from the beginning of that buffer, and a size in bytes of the region. + class D3D12BufferRegion + { + public: + D3D12BufferRegion() = default; + + // References a region of a buffer. The respective ID3D12Resource objects + // must be in the appropriate states. Each resource is optional, but if more + // than one are provided they must map to the same region of memory. + D3D12BufferRegion(uint64_t offset, uint64_t sizeInBytes, ID3D12Resource* resource); + + // Move-only + D3D12BufferRegion(const D3D12BufferRegion&) = default; + D3D12BufferRegion& operator=(const D3D12BufferRegion&) = default; + D3D12BufferRegion(D3D12BufferRegion&&) noexcept; + D3D12BufferRegion& operator=(D3D12BufferRegion&&) noexcept; + ID3D12Resource* GetD3D12Resource() const; + + uint64_t Offset() const; + uint64_t SizeInBytes() const; + + DML_BUFFER_BINDING GetBufferBinding() const; + + explicit operator bool() const { return m_resource != nullptr; } + + // Creates a subregion at an offset from the start of this region. If no + // size is provided the region runs to the end of the current region. + inline D3D12BufferRegion Subregion(uint64_t offset, uint64_t sizeInBytes = 0) const + { + // start of subregion must be within current region + ORT_THROW_HR_IF(E_INVALIDARG, offset >= m_sizeInBytes); + sizeInBytes = sizeInBytes == 0 ? m_sizeInBytes - offset : sizeInBytes; + // end of subregion must be within current region + ORT_THROW_HR_IF(E_INVALIDARG, sizeInBytes > m_sizeInBytes - offset); + + return D3D12BufferRegion(m_offset + offset, sizeInBytes, m_resource); + } + + private: + ID3D12Resource* m_resource = nullptr; + uint64_t m_offset = 0; + uint64_t m_sizeInBytes = 0; + }; + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp index 5254b23f56376..040cd9c895fa4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.cpp @@ -4,7 +4,6 @@ #include "precomp.h" #include "DmlCommandRecorder.h" #include "CommandQueue.h" -#include "BucketizedBufferAllocator.h" using namespace Dml; @@ -22,9 +21,9 @@ DmlCommandRecorder::DmlCommandRecorder( ORT_THROW_IF_FAILED(dmlDevice->CreateCommandRecorder(IID_PPV_ARGS(&m_recorder))); } -void DmlCommandRecorder::SetAllocator(std::weak_ptr allocator) +void DmlCommandRecorder::SetAllocator(std::weak_ptr allocator) { - m_bufferAllocator = allocator; + m_allocator = allocator; } void DmlCommandRecorder::InitializeOperator( @@ -57,21 +56,14 @@ void DmlCommandRecorder::InitializeOperator( UINT64 temporaryResourceSize = initBindingProps.TemporaryResourceSize; if (temporaryResourceSize > 0) { - auto allocator = m_bufferAllocator.lock(); + auto allocator = m_allocator.lock(); // Allocate and immediately free a temporary buffer. The buffer resource will still be // alive (managed by the pool); freeing allows the resource to be shared with other operators. - void* tempResourceHandle = allocator->Alloc(static_cast(temporaryResourceSize)); - if (!tempResourceHandle) - { - ORT_THROW_HR(E_OUTOFMEMORY); - } - - ID3D12Resource* buffer = allocator->DecodeDataHandle(tempResourceHandle)->GetResource(); - allocator->Free(tempResourceHandle); + auto buffer = allocator->AllocateDefaultBuffer(temporaryResourceSize); // Bind the temporary resource. - DML_BUFFER_BINDING bufferBinding = { buffer, 0, temporaryResourceSize }; + DML_BUFFER_BINDING bufferBinding = buffer.GetBufferBinding(); DML_BINDING_DESC bindingDesc = { DML_BINDING_TYPE_BUFFER, &bufferBinding }; bindingTable->BindTemporaryResource(&bindingDesc); } @@ -133,21 +125,14 @@ void DmlCommandRecorder::ExecuteOperator( UINT64 temporaryResourceSize = execBindingProps.TemporaryResourceSize; if (temporaryResourceSize > 0) { - auto allocator = m_bufferAllocator.lock(); + auto allocator = m_allocator.lock(); // Allocate and immediately free a temporary buffer. The buffer resource will still be // alive (managed by the pool); freeing allows the resource to be shared with other operators. - void* tempResourceHandle = allocator->Alloc(static_cast(temporaryResourceSize)); - if (!tempResourceHandle) - { - ORT_THROW_HR(E_OUTOFMEMORY); - } - - ID3D12Resource* buffer = allocator->DecodeDataHandle(tempResourceHandle)->GetResource(); - allocator->Free(tempResourceHandle); + auto buffer = allocator->AllocateDefaultBuffer(temporaryResourceSize); // Bind the temporary resource. - DML_BUFFER_BINDING bufferBinding = { buffer, 0, temporaryResourceSize }; + DML_BUFFER_BINDING bufferBinding = buffer.GetBufferBinding(); DML_BINDING_DESC bindingDesc = { DML_BINDING_TYPE_BUFFER, &bufferBinding }; bindingTable->BindTemporaryResource(&bindingDesc); } @@ -186,6 +171,7 @@ void DmlCommandRecorder::CopyBufferRegion( void DmlCommandRecorder::FillBufferWithPattern( ID3D12Resource* dstBuffer, + uint64_t offset, gsl::span value /* Data type agnostic value, treated as raw bits */) { // The fill pattern for ClearUnorderedAccessViewUint is 16 bytes. @@ -216,6 +202,7 @@ void DmlCommandRecorder::FillBufferWithPattern( D3D12_UNORDERED_ACCESS_VIEW_DESC uavDesc = {}; uavDesc.ViewDimension = D3D12_UAV_DIMENSION_BUFFER; uavDesc.Format = DXGI_FORMAT_R32_TYPELESS; + uavDesc.Buffer.FirstElement = gsl::narrow(offset / sizeof(uint32_t)); uavDesc.Buffer.NumElements = gsl::narrow(dstBuffer->GetDesc().Width / sizeof(uint32_t)); uavDesc.Buffer.Flags = D3D12_BUFFER_UAV_FLAG_RAW; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h index 76a7a7277851a..619939f2f3139 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommandRecorder.h @@ -6,12 +6,14 @@ #include #include "ICommandRecorder.h" #include "CommandAllocatorRing.h" +#include "core/framework/allocator.h" +#include "DmlGpuAllocator.h" #include "DescriptorPool.h" namespace Dml { class CommandQueue; - class BucketizedBufferAllocator; + class DmlReservedResourceSubAllocator; class DmlCommandRecorder : public ICommandRecorder { @@ -41,6 +43,7 @@ namespace Dml void FillBufferWithPattern( ID3D12Resource* dstBuffer, + uint64_t offset, gsl::span value /* Data type agnostic value, treated as raw bits */); void ExecuteCommandList( @@ -56,7 +59,7 @@ namespace Dml void Open() final; void CloseAndExecute() final; - void SetAllocator(std::weak_ptr allocator); + void SetAllocator(std::weak_ptr allocator); bool HasUnsubmittedWork() override { @@ -84,7 +87,7 @@ namespace Dml ID3D12DescriptorHeap* m_currentDescriptorHeap = nullptr; // The weak pointer avoids a circular reference from context->recorder->allocator->context - std::weak_ptr m_bufferAllocator; + std::weak_ptr m_allocator; CommandAllocatorRing<2> m_commandAllocatorRing; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.cpp deleted file mode 100644 index 54393e9bf1539..0000000000000 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.cpp +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "precomp.h" -#include "DmlCommittedResourceAllocator.h" -#include "DmlResourceWrapper.h" -#include "DmlCommittedResourceWrapper.h" - -namespace Dml -{ - ComPtr DmlCommittedResourceAllocator::Alloc(size_t size) - { - ComPtr resource; - auto buffer = CD3DX12_RESOURCE_DESC::Buffer(size, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); - ORT_THROW_IF_FAILED(m_device->CreateCommittedResource( - unmove_ptr(CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT)), - D3D12_HEAP_FLAG_NONE, - &buffer, - D3D12_RESOURCE_STATE_COMMON, - nullptr, - IID_GRAPHICS_PPV_ARGS(resource.GetAddressOf()) - )); - - ComPtr resourceWrapper; - wil::MakeOrThrow(std::move(resource)).As(&resourceWrapper); - return resourceWrapper; - } -} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.h deleted file mode 100644 index 7ad48be32a6c9..0000000000000 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.h +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "DmlSubAllocator.h" - -namespace Dml -{ - struct DmlResourceWrapper; - - class DmlCommittedResourceAllocator : public DmlSubAllocator - { - public: - DmlCommittedResourceAllocator(ID3D12Device* device) : m_device(device) {} - Microsoft::WRL::ComPtr Alloc(size_t size) final; - - private: - ID3D12Device* m_device = nullptr; - }; -} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCpuAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCpuAllocator.cpp new file mode 100644 index 0000000000000..a85c35e2fa72c --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCpuAllocator.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" +#include "DmlCpuAllocator.h" + +namespace Dml +{ + +DmlCpuAllocator::DmlCpuAllocator(OrtMemType memType) + : onnxruntime::IAllocator( + OrtMemoryInfo( + "DML CPU", + OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0), + 0, + memType + ) + ) +{ +} + +void* DmlCpuAllocator::Alloc(size_t size) +{ + return onnxruntime::AllocatorDefaultAlloc(size); +} + +void DmlCpuAllocator::Free(void* p) +{ + return onnxruntime::AllocatorDefaultFree(p); +} + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCpuAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCpuAllocator.h new file mode 100644 index 0000000000000..2f81975d2c4cd --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlCpuAllocator.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/allocator.h" + +namespace Dml +{ + +class DmlCpuAllocator : public onnxruntime::IAllocator +{ +public: + explicit DmlCpuAllocator(OrtMemType memType); + + void* Alloc(size_t size) override; + void Free(void* p) override; +}; + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h deleted file mode 100644 index b22f0b2853e5d..0000000000000 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalBufferAllocator.h +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include "directx/d3d12.h" -#include -#include -#include "External/D3DX12/d3dx12.h" -#include "core/framework/allocator.h" -#include "core/providers/dml/dml_provider_factory_creator.h" -#include "AllocationInfo.h" -#include "GraphicsUnknownHelper.h" -#include "ErrorHandling.h" -#include "DmlCommittedResourceWrapper.h" - -namespace Dml -{ - class DmlExternalBufferAllocator : public onnxruntime::IAllocator - { - public: - DmlExternalBufferAllocator(int device_id) : onnxruntime::IAllocator( - OrtMemoryInfo( - "DML", - OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0) - )) - { - m_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false); - } - - void* Alloc(size_t size) final - { - Microsoft::WRL::ComPtr resource; - auto buffer = CD3DX12_RESOURCE_DESC::Buffer(size, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); - auto props = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT); - ORT_THROW_IF_FAILED(m_device->CreateCommittedResource( - &props, - D3D12_HEAP_FLAG_NONE, - &buffer, - D3D12_RESOURCE_STATE_COMMON, - nullptr, - IID_GRAPHICS_PPV_ARGS(resource.GetAddressOf()) - )); - - const uint64_t resourceWidth = resource->GetDesc().Width; - constexpr uint64_t pooledResourceId = 0; // Not a pooled resource - - Microsoft::WRL::ComPtr resourceWrapper; - wil::MakeOrThrow(std::move(resource)).As(&resourceWrapper); - - Microsoft::WRL::ComPtr allocInfo = wil::MakeOrThrow( - nullptr, - 0, - pooledResourceId, - resourceWrapper.Get(), - static_cast(resourceWidth)); - - return allocInfo.Detach(); - } - - void Free(void* ptr) final - { - Microsoft::WRL::ComPtr resource; - resource.Attach(static_cast(ptr)); - } - - private: - Microsoft::WRL::ComPtr m_device; - }; -} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalGpuAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalGpuAllocator.cpp new file mode 100644 index 0000000000000..3882823629854 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalGpuAllocator.cpp @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "precomp.h" +#include "DmlExternalGpuAllocator.h" +#include "DmlResourceWrapper.h" +#include "DmlCommittedResourceWrapper.h" +#include "DmlAllocationInfo.h" +#include "core/providers/dml/dml_provider_factory_creator.h" + +namespace Dml +{ + DmlExternalGpuAllocator::DmlExternalGpuAllocator(ID3D12Device* device) + : onnxruntime::IAllocator( + OrtMemoryInfo( + onnxruntime::DML, + OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DML_EXTERNAL, 0), + -1 + )), + m_device(device) + { + } + + DmlExternalGpuAllocator::DmlExternalGpuAllocator(int device_id) + : onnxruntime::IAllocator( + OrtMemoryInfo( + onnxruntime::DML, + OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DML_EXTERNAL, gsl::narrow_cast(device_id)), + device_id + )) + { + m_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false); + } + + void* DmlExternalGpuAllocator::Alloc(size_t sizeInBytes) + { + Microsoft::WRL::ComPtr resource; + auto buffer = CD3DX12_RESOURCE_DESC::Buffer(sizeInBytes, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS); + auto props = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT); + ORT_THROW_IF_FAILED(m_device->CreateCommittedResource( + &props, + D3D12_HEAP_FLAG_NONE, + &buffer, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + nullptr, + IID_GRAPHICS_PPV_ARGS(resource.GetAddressOf()) + )); + + const uint64_t resourceWidth = resource->GetDesc().Width; + constexpr uint64_t pooledResourceId = 0; // Not a pooled resource + + Microsoft::WRL::ComPtr resourceWrapper; + wil::MakeOrThrow(std::move(resource)).As(&resourceWrapper); + + Microsoft::WRL::ComPtr allocInfo = wil::MakeOrThrow( + nullptr, + 0, + pooledResourceId, + resourceWrapper.Get(), + static_cast(resourceWidth)); + + return allocInfo.Detach(); + } + + void DmlExternalGpuAllocator::Free(void* ptr) + { + Microsoft::WRL::ComPtr resource; + resource.Attach(static_cast(ptr)); + } + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalGpuAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalGpuAllocator.h new file mode 100644 index 0000000000000..3d61bee211949 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlExternalGpuAllocator.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/allocator.h" + +namespace Dml +{ + class DmlReservedResourceSubAllocator; + class AllocationInfo; + struct TaggedPointer; + + class DmlExternalGpuAllocator : public onnxruntime::IAllocator + { + public: + DmlExternalGpuAllocator(ID3D12Device* device); + DmlExternalGpuAllocator(int device_id); + + void* Alloc(size_t sizeInBytes) final; + void Free(void* ptr) final; + + private: + Microsoft::WRL::ComPtr m_device; + }; +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGpuAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGpuAllocator.cpp new file mode 100644 index 0000000000000..4e1a401adba17 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGpuAllocator.cpp @@ -0,0 +1,153 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "precomp.h" +#include "DmlGpuAllocator.h" +#include "core/framework/allocator.h" +#include "DmlReservedResourceSubAllocator.h" +#include "DmlTaggedPointer.h" +#include "DmlAllocationInfo.h" +#include "BucketizedBufferAllocator.h" +#include "DmlAllocatorRoundingMode.h" +#include "core/framework/arena_extend_strategy.h" +#include "core/framework/bfc_arena.h" + +namespace Dml +{ + static onnxruntime::ArenaExtendStrategy RoundingModeToArenaStrategy(AllocatorRoundingMode roundingMode) + { + switch(roundingMode) + { + case AllocatorRoundingMode::Disabled: return onnxruntime::ArenaExtendStrategy::kSameAsRequested; + case AllocatorRoundingMode::Enabled: return onnxruntime::ArenaExtendStrategy::kNextPowerOfTwo; + default: + ORT_THROW_HR(E_UNEXPECTED); + } + } + + DmlGpuAllocator::DmlGpuAllocator( + onnxruntime::BFCArena* bfcAllocator, + BucketizedBufferAllocator* bucketizedBufferAllocator, + std::shared_ptr bfcSubAllocator, + ActiveAllocator activeAllocator) + : onnxruntime::IAllocator( + OrtMemoryInfo( + onnxruntime::DML, + OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), + 0 + ) + ), + m_bfcAllocator(bfcAllocator), + m_bucketizedBufferAllocator(bucketizedBufferAllocator), + m_bfcSubAllocator(bfcSubAllocator), + m_activeAllocator(activeAllocator) {} + + void* DmlGpuAllocator::Alloc(size_t sizeInBytes) + { + return Alloc(sizeInBytes, m_defaultRoundingMode); + } + + void* DmlGpuAllocator::Alloc(size_t sizeInBytes, AllocatorRoundingMode roundingMode) + { + switch(m_activeAllocator) + { + case ActiveAllocator::BfcAllocator: + { + if (m_defaultRoundingMode != roundingMode) + { + m_bfcAllocator->SetArenaExtendStrategy(RoundingModeToArenaStrategy(roundingMode)); + } + + auto allocatedPointer = m_bfcAllocator->Alloc(sizeInBytes); + + if (m_defaultRoundingMode != roundingMode) + { + m_bfcAllocator->SetArenaExtendStrategy(RoundingModeToArenaStrategy(m_defaultRoundingMode)); + } + + return allocatedPointer; + } + + case ActiveAllocator::BucketizedBufferAllocator: + return m_bucketizedBufferAllocator->Alloc(sizeInBytes, roundingMode); + default: + ORT_THROW_HR(E_UNEXPECTED); + } + } + + void DmlGpuAllocator::Free(void* ptr) + { + switch(m_activeAllocator) + { + case ActiveAllocator::BfcAllocator: + return m_bfcAllocator->Free(ptr); + case ActiveAllocator::BucketizedBufferAllocator: + return m_bucketizedBufferAllocator->Free(ptr); + default: + ORT_THROW_HR(E_UNEXPECTED); + } + } + + D3D12BufferRegion DmlGpuAllocator::CreateBufferRegion(void* opaquePointer, uint64_t sizeInBytes) + { + switch(m_activeAllocator) + { + case ActiveAllocator::BfcAllocator: + return m_bfcSubAllocator->CreateBufferRegion(opaquePointer, sizeInBytes); + case ActiveAllocator::BucketizedBufferAllocator: + return m_bucketizedBufferAllocator->CreateBufferRegion(opaquePointer, sizeInBytes); + default: + ORT_THROW_HR(E_UNEXPECTED); + } + } + + AllocationInfo* DmlGpuAllocator::GetAllocationInfo(void* opaquePointer) + { + switch(m_activeAllocator) + { + case ActiveAllocator::BfcAllocator: + return m_bfcSubAllocator->GetAllocationInfo(opaquePointer); + case ActiveAllocator::BucketizedBufferAllocator: + return m_bucketizedBufferAllocator->GetAllocationInfo(opaquePointer); + default: + ORT_THROW_HR(E_UNEXPECTED); + } + } + + void DmlGpuAllocator::SetDefaultRoundingMode(AllocatorRoundingMode roundingMode) + { + if (m_activeAllocator == ActiveAllocator::BfcAllocator) + { + m_bfcAllocator->SetArenaExtendStrategy(RoundingModeToArenaStrategy(roundingMode)); + } + + m_defaultRoundingMode = roundingMode; + } + + DmlBuffer DmlGpuAllocator::AllocateDefaultBuffer(uint64_t num_bytes) + { + return DmlBuffer(this, num_bytes, m_defaultRoundingMode); + } + + DmlBuffer DmlGpuAllocator::AllocateDefaultBuffer(uint64_t num_bytes, AllocatorRoundingMode roundingMode) + { + return DmlBuffer(this, num_bytes, roundingMode); + } + + uint64_t DmlGpuAllocator::GetUniqueId(void* opaquePointer) + { + switch(m_activeAllocator) + { + case ActiveAllocator::BfcAllocator: + return m_bfcSubAllocator->GetUniqueId(opaquePointer); + case ActiveAllocator::BucketizedBufferAllocator: + return m_bucketizedBufferAllocator->GetUniqueId(opaquePointer); + default: + ORT_THROW_HR(E_UNEXPECTED); + } + } + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGpuAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGpuAllocator.h new file mode 100644 index 0000000000000..922c4aa82f367 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGpuAllocator.h @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/allocator.h" +#include "DmlBufferRegion.h" +#include "DmlBuffer.h" +#include "DmlAllocatorRoundingMode.h" + +namespace onnxruntime +{ + class BFCArena; +} + +namespace Dml +{ + class DmlReservedResourceSubAllocator; + class BucketizedBufferAllocator; + class AllocationInfo; + struct TaggedPointer; + + enum class ActiveAllocator + { + BfcAllocator, + BucketizedBufferAllocator, + }; + + class DmlGpuAllocator : public onnxruntime::IAllocator + { + public: + DmlGpuAllocator( + onnxruntime::BFCArena* bfcAllocator, + BucketizedBufferAllocator* bucketizedBufferAllocator, + std::shared_ptr bfcSubAllocator, + ActiveAllocator activeAllocator); + + void* Alloc(size_t sizeInBytes) final; + void* Alloc(size_t sizeInBytes, AllocatorRoundingMode roundingMode); + void Free(void* ptr) final; + D3D12BufferRegion CreateBufferRegion(void* opaquePointer, uint64_t sizeInBytes); + AllocationInfo* GetAllocationInfo(void* opaquePointer); + void SetDefaultRoundingMode(AllocatorRoundingMode roundingMode); + DmlBuffer AllocateDefaultBuffer(uint64_t num_bytes); + DmlBuffer AllocateDefaultBuffer(uint64_t num_bytes, AllocatorRoundingMode roundingMode); + uint64_t GetUniqueId(void* opaquePointer); + + private: + // This allocator is managed by ORT and should be used to allocate/free memory in order + // to utilize the BFC acapabilities + onnxruntime::BFCArena* m_bfcAllocator; + + // This allocator is the old bucketized allocator that is kept for backward compatibility purposes + // and is only used when external custom ops are registered. + BucketizedBufferAllocator* m_bucketizedBufferAllocator; + + // This allocator is specific to DML and is used to decode the opaque data returned by the BFC + // allocator into objects that DML understands + std::shared_ptr m_bfcSubAllocator; + + // Unless specifically requested, allocation sizes are not rounded to enable pooling + // until SetDefaultRoundingMode is called. This should be done at completion of session + // initialization. + AllocatorRoundingMode m_defaultRoundingMode = AllocatorRoundingMode::Disabled; + + ActiveAllocator m_activeAllocator; + }; +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 8ee31d4b84f2f..98331369728f8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -1,6 +1,9 @@ #pragma once #include "DmlGraphFusionHelper.h" +#include "DmlBufferRegion.h" +#include "DmlTaggedPointer.h" +#include "DmlAllocationInfo.h" #include "DmlRuntimeFusedGraphKernel.h" using namespace Windows::AI::MachineLearning::Adapter; @@ -90,19 +93,23 @@ namespace DmlGraphFusionHelper return buffer; } - void UnwrapTensor( + D3D12BufferRegion UnwrapTensor( Windows::AI::MachineLearning::Adapter::IWinmlExecutionProvider* winmlProvider, const onnxruntime::Tensor* tensor, - ID3D12Resource** resource, uint64_t* allocId) { - IUnknown* allocationUnknown = static_cast(const_cast(tensor->DataRaw())); - Microsoft::WRL::ComPtr resourceUnknown; - winmlProvider->GetABIDataInterface(false, allocationUnknown, &resourceUnknown); + void* opaqueData = const_cast(tensor->DataRaw()); - *allocId = winmlProvider->TryGetPooledAllocationId(allocationUnknown, 0); + if (tensor->Location().device.MemType() == OrtDevice::MemType::DML_EXTERNAL) + { + // The allocation is not pooled + auto allocInfo = static_cast(opaqueData); + *allocId = 0; + return D3D12BufferRegion(0, allocInfo->GetD3D12Resource()->GetDesc().Width, allocInfo->GetD3D12Resource()); + } - ORT_THROW_IF_FAILED(resourceUnknown->QueryInterface(resource)); + *allocId = winmlProvider->GetUniqueId(opaqueData); + return winmlProvider->GetBufferRegion(opaqueData, tensor->SizeInBytes()); } std::tuple, std::vector, std::byte*, size_t> UnpackInitializer( @@ -935,7 +942,6 @@ namespace DmlGraphFusionHelper const Windows::AI::MachineLearning::Adapter::EdgeShapes& outputShapes, IWinmlExecutionProvider* winmlProvider, IExecutionProvider* provider, - IUnknown* persistentResourceAllocatorUnknown, bool keepTemporaryResourceAlive) { DML_BINDING_PROPERTIES execBindingProps = compiledExecutionPlanOperator->GetBindingProperties(); @@ -970,10 +976,9 @@ namespace DmlGraphFusionHelper const onnxruntime::Tensor* tensor = kernelContext->Input(gsl::narrow_cast(i)); uint64_t allocId; - DmlGraphFusionHelper::UnwrapTensor(winmlProvider, tensor, &inputBindings[i].Buffer, &allocId); + auto bufferRegion = DmlGraphFusionHelper::UnwrapTensor(winmlProvider, tensor, &allocId); + inputBindings[i] = bufferRegion.GetBufferBinding(); inputBindingsChanged = inputBindingsChanged || (!allocId || commandListState.inputBindingAllocIds[i] != allocId); - inputBindings[i].Buffer->Release(); // Avoid holding an additional reference - inputBindings[i].SizeInBytes = DmlGraphFusionHelper::AlignToPow2(tensor->SizeInBytes(), 4); inputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &inputBindings[i]}; commandListState.inputBindingAllocIds[i] = allocId; } @@ -1007,10 +1012,9 @@ namespace DmlGraphFusionHelper ); uint64_t allocId; - DmlGraphFusionHelper::UnwrapTensor(winmlProvider, tensor, &outputBindings[i].Buffer, &allocId); + auto bufferRegion = DmlGraphFusionHelper::UnwrapTensor(winmlProvider, tensor, &allocId); + outputBindings[i] = bufferRegion.GetBufferBinding(); outputBindingsChanged = outputBindingsChanged || (!allocId || commandListState.outputBindingAllocIds[i] != allocId); - outputBindings[i].Buffer->Release(); // Avoid holding an additional reference - outputBindings[i].SizeInBytes = DmlGraphFusionHelper::AlignToPow2(tensor->SizeInBytes(), 4); outputBindingDescs[i] = {DML_BINDING_TYPE_BUFFER, &outputBindings[i]}; commandListState.outputBindingAllocIds[i] = allocId; } @@ -1022,23 +1026,14 @@ namespace DmlGraphFusionHelper if (execBindingProps.TemporaryResourceSize > 0) { - // Allocate temporary data which will automatically be freed when the GPU work - // which is scheduled up to the point that this method returns has completed. - ComPtr tempAlloc; + // TODO (pavignol): Handle alloc ID uint64_t tempAllocId = 0; - ORT_THROW_IF_FAILED(contextWrapper.AllocateTemporaryData(static_cast(execBindingProps.TemporaryResourceSize), tempAlloc.GetAddressOf(), &tempAllocId)); - - ComPtr tempResourceUnknown; - winmlProvider->GetABIDataInterface(false, tempAlloc.Get(), &tempResourceUnknown); - - // Bind the temporary resource. - ComPtr tempResource; - ORT_THROW_IF_FAILED(tempResourceUnknown->QueryInterface(tempResource.GetAddressOf())); - DML_BUFFER_BINDING tempBufferBinding = {tempResource.Get(), 0, execBindingProps.TemporaryResourceSize}; - DML_BINDING_DESC tempBindingDesc = { DML_BINDING_TYPE_BUFFER, &tempBufferBinding }; + auto buffer = provider->AllocatePooledResource(execBindingProps.TemporaryResourceSize, AllocatorRoundingMode::Enabled); if (!tempAllocId || commandListState.tempBindingAllocId != tempAllocId) { + DML_BUFFER_BINDING tempBufferBinding = buffer.GetBufferBinding(); + DML_BINDING_DESC tempBindingDesc = { DML_BINDING_TYPE_BUFFER, &tempBufferBinding }; commandListState.bindingTable->BindTemporaryResource(&tempBindingDesc); } @@ -1046,7 +1041,7 @@ namespace DmlGraphFusionHelper if (keepTemporaryResourceAlive) { - commandListState.temporaryResource = std::move(tempResource); + commandListState.temporaryBuffer = std::move(buffer); } } @@ -1071,7 +1066,6 @@ namespace DmlGraphFusionHelper winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(commandListState.graphicsCommandList).Get()); winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(commandListState.heap).Get()); winmlProvider->QueueReference(commandListState.bindingTable.Get()); - winmlProvider->QueueReference(persistentResourceAllocatorUnknown); } } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h index 1a810ca58cf45..bd2205b7a7572 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h @@ -6,6 +6,7 @@ #include "GraphPartitioner.h" #include "FusedGraphKernel.h" #include "MLOperatorAuthorImpl.h" +#include "DmlBufferRegion.h" #include "DmlReusedCommandListState.h" using Windows::AI::MachineLearning::Adapter::IWinmlExecutionProvider; @@ -35,10 +36,9 @@ namespace DmlGraphFusionHelper const std::byte* tensorPtr, size_t tensorByteSize); - void UnwrapTensor( + D3D12BufferRegion UnwrapTensor( Windows::AI::MachineLearning::Adapter::IWinmlExecutionProvider* winmlProvider, const onnxruntime::Tensor* tensor, - ID3D12Resource** resource, uint64_t* allocId); std::unordered_map> @@ -120,7 +120,6 @@ namespace DmlGraphFusionHelper const Windows::AI::MachineLearning::Adapter::EdgeShapes& outputShapes, IWinmlExecutionProvider* winmlProvider, IExecutionProvider* provider, - IUnknown* persistentResourceAllocatorUnknown, bool keepTemporaryResourceAlive); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlHeapAllocation.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlHeapAllocation.h new file mode 100644 index 0000000000000..5ecf135a9ee43 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlHeapAllocation.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace Dml +{ + struct DmlHeapAllocation + { + Microsoft::WRL::ComPtr heap; + + // Heaps backing the memory for the allocation. If tiling is supported + // an allocation may comprise multiple heaps. If tiling is not supported + // an allocation will only have a single heap. + std::vector> heaps; + Microsoft::WRL::ComPtr resourceUavState; + }; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlManagedBuffer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlManagedBuffer.h new file mode 100644 index 0000000000000..ced81af68e92e --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlManagedBuffer.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "DmlBuffer.h" + +namespace Dml +{ + // Light wrapper around DmlBuffer used with CommandQueue::QueueReference to keep a reference on the buffer until GPU work is completed + class DmlManagedBuffer : public Microsoft::WRL::RuntimeClass, IUnknown> + { + public: + DmlManagedBuffer(DmlBuffer&& buffer) : m_buffer(std::move(buffer)) {} + uint64_t SizeInBytes() const { return m_buffer.SizeInBytes(); } + + private: + DmlBuffer m_buffer; + }; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReservedResourceAllocatorWrapper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReservedResourceAllocatorWrapper.h new file mode 100644 index 0000000000000..e92740e9ce907 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReservedResourceAllocatorWrapper.h @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/allocator.h" +#include "DmlReservedResourceSubAllocator.h" + +namespace Dml +{ + class DmlReservedResourceAllocatorWrapper : public onnxruntime::IAllocator + { + public: + DmlReservedResourceAllocatorWrapper(std::shared_ptr subAllocator) + : onnxruntime::IAllocator( + OrtMemoryInfo( + onnxruntime::DML, + OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0) + ) + ), + m_subAllocator(std::move(subAllocator)) {} + + void* Alloc(size_t sizeInBytes) final { return m_subAllocator->Alloc(sizeInBytes); } + void Free(void* ptr) final { m_subAllocator->Free(ptr); } + private: + std::shared_ptr m_subAllocator; + }; +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReservedResourceSubAllocator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReservedResourceSubAllocator.cpp new file mode 100644 index 0000000000000..de9ee09c9726d --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReservedResourceSubAllocator.cpp @@ -0,0 +1,366 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +#include "core/session/onnxruntime_c_api.h" +#include "DmlReservedResourceSubAllocator.h" +#include "DmlReservedResourceWrapper.h" +#include "DmlBufferRegion.h" + +namespace Dml +{ + DmlReservedResourceSubAllocator::~DmlReservedResourceSubAllocator() + { +#ifdef PRINT_OUTSTANDING_ALLOCATIONS + if (!m_outstandingAllocationsById.empty()) + { + printf("DmlReservedResourceSubAllocator outstanding allocation indices:\n"); + for (auto& entry : m_outstandingAllocationsById) + { + printf("%u\n", static_cast(entry.first)); + } + printf("\n"); + } +#endif + } + + /*static*/ gsl::index DmlReservedResourceSubAllocator::GetBucketIndexFromSize(uint64_t size) + { + assert(size != 0); + + // Each bucket is twice as large as the previous one, in ascending order + gsl::index index = static_cast(ceil(log2(size))); + assert((1ull << index) >= size); // This must be true unless there were some strange rounding issues + + // The smallest bucket is 2^n bytes large, where n = MinResourceSizeExponent + index = std::max(index, MinResourceSizeExponent); + index -= MinResourceSizeExponent; + + return index; + } + + /*static*/ uint64_t DmlReservedResourceSubAllocator::GetBucketSizeFromIndex(gsl::index index) + { + return (1ull << (index + MinResourceSizeExponent)); + } + + static bool GetTilingEnabled(ID3D12Device* device) + { + D3D12_FEATURE_DATA_D3D12_OPTIONS options = {}; + if (SUCCEEDED(device->CheckFeatureSupport(D3D12_FEATURE_D3D12_OPTIONS, &options, sizeof(options)))) + { + return options.TiledResourcesTier >= D3D12_TILED_RESOURCES_TIER_1; + } + + return false; + } + + static uint64_t GetMaxHeapSizeInTiles() + { + return DmlReservedResourceSubAllocator::DefaultMaxHeapSizeInTiles; + } + + DmlReservedResourceSubAllocator::DmlReservedResourceSubAllocator( + ID3D12Device* device, + ExecutionContext* context, + ID3D12CommandQueue* queue, + const D3D12_HEAP_PROPERTIES& heapProps, + D3D12_HEAP_FLAGS heapFlags, + D3D12_RESOURCE_FLAGS resourceFlags, + D3D12_RESOURCE_STATES initialState) + : m_device(device), + m_context(context), + m_queue(queue), + m_heapProperties(heapProps), + m_heapFlags(heapFlags), + m_resourceFlags(resourceFlags), + m_initialState(initialState), + m_tilingEnabled(GetTilingEnabled(device)), + m_maxHeapSizeInTiles(GetMaxHeapSizeInTiles()) + { + } + + absl::optional DmlReservedResourceSubAllocator::TryCreateTiledAllocation(uint64_t sizeInBytes) + { + DmlHeapAllocation allocation = {}; + + // The allocation may be larger than the requested size to ensure a whole + // number of tiles. + const uint64_t resourceSizeInTiles = 1 + (sizeInBytes - 1) / D3D12_TILED_RESOURCE_TILE_SIZE_IN_BYTES; + const uint64_t resourceSizeInBytes = resourceSizeInTiles * D3D12_TILED_RESOURCE_TILE_SIZE_IN_BYTES; + auto resourceDesc = CD3DX12_RESOURCE_DESC::Buffer(resourceSizeInBytes, m_resourceFlags); + + HRESULT createResourceHr = m_device->CreateReservedResource( + &resourceDesc, + m_initialState, + nullptr, + IID_PPV_ARGS(&allocation.resourceUavState)); + + if (createResourceHr == E_OUTOFMEMORY) + { + return absl::nullopt; + } + ORT_THROW_IF_FAILED(createResourceHr); + + // Reserve enough heaps to store all tiles in the resource. + const uint64_t heapCount = 1 + (resourceSizeInTiles - 1) / m_maxHeapSizeInTiles; + allocation.heaps.resize(heapCount); + + // Create heaps and map them to the primary reserved resource. + D3D12_TILED_RESOURCE_COORDINATE resourceRegionStartCoordinates = {}; + uint64_t unmappedResourceTiles = resourceSizeInTiles; + for (uint64_t i = 0; i < heapCount; i++) + { + // Create heap. The last heap of the allocation may have fewer tiles to + // avoid wasting space. + uint64_t heapSizeInTiles = std::min(unmappedResourceTiles, m_maxHeapSizeInTiles); + uint64_t heapSizeInBytes = heapSizeInTiles * D3D12_TILED_RESOURCE_TILE_SIZE_IN_BYTES; + auto heap_desc = CD3DX12_HEAP_DESC( + heapSizeInBytes, + m_heapProperties, + 0, + m_heapFlags); + + HRESULT createHeapHr = m_device->CreateHeap(&heap_desc, IID_PPV_ARGS(&allocation.heaps[i])); + if (createHeapHr == E_OUTOFMEMORY) + { + return absl::nullopt; + } + ORT_THROW_IF_FAILED(createHeapHr); + + // Source region in the resource to map. + D3D12_TILE_REGION_SIZE resourceRegionSize = {}; + resourceRegionSize.NumTiles = static_cast(heapSizeInTiles); + + // Target range in the current heap to map. + constexpr D3D12_TILE_RANGE_FLAGS tileRangeFlags = D3D12_TILE_RANGE_FLAG_NONE; + const uint32_t heapRangeTileCount = static_cast(heapSizeInTiles); + + constexpr uint32_t heapRangeStartOffset = 0; + constexpr uint32_t numResourceRegions = 1; + constexpr uint32_t numHeapRanges = 1; + + // This is a brand new allocation/resource, so the tile mappings are + // guaranteed to be set (on the GPU timeline) by the time any code can + // reference the returned resource. We only execute operations on a + // single hardware queue so there is no need to wait or signal. + m_queue->UpdateTileMappings( + allocation.resourceUavState.Get(), + numResourceRegions, + &resourceRegionStartCoordinates, + &resourceRegionSize, + allocation.heaps[i].Get(), + numHeapRanges, + &tileRangeFlags, + &heapRangeStartOffset, + &heapRangeTileCount, + D3D12_TILE_MAPPING_FLAG_NONE); + + resourceRegionStartCoordinates.X += static_cast(heapSizeInTiles); + unmappedResourceTiles -= heapSizeInTiles; + } + + assert(unmappedResourceTiles == 0); + + return allocation; + } + + absl::optional DmlReservedResourceSubAllocator::TryCreateUntiledAllocation(uint64_t sizeInBytes) + { + DmlHeapAllocation allocation = {}; + + // Create the allocation's sole heap. The allocation may be larger than the + // requested size to ensure a whole number of tiles. + allocation.heaps.resize(1); + D3D12_HEAP_DESC heap_desc = CD3DX12_HEAP_DESC(sizeInBytes, m_heapProperties, 0, m_heapFlags); + HRESULT createHeapHr = m_device->CreateHeap(&heap_desc, IID_PPV_ARGS(&allocation.heaps.front())); + if (createHeapHr == E_OUTOFMEMORY) + { + return absl::nullopt; + } + ORT_THROW_IF_FAILED(createHeapHr); + + // Create large placed resource that spans the heap. + D3D12_RESOURCE_DESC resourceDesc = CD3DX12_RESOURCE_DESC::Buffer(sizeInBytes, m_resourceFlags); + + HRESULT createResourceHr = m_device->CreatePlacedResource( + allocation.heaps.front().Get(), + 0, + &resourceDesc, + m_initialState, + nullptr, + IID_PPV_ARGS(&allocation.resourceUavState)); + if (createResourceHr == E_OUTOFMEMORY) + { + return absl::nullopt; + } + ORT_THROW_IF_FAILED(createResourceHr); + + return allocation; + } + + uint64_t DmlReservedResourceSubAllocator::ComputeRequiredSize(size_t size) + { + const uint64_t resourceSizeInTiles = 1 + (size - 1) / D3D12_TILED_RESOURCE_TILE_SIZE_IN_BYTES; + const uint64_t resourceSizeInBytes = resourceSizeInTiles * D3D12_TILED_RESOURCE_TILE_SIZE_IN_BYTES; + return resourceSizeInBytes; + } + + void* DmlReservedResourceSubAllocator::Alloc(size_t sizeInBytes) + { + // For some reason lotus likes requesting 0 bytes of memory + sizeInBytes = std::max(1, sizeInBytes); + + // The D3D12 device is thread-safe so we don't need to hold the lock while + // creating an allocation. + absl::optional allocation = + m_tilingEnabled ? TryCreateTiledAllocation(sizeInBytes) + : TryCreateUntiledAllocation(sizeInBytes); + + ORT_THROW_HR_IF(E_INVALIDARG, !allocation); + + // We need to access (mutable) state after this point, so we need to lock + std::unique_lock lock(m_mutex); + + absl::optional allocationId = TryReserveAllocationID(); + ORT_THROW_HR_IF(E_INVALIDARG, !allocationId); + + auto resourceWrapper = wil::MakeOrThrow(std::move(*allocation)); + ComPtr allocInfo = wil::MakeOrThrow( + this, + ++m_currentUniqueAllocationId, + 0, + resourceWrapper.Get(), + sizeInBytes + ); + + m_allocationsById.emplace(*allocationId, allocInfo); + + lock.unlock(); + + #if _DEBUG + m_outstandingAllocationsById[allocInfo->GetId()] = allocInfo.Get(); + #endif + + // DML only has a single device in ORT at the moment + constexpr uint64_t deviceId = 0; + constexpr uint64_t offset = 0; + return TaggedPointer::Pack(deviceId, *allocationId, offset); + } + + void DmlReservedResourceSubAllocator::Free(void* ptr) + { + ORT_THROW_HR_IF(E_INVALIDARG, ptr == nullptr); + + TaggedPointer taggedPtr = TaggedPointer::Unpack(ptr); + ORT_THROW_HR_IF(E_INVALIDARG, taggedPtr.offset != 0); + + // We need to access (mutable) state after this point, so we need to lock + std::unique_lock lock(m_mutex); + + auto it = m_allocationsById.find(taggedPtr.allocationId); + ORT_THROW_HR_IF(E_INVALIDARG, it == m_allocationsById.end()); + + ReleaseAllocationID(taggedPtr.allocationId); + + // Frees the ID3D12Heap + m_allocationsById.erase(it); + } + + uint64_t DmlReservedResourceSubAllocator::GetUniqueId(void* opaquePointer) + { + auto taggedPointer = TaggedPointer::Unpack(opaquePointer); + return taggedPointer.GetUniqueId(); + } + + void DmlReservedResourceSubAllocator::FreeResource(AllocationInfo* allocInfo, uint64_t resourceId) + { + // Since this allocator is warapped by ORT's BFC allocator, it's possible that the context is already + // close at this point if the application is winding down. + if (!m_context->Closed()) + { + assert(allocInfo != nullptr); // Can't free nullptr + + if (allocInfo->GetOwner() != this) + { + // This allocation doesn't belong to this allocator! + ORT_THROW_HR(E_INVALIDARG); + } + + m_context->QueueReference(allocInfo); + } + } + + absl::optional DmlReservedResourceSubAllocator::TryReserveAllocationID() + { + // The mutex must already be held + assert(!m_mutex.try_lock()); + + if (!m_freeAllocationIds.empty()) + { + // Return a free ID from the pool + uint32_t id = m_freeAllocationIds.back(); + m_freeAllocationIds.pop_back(); + return id; + } + + static constexpr uint32_t maxAllocationID = (1 << TaggedPointer::AllocationIDBits) - 1; + if (m_currentAllocationId == maxAllocationID) + { + // We've reached the maximum number of allocations! + return absl::nullopt; + } + + ++m_currentAllocationId; + return m_currentAllocationId; + } + + void DmlReservedResourceSubAllocator::ReleaseAllocationID(uint32_t id) + { + // The mutex must already be held + assert(!m_mutex.try_lock()); + + // Add it to the pool of free IDs + m_freeAllocationIds.push_back(id); + } + + D3D12BufferRegion DmlReservedResourceSubAllocator::CreateBufferRegion( + void* opaquePointer, + uint64_t sizeInBytes) + { + auto taggedPointer = TaggedPointer::Unpack(opaquePointer); + + // We need to access (mutable) state after this point, so we need to lock + std::unique_lock lock(m_mutex); + + // Find the allocation corresponding to this pointer + auto it = m_allocationsById.find(taggedPointer.allocationId); + ORT_THROW_HR_IF(E_INVALIDARG, it == m_allocationsById.end()); + + // Make sure that we are aligned to 4 bytes to satisfy DML's requirements + constexpr uint64_t DML_ALIGNMENT = 4; + sizeInBytes = (1 + (sizeInBytes - 1) / DML_ALIGNMENT) * DML_ALIGNMENT; + + // Make sure the region we're trying to create fits entirely in the resource + assert(it->second->GetD3D12Resource()->GetDesc().Width >= taggedPointer.offset + sizeInBytes); + + return D3D12BufferRegion( + taggedPointer.offset, + sizeInBytes, + it->second->GetD3D12Resource()); + } + + AllocationInfo* DmlReservedResourceSubAllocator::GetAllocationInfo(void* opaquePointer) + { + auto taggedPointer = TaggedPointer::Unpack(opaquePointer); + + // We need to access (mutable) state after this point, so we need to lock + std::unique_lock lock(m_mutex); + + // Find the allocation corresponding to this pointer + auto it = m_allocationsById.find(taggedPointer.allocationId); + return it->second.Get(); + } + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReservedResourceSubAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReservedResourceSubAllocator.h new file mode 100644 index 0000000000000..704c373f24966 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReservedResourceSubAllocator.h @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "ExecutionContext.h" +#include "DmlAllocationInfo.h" +#include "DmlBufferRegion.h" +#include "DmlSubAllocator.h" + +namespace Dml +{ + // An allocator that makes logically contiguous allocations backed by D3D heaps. + // + // Heaps must fit entirely in either local or non-local memory. Larger heaps + // have a greater chance of getting demoted into non-local memory, which can be + // disastrous for performance. This problem is compounded by the fact that heaps + // may be demoted even if overall local memory usage is within the process' + // budget. Heaps are not necessarily mappable to discontiguous regions of + // physical memory, which means physical memory fragmentation *may* make it + // extremely difficult to accommodate larger heaps. + // + // On D3D hardware that supports tiled resource tier 1+ this class implements + // large allocations through tiling. Each allocation is backed by however many + // small heaps are necessary to cover the requested allocation size. Buffer + // regions retrieved through this allocator are reserved resources that span the + // full collection of heaps assigned to an individual allocation. Tile mappings + // are static. + // + // On hardware that doesn't support tiled resources each allocation is backed by + // a single heap. Buffer regions retrieved through this allocator are placed + // resources that span the full heap assigned to an individual allocation. In + // this case it is better make more but smaller allocations (resulting in + // smaller heaps); this fallback path is only retained as a last resort for + // older hardware. + class DmlReservedResourceSubAllocator : public DmlSubAllocator + { + public: + // Maximum size of a heap (in tiles) when allocations are tiled. Each tile + // is 64KB. A default size of 512 tiles (32MB) does a good job of handling + // local video memory fragmentation without requiring lots of heaps. + static constexpr uint64_t DefaultMaxHeapSizeInTiles = 512; + + DmlReservedResourceSubAllocator( + ID3D12Device* device, + ExecutionContext* context, + ID3D12CommandQueue* queue, + const D3D12_HEAP_PROPERTIES& heapProps, + D3D12_HEAP_FLAGS heapFlags, + D3D12_RESOURCE_FLAGS resourceFlags, + D3D12_RESOURCE_STATES initialState); + + // Creates a reserved or placed resource buffer over the given memory range. + // The physical D3D12 resource may be larger than the requested size, so + // callers must ensure to use the offset/size returned in the + // D3D12BufferRegion else risk out of bounds access. Note that in practice + // the ID3D12Resource is cached, so this call typically has a lower cost + // than a call to ID3D12Device::CreatePlacedResource or + // CreateReservedResource. + D3D12BufferRegion CreateBufferRegion(void* opaquePointer, uint64_t sizeInBytes); + + AllocationInfo* GetAllocationInfo(void* opaquePointer); + + void FreeResource(AllocationInfo* allocInfo, uint64_t resourceId) final; + uint64_t ComputeRequiredSize(size_t size); + bool TilingEnabled() const { return m_tilingEnabled; }; + uint64_t GetUniqueId(void* opaquePointer); + + ~DmlReservedResourceSubAllocator(); + + // Constructs a DmlReservedResourceSubAllocator which allocates D3D12 committed resources with the specified heap properties, + // resource flags, and initial resource state. + DmlReservedResourceSubAllocator( + ID3D12Device* device, + ExecutionContext* context, + std::unique_ptr&& subAllocator); + + void* Alloc(size_t size); + void Free(void* p); + + private: + static constexpr uint32_t MinResourceSizeExponent = 16; // 2^16 = 64KB + + // The pool consists of a number of buckets, and each bucket contains a number of resources of the same size. + // The resources in each bucket are always sized as a power of two, and each bucket contains resources twice + // as large as the previous bucket. + struct Resource + { + ComPtr resource; + uint64_t resourceId; + }; + + struct Bucket + { + std::vector resources; + }; + + static gsl::index GetBucketIndexFromSize(uint64_t size); + static uint64_t GetBucketSizeFromIndex(gsl::index index); + + friend class AllocationInfo; + + std::vector m_pool; + size_t m_currentUniqueAllocationId = 0; + uint64_t m_currentResourceId = 0; + std::unique_ptr m_subAllocator; + + #if _DEBUG + // Useful for debugging; keeps track of all allocations that haven't been freed yet + std::map m_outstandingAllocationsById; + #endif + + std::mutex m_mutex; + + Microsoft::WRL::ComPtr m_device; + Microsoft::WRL::ComPtr m_context; + Microsoft::WRL::ComPtr m_queue; + const D3D12_HEAP_PROPERTIES m_heapProperties; + const D3D12_HEAP_FLAGS m_heapFlags; + const D3D12_RESOURCE_FLAGS m_resourceFlags; + const D3D12_RESOURCE_STATES m_initialState; + bool m_tilingEnabled; + uint64_t m_maxHeapSizeInTiles; + + // The largest allocation ID we've returned so far (or 0 if we've never done + // so). Note that our allocation IDs start at 1 (not 0) to ensure that it + // isn't possible for a valid allocation to have a pointer value of + // 0x00000000. + uint32_t m_currentAllocationId = 0; + + // A list of unused allocation IDs. This is for re-use of IDs once they get + // freed. We only bump the max_allocation_id_ once there are no more free + // IDs. + std::vector m_freeAllocationIds; + + absl::optional TryCreateTiledAllocation(uint64_t sizeInBytes); + absl::optional TryCreateUntiledAllocation(uint64_t sizeInBytes); + + friend class D3D12BufferRegion; + + absl::flat_hash_map> m_allocationsById; + + // Retrieves a free allocation ID, or nullopt if no more IDs are available. + absl::optional TryReserveAllocationID(); + + // Releases an allocation ID back to the pool of IDs. + void ReleaseAllocationID(uint32_t id); + }; + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReservedResourceWrapper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReservedResourceWrapper.h new file mode 100644 index 0000000000000..e278ecbeb7415 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReservedResourceWrapper.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "DmlResourceWrapper.h" +#include "DmlHeapAllocation.h" +#include "DmlTaggedPointer.h" + +namespace Dml +{ + class DmlReservedResourceWrapper : public Microsoft::WRL::RuntimeClass, DmlResourceWrapper> + { + public: + DmlReservedResourceWrapper(DmlHeapAllocation&& allocation) + : m_allocation(std::move(allocation)) + { + } + + ID3D12Resource* GetD3D12Resource() const final { return m_allocation.resourceUavState.Get(); } + + private: + DmlHeapAllocation m_allocation; + }; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReusedCommandListState.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReusedCommandListState.h index 6c3c2fb8c7094..b99c4393c8037 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReusedCommandListState.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlReusedCommandListState.h @@ -15,8 +15,7 @@ namespace Dml Microsoft::WRL::ComPtr heap; Microsoft::WRL::ComPtr bindingTable; Microsoft::WRL::ComPtr persistentResource; - Microsoft::WRL::ComPtr temporaryResource; - Microsoft::WRL::ComPtr persistentResourceAllocatorUnknown; + std::optional temporaryBuffer; // Bindings from previous executions of a re-used command list mutable std::vector inputBindingAllocIds; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index 7cd23256214dd..ad6d003669b87 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -68,13 +68,8 @@ namespace Dml UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize; if (persistentResourceSize > 0) { - ORT_THROW_IF_FAILED(m_provider->AllocatePooledResource( - static_cast(persistentResourceSize), - AllocatorRoundingMode::Disabled, - m_persistentResource.ReleaseAndGetAddressOf(), - m_persistentResourceAllocatorUnknown.ReleaseAndGetAddressOf())); - - m_persistentResourceBinding = DML_BUFFER_BINDING { m_persistentResource.Get(), 0, persistentResourceSize }; + auto buffer = m_provider->AllocatePooledResource(static_cast(persistentResourceSize), AllocatorRoundingMode::Disabled); + m_persistentResourceBinding = buffer.GetBufferBinding(); } ORT_THROW_IF_FAILED(m_provider->InitializeOperator( @@ -246,7 +241,6 @@ namespace Dml m_persistentResourceBinding); reusableCommandList->persistentResource = m_persistentResource; - reusableCommandList->persistentResourceAllocatorUnknown = m_persistentResourceAllocatorUnknown; // Keep the temporary resource alive since we won't call ExecuteReusableCommandList again, but will merely replay // the graph in the future. Therefore, all executions of the graph will use the same temporary resource that was @@ -264,7 +258,6 @@ namespace Dml m_outputShapes, m_winmlProvider.Get(), m_provider.Get(), - m_persistentResourceAllocatorUnknown.Get(), keepTemporaryResourceAlive); providerImpl->AppendCapturedGraph(providerImpl->GetCurrentGraphAnnotationId(), std::move(reusableCommandList)); @@ -298,7 +291,6 @@ namespace Dml m_outputShapes, m_winmlProvider.Get(), m_provider.Get(), - m_persistentResourceAllocatorUnknown.Get(), keepTemporaryResourceAlive); m_reusedCommandLists.push_back(std::move(m_reusedCommandLists.front())); @@ -330,7 +322,6 @@ namespace Dml mutable ComPtr m_compiledExecutionPlanOperator; mutable std::vector m_inputsUsed; mutable ComPtr m_persistentResource; - mutable ComPtr m_persistentResourceAllocatorUnknown; // Controls when the persistent resource is returned to the allocator mutable Windows::AI::MachineLearning::Adapter::EdgeShapes m_outputShapes; mutable std::unordered_map m_inferredInputShapes; mutable std::deque> m_reusedCommandLists; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlSubAllocator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlSubAllocator.h index cfdaf17710001..d6aa49d51c3f8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlSubAllocator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlSubAllocator.h @@ -10,7 +10,7 @@ namespace Dml class DmlSubAllocator { public: - virtual Microsoft::WRL::ComPtr Alloc(size_t size) = 0; - virtual ~DmlSubAllocator(){} + virtual void FreeResource(AllocationInfo* allocInfo, uint64_t resourceId) = 0; + virtual ~DmlSubAllocator() = default; }; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlTaggedPointer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlTaggedPointer.cpp new file mode 100644 index 0000000000000..f823d05c45382 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlTaggedPointer.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "DmlTaggedPointer.h" +#include + +namespace Dml +{ +/*static*/ TaggedPointer TaggedPointer::Unpack(const void* ptr) +{ + uint64_t ptrVal = reinterpret_cast(ptr); + + static constexpr uint64_t allocationIDMask = (1ull << AllocationIDBits) - 1; + static constexpr uint64_t offsetMask = (1ull << OffsetBits) - 1; + + TaggedPointer taggedPtr; + taggedPtr.deviceId = (ptrVal >> (AllocationIDBits + OffsetBits)); + taggedPtr.allocationId = (ptrVal >> OffsetBits) & allocationIDMask; + taggedPtr.offset = (ptrVal & offsetMask); + + return taggedPtr; +} + +/*static*/ void* TaggedPointer::Pack(uint32_t deviceId, uint32_t allocationId, uint64_t offset) +{ + assert(deviceId < (1ull << DeviceIDBits)); + assert(allocationId < (1ull << AllocationIDBits)); + assert(offset < (1ull << OffsetBits)); + + // Store the device ID in the upper bits of the pointer, followed by the + // allocation id and the offset in the lower bits + uint64_t ptr = ((uint64_t)deviceId << (AllocationIDBits + OffsetBits)) | + ((uint64_t)allocationId << OffsetBits) | offset; + + return reinterpret_cast(ptr); +} + +uint64_t TaggedPointer::GetUniqueId() const +{ + return reinterpret_cast(TaggedPointer::Pack(deviceId, allocationId, offset)); +} + +} // namespace tfdml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlTaggedPointer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlTaggedPointer.h new file mode 100644 index 0000000000000..d49e9d92eeb82 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlTaggedPointer.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +namespace Dml +{ + +// D3D12HeapAllocator and D3D12DescriptorHeapAllocator encode the allocation ID +// into the high bits of the pointers it returns, while the low bits are used as +// an offset into the allocation. Note that since the layout of bitfields is +// implementation-defined, you can't just cast a void* into a TaggedPointer: it +// must be done using masks and shifts. +struct TaggedPointer +{ + static constexpr uint64_t DeviceIDBits = 4; + static constexpr uint64_t AllocationIDBits = 20; + static constexpr uint64_t OffsetBits = 40; + + uint64_t deviceId : DeviceIDBits; + uint64_t allocationId : AllocationIDBits; + uint64_t offset : OffsetBits; + + static void* Pack(uint32_t deviceId, uint32_t allocationId, uint64_t offset); + static TaggedPointer Unpack(const void* ptr); + uint64_t GetUniqueId() const; +}; + +static_assert( + sizeof(TaggedPointer) == sizeof(void*), + "DML requires a 64-bit architecture"); +static_assert( + TaggedPointer::DeviceIDBits + TaggedPointer::AllocationIDBits + TaggedPointer::OffsetBits == sizeof(void*) * CHAR_BIT, + "DML requires a 64-bit architecture"); + +} // namespace tfdml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp index 5dc1213bd76f0..d3e5ca6d58e39 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.cpp @@ -4,6 +4,7 @@ #include "precomp.h" #include "ExecutionContext.h" #include "CommandQueue.h" +#include "DmlGpuAllocator.h" namespace Dml { @@ -22,7 +23,7 @@ namespace Dml ORT_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(m_d3dDevice.GetAddressOf()))); } - void ExecutionContext::SetAllocator(std::weak_ptr allocator) + void ExecutionContext::SetAllocator(std::weak_ptr allocator) { m_dmlRecorder.SetAllocator(allocator); } @@ -40,42 +41,113 @@ namespace Dml SetCommandRecorder(&m_dmlRecorder); - std::vector barriers; - - if (!(dstState & D3D12_RESOURCE_STATE_COPY_DEST)) - { - barriers.push_back(CD3DX12_RESOURCE_BARRIER::Transition(dstBuffer, dstState, D3D12_RESOURCE_STATE_COPY_DEST)); - } - if (!(srcState & D3D12_RESOURCE_STATE_COPY_SOURCE)) + if (dstBuffer == srcBuffer) { - barriers.push_back(CD3DX12_RESOURCE_BARRIER::Transition(srcBuffer, srcState, D3D12_RESOURCE_STATE_COPY_SOURCE)); - } + // This type of copy is not common and is only used in rare circumstances. Because a resource + // cannot be both in a source and destination state at the same time (without aliasing), we copy + // the source resource to an intermediate one, and then copy the intermediate resource to the + // destination resource. + D3D12_HEAP_PROPERTIES heapProperties = { + D3D12_HEAP_TYPE_DEFAULT, D3D12_CPU_PAGE_PROPERTY_UNKNOWN, D3D12_MEMORY_POOL_UNKNOWN, 0, 0}; + + D3D12_RESOURCE_DESC resourceDesc = { + D3D12_RESOURCE_DIMENSION_BUFFER, + 0, + byteCount, + 1, + 1, + 1, + DXGI_FORMAT_UNKNOWN, + {1, 0}, + D3D12_TEXTURE_LAYOUT_ROW_MAJOR, + D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS}; + + ComPtr intermediateBuffer; + ORT_THROW_IF_FAILED(m_d3dDevice->CreateCommittedResource( + &heapProperties, + D3D12_HEAP_FLAG_NONE, + &resourceDesc, + D3D12_RESOURCE_STATE_COPY_DEST, + nullptr, + IID_GRAPHICS_PPV_ARGS(intermediateBuffer.GetAddressOf()))); + + std::vector barriers; + + if (!(srcState & D3D12_RESOURCE_STATE_COPY_SOURCE)) + { + barriers.push_back(CD3DX12_RESOURCE_BARRIER::Transition(srcBuffer, srcState, D3D12_RESOURCE_STATE_COPY_SOURCE)); + m_dmlRecorder.ResourceBarrier(barriers); + } + + m_dmlRecorder.CopyBufferRegion(intermediateBuffer.Get(), 0, srcBuffer, srcOffset, byteCount); + + // Reset src barrier state + for (auto& barrier : barriers) + { + std::swap(barrier.Transition.StateBefore, barrier.Transition.StateAfter); + } + + barriers.push_back(CD3DX12_RESOURCE_BARRIER::Transition(intermediateBuffer.Get(), D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_COPY_SOURCE)); + + if (!(dstState & D3D12_RESOURCE_STATE_COPY_DEST)) + { + barriers.push_back(CD3DX12_RESOURCE_BARRIER::Transition(dstBuffer, dstState, D3D12_RESOURCE_STATE_COPY_DEST)); + } - if (!barriers.empty()) - { m_dmlRecorder.ResourceBarrier(barriers); - } + m_dmlRecorder.CopyBufferRegion(dstBuffer, dstOffset, intermediateBuffer.Get(), 0, byteCount); - m_dmlRecorder.CopyBufferRegion(dstBuffer, dstOffset, srcBuffer, srcOffset, byteCount); + // Reset dst barrier state + if (!(dstState & D3D12_RESOURCE_STATE_COPY_DEST)) + { + barriers.clear(); + barriers.push_back(CD3DX12_RESOURCE_BARRIER::Transition(dstBuffer, D3D12_RESOURCE_STATE_COPY_DEST, dstState)); + m_dmlRecorder.ResourceBarrier(barriers); + } - // Reset barrier state - if (!barriers.empty()) + // Keep the intermediate buffer alive until we're done with it + QueueReference(intermediateBuffer.Get()); + } + else { + std::vector barriers; + + if (!(dstState & D3D12_RESOURCE_STATE_COPY_DEST)) + { + barriers.push_back(CD3DX12_RESOURCE_BARRIER::Transition(dstBuffer, dstState, D3D12_RESOURCE_STATE_COPY_DEST)); + } + if (!(srcState & D3D12_RESOURCE_STATE_COPY_SOURCE)) + { + barriers.push_back(CD3DX12_RESOURCE_BARRIER::Transition(srcBuffer, srcState, D3D12_RESOURCE_STATE_COPY_SOURCE)); + } + + if (!barriers.empty()) + { + m_dmlRecorder.ResourceBarrier(barriers); + } + + m_dmlRecorder.CopyBufferRegion(dstBuffer, dstOffset, srcBuffer, srcOffset, byteCount); + + // Reset barrier state for (auto& barrier : barriers) { std::swap(barrier.Transition.StateBefore, barrier.Transition.StateAfter); } - m_dmlRecorder.ResourceBarrier(barriers); + if (!barriers.empty()) + { + m_dmlRecorder.ResourceBarrier(barriers); + } } } void ExecutionContext::FillBufferWithPattern( ID3D12Resource* dstBuffer, - gsl::span pattern /* Data type agnostic value, treated as raw bits */) + uint64_t offset, + gsl::span value /* Data type agnostic value, treated as raw bits */) { SetCommandRecorder(&m_dmlRecorder); - m_dmlRecorder.FillBufferWithPattern(dstBuffer, pattern); + m_dmlRecorder.FillBufferWithPattern(dstBuffer, offset, value); } void ExecutionContext::ExecuteCommandList( diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h index e7a6fa3d07296..f6c5b98400386 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionContext.h @@ -10,6 +10,7 @@ namespace Dml { class CommandQueue; + class DmlGpuAllocator; // Asynchronously performs GPU work, and automatically manages command list recording and submission to queues. // Work submitted to the ExecutionContext is typically recorded onto a command list and may not immediately begin @@ -26,7 +27,7 @@ namespace Dml bool cpuSyncSpinningEnabled, bool keepOpen); - void SetAllocator(std::weak_ptr allocator); + void SetAllocator(std::weak_ptr allocator); // Waits for flushed work, discards unflushed work, and discards associated references to // prevent circular references. Must be the last call on the object before destruction. @@ -46,7 +47,8 @@ namespace Dml void FillBufferWithPattern( ID3D12Resource* dstBuffer, - gsl::span pattern /* Data type agnostic value, treated as raw bits */); + uint64_t offset, + gsl::span value /* Data type agnostic value, treated as raw bits */); void InitializeOperator( IDMLCompiledOperator* op, @@ -89,6 +91,8 @@ namespace Dml bool CpuSyncSpinningEnabled() const { return m_cpuSyncSpinningEnabled; } bool IsClosed() const { return m_closed; } + bool Closed() const { return m_closed; } + private: Microsoft::WRL::ComPtr m_d3dDevice; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index cb6fc165a932f..7f3dafe3db240 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -8,7 +8,9 @@ #include "PooledUploadHeap.h" #include "ReadbackHeap.h" #include "ExecutionContext.h" +#include "DmlReservedResourceSubAllocator.h" #include "BucketizedBufferAllocator.h" +#include "DmlCpuAllocator.h" #include "MLOperatorAuthorImpl.h" #include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h" #include "core/providers/dml/OperatorAuthorHelper/OperatorHelper.h" @@ -17,8 +19,15 @@ #include "core/graph/indexed_sub_graph.h" #include "core/framework/compute_capability.h" #include "core/framework/fallback_cpu_capability.h" -#include "DmlCommittedResourceAllocator.h" +#include "core/framework/bfc_arena.h" #include "DmlCommittedResourceWrapper.h" +#include "DmlBufferRegion.h" +#include "DmlReservedResourceAllocatorWrapper.h" +#include "DmlGpuAllocator.h" +#include "DmlBuffer.h" +#include "DmlTaggedPointer.h" +#include "DmlExternalGpuAllocator.h" +#include "DmlAllocatorRoundingMode.h" #include "core/session/onnxruntime_run_options_config_keys.h" #include "core/common/parse_string.h" #include "core/providers/dml/dml_provider_factory_creator.h" @@ -73,7 +82,8 @@ namespace Dml bool enableMetacommands, bool enableGraphCapture, bool enableSyncSpinning, - bool disableMemoryArena) : + bool disableMemoryArena, + bool enableBfcAllocator) : IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)) { D3D12_COMMAND_LIST_TYPE queueType = executionContext->GetCommandListTypeForQueue(); @@ -86,7 +96,7 @@ namespace Dml ComPtr device; GRAPHICS_THROW_IF_FAILED(dmlDevice->GetParentDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf()))); - m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning, disableMemoryArena); + m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), executionContext, enableMetacommands, enableGraphCapture, enableSyncSpinning, disableMemoryArena, enableBfcAllocator); } std::vector> @@ -115,48 +125,25 @@ namespace Dml m_context->GetCurrentCompletionEvent().WaitForSignal(m_cpuSyncSpinningEnabled); } - HRESULT __stdcall ExecutionProviderImpl::AllocatePooledResource( - size_t size, - AllocatorRoundingMode roundingMode, - ID3D12Resource **d3dResource, - IUnknown** pooledResource - ) const noexcept + DmlBuffer ExecutionProviderImpl::AllocatePooledResource(size_t size, AllocatorRoundingMode roundingMode) const { - ORT_TRY - { - ComPtr allocation; - allocation.Attach(static_cast(m_allocator->Alloc(size, roundingMode))); - - const auto* allocInfo = m_allocator->DecodeDataHandle(allocation.Get()); - - ComPtr resource = allocInfo->GetResource(); - resource.CopyTo(d3dResource); - *pooledResource = allocation.Detach(); - return S_OK; - } - ORT_CATCH_RETURN + return m_gpuAllocator->AllocateDefaultBuffer(size, roundingMode); } - ID3D12Resource* __stdcall ExecutionProviderImpl::DecodeResource(void* allocation) const noexcept + D3D12BufferRegion ExecutionProviderImpl::GetBufferForTensor(IMLOperatorTensor* tensor) const { - ORT_TRY - { - const AllocationInfo* allocInfo = m_allocator->DecodeDataHandle(allocation); - return allocInfo->GetResource(); - } - ORT_CATCH_GENERIC - { - return nullptr; - } + auto tensorWrapper = static_cast(tensor); + return tensorWrapper->GetBufferRegion(); } - ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ExecutionContext* executionContext, bool enableMetacommands, bool enableGraphCapture, bool enableCpuSyncSpinning, bool disableMemoryArena) + ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ExecutionContext* executionContext, bool enableMetacommands, bool enableGraphCapture, bool enableCpuSyncSpinning, bool disableMemoryArena, bool enableBfcAllocator) : m_d3d12Device(d3d12Device), m_dmlDevice(dmlDevice), m_areMetacommandsEnabled(enableMetacommands), m_graphCaptureEnabled(enableGraphCapture), m_cpuSyncSpinningEnabled(enableCpuSyncSpinning), m_memoryArenaDisabled(disableMemoryArena), + m_bfcAllocatorEnabled(enableBfcAllocator), m_context(executionContext) { D3D12_FEATURE_DATA_FEATURE_LEVELS featureLevels = {}; @@ -225,26 +212,59 @@ namespace Dml m_lastUploadFlushTime = std::chrono::steady_clock::now(); } + static std::shared_ptr CreateBfcAllocator(std::shared_ptr subAllocator) + { + auto bfcArena = std::make_unique( + std::make_unique(subAllocator), + onnxruntime::BFCArena::DEFAULT_MAX_MEM, + onnxruntime::ArenaExtendStrategy::kSameAsRequested, + onnxruntime::BFCArena::DEFAULT_INITIAL_CHUNK_SIZE_BYTES, + onnxruntime::BFCArena::DEFAULT_MAX_DEAD_BYTES_PER_CHUNK, + onnxruntime::BFCArena::DEFAULT_INITIAL_GROWTH_CHUNK_SIZE_BYTES, + onnxruntime::BFCArena::DEFAULT_MAX_POWER_OF_TWO_EXTEND_BYTES); + + return bfcArena; + } + std::vector ExecutionProviderImpl::CreatePreferredAllocators() { - if (!m_allocator) + if (!m_gpuAllocator) { - // Create an allocator for D3D12 buffers used to hold tensor data. The returned buffers from the allocator - // should be DEFAULT heap buffers which can be used as UAVs, and which start in UAV state. - m_allocator = std::make_shared(m_d3d12Device.Get(), + auto subAllocator = std::make_shared( + m_d3d12Device.Get(), m_context.Get(), // TODO(leca): REVIEW: Will it cause memory issue when m_context is released in EP while alloc is released in sessionState? + m_queue.Get(), CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT), - D3D12_HEAP_FLAG_NONE, + D3D12_HEAP_FLAG_ALLOW_ONLY_BUFFERS, D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - std::make_unique(m_d3d12Device.Get())); - m_context->SetAllocator(m_allocator); + D3D12_RESOURCE_STATE_UNORDERED_ACCESS); + + m_bfcAllocator = CreateBfcAllocator(subAllocator); + + m_bucketizedAllocator = std::make_shared( + m_d3d12Device.Get(), + m_context.Get(), // TODO(leca): REVIEW: Will it cause memory issue when m_context is released in EP while alloc is released in sessionState? + CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT), + D3D12_HEAP_FLAG_ALLOW_ONLY_BUFFERS, + D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_UNORDERED_ACCESS); + + // Wrap the BFC allocator into our own allocator + m_gpuAllocator = std::make_shared( + m_bfcAllocator.get(), + m_bucketizedAllocator.get(), + subAllocator, + m_bfcAllocatorEnabled ? ActiveAllocator::BfcAllocator : ActiveAllocator::BucketizedBufferAllocator); + m_context->SetAllocator(m_gpuAllocator); + + m_externalGpuAllocator = std::make_shared(m_d3d12Device.Get()); + // CPU Allocator used to create buffers for the MemcpyFromHost, Shape and Size operators. OrtMemoryInfo memoryInfo(onnxruntime::CPU, OrtAllocatorType::OrtDeviceAllocator); memoryInfo.mem_type = ::OrtMemType::OrtMemTypeCPUInput; m_cpuInputAllocator = std::make_shared(memoryInfo); } - return std::vector{m_allocator, m_cpuInputAllocator,}; + return std::vector{m_gpuAllocator, m_externalGpuAllocator, m_cpuInputAllocator}; } HRESULT __stdcall ExecutionProviderImpl::GetD3DDevice(_COM_Outptr_ ID3D12Device** d3dDevice) const noexcept @@ -383,10 +403,8 @@ namespace Dml if (tensor) { assert(tensor->IsDataInterface()); - const AllocationInfo* allocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(tensor).GetDataInterface().Get()); - ID3D12Resource* resource = allocInfo->GetResource(); - D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc(); - bufferBindings.push_back({ resource, 0, resourceDesc.Width }); + auto bufferRegion = GetBufferForTensor(tensor); + bufferBindings.push_back(bufferRegion.GetBufferBinding()); bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() }); } else @@ -473,15 +491,11 @@ namespace Dml // // CPU -> GPU copy (upload) // - const AllocationInfo* dstAllocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(dst).GetDataInterface().Get()); - - ID3D12Resource* dstData = dstAllocInfo->GetResource(); - const void* srcData = src->GetData(); - - constexpr uint64_t dstOffset = 0; - const auto dstState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; // GPU resources are always kept in UAV state - - m_uploadHeap->BeginUploadToGpu(dstData, dstOffset, dstState, AsByteSpan(srcData, dataSizeInBytes)); + auto dstBufferRegion = GetBufferForTensor(dst); + ID3D12Resource* dstData = dstBufferRegion.GetD3D12Resource(); + const auto dstState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; + const uint64_t dstOffset = dstBufferRegion.Offset(); + m_uploadHeap->BeginUploadToGpu(dstData, dstOffset, dstState, AsByteSpan(src->GetData(), dataSizeInBytes)); // Continuously upload memory located in upload heaps during session initialization to avoid running out of it if (!m_sessionInitialized) @@ -494,29 +508,28 @@ namespace Dml // // GPU -> CPU copy (readback) // - - void* dstData = dst->GetData(); - const AllocationInfo* srcAllocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(src).GetDataInterface().Get()); - - ID3D12Resource* srcData = srcAllocInfo->GetResource(); - - const uint64_t srcOffset = 0; - const auto srcState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; // GPU resources are always kept in UAV state - - // Performs a blocking call to synchronize and read back data from the GPU into the destination buffer - m_readbackHeap->ReadbackFromGpu(AsByteSpan(dstData, dataSizeInBytes), srcData, srcOffset, srcState); + auto srcBufferRegion = GetBufferForTensor(src); + ID3D12Resource* srcData = srcBufferRegion.GetD3D12Resource(); + const auto srcState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; + const uint64_t srcOffset = srcBufferRegion.Offset(); + m_readbackHeap->ReadbackFromGpu(AsByteSpan(dst->GetData(), dataSizeInBytes), srcData, srcOffset, srcState); } else if (!src->IsCpuData() && !dst->IsCpuData()) { // // GPU -> GPU copy // - const AllocationInfo* srcAllocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(src).GetDataInterface().Get()); - const AllocationInfo* dstAllocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(dst).GetDataInterface().Get()); + auto srcBufferRegion = GetBufferForTensor(src); + ID3D12Resource* srcData = srcBufferRegion.GetD3D12Resource(); + const auto srcState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; + const uint64_t srcOffset = srcBufferRegion.Offset(); - ID3D12Resource* srcData = srcAllocInfo->GetResource(); - ID3D12Resource* dstData = dstAllocInfo->GetResource(); - m_context->CopyBufferRegion(dstData, 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, srcData, 0, D3D12_RESOURCE_STATE_UNORDERED_ACCESS, dataSizeInBytes); + auto dstBufferRegion = GetBufferForTensor(dst); + ID3D12Resource* dstData = dstBufferRegion.GetD3D12Resource(); + const auto dstState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; + const uint64_t dstOffset = dstBufferRegion.Offset(); + + m_context->CopyBufferRegion(dstData, dstOffset, dstState, srcData, srcOffset, srcState, dataSizeInBytes); } else { @@ -536,7 +549,7 @@ namespace Dml ORT_THROW_HR_IF(E_INVALIDARG, dst.size() != src.size()); // Source and destination for batched GPU -> CPU copies - std::vector srcDatas; + std::vector srcBufferRegions; std::vector dstDatas; std::vector dataSizesInBytes; @@ -565,15 +578,12 @@ namespace Dml ORT_THROW_HR_IF(E_INVALIDARG, dataSizesInBytes.back() != ComputeByteSizeFromTensor(*src[i])); // Tensors must be the same size dstDatas.push_back(dst[i]->GetData()); - const AllocationInfo* srcAllocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(src[i]).GetDataInterface().Get()); - - srcDatas.push_back(srcAllocInfo->GetResource()); + srcBufferRegions.push_back(GetBufferForTensor(src[i])); } - const auto srcState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; // GPU resources are always kept in UAV state - // Performs a blocking call to synchronize and read back data from the GPU into the destination buffer - m_readbackHeap->ReadbackFromGpu(dstDatas, dataSizesInBytes, srcDatas, srcState); + const auto srcState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; + m_readbackHeap->ReadbackFromGpu(dstDatas, dataSizesInBytes, srcBufferRegions, srcState); return S_OK; } @@ -590,9 +600,8 @@ namespace Dml auto mlTensor = MLOperatorTensor(dst).GetDataInterface(); if (mlTensor != nullptr) { - const AllocationInfo* dstAllocInfo = m_allocator->DecodeDataHandle(mlTensor.Get()); - ID3D12Resource* dstData = dstAllocInfo->GetResource(); - m_context->FillBufferWithPattern(dstData, rawValue); + auto dstBufferRegion = GetBufferForTensor(dst); + m_context->FillBufferWithPattern(dstBufferRegion.GetD3D12Resource(), dstBufferRegion.Offset(), rawValue); } return S_OK; @@ -947,9 +956,14 @@ namespace Dml Status ExecutionProviderImpl::CopyTensors(const std::vector& src_dst_pairs) const { // Source and destination for batched GPU -> CPU copies - std::vector srcDatas; + std::vector srcBufferRegions; + srcBufferRegions.reserve(src_dst_pairs.size()); + std::vector dstDatas; + dstDatas.reserve(src_dst_pairs.size()); + std::vector dataSizesInBytes; + dataSizesInBytes.reserve(src_dst_pairs.size()); assert(!m_closed); auto provider = const_cast(this); @@ -988,15 +1002,12 @@ namespace Dml ORT_THROW_HR_IF(E_INVALIDARG, dataSizesInBytes[i] != ComputeByteSizeFromTensor(srcWrapper)); // Tensors must be the same size dstDatas.push_back(dstWrapper.GetData()); - const AllocationInfo* srcAllocInfo = m_allocator->DecodeDataHandle(MLOperatorTensor(&srcWrapper).GetDataInterface().Get()); - - srcDatas.push_back(srcAllocInfo->GetResource()); + srcBufferRegions.push_back(GetBufferForTensor(&srcWrapper)); } - const auto srcState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; // GPU resources are always kept in UAV state - // Performs a blocking call to synchronize and read back data from the GPU into the destination buffer - m_readbackHeap->ReadbackFromGpu(dstDatas, dataSizesInBytes, srcDatas, srcState); + const auto srcState = D3D12_RESOURCE_STATE_UNORDERED_ACCESS; + m_readbackHeap->ReadbackFromGpu(dstDatas, dataSizesInBytes, srcBufferRegions, srcState); return onnxruntime::common::Status::OK(); } @@ -1018,47 +1029,14 @@ namespace Dml m_context->QueueReference(object); } - void ExecutionProviderImpl::GetShadowCopyIfRequired( - bool isInternalOperator, - IUnknown* data, - IUnknown** dataCopy) const - { - assert(!m_closed); - - *dataCopy = data; - data->AddRef(); - } - - void ExecutionProviderImpl::GetABIDataInterface( - bool isInternalOperator, - IUnknown* data, - IUnknown** abiData) const + D3D12BufferRegion ExecutionProviderImpl::GetBufferRegion(void* opaquePointer, uint64_t size) const { - assert(!m_closed); - - if (isInternalOperator) - { - *abiData = data; - data->AddRef(); - } - else - { -#ifdef _GAMING_XBOX - ComPtr wrappedResource = Microsoft::WRL::Make(m_allocator->DecodeDataHandle(data)->GetResource()); - *abiData = wrappedResource.Detach(); -#else - ComPtr resource = m_allocator->DecodeDataHandle(data)->GetResource(); - *abiData = resource.Detach(); -#endif - } + return m_gpuAllocator->CreateBufferRegion(opaquePointer, size); } - uint64_t ExecutionProviderImpl::TryGetPooledAllocationId( - IUnknown* data, - bool isInternalOperator) + uint64_t ExecutionProviderImpl::GetUniqueId(void* opaquePointer) { - assert(!isInternalOperator); - return m_allocator->DecodeDataHandle(data)->GetPooledResourceId(); + return m_gpuAllocator->GetUniqueId(opaquePointer); } void ExecutionProviderImpl::GetABIExecutionInterfaceAndInvalidateState( @@ -1166,7 +1144,7 @@ namespace Dml std::shared_ptr ExecutionProviderImpl::GetGpuAllocator() { - return m_allocator; + return m_gpuAllocator; } std::shared_ptr ExecutionProviderImpl::GetCpuInputAllocator() @@ -1188,7 +1166,7 @@ namespace Dml { // Allocations after this point are potentially transient and their sizes are // rounded to enable pooling. - m_allocator->SetDefaultRoundingMode(AllocatorRoundingMode::Enabled); + m_gpuAllocator->SetDefaultRoundingMode(AllocatorRoundingMode::Enabled); } m_sessionInitialized = true; @@ -1250,15 +1228,16 @@ namespace Dml bool enableMetacommands, bool enableGraphCapture, bool enableCpuSyncSpinning, - bool disableMemoryArena) + bool disableMemoryArena, + bool enableBfcAllocator) { - return std::make_unique(dmlDevice, executionContext, enableMetacommands, enableGraphCapture, enableCpuSyncSpinning, disableMemoryArena); + return std::make_unique(dmlDevice, executionContext, enableMetacommands, enableGraphCapture, enableCpuSyncSpinning, disableMemoryArena, enableBfcAllocator); } ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr) { - Dml::BucketizedBufferAllocator* pAllocationInfo = static_cast(allocator); - return pAllocationInfo->DecodeDataHandle(ptr)->GetResource(); + auto pAllocationInfo = static_cast(allocator); + return pAllocationInfo->GetAllocationInfo(ptr)->GetD3D12Resource(); } void FlushContext(onnxruntime::IExecutionProvider* provider) @@ -1285,12 +1264,10 @@ namespace Dml void* CreateGPUAllocationFromD3DResource(ID3D12Resource* pResource) { - uint64_t pooledResourceId = 0; // Not a pooled resource - ComPtr resourceWrapper; wil::MakeOrThrow(pResource).As(&resourceWrapper); - ComPtr allocInfo = wil::MakeOrThrow(nullptr, 0, pooledResourceId, resourceWrapper.Get(), (size_t)pResource->GetDesc().Width); + ComPtr allocInfo = wil::MakeOrThrow(nullptr, 0, 0, resourceWrapper.Get(), (size_t)pResource->GetDesc().Width); return allocInfo.Detach(); } void FreeGPUAllocation(void* ptr) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index c20969250fe84..2fb519d80ae77 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -5,6 +5,9 @@ #include "GraphTransformer.h" #include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h" +#include "DmlBufferRegion.h" +#include "DmlBuffer.h" +#include "DmlAllocatorRoundingMode.h" #include "core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h" #include "core/providers/dml/DmlExecutionProvider/src/DmlReusedCommandListState.h" @@ -18,6 +21,11 @@ using Base = Microsoft::WRL::RuntimeClass< TInterfaces...>; } +namespace onnxruntime +{ + class BFCArena; +} + namespace Dml { using Microsoft::WRL::ComPtr; @@ -26,6 +34,9 @@ namespace Dml class ExecutionContext; class BucketizedBufferAllocator; class ExecutionProvider; + class DmlGpuAllocator; + class DmlExternalGpuAllocator; + struct TaggedPointer; class ExecutionProviderImpl : public WRL::Base @@ -38,9 +49,11 @@ namespace Dml bool enableMetacommands, bool enableGraphCapture, bool enableCpuSyncSpinning, - bool disableMemoryArena); + bool disableMemoryArena, + bool enableBfcAllocator); void ReleaseCompletedReferences(); + uint64_t GetUniqueId(void* opaquePointer); public: // implements Dml::IExecutionProvider STDMETHOD(GetD3DDevice)(_COM_Outptr_ ID3D12Device** d3dDevice) const noexcept final; @@ -99,19 +112,7 @@ namespace Dml // IWinmlExecutionProvider methods void QueueReference(IUnknown* object) override; - void GetShadowCopyIfRequired( - bool isInternalOperator, - IUnknown* data, - IUnknown** dataCopy) const override; - - void GetABIDataInterface( - bool isInternalOperator, - IUnknown* data, - IUnknown** abiData) const override; - - uint64_t TryGetPooledAllocationId( - IUnknown* data, - bool isInternalOperator) override; + D3D12BufferRegion GetBufferRegion(void* opaquePointer, uint64_t size) const override; void GetABIExecutionInterfaceAndInvalidateState( bool isInternalOperator, @@ -136,15 +137,8 @@ namespace Dml void WaitForOutstandingWork(); - // Allocate a resource from pools. Releasing pooledResource returns it to the pool. - STDMETHOD(AllocatePooledResource)( - size_t size, - AllocatorRoundingMode roundingMode, - ID3D12Resource **d3dResource, - IUnknown* *pooledResource - ) const noexcept final; - - STDMETHOD_(ID3D12Resource*, DecodeResource)(void* allocation) const noexcept final; + // Allocate a resource from pools. Releasing the returned buffer returns it to the pool. + DmlBuffer ExecutionProviderImpl::AllocatePooledResource(size_t size, AllocatorRoundingMode roundingMode) const; std::shared_ptr GetKernelRegistry() const { @@ -191,6 +185,7 @@ namespace Dml uint32_t supportedDeviceDataTypeMask // Each bit corresponds to each DML_TENSOR_DATA_TYPE. ) const; + D3D12BufferRegion GetBufferForTensor(IMLOperatorTensor* tensor) const; void FlushUploadsIfReady() const; ComPtr m_d3d12Device; @@ -208,10 +203,14 @@ namespace Dml bool m_sessionInitialized = false; bool m_cpuSyncSpinningEnabled = false; bool m_memoryArenaDisabled = false; + bool m_bfcAllocatorEnabled = true; ComPtr m_context; std::unique_ptr m_uploadHeap; std::unique_ptr m_readbackHeap; - std::shared_ptr m_allocator; + std::shared_ptr m_bfcAllocator; + std::shared_ptr m_bucketizedAllocator; + std::shared_ptr m_gpuAllocator; + std::shared_ptr m_externalGpuAllocator; std::shared_ptr m_cpuInputAllocator; std::shared_ptr m_kernelRegistry; std::shared_ptr m_internalRegInfoMap; @@ -219,6 +218,7 @@ namespace Dml bool m_closed = false; mutable std::chrono::time_point m_lastUploadFlushTime; static constexpr std::chrono::milliseconds m_batchFlushInterval = std::chrono::milliseconds(10); + ComPtr m_queue; }; class DataTransfer : public onnxruntime::IDataTransfer @@ -262,8 +262,8 @@ namespace Dml bool enableMetacommands, bool enableGraphCapture, bool enableSyncSpinning, - bool disableMemoryArena - ); + bool disableMemoryArena, + bool enableBfcAllocator); std::unique_ptr GetDataTransfer() const final override { @@ -332,6 +332,16 @@ namespace Dml return m_impl->CreatePreferredAllocators(); } + virtual OrtDevice GetExternalOrtDeviceByMemType(OrtMemType mem_type) const final + { + if (mem_type == OrtMemType::OrtMemTypeDefault) + { + return OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DML_EXTERNAL, 0); + } + + return GetOrtDeviceByMemType(mem_type); + } + bool IsGraphCaptureEnabled() const override { return m_impl->GraphCaptureEnabled(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp index 53538cfb79ab5..afa62660d1432 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp @@ -6,6 +6,8 @@ #include "MLOperatorAuthorImpl.h" #include "FusedGraphKernel.h" #include "DmlGraphFusionHelper.h" +#include "DmlManagedBuffer.h" +#include "DmlAllocatorRoundingMode.h" using namespace Windows::AI::MachineLearning::Adapter; @@ -63,13 +65,11 @@ namespace Dml UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize; if (persistentResourceSize > 0) { - ORT_THROW_IF_FAILED(m_provider->AllocatePooledResource( - static_cast(persistentResourceSize), - AllocatorRoundingMode::Disabled, - m_persistentResource.GetAddressOf(), - m_persistentResourceAllocatorUnknown.GetAddressOf())); - - m_persistentResourceBinding = DML_BUFFER_BINDING { m_persistentResource.Get(), 0, persistentResourceSize }; + auto buffer = m_provider->AllocatePooledResource(persistentResourceSize, AllocatorRoundingMode::Disabled); + m_persistentResource = buffer.GetD3D12Resource(); + m_persistentResourceBinding = buffer.GetBufferBinding(); + m_managedPersistentBuffer = wil::MakeOrThrow(std::move(buffer)); + m_winmlProvider->QueueReference(m_managedPersistentBuffer.Get()); } ORT_THROW_IF_FAILED(m_provider->InitializeOperator( @@ -79,7 +79,6 @@ namespace Dml // Queue references to objects which must be kept alive until resulting GPU work completes m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); - m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnknown.Get()); std::for_each( initializeResourceRefs.begin(), @@ -117,7 +116,7 @@ namespace Dml // Get input resources for execution, excluding those which were specified as owned by DML and provided // at initialization instead. std::vector> inputTensors(kernelContext->InputCount()); - std::vector inputPtrs(kernelContext->InputCount()); + std::vector inputBufferRegions(kernelContext->InputCount()); for (int i = 0; i < kernelContext->InputCount(); ++i) { @@ -128,12 +127,16 @@ namespace Dml if (m_nonOwnedGraphInputsFromInitializers[i]) { - inputPtrs[i] = m_nonOwnedGraphInputsFromInitializers[i].Get(); + inputBufferRegions[i] = D3D12BufferRegion( + 0, + m_nonOwnedGraphInputsFromInitializers[i]->GetDesc().Width, + m_nonOwnedGraphInputsFromInitializers[i].Get()); } else if (!m_isInputsUploadedByDmlEP[i]) { ORT_THROW_IF_FAILED(contextWrapper.GetInputTensor(i, inputTensors[i].GetAddressOf())); - inputPtrs[i] = m_provider->DecodeResource(MLOperatorTensor(inputTensors[i].Get()).GetDataInterface().Get()); + auto tensorWrapper = static_cast(inputTensors[i].Get()); + inputBufferRegions[i] = tensorWrapper->GetBufferRegion(); } } @@ -141,14 +144,14 @@ namespace Dml ExecuteOperator( m_compiledExecutionPlanOperator.Get(), m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr, - inputPtrs, + inputBufferRegions, aux); ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); // Queue references to objects which must be kept alive until resulting GPU work completes m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); - m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnknown.Get()); + m_winmlProvider->QueueReference(m_managedPersistentBuffer.Get()); } else { @@ -179,7 +182,6 @@ namespace Dml m_outputShapes, m_winmlProvider.Get(), m_provider.Get(), - m_persistentResourceAllocatorUnknown.Get(), keepTemporaryResourceAlive); m_reusedCommandLists.push_back(std::move(m_reusedCommandLists.front())); @@ -192,7 +194,7 @@ namespace Dml void ExecuteOperator( IDMLCompiledOperator* op, _In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding, - gsl::span inputTensors, + gsl::span inputBufferRegions, gsl::span outputTensors) const { auto FillBindingsFromTensors = [this](auto& bufferBindings, auto& bindingDescs, gsl::span& tensors) @@ -201,10 +203,10 @@ namespace Dml { if (tensor) { + auto tensorWrapper = static_cast(tensor); + assert(tensor->IsDataInterface()); - ID3D12Resource* resource = m_provider->DecodeResource(MLOperatorTensor(tensor).GetDataInterface().Get()); - D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc(); - bufferBindings.push_back({ resource, 0, resourceDesc.Width }); + bufferBindings.push_back(tensorWrapper->GetBufferRegion().GetBufferBinding()); bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() }); } else @@ -215,29 +217,28 @@ namespace Dml } }; - auto FillBindingsFromBuffers = [](auto& bufferBindings, auto& bindingDescs, gsl::span& resources) + auto FillBindingsFromBufferRegions = [](auto& bufferBindings, auto& bindingDescs, gsl::span& bufferRegions) { - for (ID3D12Resource* resource : resources) + for (const D3D12BufferRegion& bufferRegion : bufferRegions) { - if (resource) + bufferBindings.push_back(bufferRegion.GetBufferBinding()); + + if (bufferRegion.GetD3D12Resource() != nullptr) { - D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc(); - bufferBindings.push_back({ resource, 0, resourceDesc.Width }); bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() }); } else { - bufferBindings.push_back({ nullptr, 0, 0 }); bindingDescs.push_back({ DML_BINDING_TYPE_NONE, nullptr }); } } }; std::vector inputBufferBindings; - inputBufferBindings.reserve(inputTensors.size()); + inputBufferBindings.reserve(inputBufferRegions.size()); std::vector inputBindings; - inputBindings.reserve(inputTensors.size()); - FillBindingsFromBuffers(inputBufferBindings, inputBindings, inputTensors); + inputBindings.reserve(inputBufferRegions.size()); + FillBindingsFromBufferRegions(inputBufferBindings, inputBindings, inputBufferRegions); std::vector outputBufferBindings; outputBufferBindings.reserve(outputTensors.size()); @@ -264,7 +265,7 @@ namespace Dml std::optional m_persistentResourceBinding; ComPtr m_persistentResource; - ComPtr m_persistentResourceAllocatorUnknown; // Controls when the persistent resource is returned to the allocator + ComPtr m_managedPersistentBuffer; std::vector m_isInputsUploadedByDmlEP; std::vector> m_nonOwnedGraphInputsFromInitializers; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h index f4c3f326274ad..f630fdaa81610 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h @@ -6,6 +6,8 @@ #include "directx/d3d12.h" #include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlAllocatorRoundingMode.h" +#include "DmlBuffer.h" interface IDMLCompiledOperator; struct DML_BUFFER_BINDING; @@ -13,6 +15,8 @@ struct DML_BINDING_DESC; namespace Dml { + class DmlManagedBufferRegion; + struct Binding { // Non-null if required at the stage where it is used, i.e. Initialization @@ -72,11 +76,10 @@ namespace Dml STDMETHOD_(D3D12_COMMAND_LIST_TYPE, GetCommandListTypeForQueue)() const noexcept = 0; STDMETHOD_(void, Flush)() const noexcept = 0; - STDMETHOD_(ID3D12Resource*, DecodeResource)(void* allocation) const noexcept = 0; - STDMETHOD(AllocatePooledResource(size_t size, AllocatorRoundingMode roundingMode, ID3D12Resource **d3dResource, IUnknown* *pooledResource)) const noexcept = 0; - STDMETHOD_(bool, IsMcdmDevice)() const noexcept = 0; STDMETHOD_(bool, CustomHeapsSupported)() const noexcept = 0; STDMETHOD_(bool, MetacommandsEnabled)() const noexcept = 0; + + virtual DmlBuffer AllocatePooledResource(size_t size, AllocatorRoundingMode roundingMode) const = 0; }; } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 26559b54bceb6..96143b7d2a551 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -9,9 +9,13 @@ #include "core/session/onnxruntime_c_api.h" #include "core/providers/dml/DmlExecutionProvider/inc/MLOperatorAuthor.h" +#include "DmlBufferRegion.h" #include "MLOperatorAuthorImpl.h" #include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h" +#include "DmlGpuAllocator.h" +#include "DmlAllocationInfo.h" +#include "DmlTaggedPointer.h" using namespace Microsoft::WRL; @@ -101,27 +105,6 @@ namespace Windows::AI::MachineLearning::Adapter return strcmp(info.name, onnxruntime::CPU) && !(info.mem_type == ::OrtMemType::OrtMemTypeCPUOutput || info.mem_type == ::OrtMemType::OrtMemTypeCPUInput); } - // Translate the data object stored in a tensor to the type which will be returned through - // the ABI. The translation is determined by the provider and based on options with which the - // kernels are registered. - void TranslateAllocationDataToAbi( - IWinmlExecutionProvider* winmlProvider, - bool isInternalOperator, - const ::OrtMemoryInfo& allocInfo, - IUnknown* allocation, - IUnknown** abiAllocation) - { - if (winmlProvider) - { - winmlProvider->GetABIDataInterface(isInternalOperator, allocation, abiAllocation); - } - else - { - ComPtr tmp = allocation; - *abiAllocation = tmp.Detach(); - } - } - // // Traits for numeric attribute types // @@ -1671,41 +1654,19 @@ namespace Windows::AI::MachineLearning::Adapter { if (impl) { - if (isDataInterface) - { - // We assume that all data handles derive from IUnknown as their first base. - m_dataInterface = static_cast(m_impl->MutableDataRaw()); + m_tensorData = m_impl->MutableDataRaw(); + } + } - if (m_dataInterface) - { - if (m_winmlExecutionProvider) - { - // The resource may require conversion to the layout expected according to the kernel options. - // This will return either the original object or a shadow copy which uses a different layout. - // This pattern assumes that Lotus is not re-using tensor allocations, so each output is - // a fresh allocation which will not trigger a conversion in the provider. - m_winmlExecutionProvider->GetShadowCopyIfRequired(m_internalOperator, m_dataInterface.Get(), m_dataInterfaceOrShadowCopy.GetAddressOf()); - - // Get the actual object to be returned from the ABI, which varies for internal and external - // kernels (i.e. ID3D12Resource, versus something that tracks the layout). - TranslateAllocationDataToAbi( - m_winmlExecutionProvider.Get(), - m_internalOperator, - m_impl->Location(), - m_dataInterfaceOrShadowCopy ? m_dataInterfaceOrShadowCopy.Get() : m_dataInterface.Get(), - m_abiDataInterface.GetAddressOf()); - } - else - { - m_abiDataInterface = m_dataInterface; - } - } - } - else - { - m_tensorData = m_impl->MutableDataRaw(); - } + Dml::D3D12BufferRegion TensorWrapper::GetBufferRegion() const + { + if (m_impl->Location().device.MemType() == OrtDevice::MemType::DML_EXTERNAL) + { + auto allocInfo = static_cast(m_tensorData); + return Dml::D3D12BufferRegion(0, allocInfo->GetD3D12Resource()->GetDesc().Width, allocInfo->GetD3D12Resource()); } + + return m_winmlExecutionProvider->GetBufferRegion(m_tensorData, m_impl->SizeInBytes()); } uint32_t STDMETHODCALLTYPE TensorWrapper::GetDimensionCount() const noexcept @@ -1782,7 +1743,7 @@ namespace Windows::AI::MachineLearning::Adapter return nullptr; } - return m_isDataInterface ? nullptr : m_tensorData; + return m_tensorData; } void STDMETHODCALLTYPE TensorWrapper::GetDataInterface(IUnknown** dataInterface) noexcept @@ -1794,7 +1755,9 @@ namespace Windows::AI::MachineLearning::Adapter } else { - m_abiDataInterface.CopyTo(dataInterface); + auto bufferRegion = GetBufferRegion(); + bufferRegion.GetD3D12Resource()->AddRef(); + *dataInterface = bufferRegion.GetD3D12Resource(); } } @@ -1808,7 +1771,7 @@ namespace Windows::AI::MachineLearning::Adapter totalInputTensorCount += static_cast(inputTensor.size()); } std::vector resourcesToTransition; - resourcesToTransition.reserve(totalInputTensorCount + m_outputTensors.size() + m_temporaryAllocations.size()); + resourcesToTransition.reserve(totalInputTensorCount + m_outputTensors.size() + m_temporaryBuffers.size()); for (uint32_t i = 0; i < m_inputTensors.size(); ++i) { @@ -1849,9 +1812,9 @@ namespace Windows::AI::MachineLearning::Adapter } } - for (auto& tempAlloc : m_temporaryAbiAllocations) + for (auto& tempBuffer : m_temporaryBuffers) { - resourcesToTransition.push_back(tempAlloc.Get()); + resourcesToTransition.push_back(tempBuffer.GetD3D12Resource()); } m_winmlProvider->TransitionResourcesForOperator( @@ -1903,8 +1866,7 @@ namespace Windows::AI::MachineLearning::Adapter { if (m_winmlProvider) { - m_temporaryAllocations.clear(); - m_temporaryAbiAllocations.clear(); + m_temporaryBuffers.clear(); } } @@ -2216,16 +2178,6 @@ namespace Windows::AI::MachineLearning::Adapter } HRESULT STDMETHODCALLTYPE OpKernelContextWrapper::AllocateTemporaryData(size_t size, IUnknown** abiAllocation) const noexcept - { - ORT_TRY - { - uint64_t allocId; - return AllocateTemporaryData(size, abiAllocation, &allocId); - } - ORT_CATCH_RETURN - } - - HRESULT STDMETHODCALLTYPE OpKernelContextWrapper::AllocateTemporaryData(size_t size, IUnknown** abiAllocation, uint64_t* allocId) const { ORT_TRY { @@ -2240,27 +2192,35 @@ namespace Windows::AI::MachineLearning::Adapter return E_FAIL; } - ComPtr allocation; - allocation.Attach(static_cast(alloc->Alloc(size))); - - *allocId = m_winmlProvider->TryGetPooledAllocationId(allocation.Get(), 0); - - TranslateAllocationDataToAbi(m_winmlProvider.Get(), m_internalOperator, alloc->Info(), allocation.Get(), abiAllocation); - - if (m_winmlProvider->TransitionsRequiredForOperator(m_internalOperator)) - { - m_winmlProvider->TransitionResourcesForOperator(true, 1, abiAllocation); - } + auto dml_gpu_allocator = static_cast(alloc.get()); + auto buffer = dml_gpu_allocator->AllocateDefaultBuffer(size); + buffer.GetD3D12Resource()->AddRef(); + *abiAllocation = buffer.GetD3D12Resource(); // Ensure the allocation is freed and transitioned when the context destructs - m_temporaryAllocations.push_back(allocation); - m_temporaryAbiAllocations.push_back(*abiAllocation); + m_temporaryBuffers.push_back(std::move(buffer)); return S_OK; } ORT_CATCH_RETURN } + const Dml::D3D12BufferRegion& OpKernelContextWrapper::AllocateDefaultBuffer(size_t size) + { + VerifyNotClosed(); + + onnxruntime::AllocatorPtr alloc; + THROW_IF_NOT_OK(m_impl->GetTempSpaceAllocator(&alloc)); + + ORT_THROW_HR_IF(E_FAIL, !IsAllocationInterface(alloc->Info())); + auto dml_gpu_allocator = static_cast(alloc.get()); + auto buffer = dml_gpu_allocator->AllocateDefaultBuffer(size); + + // Ensure the allocation is freed and transitioned when the context destructs + m_temporaryBuffers.push_back(std::move(buffer)); + return m_temporaryBuffers.back().Region(); + } + void STDMETHODCALLTYPE OpKernelContextWrapper::GetExecutionInterface(IUnknown** executionInterface) const noexcept { m_abiExecutionObject.CopyTo(executionInterface); @@ -2574,14 +2534,16 @@ namespace Windows::AI::MachineLearning::Adapter } } - ComPtr kernelContextWrapper = wil::MakeOrThrow( - context, - Info().GetExecutionProvider(), - m_internalOperator, - m_requiresOutputShapesAtCreation ? &m_inferredOutputShapes : nullptr); + { + ComPtr kernelContextWrapper = wil::MakeOrThrow( + context, + Info().GetExecutionProvider(), + m_internalOperator, + m_requiresOutputShapesAtCreation ? &m_inferredOutputShapes : nullptr); - ORT_THROW_IF_FAILED(m_kernel->Compute(kernelContextWrapper.Get())); - kernelContextWrapper->Close(); + ORT_THROW_IF_FAILED(m_kernel->Compute(kernelContextWrapper.Get())); + kernelContextWrapper->Close(); + } // Ensure that scheduled work, if any, is completed before freeing the kernel if the execution // provider requires this. diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index 7e51ce026d365..75488552edaea 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -12,6 +12,9 @@ #include "core/framework/tensorprotoutils.h" #include #include +#include "core/providers/dml/DmlExecutionProvider/src/DmlBuffer.h" +#include "DmlBufferRegion.h" +#include "DmlBuffer.h" interface IDMLOperator; @@ -251,6 +254,8 @@ class TensorWrapper : public WRL::Base, public Closable MLOperatorTensorDataType STDMETHODCALLTYPE GetTensorDataType() const noexcept override; + Dml::D3D12BufferRegion GetBufferRegion() const; + bool STDMETHODCALLTYPE IsCpuData() const noexcept override; bool STDMETHODCALLTYPE IsDataInterface() const noexcept override; @@ -270,14 +275,7 @@ class TensorWrapper : public WRL::Base, public Closable bool m_internalOperator = false; void* m_tensorData = nullptr; - ComPtr m_dataInterface; bool m_isDataInterface = false; - - // The returned data may be a converted shadow copy, and the piece of it which - // is returned may vary according to kernel registration options. - ComPtr m_dataInterfaceOrShadowCopy; - ComPtr m_abiDataInterface; - }; class OnnxTensorWrapper : public WRL::Base, public Closable @@ -476,9 +474,7 @@ class OpKernelContextWrapper : public WRL::Base GetInputTensors(); std::vector GetOutputTensors(const EdgeShapes& outputShapes); + const Dml::D3D12BufferRegion& AllocateDefaultBuffer(uint64_t size); onnxruntime::OpKernelContext* GetOpKernelContext() { return m_impl; } @@ -510,8 +507,7 @@ class OpKernelContextWrapper : public WRL::Base> m_temporaryAllocations; - mutable std::vector> m_temporaryAbiAllocations; + mutable std::vector m_temporaryBuffers; }; class AbiOpKernel : public onnxruntime::OpKernel diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlDFT.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlDFT.h index 1de88a61a0d77..24454c9f2018e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlDFT.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlDFT.h @@ -2,8 +2,8 @@ #include "../MLOperatorAuthorImpl.h" #include "../../../OperatorAuthorHelper/OperatorHelper.h" - #include "../External/D3DX12/d3dx12.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlBufferRegion.h" #include "directx/d3d12.h" // NOTE: When this operator's implementation is moved into DML, the associated FP16 fallback @@ -111,7 +111,7 @@ class GpuDFTOperator : public WRL::Base // Allocate temporary buffers if needed struct ResourceDesc { - ComPtr Resource; + Dml::D3D12BufferRegion BufferRegion; std::array Sizes; std::array Strides; }; @@ -417,16 +417,6 @@ class GpuDFTOperator : public WRL::Base auto outputDims = GetTensorDimensions(outputTensor.Get()); ORT_THROW_HR_IF(E_FAIL, inputDims.size() != outputDims.size()); - ComPtr inputUnknown; - ComPtr inputResource; - inputTensor->GetDataInterface(inputUnknown.GetAddressOf()); - ORT_THROW_IF_FAILED(inputUnknown.As(&inputResource)); - - ComPtr outputUnknown; - ComPtr outputResource; - outputTensor->GetDataInterface(outputUnknown.GetAddressOf()); - ORT_THROW_IF_FAILED(outputUnknown.As(&outputResource)); - // Get optional dft_length input uint32_t dftLength = inputDims[onnxruntime::narrow(m_axis)]; ComPtr dftLengthTensor; @@ -436,12 +426,15 @@ class GpuDFTOperator : public WRL::Base dftLength = onnxruntime::narrow(OperatorHelper::ReadScalarTensorCastToInt64(tensor)); } + auto inputTensorWrapper = static_cast(inputTensor.Get()); + auto outputTensorWrapper = static_cast(outputTensor.Get()); + return Compute( commandList.Get(), context, - inputResource.Get(), + inputTensorWrapper->GetBufferRegion(), inputDims, - outputResource.Get(), + outputTensorWrapper->GetBufferRegion(), outputDims, dftLength ); @@ -457,16 +450,16 @@ class GpuDFTOperator : public WRL::Base HRESULT Compute( ID3D12GraphicsCommandList* commandList, IMLOperatorKernelContext* context, - ID3D12Resource* inputResource, + const Dml::D3D12BufferRegion& inputBufferRegion, gsl::span inputDims, - ID3D12Resource* outputResource, + const Dml::D3D12BufferRegion& outputBufferRegion, gsl::span outputDims, uint32_t dftLength ) { try { - auto dftParams = PrepareDFT(context, inputResource, inputDims, outputResource, outputDims, dftLength); + auto dftParams = PrepareDFT(context, inputBufferRegion, inputDims, outputBufferRegion, outputDims, dftLength); switch (dftParams.Type) { @@ -509,9 +502,9 @@ class GpuDFTOperator : public WRL::Base void PrepareStockhamFFTParams( IMLOperatorKernelContext* context, - ID3D12Resource* inputResource, + const Dml::D3D12BufferRegion& inputBufferRegion, gsl::span inputDims, - ID3D12Resource* outputResource, + const Dml::D3D12BufferRegion& outputBufferRegion, gsl::span outputDims, uint32_t dftLength, int64_t inAxis, @@ -571,36 +564,34 @@ class GpuDFTOperator : public WRL::Base // Create the resource loop list // Add the input resource to the loop list params.ResourceLoopList.push_back({}); - params.ResourceLoopList.back().Resource = inputResource; + params.ResourceLoopList.back().BufferRegion = inputBufferRegion; params.ResourceLoopList.back().Sizes = reshapedInputSize; params.ResourceLoopList.back().Strides = reshapedInputStrides; + auto kernelContext = static_cast(context); + // If 1 temporary should be placed first, or multiple temporaries, then // Add a temp in the list if (oscillateFirstTemporaryThenOutput || oscillateBetweenTwoTemporaries) { params.ResourceLoopList.push_back({}); + params.ResourceLoopList.back().BufferRegion = kernelContext->AllocateDefaultBuffer(temporaryBufferByteSize); params.ResourceLoopList.back().Sizes = temporarySize; params.ResourceLoopList.back().Strides = temporaryStrides; - - auto& resource = params.ResourceLoopList.back().Resource; - ORT_THROW_IF_FAILED(context->AllocateTemporaryData(temporaryBufferByteSize, &resource)); } // If 2 temps, add another if (oscillateBetweenTwoTemporaries) { params.ResourceLoopList.push_back({}); + params.ResourceLoopList.back().BufferRegion = kernelContext->AllocateDefaultBuffer(temporaryBufferByteSize); params.ResourceLoopList.back().Sizes = temporarySize; params.ResourceLoopList.back().Strides = temporaryStrides; - - auto& resource = params.ResourceLoopList.back().Resource; - ORT_THROW_IF_FAILED(context->AllocateTemporaryData(temporaryBufferByteSize, &resource)); } // Add output resource params.ResourceLoopList.push_back({}); - params.ResourceLoopList.back().Resource = outputResource; + params.ResourceLoopList.back().BufferRegion = outputBufferRegion; params.ResourceLoopList.back().Sizes = reshapedOutputSize; params.ResourceLoopList.back().Strides = reshapedOutputStrides; params.OutputIndex = static_cast(params.ResourceLoopList.size() - 1); @@ -609,11 +600,9 @@ class GpuDFTOperator : public WRL::Base if (oscillateFirstOutputThenTemporary) { params.ResourceLoopList.push_back({}); + params.ResourceLoopList.back().BufferRegion = kernelContext->AllocateDefaultBuffer(temporaryBufferByteSize); params.ResourceLoopList.back().Sizes = temporarySize; params.ResourceLoopList.back().Strides = temporaryStrides; - - auto& resource = params.ResourceLoopList.back().Resource; - ORT_THROW_IF_FAILED(context->AllocateTemporaryData(temporaryBufferByteSize, &resource)); } // Define the loop range @@ -622,16 +611,15 @@ class GpuDFTOperator : public WRL::Base if (oscillateFirstOutputThenTemporary) { params.LoopRange = { 1, 2, params.NumberOfPasses + 1 }; } if (oscillateFirstTemporaryThenOutput) { params.LoopRange = { 1, 2, params.NumberOfPasses + 1 }; } - params.Window.Resource = nullptr; params.Window.Sizes = std::array {0, 0, 0, 0}; params.Window.Strides = std::array {0, 0, 0, 0}; } DFTParameters PrepareDFT( IMLOperatorKernelContext* context, - ID3D12Resource* inputResource, + const Dml::D3D12BufferRegion& inputBufferRegion, gsl::span inputDims, - ID3D12Resource* outputResource, + const Dml::D3D12BufferRegion& outputBufferRegion, gsl::span outputDims, uint32_t dftLength ) @@ -647,9 +635,9 @@ class GpuDFTOperator : public WRL::Base params.Type = DFTType::Stockham; PrepareStockhamFFTParams( context, - inputResource, + inputBufferRegion, inputDims, - outputResource, + outputBufferRegion, outputDims, dftLength, m_axis, @@ -681,14 +669,17 @@ class GpuDFTOperator : public WRL::Base auto aIntermediateBufferByteSize = sizeof(float) * ComputeElementCountFromDimensions(params.BluesteinZChirpParams.AFFT.Sizes); auto bIntermediateBufferByteSize = sizeof(float) * ComputeElementCountFromDimensions(params.BluesteinZChirpParams.BFFT.Sizes); - auto& zChirpResource = params.BluesteinZChirpParams.ZChirp.Resource; - auto& aFFTResource = params.BluesteinZChirpParams.AFFT.Resource; - auto& bResource = params.BluesteinZChirpParams.B.Resource; - auto& bFFTResource = params.BluesteinZChirpParams.BFFT.Resource; - ORT_THROW_IF_FAILED(context->AllocateTemporaryData(zChirpBufferByteSize, &zChirpResource)); - ORT_THROW_IF_FAILED(context->AllocateTemporaryData(aIntermediateBufferByteSize, &aFFTResource)); - ORT_THROW_IF_FAILED(context->AllocateTemporaryData(bIntermediateBufferByteSize, &bResource)); - ORT_THROW_IF_FAILED(context->AllocateTemporaryData(bIntermediateBufferByteSize, &bFFTResource)); + auto& zChirpBufferRegion = params.BluesteinZChirpParams.ZChirp.BufferRegion; + auto& aFFTBufferRegion = params.BluesteinZChirpParams.AFFT.BufferRegion; + auto& bBufferRegion = params.BluesteinZChirpParams.B.BufferRegion; + auto& bFFTBufferRegion = params.BluesteinZChirpParams.BFFT.BufferRegion; + + auto kernelContext = static_cast(context); + + zChirpBufferRegion = kernelContext->AllocateDefaultBuffer(zChirpBufferByteSize); + aFFTBufferRegion = kernelContext->AllocateDefaultBuffer(aIntermediateBufferByteSize); + bBufferRegion = kernelContext->AllocateDefaultBuffer(bIntermediateBufferByteSize); + bFFTBufferRegion = kernelContext->AllocateDefaultBuffer(bIntermediateBufferByteSize); // The AFFT call takes input A, and produces output A_FFT. // @@ -699,8 +690,10 @@ class GpuDFTOperator : public WRL::Base // Padding should be handled by the shader. PrepareStockhamFFTParams( context, - inputResource, inputDims, - aFFTResource.Get(), params.BluesteinZChirpParams.AFFT.Sizes, + inputBufferRegion, + inputDims, + aFFTBufferRegion, + params.BluesteinZChirpParams.AFFT.Sizes, M, m_axis, 1, @@ -712,8 +705,10 @@ class GpuDFTOperator : public WRL::Base // Therefore the window function logic shold hangle complex multiplication, and B_FTT should be used like a window function. PrepareStockhamFFTParams( context, - aFFTResource.Get(), params.BluesteinZChirpParams.AFFT.Sizes, - outputResource, outputDims, + aFFTBufferRegion, + params.BluesteinZChirpParams.AFFT.Sizes, + outputBufferRegion, + outputDims, M, 1, m_axis, @@ -729,8 +724,10 @@ class GpuDFTOperator : public WRL::Base // The BFFT call takes input B, and produces output B_FFT. PrepareStockhamFFTParams( context, - bResource.Get(), params.BluesteinZChirpParams.B.Sizes, - bFFTResource.Get(), params.BluesteinZChirpParams.BFFT.Sizes, + bBufferRegion, + params.BluesteinZChirpParams.B.Sizes, + bFFTBufferRegion, + params.BluesteinZChirpParams.BFFT.Sizes, M, 2, 2, @@ -744,25 +741,23 @@ class GpuDFTOperator : public WRL::Base { const auto& bluesteinZChirpParams = dftParams.BluesteinZChirpParams; - // Get input and output resources - auto inputResource = bluesteinZChirpParams.AFFTParams.ResourceLoopList.front().Resource.Get(); - auto outputResource = bluesteinZChirpParams.AFFTInverseParams.ResourceLoopList[bluesteinZChirpParams.AFFTInverseParams.OutputIndex].Resource.Get(); - auto zChirpResource = bluesteinZChirpParams.ZChirp.Resource.Get(); - auto aFFTResource = bluesteinZChirpParams.AFFT.Resource.Get(); - auto bResource = bluesteinZChirpParams.B.Resource.Get(); - auto bFFTResource = bluesteinZChirpParams.BFFT.Resource.Get(); + // Get resources + auto inputBufferRegion = bluesteinZChirpParams.AFFTParams.ResourceLoopList.front().BufferRegion; + auto outputBufferRegion = bluesteinZChirpParams.AFFTInverseParams.ResourceLoopList[bluesteinZChirpParams.AFFTInverseParams.OutputIndex].BufferRegion; + auto zChirpBufferRegion = bluesteinZChirpParams.ZChirp.BufferRegion; + auto bBufferRegion = bluesteinZChirpParams.B.BufferRegion; // Transition resources from common to UAV state D3D12_RESOURCE_BARRIER barriers[2]; barriers[0] = CD3DX12_RESOURCE_BARRIER::Transition( - inputResource, + inputBufferRegion.GetD3D12Resource(), D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_UNORDERED_ACCESS ); barriers[1] = CD3DX12_RESOURCE_BARRIER::Transition( - outputResource, + outputBufferRegion.GetD3D12Resource(), D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_UNORDERED_ACCESS ); @@ -781,8 +776,8 @@ class GpuDFTOperator : public WRL::Base auto totalElementCount = ComputeElementCountFromDimensions(bluesteinZChirpParams.B.Sizes); constants.ElementCount = totalElementCount / bluesteinZChirpParams.B.Sizes[3]; - std::array uav_resources = { zChirpResource, bResource }; - Dispatch(uav_resources, constants, commandList); + std::array uavBufferRegions = { zChirpBufferRegion, bBufferRegion }; + Dispatch(uavBufferRegions, constants, commandList); DFTParameters fft_params = {}; fft_params.Type = DFTType::Stockham; @@ -806,16 +801,16 @@ class GpuDFTOperator : public WRL::Base // Transition resources to common state barriers[0] = CD3DX12_RESOURCE_BARRIER::Transition( - inputResource, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - D3D12_RESOURCE_STATE_COMMON - ); + inputBufferRegion.GetD3D12Resource(), + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COMMON + ); barriers[1] = CD3DX12_RESOURCE_BARRIER::Transition( - outputResource, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - D3D12_RESOURCE_STATE_COMMON - ); + outputBufferRegion.GetD3D12Resource(), + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COMMON + ); commandList->ResourceBarrier(2, barriers); } @@ -833,21 +828,21 @@ class GpuDFTOperator : public WRL::Base const auto& loopList = stockhamParams.ResourceLoopList; // Get input and output resources - auto inputResource = loopList[0].Resource.Get(); - auto outputResource = loopList[stockhamParams.OutputIndex].Resource.Get(); - auto windowResource = dftParams.StockhamParams.Window.Resource.Get(); + auto inputBufferRegion = loopList[0].BufferRegion; + auto outputBufferRegion = loopList[stockhamParams.OutputIndex].BufferRegion; + auto windowBufferRegion = dftParams.StockhamParams.Window.BufferRegion; // Transition resources from common to UAV state D3D12_RESOURCE_BARRIER barriers[2]; barriers[0] = CD3DX12_RESOURCE_BARRIER::Transition( - inputResource, + inputBufferRegion.GetD3D12Resource(), D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_UNORDERED_ACCESS ); barriers[1] = CD3DX12_RESOURCE_BARRIER::Transition( - outputResource, + outputBufferRegion.GetD3D12Resource(), D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_UNORDERED_ACCESS ); @@ -872,11 +867,11 @@ class GpuDFTOperator : public WRL::Base auto inIdx = stockhamParams.LoopRange.CalculateIndex(index); auto outIdx = stockhamParams.LoopRange.CalculateIndex(index + 1); - auto in = loopList[inIdx].Resource.Get(); + const auto& in = loopList[inIdx].BufferRegion; std::copy(loopList[inIdx].Sizes.begin(), loopList[inIdx].Sizes.end(), constants.InputSizes); std::copy(loopList[inIdx].Strides.begin(), loopList[inIdx].Strides.end(), constants.InputStrides); - auto out = loopList[outIdx].Resource.Get(); + const auto& out = loopList[outIdx].BufferRegion; std::copy(loopList[outIdx].Sizes.begin(), loopList[outIdx].Sizes.end(), constants.OutputSizes); std::copy(loopList[outIdx].Strides.begin(), loopList[outIdx].Strides.end(), constants.OutputStrides); @@ -890,24 +885,24 @@ class GpuDFTOperator : public WRL::Base constants.ElementCount = totalElementCount / constants.OutputSizes[3]; constants.DFTIteration = index + 1; constants.ChirpLength = isLastPass ? chirpLength : 0; - constants.HasWindow = isFirstPass && windowResource != nullptr; - auto window = constants.HasWindow ? windowResource : out; - std::array uav_resources = { in, out, window }; + constants.HasWindow = isFirstPass && windowBufferRegion.GetD3D12Resource() != nullptr; + auto window = constants.HasWindow ? windowBufferRegion : out; + std::array uav_resources = { in, out, window }; Dispatch(uav_resources, constants, commandList); } // Transition resources to common state barriers[0] = CD3DX12_RESOURCE_BARRIER::Transition( - inputResource, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - D3D12_RESOURCE_STATE_COMMON - ); + inputBufferRegion.GetD3D12Resource(), + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COMMON + ); barriers[1] = CD3DX12_RESOURCE_BARRIER::Transition( - outputResource, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - D3D12_RESOURCE_STATE_COMMON - ); + outputBufferRegion.GetD3D12Resource(), + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COMMON + ); commandList->ResourceBarrier(2, barriers); } @@ -922,25 +917,25 @@ class GpuDFTOperator : public WRL::Base template void Dispatch( - std::array& resources, + std::array& bufferRegions, TConstants& constants, ID3D12GraphicsCommandList* commandList) { D3D12_RESOURCE_BARRIER uav_barriers[TSize]; std::transform( - resources.begin(), resources.end(), + bufferRegions.begin(), bufferRegions.end(), uav_barriers, - [](auto& resource) { return CD3DX12_RESOURCE_BARRIER::UAV(resource); } ); + [](auto& bufferRegion) { return CD3DX12_RESOURCE_BARRIER::UAV(bufferRegion.GetD3D12Resource()); } ); commandList->ResourceBarrier(TSize, uav_barriers); for (uint32_t i = 0; i < TSize; i++) { // Set resource views - if (resources[i]) { + if (bufferRegions[i]) { commandList->SetComputeRootUnorderedAccessView( i, // root parameter index - resources[i]->GetGPUVirtualAddress() + bufferRegions[i].GetD3D12Resource()->GetGPUVirtualAddress() + bufferRegions[i].Offset() ); } else @@ -982,7 +977,7 @@ class GpuDFTOperator : public WRL::Base commandList->Dispatch(dispatchSizeX, 1, 1); } - commandList->ResourceBarrier(2, uav_barriers); + commandList->ResourceBarrier(TSize, uav_barriers); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h index 5ba936ddf3976..248e2e46bc0f7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h @@ -4,6 +4,7 @@ #include "../MLOperatorAuthorImpl.h" #include "../External/D3DX12/d3dx12.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlBufferRegion.h" #include "directx/d3d12.h" // NOTE: When this operator's implementation is moved into DML, the associated FP16 fallback @@ -259,15 +260,6 @@ class DmlGridSampleOperator : public WRL::Base ComPtr m_gridSamplePipelineState; DmlGridSampleParameters m_params = {}; - - // Allocate temporary buffers if needed - struct ResourceDesc - { - ComPtr Resource; - std::array Sizes; - std::array Strides; - }; - struct GridSampleShaderConstants { uint32_t StartIndex; @@ -488,29 +480,18 @@ class DmlGridSampleOperator : public WRL::Base auto gridDims = GetTensorDimensions(gridTensor.Get()); auto outputDims = GetTensorDimensions(outputTensor.Get()); - ComPtr inputUnknown; - ComPtr inputResource; - inputTensor->GetDataInterface(inputUnknown.GetAddressOf()); - ORT_THROW_IF_FAILED(inputUnknown.As(&inputResource)); - - ComPtr gridUnknown; - ComPtr gridResource; - gridTensor->GetDataInterface(gridUnknown.GetAddressOf()); - ORT_THROW_IF_FAILED(gridUnknown.As(&gridResource)); - - ComPtr outputUnknown; - ComPtr outputResource; - outputTensor->GetDataInterface(outputUnknown.GetAddressOf()); - ORT_THROW_IF_FAILED(outputUnknown.As(&outputResource)); + auto inputTensorWrapper = static_cast(inputTensor.Get()); + auto gridTensorWrapper = static_cast(gridTensor.Get()); + auto outputTensorWrapper = static_cast(outputTensor.Get()); return Compute( commandList.Get(), context, - inputResource.Get(), + inputTensorWrapper->GetBufferRegion(), inputDims, - gridResource.Get(), + gridTensorWrapper->GetBufferRegion(), gridDims, - outputResource.Get(), + outputTensorWrapper->GetBufferRegion(), outputDims ); } @@ -525,21 +506,21 @@ class DmlGridSampleOperator : public WRL::Base HRESULT Compute( ID3D12GraphicsCommandList* commandList, IMLOperatorKernelContext* context, - ID3D12Resource* inputResource, + const Dml::D3D12BufferRegion& inputBufferRegion, gsl::span inputDims, - ID3D12Resource* gridResource, + const Dml::D3D12BufferRegion& gridBufferRegion, gsl::span gridDims, - ID3D12Resource* outputResource, + const Dml::D3D12BufferRegion& outputBufferRegion, gsl::span outputDims) { try { GridSample( - inputResource, + inputBufferRegion, inputDims, - gridResource, + gridBufferRegion, gridDims, - outputResource, + outputBufferRegion, outputDims, commandList); } @@ -552,11 +533,11 @@ class DmlGridSampleOperator : public WRL::Base } void GridSample( - ID3D12Resource* inputResource, + const Dml::D3D12BufferRegion& inputBufferRegion, gsl::span inputDims, - ID3D12Resource* gridResource, + const Dml::D3D12BufferRegion& gridBufferRegion, gsl::span gridDims, - ID3D12Resource* outputResource, + const Dml::D3D12BufferRegion& outputBufferRegion, gsl::span outputDims, ID3D12GraphicsCommandList* commandList) { @@ -571,27 +552,23 @@ class DmlGridSampleOperator : public WRL::Base D3D12_RESOURCE_BARRIER barriers[3]; barriers[0] = CD3DX12_RESOURCE_BARRIER::Transition( - inputResource, + inputBufferRegion.GetD3D12Resource(), D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_UNORDERED_ACCESS ); barriers[1] = CD3DX12_RESOURCE_BARRIER::Transition( - gridResource, + gridBufferRegion.GetD3D12Resource(), D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_UNORDERED_ACCESS ); barriers[2] = CD3DX12_RESOURCE_BARRIER::Transition( - outputResource, + outputBufferRegion.GetD3D12Resource(), D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_UNORDERED_ACCESS ); - inputResource->SetName(L"InputResource"); - outputResource->SetName(L"OutputResource"); - gridResource->SetName(L"GridResource"); - commandList->ResourceBarrier(3, barriers); // Set the root signature and pipeline state @@ -612,27 +589,27 @@ class DmlGridSampleOperator : public WRL::Base std::copy(outputStrides.begin(), outputStrides.end(), constants.OutputStrides); constants.ElementCount = ComputeElementCountFromDimensions(constants.OutputSizes); - std::array uav_resources = { inputResource, gridResource, outputResource }; - Dispatch(uav_resources, constants, commandList); + std::array uavBufferRegions = { inputBufferRegion, gridBufferRegion, outputBufferRegion }; + Dispatch(uavBufferRegions, constants, commandList); // Transition resources to common state barriers[0] = CD3DX12_RESOURCE_BARRIER::Transition( - inputResource, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - D3D12_RESOURCE_STATE_COMMON - ); + inputBufferRegion.GetD3D12Resource(), + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COMMON + ); barriers[1] = CD3DX12_RESOURCE_BARRIER::Transition( - gridResource, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - D3D12_RESOURCE_STATE_COMMON - ); + gridBufferRegion.GetD3D12Resource(), + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COMMON + ); barriers[2] = CD3DX12_RESOURCE_BARRIER::Transition( - outputResource, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - D3D12_RESOURCE_STATE_COMMON - ); + outputBufferRegion.GetD3D12Resource(), + D3D12_RESOURCE_STATE_UNORDERED_ACCESS, + D3D12_RESOURCE_STATE_COMMON + ); commandList->ResourceBarrier(3, barriers); } @@ -647,25 +624,25 @@ class DmlGridSampleOperator : public WRL::Base template void Dispatch( - std::array& resources, + std::array& bufferRegions, TConstants& constants, ID3D12GraphicsCommandList* commandList) { D3D12_RESOURCE_BARRIER uav_barriers[TSize]; std::transform( - resources.begin(), resources.end(), + bufferRegions.begin(), bufferRegions.end(), uav_barriers, - [](auto& resource) { return CD3DX12_RESOURCE_BARRIER::UAV(resource); } ); + [](auto& bufferRegion) { return CD3DX12_RESOURCE_BARRIER::UAV(bufferRegion.GetD3D12Resource()); } ); commandList->ResourceBarrier(TSize, uav_barriers); for (uint32_t i = 0; i < TSize; i++) { // Set resource views - if (resources[i]) { + if (bufferRegions[i]) { commandList->SetComputeRootUnorderedAccessView( i, // root parameter index - resources[i]->GetGPUVirtualAddress() + bufferRegions[i].GetD3D12Resource()->GetGPUVirtualAddress() + bufferRegions[i].Offset() ); } else @@ -707,7 +684,7 @@ class DmlGridSampleOperator : public WRL::Base commandList->Dispatch(dispatchSizeX, 1, 1); } - commandList->ResourceBarrier(2, uav_barriers); + commandList->ResourceBarrier(TSize, uav_barriers); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp index 287f1e5b6dfe7..6bc5ec052420d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp @@ -3,6 +3,8 @@ #include "precomp.h" #include "DmlOperator.h" +#include "../DmlManagedBuffer.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlAllocatorRoundingMode.h" namespace Dml { @@ -101,13 +103,9 @@ namespace Dml UINT64 persistentResourceSize = m_compiledOperator->GetBindingProperties().PersistentResourceSize; if (persistentResourceSize > 0) { - ORT_THROW_IF_FAILED(m_executionProvider->AllocatePooledResource( - static_cast(persistentResourceSize), - AllocatorRoundingMode::Enabled, - m_persistentResource.GetAddressOf(), - m_persistentResourcePoolingUnk.GetAddressOf())); - - m_persistentResourceBinding = DML_BUFFER_BINDING{ m_persistentResource.Get(), 0, persistentResourceSize }; + auto buffer = m_executionProvider->AllocatePooledResource(persistentResourceSize, AllocatorRoundingMode::Enabled); + m_persistentResourceBinding = buffer.GetBufferBinding(); + m_managedPersistentBuffer = wil::MakeOrThrow(std::move(buffer)); } std::vector initializationInputBindings(m_kernelInputIndices.size()); @@ -205,13 +203,9 @@ namespace Dml UINT64 persistentResourceSize = m_compiledOperator->GetBindingProperties().PersistentResourceSize; if (persistentResourceSize > 0) { - ORT_THROW_IF_FAILED(m_executionProvider->AllocatePooledResource( - static_cast(persistentResourceSize), - AllocatorRoundingMode::Enabled, - m_persistentResource.GetAddressOf(), - m_persistentResourcePoolingUnk.GetAddressOf())); - - m_persistentResourceBinding = DML_BUFFER_BINDING{ m_persistentResource.Get(), 0, persistentResourceSize }; + auto buffer = m_executionProvider->AllocatePooledResource(persistentResourceSize, AllocatorRoundingMode::Enabled); + m_persistentResourceBinding = buffer.GetBufferBinding(); + m_managedPersistentBuffer = wil::MakeOrThrow(std::move(buffer)); } std::vector initializationInputBindings(m_kernelInputIndices.size()); @@ -239,17 +233,12 @@ namespace Dml UINT64 persistentResourceSize = m_compiledOperator->GetBindingProperties().PersistentResourceSize; if (persistentResourceSize > 0) { - if (!m_persistentResource || m_persistentResource->GetDesc().Width < persistentResourceSize) + if (!m_managedPersistentBuffer || m_managedPersistentBuffer->SizeInBytes() < persistentResourceSize) { - m_persistentResource = nullptr; - ORT_THROW_IF_FAILED(m_executionProvider->AllocatePooledResource( - static_cast(persistentResourceSize), - AllocatorRoundingMode::Enabled, - m_persistentResource.GetAddressOf(), - m_persistentResourcePoolingUnk.GetAddressOf())); + auto buffer = m_executionProvider->AllocatePooledResource(persistentResourceSize, AllocatorRoundingMode::Enabled); + m_persistentResourceBinding = buffer.GetBufferBinding(); + m_managedPersistentBuffer = wil::MakeOrThrow(std::move(buffer)); } - - m_persistentResourceBinding = DML_BUFFER_BINDING{ m_persistentResource.Get(), 0, persistentResourceSize }; } ORT_THROW_IF_FAILED(m_executionProvider->InitializeOperator( diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h index fa54d4b041b5f..db80d998a4102 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.h @@ -4,6 +4,7 @@ #pragma once #include "OperatorUtility.h" +#include "../DmlManagedBuffer.h" namespace Dml { @@ -25,8 +26,7 @@ namespace Dml std::vector m_outputTensorDescs; ComPtr m_compiledOperator; - ComPtr m_persistentResource; - ComPtr m_persistentResourcePoolingUnk; // Controls when the persistent resource is returned to the pool + ComPtr m_managedPersistentBuffer; std::optional m_persistentResourceBinding; void Initialize( diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCopy.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCopy.cpp index af983b26772d9..96fec218ed87e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCopy.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCopy.cpp @@ -16,7 +16,7 @@ class DmlOperatorCopy : public DmlOperator ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 1); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); - std::vector> kernelInputOutputIndices = {0}; + std::vector> kernelInputOutputIndices = {0}; Initialize(kernelInfo, kernelInputOutputIndices); @@ -29,30 +29,39 @@ class DmlOperatorCopy : public DmlOperator ComPtr contextPrivate; ORT_THROW_IF_FAILED(kernelInfo.GetInterface()->QueryInterface(contextPrivate.GetAddressOf())); - if (contextPrivate->IsDmlGraphNode()) - { - std::vector inputDescs = GetDmlInputDescs(); - std::vector outputDescs = GetDmlOutputDescs(); + // Although we always compile the operator because we don't know where the memory will be allocated in the future, + // we may not always end up executing it. + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); - DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC opDesc = {}; - opDesc.InputTensor = inputDescs.data(); - opDesc.OutputTensor = outputDescs.data(); + DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC opDesc = {}; + opDesc.InputTensor = inputDescs.data(); + opDesc.OutputTensor = outputDescs.data(); - SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_IDENTITY, &opDesc }, kernelInfo); - } + SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_IDENTITY, &opDesc }, kernelInfo); } - void Compute(const MLOperatorKernelContext& kernelContext) + void Compute(const MLOperatorKernelContext& kernelContext) final { MLOperatorTensor inputTensor = kernelContext.GetInputTensor(0); - - // Reshape the output tensor. MLOperatorTensor outputTensor = kernelContext.GetOutputTensor(0); - // Avoid self copying. - if (inputTensor.GetDataInterface().Get() != outputTensor.GetDataInterface().Get()) + // If the input is aliasing the output (i.e. they share the same resource at the same offset), + // we don't need to do anything. This is essentially a no-op. + if (inputTensor.GetByteData() == outputTensor.GetByteData()) + { + return; + } + + // If the input is not aliasing the output but shares the same resource, we have to use an Identity operation + // because the resource cannot simultaneously be in both the COPY_SOURCE and COPY_DEST states. + if (inputTensor.GetDataInterface().Get() == outputTensor.GetDataInterface().Get()) + { + DmlOperator::Compute(kernelContext); + } + else { - // Copy elements from input tensor to output tensor. + // The input and the output don't share the same resource, so we can do a simple copy. ORT_THROW_IF_FAILED(m_executionProvider->CopyTensor( outputTensor.GetInterface().Get(), inputTensor.GetInterface().Get())); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlSTFT.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlSTFT.h index e2f38231f7295..b3107dd8d16e6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlSTFT.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlSTFT.h @@ -1,6 +1,7 @@ #pragma once #include "DmlDFT.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlAllocatorRoundingMode.h" // NOTE: When this operator's implementation is moved into DML, the associated FP16 fallback // should be removed from IsCustomOpShader(...) in @@ -106,7 +107,7 @@ struct DmlSTFTParameters namespace DmlSTFTHelpers { - ComPtr GetResourceFromKernelContext(IMLOperatorKernelContext* context, uint32_t index, bool isInput) + Dml::D3D12BufferRegion GetBufferRegionFromKernelContext(IMLOperatorKernelContext* context, uint32_t index, bool isInput) { ComPtr tensor; if (isInput) @@ -118,23 +119,18 @@ namespace DmlSTFTHelpers ORT_THROW_IF_FAILED(context->GetOutputTensor(index, &tensor)); } - ComPtr dataInterface; - tensor->GetDataInterface(&dataInterface); - - ComPtr resource; - ORT_THROW_IF_FAILED(dataInterface.As(&resource)); - - return resource; + auto tensorWrapper = static_cast(tensor.Get()); + return tensorWrapper->GetBufferRegion(); } - ComPtr GetInputResourceFromKernelContext(IMLOperatorKernelContext* context, uint32_t index) + Dml::D3D12BufferRegion GetInputBufferRegionFromKernelContext(IMLOperatorKernelContext* context, uint32_t index) { - return GetResourceFromKernelContext(context, index, true); + return GetBufferRegionFromKernelContext(context, index, true); } - ComPtr GetOutputResourceFromKernelContext(IMLOperatorKernelContext* context, uint32_t index) + Dml::D3D12BufferRegion GetOutputBufferRegionFromKernelContext(IMLOperatorKernelContext* context, uint32_t index) { - return GetResourceFromKernelContext(context, index, false); + return GetBufferRegionFromKernelContext(context, index, false); } } @@ -197,9 +193,7 @@ class DmlSTFTOperator : public WRL::Base ComPtr descriptorHeap; ComPtr bindingTable; ComPtr commandRecorder; - ComPtr persistentResource; - ComPtr persistentResourcePoolingUnk; - std::optional persistentResourceBinding; + std::optional persistentBufferRegion; bool hasWindowTensor = false; uint64_t signalBufferSizeInBytes = 0; uint64_t windowBufferSizeInBytes = 0; @@ -320,28 +314,31 @@ class DmlSTFTOperator : public WRL::Base // Initialize { + std::vector initializationInputBindings(params.hasWindowTensor ? 2 : 1); + uint64_t persistentResourceSize = m_framingOperator.op->GetBindingProperties().PersistentResourceSize; if (persistentResourceSize > 0) { - ORT_THROW_IF_FAILED(m_dmlProvider->AllocatePooledResource( - static_cast(persistentResourceSize), - AllocatorRoundingMode::Enabled, - m_framingOperator.persistentResource.GetAddressOf(), - m_framingOperator.persistentResourcePoolingUnk.GetAddressOf())); - - m_framingOperator.persistentResourceBinding = DML_BUFFER_BINDING{ - m_framingOperator.persistentResource.Get(), - 0, - persistentResourceSize - }; + m_framingOperator.persistentBufferRegion = m_dmlProvider->AllocatePooledResource( + persistentResourceSize, + Dml::AllocatorRoundingMode::Enabled); + auto binding = m_framingOperator.persistentBufferRegion->GetBufferBinding(); + ORT_THROW_IF_FAILED(m_dmlProvider->InitializeOperator( + m_framingOperator.op.Get(), + &binding, + gsl::make_span(initializationInputBindings) + )); + } + else + { + ORT_THROW_IF_FAILED(m_dmlProvider->InitializeOperator( + m_framingOperator.op.Get(), + nullptr, + gsl::make_span(initializationInputBindings) + )); } - std::vector initializationInputBindings(params.hasWindowTensor ? 2 : 1); - ORT_THROW_IF_FAILED(m_dmlProvider->InitializeOperator( - m_framingOperator.op.Get(), - m_framingOperator.persistentResourceBinding ? &*m_framingOperator.persistentResourceBinding : nullptr, - gsl::make_span(initializationInputBindings) - )); + } auto execBindingProps = m_framingOperator.op->GetBindingProperties(); @@ -374,11 +371,12 @@ class DmlSTFTOperator : public WRL::Base ComPtr commandList; ORT_THROW_IF_FAILED(executionObject.As(&commandList)); - ComPtr framingOutputResource; - ORT_THROW_IF_FAILED(context->AllocateTemporaryData(onnxruntime::narrow(m_framingOperator.outputBufferSizeInBytes), &framingOutputResource)); - DispatchFramingOperator(commandList.Get(), context, framingOutputResource.Get()); + auto framingOutputBuffer = m_dmlProvider->AllocatePooledResource( + m_framingOperator.outputBufferSizeInBytes, + Dml::AllocatorRoundingMode::Enabled); + DispatchFramingOperator(commandList.Get(), context, framingOutputBuffer.Region()); - ComPtr outputResource = DmlSTFTHelpers::GetOutputResourceFromKernelContext(context, 0); + Dml::D3D12BufferRegion outputBufferRegion = DmlSTFTHelpers::GetOutputBufferRegionFromKernelContext(context, 0); D3D12_RESOURCE_BARRIER uavBarrier = { CD3DX12_RESOURCE_BARRIER::UAV(nullptr) }; commandList->ResourceBarrier(1, &uavBarrier); @@ -386,9 +384,9 @@ class DmlSTFTOperator : public WRL::Base return m_dftOperator.op->Compute( commandList.Get(), context, - framingOutputResource.Get(), + framingOutputBuffer.Region(), m_dftOperator.inputDims, - outputResource.Get(), + outputBufferRegion, m_dftOperator.outputDims, m_dftOperator.dftLength ); @@ -401,7 +399,7 @@ class DmlSTFTOperator : public WRL::Base return S_OK; } - void DispatchFramingOperator(ID3D12GraphicsCommandList* commandList, IMLOperatorKernelContext* context, ID3D12Resource* outputResource) + void DispatchFramingOperator(ID3D12GraphicsCommandList* commandList, IMLOperatorKernelContext* context, const Dml::D3D12BufferRegion& outputBufferRegion) { ID3D12DescriptorHeap* descriptorHeaps[] = { m_framingOperator.descriptorHeap.Get() }; commandList->SetDescriptorHeaps(_countof(descriptorHeaps), descriptorHeaps); @@ -417,38 +415,34 @@ class DmlSTFTOperator : public WRL::Base D3D12_RESOURCE_BARRIER barriers[3]; uint32_t barrierCount = 0; - ComPtr signalResource = DmlSTFTHelpers::GetInputResourceFromKernelContext(context, DmlSTFTKernelInputIndex::Signal); - inputBuffers[0] = { signalResource.Get(), 0, m_framingOperator.signalBufferSizeInBytes }; + Dml::D3D12BufferRegion signalBufferRegion = DmlSTFTHelpers::GetInputBufferRegionFromKernelContext(context, DmlSTFTKernelInputIndex::Signal); + inputBuffers[0] = signalBufferRegion.GetBufferBinding(); inputBindings[0] = { DML_BINDING_TYPE_BUFFER, &inputBuffers[0] }; - barriers[barrierCount++] = CD3DX12_RESOURCE_BARRIER::Transition(signalResource.Get(), D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_UNORDERED_ACCESS); + barriers[barrierCount++] = CD3DX12_RESOURCE_BARRIER::Transition(signalBufferRegion.GetD3D12Resource(), D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_UNORDERED_ACCESS); - ComPtr windowResource; + Dml::D3D12BufferRegion windowBufferRegion; if (m_framingOperator.hasWindowTensor) { - windowResource = DmlSTFTHelpers::GetInputResourceFromKernelContext(context, DmlSTFTKernelInputIndex::Window); - inputBuffers[1] = { windowResource.Get(), 0, m_framingOperator.windowBufferSizeInBytes }; + windowBufferRegion = DmlSTFTHelpers::GetInputBufferRegionFromKernelContext(context, DmlSTFTKernelInputIndex::Window); + inputBuffers[1] = windowBufferRegion.GetBufferBinding(); inputBindings[1] = { DML_BINDING_TYPE_BUFFER, &inputBuffers[1] }; - barriers[barrierCount++] = CD3DX12_RESOURCE_BARRIER::Transition(windowResource.Get(), D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_UNORDERED_ACCESS); + barriers[barrierCount++] = CD3DX12_RESOURCE_BARRIER::Transition(windowBufferRegion.GetD3D12Resource(), D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_UNORDERED_ACCESS); inputBindingsCount++; } m_framingOperator.bindingTable->BindInputs(inputBindingsCount, inputBindings.data()); - DML_BUFFER_BINDING outputBuffer = {}; - outputBuffer.Buffer = outputResource; - outputBuffer.SizeInBytes = m_framingOperator.outputBufferSizeInBytes; + DML_BUFFER_BINDING outputBuffer = outputBufferRegion.GetBufferBinding(); DML_BINDING_DESC outputBinding = { DML_BINDING_TYPE_BUFFER, &outputBuffer }; - barriers[barrierCount++] = CD3DX12_RESOURCE_BARRIER::Transition(outputResource, D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_UNORDERED_ACCESS); + barriers[barrierCount++] = CD3DX12_RESOURCE_BARRIER::Transition(outputBufferRegion.GetD3D12Resource(), D3D12_RESOURCE_STATE_COMMON, D3D12_RESOURCE_STATE_UNORDERED_ACCESS); m_framingOperator.bindingTable->BindOutputs(1, &outputBinding); - ComPtr tempBuffer; auto tempBufferSize = bindingProps.TemporaryResourceSize; if (tempBufferSize > 0) { - ORT_THROW_IF_FAILED(context->AllocateTemporaryData(onnxruntime::narrow(tempBufferSize), &tempBuffer)); - - DML_BUFFER_BINDING bufferBinding = { tempBuffer.Get(), 0, tempBufferSize }; + auto buffer = m_dmlProvider->AllocatePooledResource(tempBufferSize, Dml::AllocatorRoundingMode::Enabled); + DML_BUFFER_BINDING bufferBinding = buffer.GetBufferBinding(); DML_BINDING_DESC bindingDesc = { DML_BINDING_TYPE_BUFFER, &bufferBinding }; m_framingOperator.bindingTable->BindTemporaryResource(&bindingDesc); } @@ -456,8 +450,9 @@ class DmlSTFTOperator : public WRL::Base auto persistentBufferSize = bindingProps.PersistentResourceSize; if (persistentBufferSize > 0) { - DML_BUFFER_BINDING bufferBinding = { m_framingOperator.persistentResource.Get(), 0, persistentBufferSize }; - DML_BINDING_DESC bindingDesc = { DML_BINDING_TYPE_BUFFER, &bufferBinding }; + assert(m_framingOperator.persistentBufferRegion.has_value()); + auto persistentResourceBinding = m_framingOperator.persistentBufferRegion->GetBufferBinding(); + DML_BINDING_DESC bindingDesc = { DML_BINDING_TYPE_BUFFER, &persistentResourceBinding }; m_framingOperator.bindingTable->BindPersistentResource(&bindingDesc); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.cpp index 375ee87bd42f1..4c9fec7ced95a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.cpp @@ -99,13 +99,19 @@ namespace Dml auto heap = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD); auto buffer = CD3DX12_RESOURCE_DESC::Buffer(sizeInBytes); - ORT_THROW_IF_FAILED(device->CreateCommittedResource( + HRESULT hr = device->CreateCommittedResource( &heap, D3D12_HEAP_FLAG_NONE, &buffer, D3D12_RESOURCE_STATE_GENERIC_READ, nullptr, - IID_GRAPHICS_PPV_ARGS(uploadBuffer.ReleaseAndGetAddressOf()))); + IID_GRAPHICS_PPV_ARGS(uploadBuffer.ReleaseAndGetAddressOf())); + + if (hr == DXGI_ERROR_DEVICE_REMOVED) + { + ORT_THROW_IF_FAILED(device->GetDeviceRemovedReason()); + } + ORT_THROW_IF_FAILED(hr); return Chunk{ sizeInBytes, std::move(uploadBuffer) }; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ReadbackHeap.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ReadbackHeap.cpp index 6147b3bf8665f..0be0792318b14 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ReadbackHeap.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ReadbackHeap.cpp @@ -104,11 +104,11 @@ namespace Dml void ReadbackHeap::ReadbackFromGpu( gsl::span dst, gsl::span dstSizes, - gsl::span src, + gsl::span srcBufferRegions, D3D12_RESOURCE_STATES srcState) { - assert(dst.size() == src.size()); - assert(dstSizes.size() == src.size()); + assert(dst.size() == srcBufferRegions.size()); + assert(dstSizes.size() == srcBufferRegions.size()); if (dst.empty()) { @@ -131,8 +131,8 @@ namespace Dml m_readbackHeap.Get(), offset, D3D12_RESOURCE_STATE_COPY_DEST, - src[i], - 0, + srcBufferRegions[i].GetD3D12Resource(), + srcBufferRegions[i].Offset(), srcState, dstSizes[i]); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ReadbackHeap.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ReadbackHeap.h index d641afdd818ba..c941ae1af6c73 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ReadbackHeap.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ReadbackHeap.h @@ -26,7 +26,7 @@ namespace Dml void ReadbackFromGpu( gsl::span dst, gsl::span dstSizes, - gsl::span src, + gsl::span srcBufferRegions, D3D12_RESOURCE_STATES srcState); private: diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index ac77616cb96f0..fdc204a5f4aac 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -517,15 +517,11 @@ class MLOperatorTensor // needing to agnostically copy memory. const void* GetByteData() const { - ML_CHECK_BOOL(!IsDataInterface()); - return m_impl->GetData(); } void* GetByteData() { - ML_CHECK_BOOL(!IsDataInterface()); - return m_impl->GetData(); } diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index 8a218470d30bb..6346ca87180d1 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -181,6 +181,8 @@ IMLOperatorRegistryPrivate : public IUnknown _In_reads_(aliasCount) const std::pair* aliases = nullptr, uint32_t aliasCount = 0 ) const noexcept PURE; + + STDMETHOD_(bool, HasExternalOperators)() const noexcept PURE; }; //! \interface IMLOperatorTensorShapeDescription1 diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index e8fe235fc1d46..98230668edde0 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -29,6 +29,8 @@ using Microsoft::WRL::ComPtr; #include "core/framework/error_code_helper.h" #include "DmlExecutionProvider/src/ErrorHandling.h" #include "DmlExecutionProvider/src/GraphicsUnknownHelper.h" +#include "DmlExecutionProvider/src/DmlAllocationInfo.h" +#include "DmlExecutionProvider/src/DmlTaggedPointer.h" #include "DmlExecutionProvider/inc/DmlExecutionProvider.h" #include "core/platform/env.h" #include "core/providers/dml/dml_session_options_config_keys.h" @@ -62,8 +64,8 @@ struct DMLProviderFactory : IExecutionProviderFactory { ~DMLProviderFactory() override {} std::unique_ptr CreateProvider() override; - void SetMetacommandsEnabled(bool metacommands_enabled); + void SetBfcAllocatorEnabled(bool bfc_allocator_enabled); private: ComPtr dml_device_{}; @@ -73,6 +75,7 @@ struct DMLProviderFactory : IExecutionProviderFactory { bool cpu_sync_spinning_enabled_ = false; bool disable_memory_arena_ = false; bool python_api_ = false; + bool bfc_allocator_enabled_ = true; }; std::unique_ptr DMLProviderFactory::CreateProvider() { @@ -93,7 +96,7 @@ std::unique_ptr DMLProviderFactory::CreateProvider() { execution_context = wil::MakeOrThrow(d3d12_device.Get(), dml_device_.Get(), cmd_queue_.Get(), cpu_sync_spinning_enabled_, false); } - auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), execution_context.Get(), metacommands_enabled_, graph_capture_enabled_, cpu_sync_spinning_enabled_, disable_memory_arena_); + auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), execution_context.Get(), metacommands_enabled_, graph_capture_enabled_, cpu_sync_spinning_enabled_, disable_memory_arena_, bfc_allocator_enabled_); return provider; } @@ -101,6 +104,10 @@ void DMLProviderFactory::SetMetacommandsEnabled(bool metacommands_enabled) { metacommands_enabled_ = metacommands_enabled; } +void DMLProviderFactory::SetBfcAllocatorEnabled(bool bfc_allocator_enabled) { + bfc_allocator_enabled_ = bfc_allocator_enabled; +} + std::shared_ptr CreateExecutionProviderFactory_DML(const ConfigOptions& config_options, IDMLDevice* dml_device, ID3D12CommandQueue* cmd_queue, @@ -132,6 +139,10 @@ void DmlConfigureProviderFactoryMetacommandsEnabled(IExecutionProviderFactory* f dml_provider_factory->SetMetacommandsEnabled(metacommandsEnabled); } +void DmlConfigureProviderFactoryBfcAllocatorEnabled(IExecutionProviderFactory* factory, bool bfc_allocator_enabled) { + auto dml_provider_factory = static_cast(factory); + dml_provider_factory->SetBfcAllocatorEnabled(bfc_allocator_enabled); +} bool IsSoftwareAdapter(IDXGIAdapter1* adapter) { DXGI_ADAPTER_DESC1 desc; @@ -685,8 +696,15 @@ ORT_API_STATUS_IMPL(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* ort_alloc if (!allocator) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available"); } - *d3d_resource = Dml::GetD3D12ResourceFromAllocation(allocator.get(), allocation); + + // This should never happen since external users of the ORT API should only be able to create DML_EXTERNAL memory + if (wrapping_allocator->Info()->device.MemType() != OrtDevice::MemType::DML_EXTERNAL) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "The resource has been allocated with "); + } + + *d3d_resource = static_cast(allocation)->GetD3D12Resource(); (*d3d_resource)->AddRef(); + #else *d3d_resource = nullptr; #endif // USE_DML diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index 8fdac257297c1..f21bf82169f4d 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -24,13 +24,13 @@ #include "core/framework/provider_options_utils.h" #ifdef USE_DML +#include "core/providers/dml/DmlExecutionProvider/src/DmlExternalGpuAllocator.h" using Microsoft::WRL::ComPtr; #include #include "core/providers/dml/DmlExecutionProvider/src/External/D3DX12/d3dx12.h" #include "core/providers/dml/DmlExecutionProvider/src/ErrorHandling.h" #include "core/providers/dml/DmlExecutionProvider/src/DescriptorPool.h" -#include "core/providers/dml/DmlExecutionProvider/src/DmlCommittedResourceAllocator.h" #include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" #include "core/providers/dml/DmlExecutionProvider/src/BucketizedBufferAllocator.h" #include "core/providers/dml/DmlExecutionProvider/src/PooledUploadHeap.h" @@ -208,46 +208,7 @@ AllocatorPtr GetDmlAllocator(OrtDevice::DeviceId id) { auto hit = id_to_allocator_map->find(id); if (hit == id_to_allocator_map->end()) { - constexpr uint32_t device_id = 0; - auto d3d12_device = onnxruntime::DMLProviderFactoryCreator::CreateD3D12Device(device_id, false); - - ComPtr context; - uint32_t execution_context_ptr_size = gsl::narrow_cast(sizeof(context.GetAddressOf())); - - // First, check if an I/O binding API that was used before this session or another session has already created a queue - if (FAILED(d3d12_device->GetPrivateData(dml_execution_context_guid, &execution_context_ptr_size, context.GetAddressOf()))) { - D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {}; - cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; - cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT; - - ComPtr cmd_queue; - ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf()))); - - auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device.Get()); - ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_device_guid, dml_device.Get())); - - context = wil::MakeOrThrow(d3d12_device.Get(), dml_device.Get(), cmd_queue.Get(), true, true); - ORT_THROW_IF_FAILED(d3d12_device->SetPrivateDataInterface(dml_execution_context_guid, context.Get())); - } - - // We leak the readback and upload heap to keep them alive, just like the map - auto readback_heap = std::make_unique(d3d12_device.Get(), context.Get()).release(); - auto upload_heap = std::make_unique(d3d12_device.Get(), context.Get()).release(); - - auto dml_allocator = std::make_shared( - d3d12_device.Get(), - context.Get(), - CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT), - D3D12_HEAP_FLAG_NONE, - D3D12_RESOURCE_FLAG_ALLOW_UNORDERED_ACCESS, - D3D12_RESOURCE_STATE_UNORDERED_ACCESS, - std::make_unique(d3d12_device.Get())); - dml_allocator->SetDefaultRoundingMode(AllocatorRoundingMode::Enabled); - context->SetAllocator(dml_allocator); - - ORT_THROW_IF_FAILED(d3d12_device->SetPrivateData(dml_readback_heap_guid, sizeof(readback_heap), &readback_heap)); - ORT_THROW_IF_FAILED(d3d12_device->SetPrivateData(dml_upload_heap_guid, sizeof(upload_heap), &upload_heap)); - + auto dml_allocator = std::make_shared(id); hit = id_to_allocator_map->emplace(id, std::move(dml_allocator)).first; } diff --git a/winml/adapter/winml_adapter_apis.h b/winml/adapter/winml_adapter_apis.h index 9542c3a03fefe..313bb20ff81be 100644 --- a/winml/adapter/winml_adapter_apis.h +++ b/winml/adapter/winml_adapter_apis.h @@ -82,7 +82,8 @@ ORT_API_STATUS( _In_ OrtSessionOptions* options, _In_ ID3D12Device* d3d_device, _In_ ID3D12CommandQueue* cmd_queue, - bool metacommands_enabled + bool metacommands_enabled, + bool bfc_allocator_enabled ); // OrtSession methods diff --git a/winml/adapter/winml_adapter_c_api.h b/winml/adapter/winml_adapter_c_api.h index 6c77664b92c37..0e76e2394aaac 100644 --- a/winml/adapter/winml_adapter_c_api.h +++ b/winml/adapter/winml_adapter_c_api.h @@ -294,7 +294,11 @@ struct WinmlAdapterApi { * This api is used to add the DML EP to OrtSessionOptions. */ OrtStatus*(ORT_API_CALL* OrtSessionOptionsAppendExecutionProvider_DML)( - _In_ OrtSessionOptions* options, ID3D12Device* device, ID3D12CommandQueue* queue, bool metacommands_enabled + _In_ OrtSessionOptions* options, + ID3D12Device* device, + ID3D12CommandQueue* queue, + bool metacommands_enabled, + bool bfc_allocator_enabled )NO_EXCEPTION; // OrtSession methods diff --git a/winml/adapter/winml_adapter_dml.cpp b/winml/adapter/winml_adapter_dml.cpp index 1b3ceed40cfad..0c4c451f4ed39 100644 --- a/winml/adapter/winml_adapter_dml.cpp +++ b/winml/adapter/winml_adapter_dml.cpp @@ -69,6 +69,7 @@ Microsoft::WRL::ComPtr CreateDmlDevice(ID3D12Device* d3d12Device) { namespace onnxruntime { void DmlConfigureProviderFactoryMetacommandsEnabled(IExecutionProviderFactory* factory, bool metacommandsEnabled); +void DmlConfigureProviderFactoryBfcAllocatorEnabled(IExecutionProviderFactory* factory, bool bfc_allocator_enabled); } // namespace onnxruntime #endif // USE_DML @@ -78,7 +79,8 @@ ORT_API_STATUS_IMPL( _In_ OrtSessionOptions* options, _In_ ID3D12Device* d3d_device, _In_ ID3D12CommandQueue* queue, - bool metacommands_enabled + bool metacommands_enabled, + bool bfc_allocator_enabled ) { API_IMPL_BEGIN #ifdef USE_DML @@ -89,6 +91,7 @@ ORT_API_STATUS_IMPL( auto factory = options->provider_factories.back().get(); onnxruntime::DmlConfigureProviderFactoryMetacommandsEnabled(factory, metacommands_enabled); + onnxruntime::DmlConfigureProviderFactoryBfcAllocatorEnabled(factory, bfc_allocator_enabled); #endif // USE_DML return nullptr; API_IMPL_END diff --git a/winml/adapter/winml_adapter_execution_provider.cpp b/winml/adapter/winml_adapter_execution_provider.cpp index 52dbf9710abc7..0d3ae2f0d5ac4 100644 --- a/winml/adapter/winml_adapter_execution_provider.cpp +++ b/winml/adapter/winml_adapter_execution_provider.cpp @@ -51,7 +51,9 @@ ORT_API_STATUS_IMPL( auto inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); const auto execution_provider = reinterpret_cast(provider); OrtMemoryInfo mem_info( - "", OrtAllocatorType::OrtDeviceAllocator, execution_provider->GetOrtDeviceByMemType(::OrtMemType::OrtMemTypeDefault) + "", + OrtAllocatorType::OrtDeviceAllocator, + execution_provider->GetExternalOrtDeviceByMemType(::OrtMemType::OrtMemTypeDefault) ); auto allocator_ptr = inference_session->GetAllocator(mem_info); *allocator = new (std::nothrow) OrtAllocatorWrapper(allocator_ptr); @@ -66,7 +68,7 @@ ORT_API_STATUS_IMPL(winmla::GetProviderMemoryInfo, _In_ OrtExecutionProvider* pr API_IMPL_BEGIN const auto execution_provider = reinterpret_cast(provider); - auto device = execution_provider->GetOrtDeviceByMemType(::OrtMemType::OrtMemTypeDefault); + auto device = execution_provider->GetExternalOrtDeviceByMemType(::OrtMemType::OrtMemTypeDefault); *memory_info = new (std::nothrow) OrtMemoryInfo("", ::OrtAllocatorType::OrtDeviceAllocator, device); if (*memory_info == nullptr) { return OrtApis::CreateStatus(ORT_FAIL, "Out of memory"); diff --git a/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp index 4997e07b037c8..9de5585e4ba78 100644 --- a/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.cpp @@ -13,12 +13,17 @@ using namespace _winml; HRESULT OnnxruntimeDmlSessionBuilder::RuntimeClassInitialize( - OnnxruntimeEngineFactory* engine_factory, ID3D12Device* device, ID3D12CommandQueue* queue, bool metacommands_enabled + OnnxruntimeEngineFactory* engine_factory, + ID3D12Device* device, + ID3D12CommandQueue* queue, + bool metacommands_enabled, + bool bfc_allocator_enabled ) { engine_factory_ = engine_factory; device_.copy_from(device); queue_.copy_from(queue); metacommands_enabled_ = metacommands_enabled; + bfc_allocator_enabled_ = bfc_allocator_enabled; return S_OK; } @@ -45,7 +50,7 @@ OnnxruntimeDmlSessionBuilder::CreateSessionOptions(OrtSessionOptions** options) // Request the dml ep RETURN_HR_IF_NOT_OK_MSG( winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_DML( - session_options.get(), device_.get(), queue_.get(), metacommands_enabled_ + session_options.get(), device_.get(), queue_.get(), metacommands_enabled_, bfc_allocator_enabled_ ), ort_api ); diff --git a/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.h b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.h index ad9442fb98b83..3b1ade796d80f 100644 --- a/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.h +++ b/winml/lib/Api.Ort/OnnxruntimeDmlSessionBuilder.h @@ -17,7 +17,8 @@ class OnnxruntimeDmlSessionBuilder OnnxruntimeEngineFactory* engine_factory, ID3D12Device* device, ID3D12CommandQueue* queue, - bool metacommands_enabled_ + bool metacommands_enabled, + bool bfc_allocator_enabled ); HRESULT STDMETHODCALLTYPE CreateSessionOptions(OrtSessionOptions** options) override; @@ -36,6 +37,7 @@ class OnnxruntimeDmlSessionBuilder winrt::com_ptr device_; winrt::com_ptr queue_; bool metacommands_enabled_ = true; + bool bfc_allocator_enabled_ = true; }; } // namespace _winml diff --git a/winml/lib/Api.Ort/OnnxruntimeEngine.cpp b/winml/lib/Api.Ort/OnnxruntimeEngine.cpp index 5bb0ce424f66c..8fcaa773f8884 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEngine.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeEngine.cpp @@ -656,10 +656,6 @@ HRESULT OnnxruntimeEngine::CreateTensorValueFromExternalD3DResource( ort_api ); - OrtAllocator* ort_allocator; - RETURN_HR_IF_NOT_OK_MSG(ort_api->CreateAllocator(session_.get(), ort_memory_info, &ort_allocator), ort_api); - auto allocator = UniqueOrtAllocator(ort_allocator, ort_api->ReleaseAllocator); - void* dml_allocator_resource; RETURN_HR_IF_NOT_OK_MSG( ort_dml_api->CreateGPUAllocationFromD3DResource(d3d_resource, &dml_allocator_resource), engine_factory_->UseOrtApi() diff --git a/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp index 8ddcbe537ebd9..a055b1b02ef64 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp +++ b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.cpp @@ -32,7 +32,12 @@ STDMETHODIMP OnnxruntimeEngineBuilder::CreateEngine(_Outptr_ _winml::IEngine** o } else { #ifdef USE_DML RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize( - &onnxruntime_session_builder, engine_factory_.Get(), device_.Get(), queue_.Get(), metacommands_enabled_ + &onnxruntime_session_builder, + engine_factory_.Get(), + device_.Get(), + queue_.Get(), + metacommands_enabled_, + bfc_allocator_enabled_ )); #endif } @@ -107,6 +112,11 @@ STDMETHODIMP OnnxruntimeEngineBuilder::SetMetacommandsEnabled(int enabled) { return S_OK; } +STDMETHODIMP OnnxruntimeEngineBuilder::SetBfcAllocatorEnabled(int enabled) { + bfc_allocator_enabled_ = static_cast(enabled); + return S_OK; +} + STDMETHODIMP OnnxruntimeEngineBuilder::GetID3D12CommandQueue(_Outptr_ ID3D12CommandQueue** queue) { *queue = queue_.Get(); return S_OK; diff --git a/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h index 99add42297f20..e1261272268df 100644 --- a/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h +++ b/winml/lib/Api.Ort/OnnxruntimeEngineBuilder.h @@ -16,6 +16,9 @@ class OnnxruntimeEngineBuilder STDMETHOD(SetMetacommandsEnabled) (int enabled); + STDMETHOD(SetBfcAllocatorEnabled) + (int enabled); + STDMETHOD(GetD3D12Device) (_Outptr_ ID3D12Device** device); @@ -49,6 +52,7 @@ class OnnxruntimeEngineBuilder Microsoft::WRL::ComPtr queue_ = nullptr; Microsoft::WRL::ComPtr thread_pool_ = nullptr; bool metacommands_enabled_ = true; + bool bfc_allocator_enabled_ = true; std::optional batch_size_override_; wfc::IMapView named_dimension_overrides_; std::optional intra_op_num_threads_override_; diff --git a/winml/lib/Api/LearningModelSession.cpp b/winml/lib/Api/LearningModelSession.cpp index 57bafda57fe54..45dd4624de38e 100644 --- a/winml/lib/Api/LearningModelSession.cpp +++ b/winml/lib/Api/LearningModelSession.cpp @@ -13,6 +13,7 @@ #include "LearningModelSessionOptions.h" #include "TensorFeatureDescriptor.h" #include "TelemetryEvent.h" +#include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h" #include "D3DDeviceCache.h" @@ -110,6 +111,12 @@ void LearningModelSession::Initialize() { WINML_THROW_IF_FAILED(engine_builder->SetD3D12Resources(device_impl->GetD3DDevice(), device_impl->GetDeviceQueue()) ); WINML_THROW_IF_FAILED(engine_builder->SetMetacommandsEnabled(device_impl->MetacommandsEnabled())); + + if (model_impl->GetOperatorRegistry()) { + winrt::com_ptr registryPrivate; + WINML_THROW_IF_FAILED(model_impl->GetOperatorRegistry()->QueryInterface(IID_PPV_ARGS(registryPrivate.put()))); + WINML_THROW_IF_FAILED(engine_builder->SetBfcAllocatorEnabled(!registryPrivate->HasExternalOperators())); + } } auto num_intra_op_threads = device_impl->NumberOfIntraOpThreads(); diff --git a/winml/lib/Common/inc/iengine.h b/winml/lib/Common/inc/iengine.h index b7dadcbdbc7ff..4451382114905 100644 --- a/winml/lib/Common/inc/iengine.h +++ b/winml/lib/Common/inc/iengine.h @@ -12,7 +12,7 @@ interface IEngineFactory; using Resource = std::unique_ptr>; // clang-format off -MIDL_INTERFACE("31f39226-cfe8-4758-af38-3d01b2a33ee1") +MIDL_INTERFACE("8ac0b6b9-4561-492b-b63d-a07bdd8292c6") IValue : IUnknown { STDMETHOD(IsEmpty) (bool* out) PURE; @@ -211,7 +211,7 @@ IThreading : IUnknown { }; -MIDL_INTERFACE("8ac0b6b9-4561-492b-b63d-a07bdd8292c6") +MIDL_INTERFACE("edf7b6d1-f788-4057-9f99-28f9b05360e8") IEngineBuilder : IUnknown { STDMETHOD(SetD3D12Resources) (ID3D12Device* device, ID3D12CommandQueue* queue) PURE; @@ -219,6 +219,9 @@ IEngineBuilder : IUnknown { STDMETHOD(SetMetacommandsEnabled) (int enabled) PURE; + STDMETHOD(SetBfcAllocatorEnabled) + (int enabled) PURE; + STDMETHOD(GetD3D12Device) (_Outptr_ ID3D12Device** device) PURE; diff --git a/winml/test/adapter/AdapterDmlEpTest.cpp b/winml/test/adapter/AdapterDmlEpTest.cpp index b4220650abb9c..2666d6cd19ff6 100644 --- a/winml/test/adapter/AdapterDmlEpTest.cpp +++ b/winml/test/adapter/AdapterDmlEpTest.cpp @@ -1,5 +1,5 @@ -// // Copyright (c) Microsoft Corporation. All rights reserved. -// // Licensed under the MIT License. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. #include "testPch.h" @@ -65,7 +65,7 @@ UniqueOrtSession CreateUniqueOrtSession( return UniqueOrtSession(session, ort_api->ReleaseSession); } -UniqueOrtSession CreateDmlSession() { +UniqueOrtSession CreateDmlSession(bool bfc_allocator_enabled) { const auto session_options = CreateUniqueOrtSessionOptions(); THROW_IF_NOT_OK_MSG(ort_api->DisableMemPattern(session_options.get()), ort_api); @@ -79,9 +79,10 @@ UniqueOrtSession CreateDmlSession() { command_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; WINML_EXPECT_HRESULT_SUCCEEDED(device->CreateCommandQueue(&command_queue_desc, IID_PPV_ARGS(queue.put()))); + constexpr bool metacommands_enabled = false; THROW_IF_NOT_OK_MSG( winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_DML( - session_options.get(), device.get(), queue.get(), false + session_options.get(), device.get(), queue.get(), metacommands_enabled, bfc_allocator_enabled ), ort_api ); @@ -95,18 +96,22 @@ UniqueOrtSession CreateCpuSession() { void DmlExecutionProviderFlushContext() { GPUTEST; - auto session = CreateDmlSession(); - OrtExecutionProvider* ort_provider; - THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api); - THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderFlushContext(ort_provider), ort_api); + for (bool bfc_allocator_enabled : {false, true}) { + auto session = CreateDmlSession(bfc_allocator_enabled); + OrtExecutionProvider* ort_provider; + THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api); + THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderFlushContext(ort_provider), ort_api); + } } void DmlExecutionProviderReleaseCompletedReferences() { GPUTEST; - auto session = CreateDmlSession(); - OrtExecutionProvider* ort_provider; - THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api); - THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderReleaseCompletedReferences(ort_provider), ort_api); + for (bool bfc_allocator_enabled : {false, true}) { + auto session = CreateDmlSession(bfc_allocator_enabled); + OrtExecutionProvider* ort_provider; + THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api); + THROW_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderReleaseCompletedReferences(ort_provider), ort_api); + } } constexpr std::array dimensions{1, 3, 720, 720}; @@ -168,27 +173,31 @@ void DmlGetD3D12ResourceFromAllocation() { void* gpu_allocation; THROW_IF_NOT_OK_MSG(ort_dml_api->CreateGPUAllocationFromD3DResource(d3d12_resource.get(), &gpu_allocation), ort_api); - auto session = CreateDmlSession(); - - OrtMemoryInfo* ort_memory_info; - THROW_IF_NOT_OK_MSG( - ort_api->CreateMemoryInfo( - "DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info - ), - ort_api - ); - - OrtAllocator* ort_allocator; - THROW_IF_NOT_OK_MSG(ort_api->CreateAllocator(session.get(), ort_memory_info, &ort_allocator), ort_api); - auto allocator = UniqueOrtAllocator(ort_allocator, ort_api->ReleaseAllocator); - - winrt::com_ptr d3d12_resource_from_allocation; - THROW_IF_NOT_OK_MSG( - ort_dml_api->GetD3D12ResourceFromAllocation(allocator.get(), gpu_allocation, d3d12_resource_from_allocation.put()), - ort_api - ); - // Ensure resource is the same - WINML_EXPECT_EQUAL(d3d12_resource, d3d12_resource_from_allocation); + for (bool bfc_allocator_enabled : {false, true}) { + auto session = CreateDmlSession(bfc_allocator_enabled); + + OrtMemoryInfo* ort_memory_info; + THROW_IF_NOT_OK_MSG( + ort_api->CreateMemoryInfo( + "DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info + ), + ort_api + ); + + OrtAllocator* ort_allocator; + THROW_IF_NOT_OK_MSG(ort_api->CreateAllocator(session.get(), ort_memory_info, &ort_allocator), ort_api); + auto allocator = UniqueOrtAllocator(ort_allocator, ort_api->ReleaseAllocator); + + winrt::com_ptr d3d12_resource_from_allocation; + THROW_IF_NOT_OK_MSG( + ort_dml_api->GetD3D12ResourceFromAllocation( + allocator.get(), gpu_allocation, d3d12_resource_from_allocation.put() + ), + ort_api + ); + // Ensure resource is the same + WINML_EXPECT_EQUAL(d3d12_resource, d3d12_resource_from_allocation); + } THROW_IF_NOT_OK_MSG(ort_dml_api->FreeGPUAllocation(gpu_allocation), ort_api); } @@ -212,28 +221,32 @@ UniqueOrtValue CreateTensorFromMemoryInfo(const OrtMemoryInfo* memory_info) { void GetTensorMemoryInfo() { GPUTEST; - auto session = CreateDmlSession(); - - OrtMemoryInfo* ort_memory_info; - THROW_IF_NOT_OK_MSG( - ort_api->CreateMemoryInfo( - "DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info - ), - ort_api - ); - auto tensor = CreateTensorFromMemoryInfo(ort_memory_info); - - const OrtMemoryInfo* value_memory_info; - THROW_IF_NOT_OK_MSG(ort_api->GetTensorMemoryInfo(tensor.get(), &value_memory_info), ort_api); - CreateTensorFromMemoryInfo(value_memory_info); + for (bool bfc_allocator_enabled : {false, true}) { + auto session = CreateDmlSession(bfc_allocator_enabled); + + OrtMemoryInfo* ort_memory_info; + THROW_IF_NOT_OK_MSG( + ort_api->CreateMemoryInfo( + "DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info + ), + ort_api + ); + auto tensor = CreateTensorFromMemoryInfo(ort_memory_info); + + const OrtMemoryInfo* value_memory_info; + THROW_IF_NOT_OK_MSG(ort_api->GetTensorMemoryInfo(tensor.get(), &value_memory_info), ort_api); + CreateTensorFromMemoryInfo(value_memory_info); + } } void ExecutionProviderSync() { GPUTEST; - auto session = CreateDmlSession(); - OrtExecutionProvider* ort_provider; - THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api); - THROW_IF_NOT_OK_MSG(winml_adapter_api->ExecutionProviderSync(ort_provider), ort_api); + for (bool bfc_allocator_enabled : {false, true}) { + auto session = CreateDmlSession(bfc_allocator_enabled); + OrtExecutionProvider* ort_provider; + THROW_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session.get(), 0, &ort_provider), ort_api); + THROW_IF_NOT_OK_MSG(winml_adapter_api->ExecutionProviderSync(ort_provider), ort_api); + } } void DmlCopyTensor() { @@ -251,9 +264,11 @@ void DmlCopyTensor() { command_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; WINML_EXPECT_HRESULT_SUCCEEDED(device->CreateCommandQueue(&command_queue_desc, IID_PPV_ARGS(queue.put()))); + constexpr bool metacommands_enabled = false; + constexpr bool bfc_allocator_enabled = true; THROW_IF_NOT_OK_MSG( winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_DML( - session_options.get(), device.get(), queue.get(), false + session_options.get(), device.get(), queue.get(), metacommands_enabled, bfc_allocator_enabled ), ort_api ); @@ -315,41 +330,45 @@ void CreateCustomRegistry() { void ValueGetDeviceId() { GPUTEST; - auto session = CreateDmlSession(); - - OrtMemoryInfo* ort_memory_info; - THROW_IF_NOT_OK_MSG( - ort_api->CreateMemoryInfo( - "DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info - ), - ort_api - ); - auto gpu_tensor = CreateTensorFromMemoryInfo(ort_memory_info); - - int16_t device_id; - THROW_IF_NOT_OK_MSG(winml_adapter_api->ValueGetDeviceId(gpu_tensor.get(), &device_id), ort_api); - - OrtMemoryInfo* cpu_memory_info; - THROW_IF_NOT_OK_MSG(ort_api->CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &cpu_memory_info), ort_api); - auto unique_cpu_memory_info = UniqueOrtMemoryInfo(cpu_memory_info, ort_api->ReleaseMemoryInfo); - auto cpu_tensor = CreateTensorFromMemoryInfo(unique_cpu_memory_info.get()); - THROW_IF_NOT_OK_MSG(winml_adapter_api->ValueGetDeviceId(cpu_tensor.get(), &device_id), ort_api); - WINML_EXPECT_EQUAL(0, device_id); + for (bool bfc_allocator_enabled : {false, true}) { + auto session = CreateDmlSession(bfc_allocator_enabled); + + OrtMemoryInfo* ort_memory_info; + THROW_IF_NOT_OK_MSG( + ort_api->CreateMemoryInfo( + "DML", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemType::OrtMemTypeDefault, &ort_memory_info + ), + ort_api + ); + auto gpu_tensor = CreateTensorFromMemoryInfo(ort_memory_info); + + int16_t device_id; + THROW_IF_NOT_OK_MSG(winml_adapter_api->ValueGetDeviceId(gpu_tensor.get(), &device_id), ort_api); + + OrtMemoryInfo* cpu_memory_info; + THROW_IF_NOT_OK_MSG(ort_api->CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &cpu_memory_info), ort_api); + auto unique_cpu_memory_info = UniqueOrtMemoryInfo(cpu_memory_info, ort_api->ReleaseMemoryInfo); + auto cpu_tensor = CreateTensorFromMemoryInfo(unique_cpu_memory_info.get()); + THROW_IF_NOT_OK_MSG(winml_adapter_api->ValueGetDeviceId(cpu_tensor.get(), &device_id), ort_api); + WINML_EXPECT_EQUAL(0, device_id); + } } void SessionGetInputRequiredDeviceId() { GPUTEST; - auto session = CreateDmlSession(); - int16_t device_id; - THROW_IF_NOT_OK_MSG( - winml_adapter_api->SessionGetInputRequiredDeviceId(session.get(), "inputImage", &device_id), ort_api - ); - - auto cpu_session = CreateCpuSession(); - THROW_IF_NOT_OK_MSG( - winml_adapter_api->SessionGetInputRequiredDeviceId(cpu_session.get(), "inputImage", &device_id), ort_api - ); - WINML_EXPECT_EQUAL(0, device_id); + for (bool bfc_allocator_enabled : {false, true}) { + auto session = CreateDmlSession(bfc_allocator_enabled); + int16_t device_id; + THROW_IF_NOT_OK_MSG( + winml_adapter_api->SessionGetInputRequiredDeviceId(session.get(), "inputImage", &device_id), ort_api + ); + + auto cpu_session = CreateCpuSession(); + THROW_IF_NOT_OK_MSG( + winml_adapter_api->SessionGetInputRequiredDeviceId(cpu_session.get(), "inputImage", &device_id), ort_api + ); + WINML_EXPECT_EQUAL(0, device_id); + } } } // namespace diff --git a/winml/test/adapter/AdapterSessionTest.cpp b/winml/test/adapter/AdapterSessionTest.cpp index 8c9124b2ff4ae..015a9bdaf5012 100644 --- a/winml/test/adapter/AdapterSessionTest.cpp +++ b/winml/test/adapter/AdapterSessionTest.cpp @@ -103,9 +103,11 @@ void AppendExecutionProvider_DML() { const auto device = CreateD3DDevice(); const auto queue = CreateD3DQueue(device.get()); + constexpr bool metacommands_enabled = true; + constexpr bool bfc_allocator_enabled = true; THROW_IF_NOT_OK_MSG( winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_DML( - session_options.get(), device.get(), queue.get(), true + session_options.get(), device.get(), queue.get(), metacommands_enabled, bfc_allocator_enabled ), ort_api ); @@ -130,9 +132,11 @@ void GetExecutionProvider_DML() { THROW_IF_NOT_OK_MSG(ort_api->DisableMemPattern(session_options.get()), ort_api); const auto device = CreateD3DDevice(); const auto queue = CreateD3DQueue(device.get()); + constexpr bool metacommands_enabled = true; + constexpr bool bfc_allocator_enabled = true; THROW_IF_NOT_OK_MSG( winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_DML( - session_options.get(), device.get(), queue.get(), true + session_options.get(), device.get(), queue.get(), metacommands_enabled, bfc_allocator_enabled ), ort_api ); @@ -290,9 +294,11 @@ void CopyInputAcrossDevices_DML() { THROW_IF_NOT_OK_MSG(ort_api->DisableMemPattern(session_options.get()), ort_api); const auto device = CreateD3DDevice(); const auto queue = CreateD3DQueue(device.get()); + constexpr bool metacommands_enabled = true; + constexpr bool bfc_allocator_enabled = true; THROW_IF_NOT_OK_MSG( winml_adapter_api->OrtSessionOptionsAppendExecutionProvider_DML( - session_options.get(), device.get(), queue.get(), true + session_options.get(), device.get(), queue.get(), metacommands_enabled, bfc_allocator_enabled ), ort_api );