Description
This RFC proposes a performance optimization for converting framework-specific tensors to the DLPack format (DLManagedTensorVersioned). The core idea is to cache the DLManagedTensorVersioned object directly within the source tensor object, leveraging the tensor's existing reference-counting mechanism. This approach aims to significantly reduce the overhead of repeated DLPack conversions, particularly for tensors that are frequently exchanged.
The DLPack standard is a critical component for enabling zero-copy tensor exchange between different deep learning frameworks. To facilitate this exchange, frameworks must convert their internal tensor representations into a DLManagedTensorVersioned structure. This conversion, which involves populating metadata such as shape, data pointer, and strides, can introduce a small but non-negligible overhead, typically in the range of 40-80 nanoseconds on the C++ side. While this latency is already quite low, frequent tensor exchanges—such as those involving model weights or intermediate values used multiple times—can accumulate this overhead. This RFC addresses the question of whether this overhead can be further reduced, particularly in scenarios where the same tensor is converted to DLPack multiple times during its lifetime.
Thread Safety
In a C++ environment, different threads may write concurrently to the cached field, so thread safety is important: exactly one cached value must be stored and returned to all callers. At a high level, the updated scheme works as follows:
- Different threads can race to create their own DLManagedTensorVersioned when they observe that the cached field is nullptr
- Use atomic_compare_exchange_strong_explicit to ensure exactly one value gets stored, and only when the cached field is still nullptr
- Always return the stored value; if the stored value was created by another thread, delete the locally created one and return the one created by the other thread
Proposal
We propose an approach that integrates a caching mechanism directly into the framework's tensor object. The high-level concept is as follows:
- Cache Storage: An atomic pointer field (std::atomic&lt;DLManagedTensorVersioned*&gt;, owned by the tensor and released in its destructor) will be added as a member of the framework's tensor object (e.g., TensorImpl). This modification requires a change to the framework's internal tensor structure.
- On-Demand Population: When the ToDLPack conversion method is called for the first time on a given tensor, the DLManagedTensorVersioned object will be created and populated. The framework's internal metadata will be transferred, and the manager_ctx of the DLManagedTensorVersioned will be set to point back to the TensorImpl itself. The deleter will also be configured at this time.
- Ref counting integration
- To prevent the TensorImpl from being deallocated while a DLPack consumer holds a reference, a new reference will be added to the TensorImpl intrusive reference counter each time a DLManagedTensorVersioned is returned.
- The DLManagedTensorVersioned's deleter function will be configured to simply decrement the TensorImpl's reference counter. This ensures that the TensorImpl and its cached DLManagedTensorVersioned are not deallocated until all DLPack and internal references are released.
- Cache Reuse: For subsequent calls to ToDLPack on the same tensor object, the cached DLManagedTensorVersioned will be directly returned. The only overhead will be a pointer lookup and a reference count increment, which is an extremely fast operation, measured to be as low as 1 nanosecond in preliminary benchmarks.
Expected Benefits and Tradeoffs
- Significant Performance Improvement: This caching strategy can reduce the DLPack conversion overhead from 40-80ns to a mere 1ns for repeated conversions.
- Reduced Redundancy: Avoids repeated allocation and population of DLManagedTensorVersioned objects for the same tensor.
- Minimal Cost: The overhead of this approach is limited to one extra pointer field per tensor object, which is negligible given the typical size of tensor metadata and data.
Example Implementation
The following C++ code snippet illustrates the proposed mechanism within a hypothetical TensorImpl class that uses intrusive reference counting.
#include <atomic>

#include <dlpack/dlpack.h>  // defines DLManagedTensorVersioned

// TensorImpl is the target of an intrusive ptr and contains a reference counter.
class TensorImpl : public intrusive_ptr_target<TensorImpl> {
 public:
  ~TensorImpl() {
    // Delete the cached DLManagedTensorVersioned, if any.
    // Acquire is needed in case the value was published by another thread.
    DLManagedTensorVersioned* cached =
        cached_dl_managed_tensor_.load(std::memory_order_acquire);
    if (cached != nullptr) {
      delete cached;
    }
  }

  /*!
   * \brief Converts the current Tensor to a DLPack Tensor.
   * \return The converted DLManagedTensorVersioned pointer.
   */
  DLManagedTensorVersioned* ToDLPack() const {
    // The caller holds a strong reference to this TensorImpl.
    TensorImpl* self = const_cast<TensorImpl*>(this);
    // Acquire ensures that a write to the DLManagedTensorVersioned made by
    // another thread is visible to this thread.
    DLManagedTensorVersioned* cached =
        self->cached_dl_managed_tensor_.load(std::memory_order_acquire);
    if (cached == nullptr) {
      // First-time conversion: create and populate the DLManagedTensorVersioned.
      // Creation may race among multiple threads, so publish it with a CAS.
      DLManagedTensorVersioned* ret = new DLManagedTensorVersioned();
      // Populate metadata (framework-specific logic) into the new object.
      PopulateMetadata(this, &(ret->dl_tensor));
      // Point manager_ctx back at the TensorImpl so the deleter can find it.
      ret->manager_ctx = self;
      // Set the deleter to our custom function.
      ret->deleter = DLManagedTensorDeleter;
      // Now set the cached_dl_managed_tensor_ field using CAS.
      // A successful store must release the new value to all other threads;
      // a failed store must acquire, since the expected value is coming
      // from another thread that released it.
      DLManagedTensorVersioned* expected = nullptr;
      if (std::atomic_compare_exchange_strong_explicit(
              &self->cached_dl_managed_tensor_, &expected, ret,
              std::memory_order_release, std::memory_order_acquire)) {
        // The store succeeded.
        cached = ret;
      } else {
        // Another thread raced to set the field first; delete our copy.
        // expected now contains the value stored by the other thread.
        delete ret;
        cached = expected;
      }
    }
    // At this point, cached is the value officially stored in the field.
    // Always increment the reference counter of the TensorImpl.
    // This ensures the TensorImpl remains valid as long as the
    // DLPack tensor is in use.
    self->IncRef();
    return cached;
  }

 private:
  // Intrusive reference counter methods.
  void IncRef();
  void DecRef();

  // Custom deleter function for DLPack.
  static void DLManagedTensorDeleter(DLManagedTensorVersioned* tensor) {
    // Cast the manager_ctx back to TensorImpl and decrement its reference
    // counter. The TensorImpl and its embedded cache will be deallocated once
    // all references (internal and DLPack) are released.
    static_cast<TensorImpl*>(tensor->manager_ctx)->DecRef();
  }

  // Normal tensor metadata, e.g., shape, strides, data pointer.
  // The cached DLManagedTensorVersioned object; nullptr until first ToDLPack.
  mutable std::atomic<DLManagedTensorVersioned*> cached_dl_managed_tensor_{nullptr};
};