Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create Image and Tensor/Storage from parts #208

Merged
merged 2 commits into from
Jan 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions crates/kornia-image/src/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,32 @@ impl<T, const C: usize> Image<T, C> {
Ok(image)
}

/// Create a new image from raw parts.
///
/// # Arguments
///
/// * `size` - The size of the image in pixels.
/// * `data` - A pointer to the pixel data.
/// * `len` - The length of the pixel data.
///
/// # Returns
///
/// A new image created from the given size and pixel data.
///
/// # Safety
///
/// The pointer must be non-null and the length must be valid.
pub unsafe fn from_raw_parts(
size: ImageSize,
data: *const T,
len: usize,
) -> Result<Self, ImageError>
where
T: Clone,
{
Tensor::from_raw_parts([size.height, size.width, C], data, len, CpuAllocator)?.try_into()
}

/// Create a new image from a slice of pixel data.
///
/// # Arguments
Expand Down Expand Up @@ -654,4 +680,16 @@ mod tests {

Ok(())
}

#[test]
fn image_from_raw_parts() -> Result<(), ImageError> {
let data = vec![0u8, 1, 2, 3, 4, 5];
let image =
unsafe { Image::<_, 1>::from_raw_parts([2, 3].into(), data.as_ptr(), data.len())? };
std::mem::forget(data);
assert_eq!(image.size().width, 2);
assert_eq!(image.size().height, 3);
assert_eq!(image.num_channels(), 1);
Ok(())
}
}
4 changes: 3 additions & 1 deletion crates/kornia-tensor/src/allocator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,9 @@ impl TensorAllocator for CpuAllocator {
/// The pointer must be non-null and the layout must be correct.
#[allow(clippy::not_unsafe_ptr_arg_deref)]
fn dealloc(&self, ptr: *mut u8, layout: Layout) {
unsafe { alloc::dealloc(ptr, layout) }
if !ptr.is_null() {
unsafe { alloc::dealloc(ptr, layout) }
}
}
}

Expand Down
41 changes: 16 additions & 25 deletions crates/kornia-tensor/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,22 @@ impl<T, A: TensorAllocator> TensorStorage<T, A> {
}
}

/// Creates a new tensor buffer from a raw pointer.
///
/// # Safety
///
/// The pointer must be non-null and the length must be valid.
pub unsafe fn from_raw_parts(data: *const T, len: usize, alloc: A) -> Self {
let ptr = NonNull::new_unchecked(data as _);
let layout = Layout::from_size_align_unchecked(len, std::mem::size_of::<T>());
Self {
ptr,
len,
layout,
alloc,
}
}

/// Converts the `TensorStorage` into a `Vec<T>`.
///
/// Returns `Err(self)` if the buffer does not have the same layout as the destination Vec.
Expand All @@ -108,31 +124,6 @@ impl<T, A: TensorAllocator> TensorStorage<T, A> {
}
}

// TODO: pass the allocator to constructor
impl<T, A: TensorAllocator> From<Vec<T>> for TensorStorage<T, A>
where
A: Default,
{
/// Creates a new tensor buffer from a vector.
fn from(value: Vec<T>) -> Self {
// Safety
// Vec::as_ptr guaranteed to not be null
let ptr = unsafe { NonNull::new_unchecked(value.as_ptr() as *mut T) };
let len = value.len() * std::mem::size_of::<T>();
// Safety
// Vec guaranteed to have a valid layout matching that of `Layout::array`
// This is based on `RawVec::current_memory`
let layout = unsafe { Layout::array::<T>(value.capacity()).unwrap_unchecked() };
std::mem::forget(value);

Self {
ptr,
len,
layout,
alloc: A::default(),
}
}
}
// Safety:
// TensorStorage is thread safe if the allocator is thread safe.
unsafe impl<T, A: TensorAllocator> Send for TensorStorage<T, A> {}
Expand Down
41 changes: 41 additions & 0 deletions crates/kornia-tensor/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,37 @@ where
})
}

/// Creates a new `Tensor` with the given shape and raw parts.
///
/// # Arguments
///
/// * `shape` - An array containing the shape of the tensor.
/// * `data` - A pointer to the data of the tensor.
/// * `len` - The length of the data.
/// * `alloc` - The allocator to use.
///
/// # Safety
///
/// The pointer must be non-null and the length must be valid.
pub unsafe fn from_raw_parts(
shape: [usize; N],
data: *const T,
len: usize,
alloc: A,
) -> Result<Self, TensorError>
where
T: Clone,
{
let storage = TensorStorage::from_raw_parts(data, len, alloc);
let strides = get_strides_from_shape(shape);
Ok(Self {
storage,
shape,
strides,
})
}

/// Creates a new `Tensor` with the given shape and a default value.
/// Creates a new `Tensor` with the given shape and a default value.
///
/// # Arguments
Expand Down Expand Up @@ -1581,4 +1612,14 @@ mod tests {
.is_err_and(|x| x == TensorError::IndexOutOfBounds(12)));
Ok(())
}

#[test]
fn from_raw_parts() -> Result<(), TensorError> {
let data: Vec<u8> = vec![1, 2, 3, 4];
let t = unsafe { Tensor::from_raw_parts([2, 2], data.as_ptr(), data.len(), CpuAllocator)? };
std::mem::forget(data);
assert_eq!(t.shape, [2, 2]);
assert_eq!(t.as_slice(), &[1, 2, 3, 4]);
Ok(())
}
}
Loading