
Commit 5ebb4f8

Tensor list.
Signed-off-by: Michał Zientkiewicz <[email protected]>
mzient committed Jan 27, 2025
1 parent 41b2ebc commit 5ebb4f8
Showing 3 changed files with 280 additions and 17 deletions.
192 changes: 175 additions & 17 deletions dali/c_api_2/data_objects.h
@@ -15,15 +15,19 @@
#ifndef DALI_C_API_2_DATA_OBJECTS_H_
#define DALI_C_API_2_DATA_OBJECTS_H_

#include <cstdint>
#include <memory>
#include <optional>
#include "dali/dali.h"
#include "dali/pipeline/data/tensor_list.h"
#include "dali/c_api_2/ref_counting.h"

struct _DALITensorList {};

namespace dali {
namespace c_api {

class TensorListInterface : public RefCountedObject {
class TensorListInterface : public _DALITensorList, public RefCountedObject {
public:
virtual ~TensorListInterface() = default;

@@ -37,6 +41,7 @@ class TensorListInterface : public RefCountedObject {
int num_samples,
int ndim,
daliDataType_t dtype,
const char *layout,
const int64_t *shapes,
void *data,
const ptrdiff_t *sample_offsets,
@@ -46,6 +51,7 @@
int num_samples,
int ndim,
daliDataType_t dtype,
const char *layout,
const daliTensorDesc_t *samples,
const daliDeleter_t *sample_deleters) = 0;

@@ -55,12 +61,14 @@

virtual std::optional<cudaStream_t> GetStream() const = 0;

virtual std::optional<cudaEvent_t> GetReadyEvent() const() = 0;
virtual std::optional<cudaEvent_t> GetReadyEvent() const = 0;

virtual cudaEvent_t GetOrCreateReadyEvent() = 0;

static RefCountedPtr<TensorListInterface> Create(daliBufferPlacement_t placement);
};

struct TensorListDeleter {
struct BufferDeleter {
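// Makes a user-supplied daliDeleter_t usable as a std::shared_ptr deleter:
// delete_buffer runs in the buffer's deletion order (a stream is passed only
// for device-ordered buffers), after which the deleter's context is destroyed.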
daliDeleter_t deleter;
AccessOrder deletion_order;

@@ -71,7 +79,7 @@ struct TensorListDeleter {
deletion_order.is_device() ? &stream : nullptr);
}
if (deleter.destroy_context) {
deleter.destroy_context(deleter.destroy_context);
deleter.destroy_context(deleter.deleter_ctx);
}
}
};
@@ -86,63 +94,213 @@ class TensorListWrapper : public TensorListInterface {
int ndim,
daliDataType_t dtype,
const int64_t *shapes) override {
tl_->Resize(TensorListShape<>(make_cspan(shapes, num_samples*ndim), num_samples, ndim), dtype);
std::vector<int64_t> shape_data(shapes, shapes + ndim * num_samples);
tl_->Resize(TensorListShape<>(shape_data, num_samples, ndim), dtype);
}

void AttachBuffer(
int num_samples,
int ndim,
daliDataType_t dtype,
const char *layout,
const int64_t *shapes,
void *data,
const ptrdiff_t *sample_offsets,
daliDeleter_t deleter) override {

if (num_samples < 0)
throw std::invalid_argument("The number of samples must not be negative.");
if (ndim < 0)
throw std::invalid_argument("The number of dimensions must not be negative.");
if (!shapes && ndim > 0)
throw std::invalid_argument("The `shapes` are required for non-scalar (ndim > 0) samples.");
if (!data && num_samples > 0) {
for (int i = 0; i < num_samples; i++) {
auto sample_shape = make_cspan(shapes + i * ndim, ndim);  // null-safe when ndim == 0
if (volume(sample_shape) > 0)
throw std::invalid_argument(
"The pointer to the data buffer must not be null for a non-empty tensor list.");
if (sample_offsets && sample_offsets[i])
throw std::invalid_argument(
"All sample_offsets must be zero when the data pointer is NULL.");
}
}

TensorLayout new_layout = {};

if (!layout) {
if (ndim == tl_->sample_dim())
new_layout = tl_->GetLayout();
} else {
new_layout = layout;
if (new_layout.ndim() != ndim)
throw std::invalid_argument(make_string(
"The layout '", new_layout, "' cannot describe ", ndim, "-dimensional data."));
}

tl_->Reset();
tl_->SetSize(num_samples);
tl_->set_sample_dim(ndim);
ptridff_t next_offset = 0;
tl_->SetLayout(new_layout);
ptrdiff_t next_offset = 0;
auto type_info = TypeTable::GetTypeInfo(dtype);
auto element_size = type_info.size();
std::shared_ptr<void *> buffer;

std::shared_ptr<void> buffer;
if (!deleter.delete_buffer && !deleter.destroy_context) {
buffer.reset(buffer, [](void *){});
buffer = std::shared_ptr<void>(data, [](void *){});
} else {
buffer.reset(buffer, TensorListDeleter{deleter, order()});
buffer = std::shared_ptr<void>(data, BufferDeleter{deleter, tl_->order()});
}

for (int i = 0; i < num_samples; i++) {
TensorShape<> sample_shape(make_cspan(&shapes[i*ndim]. ndim));
TensorShape<> sample_shape(make_cspan(&shapes[i*ndim], ndim));
void *sample_data;
size_t sample_bytes = volume(sample_shape) * element_size;
if (sample_offsets) {
sample_data = static_cast<char *>(data) + sample_offsets[i];
} else {
sample_data = static_cast<char *>(data) + next_offset;
next_offset += volme(sample_shape) * element_size;
next_offset += sample_bytes;
}
tl_->SetSample(
i,
std::shared_ptr<void>(buffer, sample_data),
sample_bytes,
tl_->is_pinned(),
sample_shape,
dtype,
tl_->device_id(),
tl_->order(),
new_layout);
}
}

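// Attaches individually described, externally owned samples; each sample may
// carry its own deleter, which runs in the list's deletion order.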
void AttachSamples(
int num_samples,
int ndim,
daliDataType_t dtype,
const char *layout,
const daliTensorDesc_t *samples,
const daliDeleter_t *sample_deleters) override {
if (num_samples < 0)
throw std::invalid_argument("The number of samples must not be negative.");
if (num_samples > 0 && !samples)
throw std::invalid_argument("The pointer to sample descriptors must not be NULL.");
if (ndim < 0) {
if (num_samples == 0)
throw std::invalid_argument(
"The number of dimensions must not be negative when num_samples is 0.");
else
ndim = samples[0].ndim;
}

for (int i = 0; i < num_samples; i++) {
if (samples[i].ndim != ndim)
throw std::invalid_argument(make_string(
"Invalid `ndim` at sample ", i, ": got ", samples[i].ndim, ", expected ", ndim, "."));
if (ndim && !samples[i].shape)
throw std::invalid_argument(make_string("Got NULL shape in sample ", i, "."));
if (!samples[i].data && volume(make_cspan(samples[i].shape, ndim)))
throw std::invalid_argument(make_string(
"Got NULL data pointer in a non-empty sample ", i, "."));
}

TensorLayout new_layout = {};

if (!layout) {
if (ndim == tl_->sample_dim())
new_layout = tl_->GetLayout();
} else {
new_layout = layout;
if (new_layout.ndim() != ndim)
throw std::invalid_argument(make_string(
"The layout '", new_layout, "' cannot describe ", ndim, "-dimensional data."));
}

tl_->Reset();
tl_->SetSize(num_samples);
tl_->set_sample_dim(ndim);
tl_->SetLayout(new_layout);

auto deletion_order = tl_->order();

auto type_info = TypeTable::GetTypeInfo(dtype);
auto element_size = type_info.size();
for (int i = 0; i < num_samples; i++) {
TensorShape<> sample_shape(make_cspan(samples[i].shape, samples[i].ndim));
size_t sample_bytes = volume(sample_shape) * element_size;
std::shared_ptr<void> sample_ptr;
if (sample_deleters) {
sample_ptr = std::shared_ptr<void>(
samples[i].data,
BufferDeleter{sample_deleters[i], deletion_order});
} else {
sample_ptr = std::shared_ptr<void>(samples[i].data, [](void*) {});
}

tl_->SetSample(
i,
sample_ptr,
sample_bytes,
tl_->is_pinned(),
sample_shape,
dtype,
tl_->device_id(),
tl_->order(),
new_layout);
}
}

virtual daliBufferPlacement_t GetBufferPlacement() const = 0;
daliBufferPlacement_t GetBufferPlacement() const override {
daliBufferPlacement_t placement;
placement.device_id = tl_->device_id();
StorageDevice dev = backend_to_storage_device<Backend>::value;
placement.device_type = static_cast<daliStorageDevice_t>(dev);
placement.pinned = tl_->is_pinned();
return placement;
}

virtual void SetStream(std::optional<cudaStream_t> stream, bool synchronize) = 0;
void SetStream(std::optional<cudaStream_t> stream, bool synchronize) override {
tl_->set_order(stream.has_value() ? AccessOrder(*stream) : AccessOrder::host(), synchronize);
}

virtual std::optional<cudaStream_t> GetStream() const = 0;
std::optional<cudaStream_t> GetStream() const override {
auto o = tl_->order();
if (o.is_device())
return o.stream();
else
return std::nullopt;
}

virtual std::optional<cudaEvent_t> GetReadyEvent() const() = 0;
std::optional<cudaEvent_t> GetReadyEvent() const override {
auto &e = tl_->ready_event();
if (e)
return e.get();
else
return std::nullopt;
}

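// Returns the list's ready event, lazily creating one on its device.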
virtual cudaEvent_t GetOrCreateReadyEvent() = 0;
cudaEvent_t GetOrCreateReadyEvent() override {
auto &e = tl_->ready_event();
if (e)
return e.get();
int device_id = tl_->device_id();
if (device_id < 0)
throw std::runtime_error("The tensor list is not associated with a CUDA device.");
tl_->set_ready_event(CUDASharedEvent::Create(device_id));
return tl_->ready_event().get();
}
private:
std::shared_ptr<TensorList<Backend>> impl_;
std::shared_ptr<TensorList<Backend>> tl_;
};

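// Wraps a TensorList in a ref-counted C-API handle; the returned pointer
// adopts the wrapper's initial reference (no extra IncRef).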
template <typename Backend>
RefCountedPtr<TensorListWrapper<Backend>> Wrap(std::shared_ptr<TensorList<Backend>> tl) {
return RefCountedPtr<TensorListWrapper<Backend>>(new TensorListWrapper<Backend>(std::move(tl)));
}


} // namespace c_api
} // namespace dali
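For illustration, here is a minimal sketch of attaching externally allocated memory through the wrapper above. It is not part of this commit: the AttachExample function, buffer, and shapes are hypothetical, and it assumes a CPU-backed TensorList and the DALI_FLOAT type code from dali.h. A zero-initialized daliDeleter_t takes the non-owning branch shown in AttachBuffer.

#include <memory>
#include "dali/c_api_2/data_objects.h"

void AttachExample() {
  auto tl = std::make_shared<dali::TensorList<dali::CPUBackend>>();
  auto wrapper = dali::c_api::Wrap(tl);

  static float data[12] = {};        // two 2x3 float samples, packed back-to-back
  int64_t shapes[4] = {2, 3, 2, 3};  // shape of sample 0, then of sample 1
  daliDeleter_t no_op{};             // null callbacks -> non-owning attachment

  // Null sample_offsets selects the next_offset accumulation path above;
  // the layout string must match ndim (here "HW" for 2-dimensional samples).
  wrapper->AttachBuffer(2, 2, DALI_FLOAT, "HW", shapes, data, nullptr, no_op);
}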

92 changes: 92 additions & 0 deletions dali/c_api_2/ref_counting.h
@@ -15,10 +15,102 @@
#ifndef DALI_C_API_2_REF_COUNTING_H_
#define DALI_C_API_2_REF_COUNTING_H_

#include <atomic>
#include <type_traits>
#include <utility>

namespace dali::c_api {

class RefCountedObject {
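// Intrusive reference count: a new object starts at 1 (the creator's
// reference); DecRef() deletes the object when the count drops to zero.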
public:
int IncRef() noexcept {
return std::atomic_fetch_add_explicit(&ref_, 1, std::memory_order_relaxed) + 1;
}

int DecRef() noexcept {
int ret = std::atomic_fetch_sub_explicit(&ref_, 1, std::memory_order_acq_rel) - 1;
if (!ret)
delete this;
return ret;
}

int RefCount() const noexcept {
return ref_.load(std::memory_order_relaxed);
}

virtual ~RefCountedObject() = default;
private:
std::atomic<int> ref_{1};
};

template <typename T>
class RefCountedPtr {
public:
constexpr RefCountedPtr() noexcept = default;

explicit RefCountedPtr(T *ptr, bool inc_ref = false) noexcept : ptr_(ptr) {
if (inc_ref && ptr_)
ptr_->IncRef();
}

~RefCountedPtr() {
reset();
}

// Same-type copy/move must be declared explicitly: the converting templates
// below do not replace the implicitly generated special members, which would
// copy ptr_ without adjusting the reference count.
RefCountedPtr(const RefCountedPtr &other) noexcept : ptr_(other.ptr_) {
if (ptr_)
ptr_->IncRef();
}

RefCountedPtr(RefCountedPtr &&other) noexcept : ptr_(other.ptr_) {
other.ptr_ = nullptr;
}

template <typename U, std::enable_if_t<std::is_convertible_v<U *, T *>, int> = 0>
RefCountedPtr(const RefCountedPtr<U> &other) noexcept : ptr_(other.ptr_) {
if (ptr_)
ptr_->IncRef();
}

template <typename U, std::enable_if_t<std::is_convertible_v<U *, T *>, int> = 0>
RefCountedPtr(RefCountedPtr<U> &&other) noexcept : ptr_(other.ptr_) {
other.ptr_ = nullptr;
}

RefCountedPtr &operator=(const RefCountedPtr &other) noexcept {
if (ptr_ == other.ptr_)
return *this;
if (other.ptr_)
other.ptr_->IncRef();
if (ptr_)  // the current pointer may be null
ptr_->DecRef();
ptr_ = other.ptr_;
return *this;
}

RefCountedPtr &operator=(RefCountedPtr &&other) noexcept {
if (&other == this)
return *this;
std::swap(ptr_, other.ptr_);
other.reset();
return *this;
}

template <typename U>
std::enable_if_t<std::is_convertible_v<U *, T *>, RefCountedPtr> &
operator=(const RefCountedPtr<U> &other) noexcept {
if (other.ptr_)
other.ptr_->IncRef();
if (ptr_)
ptr_->DecRef();
ptr_ = other.ptr_;
return *this;
}

template <typename U>
std::enable_if_t<std::is_convertible_v<U *, T *>, RefCountedPtr> &
operator=(RefCountedPtr<U> &&other) noexcept {
reset();
ptr_ = other.ptr_;
other.ptr_ = nullptr;
return *this;
}

void reset() noexcept {
if (ptr_)
ptr_->DecRef();
ptr_ = nullptr;
}

[[nodiscard]] T *release() noexcept {
T *p = ptr_;
ptr_ = nullptr;
return p;
}

constexpr T *operator->() const & noexcept { return ptr_; }

constexpr T &operator*() const & noexcept { return *ptr_; }

constexpr T *get() const & noexcept { return ptr_; }

private:
template <typename U>
friend class RefCountedPtr;
T *ptr_ = nullptr;
};

} // namespace dali::c_api
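A minimal usage sketch of the smart pointer above; the Widget type and Example function are hypothetical, not part of this commit:

#include "dali/c_api_2/ref_counting.h"

struct Widget : dali::c_api::RefCountedObject {};

void Example() {
  // The constructor adopts the initial reference (inc_ref defaults to false),
  // so the count stays at 1.
  dali::c_api::RefCountedPtr<Widget> p(new Widget());
  auto q = p;  // copy: IncRef -> count == 2
  q.reset();   // DecRef -> count == 1; the object stays alive
}              // p's destructor drops the last reference and deletes the Widget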