From 41b2ebcf73b24a1f161be136d5dbbfa8390ec5a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Zientkiewicz?= Date: Fri, 17 Jan 2025 16:35:19 +0100 Subject: [PATCH] [WIP] MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: MichaƂ Zientkiewicz --- dali/c_api_2/data_objects.h | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/dali/c_api_2/data_objects.h b/dali/c_api_2/data_objects.h index 227ca87a75..bacbe7167f 100644 --- a/dali/c_api_2/data_objects.h +++ b/dali/c_api_2/data_objects.h @@ -60,6 +60,22 @@ class TensorListInterface : public RefCountedObject { virtual cudaEvent_t GetOrCreateReadyEvent() = 0; }; +struct TensorListDeleter { + daliDeleter_t deleter; + AccessOrder deletion_order; + + void operator()(void *data) { + if (deleter.delete_buffer) { + cudaStream_t stream = deletion_order.stream(); + deleter.delete_buffer(deleter.deleter_ctx, data, + deletion_order.is_device() ? &stream : nullptr); + } + if (deleter.destroy_context) { + deleter.destroy_context(deleter.destroy_context); + } + } +}; + template class TensorListWrapper : public TensorListInterface { public: @@ -91,10 +107,7 @@ class TensorListWrapper : public TensorListInterface { if (!deleter.delete_buffer && !deleter.destroy_context) { buffer.reset(buffer, [](void *){}); } else { - buffer.reset(buffer, [deleter](void *p) { - if (deleter.delete_buffer) - deleter.delete_buffer(deleter.deleter_ctx, p, nullptr); - }); + buffer.reset(buffer, TensorListDeleter{deleter, order()}); } for (int i = 0; i < num_samples; i++) { TensorShape<> sample_shape(make_cspan(&shapes[i*ndim]. ndim)); @@ -105,9 +118,7 @@ class TensorListWrapper : public TensorListInterface { sample_data = static_cast(data) + next_offset; next_offset += volme(sample_shape) * element_size; } - } - } virtual void AttachSamples( @@ -115,7 +126,9 @@ class TensorListWrapper : public TensorListInterface { int ndim, daliDataType_t dtype, const daliTensorDesc_t *samples, - const daliDeleter_t *sample_deleters) = 0; + const daliDeleter_t *sample_deleters) { + + } virtual daliBufferPlacement_t GetBufferPlacement() const = 0;