Skip to content

Commit

Permalink
create Image and Tensor/Storage from parts (#208)
Browse files Browse the repository at this point in the history
* implement image and tensor from raw parts

* add tests for from_raw_parts
  • Loading branch information
edgarriba authored Jan 1, 2025
1 parent 5816a5c commit 2bc0585
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 26 deletions.
38 changes: 38 additions & 0 deletions crates/kornia-image/src/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,32 @@ impl<T, const C: usize> Image<T, C> {
Ok(image)
}

/// Create a new image from raw parts.
///
/// # Arguments
///
/// * `size` - The size of the image in pixels.
/// * `data` - A pointer to the pixel data.
/// * `len` - The length of the pixel data.
///
/// # Returns
///
/// A new image created from the given size and pixel data.
///
/// # Safety
///
/// The pointer must be non-null and the length must be valid.
pub unsafe fn from_raw_parts(
size: ImageSize,
data: *const T,
len: usize,
) -> Result<Self, ImageError>
where
T: Clone,
{
Tensor::from_raw_parts([size.height, size.width, C], data, len, CpuAllocator)?.try_into()
}

/// Create a new image from a slice of pixel data.
///
/// # Arguments
Expand Down Expand Up @@ -654,4 +680,16 @@ mod tests {

Ok(())
}

#[test]
fn image_from_raw_parts() -> Result<(), ImageError> {
let data = vec![0u8, 1, 2, 3, 4, 5];
let image =
unsafe { Image::<_, 1>::from_raw_parts([2, 3].into(), data.as_ptr(), data.len())? };
std::mem::forget(data);
assert_eq!(image.size().width, 2);
assert_eq!(image.size().height, 3);
assert_eq!(image.num_channels(), 1);
Ok(())
}
}
4 changes: 3 additions & 1 deletion crates/kornia-tensor/src/allocator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ impl TensorAllocator for CpuAllocator {
/// The pointer must be non-null and the layout must be correct.
#[allow(clippy::not_unsafe_ptr_arg_deref)]
fn dealloc(&self, ptr: *mut u8, layout: Layout) {
unsafe { alloc::dealloc(ptr, layout) }
if !ptr.is_null() {
unsafe { alloc::dealloc(ptr, layout) }
}
}
}

Expand Down
41 changes: 16 additions & 25 deletions crates/kornia-tensor/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,22 @@ impl<T, A: TensorAllocator> TensorStorage<T, A> {
}
}

/// Creates a new tensor buffer from a raw pointer.
///
/// # Safety
///
/// The pointer must be non-null and the length must be valid.
pub unsafe fn from_raw_parts(data: *const T, len: usize, alloc: A) -> Self {
let ptr = NonNull::new_unchecked(data as _);
let layout = Layout::from_size_align_unchecked(len, std::mem::size_of::<T>());
Self {
ptr,
len,
layout,
alloc,
}
}

/// Converts the `TensorStorage` into a `Vec<T>`.
///
/// Returns `Err(self)` if the buffer does not have the same layout as the destination Vec.
Expand All @@ -108,31 +124,6 @@ impl<T, A: TensorAllocator> TensorStorage<T, A> {
}
}

// TODO: pass the allocator to constructor
impl<T, A: TensorAllocator> From<Vec<T>> for TensorStorage<T, A>
where
A: Default,
{
/// Creates a new tensor buffer from a vector.
fn from(value: Vec<T>) -> Self {
// Safety
// Vec::as_ptr guaranteed to not be null
let ptr = unsafe { NonNull::new_unchecked(value.as_ptr() as *mut T) };
let len = value.len() * std::mem::size_of::<T>();
// Safety
// Vec guaranteed to have a valid layout matching that of `Layout::array`
// This is based on `RawVec::current_memory`
let layout = unsafe { Layout::array::<T>(value.capacity()).unwrap_unchecked() };
std::mem::forget(value);

Self {
ptr,
len,
layout,
alloc: A::default(),
}
}
}
// Safety:
// TensorStorage is thread safe if the allocator is thread safe.
unsafe impl<T, A: TensorAllocator> Send for TensorStorage<T, A> {}
Expand Down
41 changes: 41 additions & 0 deletions crates/kornia-tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,37 @@ where
})
}

/// Creates a new `Tensor` with the given shape and raw parts.
///
/// # Arguments
///
/// * `shape` - An array containing the shape of the tensor.
/// * `data` - A pointer to the data of the tensor.
/// * `len` - The length of the data.
/// * `alloc` - The allocator to use.
///
/// # Safety
///
/// The pointer must be non-null and the length must be valid.
pub unsafe fn from_raw_parts(
shape: [usize; N],
data: *const T,
len: usize,
alloc: A,
) -> Result<Self, TensorError>
where
T: Clone,
{
let storage = TensorStorage::from_raw_parts(data, len, alloc);
let strides = get_strides_from_shape(shape);
Ok(Self {
storage,
shape,
strides,
})
}

/// Creates a new `Tensor` with the given shape and a default value.
/// Creates a new `Tensor` with the given shape and a default value.
///
/// # Arguments
Expand Down Expand Up @@ -1581,4 +1612,14 @@ mod tests {
.is_err_and(|x| x == TensorError::IndexOutOfBounds(12)));
Ok(())
}

#[test]
fn from_raw_parts() -> Result<(), TensorError> {
let data: Vec<u8> = vec![1, 2, 3, 4];
let t = unsafe { Tensor::from_raw_parts([2, 2], data.as_ptr(), data.len(), CpuAllocator)? };
std::mem::forget(data);
assert_eq!(t.shape, [2, 2]);
assert_eq!(t.as_slice(), &[1, 2, 3, 4]);
Ok(())
}
}

0 comments on commit 2bc0585

Please sign in to comment.