diff --git a/rust/tvm-ffi/src/any.rs b/rust/tvm-ffi/src/any.rs index d5a4c855..ecf8b9ea 100644 --- a/rust/tvm-ffi/src/any.rs +++ b/rust/tvm-ffi/src/any.rs @@ -177,6 +177,7 @@ impl Any { #[inline] pub unsafe fn into_raw_ffi_any(this: Self) -> TVMFFIAny { + let this = std::mem::ManuallyDrop::new(this); this.data } diff --git a/rust/tvm-ffi/src/collections/array.rs b/rust/tvm-ffi/src/collections/array.rs new file mode 100644 index 00000000..6f259ba1 --- /dev/null +++ b/rust/tvm-ffi/src/collections/array.rs @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +use std::fmt::Debug; +use std::marker::PhantomData; +use std::ops::Deref; + +use crate::any::TryFromTemp; +use crate::derive::Object; +use crate::object::{Object, ObjectArc}; +use crate::{Any, AnyCompatible, AnyView, ObjectCoreWithExtraItems, ObjectRefCore}; +use tvm_ffi_sys::TVMFFITypeIndex as TypeIndex; +use tvm_ffi_sys::{TVMFFIAny, TVMFFIObject}; + +#[repr(C)] +#[derive(Object)] +#[type_key = "ffi.Array"] +#[type_index(TypeIndex::kTVMFFIArray)] +pub struct ArrayObj { + pub object: Object, + /// Pointer to the start of the element buffer (AddressOf(0)). + pub data: *mut core::ffi::c_void, + pub size: i64, + pub capacity: i64, + /// Optional custom deleter for the data pointer. + pub data_deleter: Option, +} + +unsafe impl ObjectCoreWithExtraItems for ArrayObj { + type ExtraItem = TVMFFIAny; + fn extra_items_count(this: &Self) -> usize { + this.size as usize + } +} + +#[repr(C)] +#[derive(Clone)] +pub struct Array { + data: ObjectArc, + _marker: PhantomData, +} + +impl Debug for Array { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let full_name = std::any::type_name::(); + let short_name = full_name.split("::").last().unwrap_or(full_name); + write!(f, "Array<{}>[{}]", short_name, self.len()) + } +} + +impl Default for Array { + fn default() -> Self { + Self::new(vec![]) + } +} + +unsafe impl ObjectRefCore for Array { + type ContainerType = ArrayObj; + + fn data(this: &Self) -> &ObjectArc { + &this.data + } + + fn into_data(this: Self) -> ObjectArc { + this.data + } + + fn from_data(data: ObjectArc) -> Self { + Self { + data, + _marker: PhantomData, + } + } +} + +impl Array { + /// Creates a new Array from a vector of items. + pub fn new(items: Vec) -> Self { + let capacity = items.len(); + Self::new_with_capacity(items, capacity) + } + + /// Internal helper to allocate an ArrayObj with specific headroom. + fn new_with_capacity(items: Vec, capacity: usize) -> Self { + let size = items.len(); + + // Allocate with capacity + let arc = ObjectArc::::new_with_extra_items(ArrayObj { + object: Object::new(), + data: core::ptr::null_mut(), + size: size as i64, + capacity: capacity as i64, + data_deleter: None, + }); + + unsafe { + let raw_ptr = ObjectArc::as_raw(&arc) as *mut ArrayObj; + let container = &mut *raw_ptr; + + let base_ptr = ArrayObj::extra_items_mut(container).as_ptr() as *mut TVMFFIAny; + container.data = base_ptr as *mut _; + + for (i, item) in items.into_iter().enumerate() { + let any: Any = Any::from(item); + let raw = Any::into_raw_ffi_any(any); + core::ptr::write(base_ptr.add(i), raw); + } + } + Self::from_data(arc) + } + + pub fn len(&self) -> usize { + self.data.size as usize + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Retrieves an item at the given index. + pub fn get(&self, index: usize) -> Result { + if index >= self.len() { + crate::bail!(crate::error::INDEX_ERROR, "Array get index out of bound"); + } + unsafe { + let container = self.data.deref(); + let base_ptr = container.data as *const TVMFFIAny; + let raw_any_ref = &*base_ptr.add(index); + + match T::try_cast_from_any_view(raw_any_ref) { + Ok(val) => Ok(val), + Err(_) => crate::bail!( + crate::error::TYPE_ERROR, + "Failed to cast element at {} to {}", + index, + T::type_str() + ), + } + } + } + + pub fn iter(&'_ self) -> ArrayIterator<'_, T> { + ArrayIterator { + array: self, + index: 0, + len: self.len(), + } + } + + #[inline] + fn as_container(&self) -> &ArrayObj { + unsafe { + let ptr = ObjectArc::as_raw(&self.data) as *const ArrayObj; + &*ptr + } + } +} + +// --- Index Implementation --- + +impl std::ops::Index for Array { + type Output = AnyView<'static>; + + fn index(&self, index: usize) -> &Self::Output { + let container = self.as_container(); + let len = container.size as usize; + if index >= len { + panic!( + "Index out of bounds: the len is {} but the index is {}", + len, index + ); + } + unsafe { + let ptr = (container.data as *const AnyView<'static>).add(index); + &*ptr + } + } +} + +// --- Iterator Implementations --- + +pub struct ArrayIterator<'a, T: AnyCompatible + Clone> { + array: &'a Array, + index: usize, + len: usize, +} + +impl<'a, T: AnyCompatible + Clone> Iterator for ArrayIterator<'a, T> { + type Item = T; + + fn next(&mut self) -> Option { + if self.index < self.len { + let item = self.array.get(self.index).ok(); + self.index += 1; + item + } else { + None + } + } +} + +impl<'a, T: AnyCompatible + Clone> IntoIterator for &'a Array { + type Item = T; + type IntoIter = ArrayIterator<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} + +impl FromIterator for Array { + fn from_iter>(iter: I) -> Self { + let items: Vec = iter.into_iter().collect(); + Self::new(items) + } +} + +// --- Any Type System Conversions --- + +unsafe impl AnyCompatible for Array +where + T: AnyCompatible + Clone + 'static, +{ + fn type_str() -> String { + format!("Array<{}>", T::type_str()) + } + + unsafe fn check_any_strict(data: &TVMFFIAny) -> bool { + if data.type_index != TypeIndex::kTVMFFIArray as i32 { + return false; + } + + if std::any::TypeId::of::() == std::any::TypeId::of::() { + return true; + } + + let container = &*(data.data_union.v_obj as *const ArrayObj); + let base_ptr = container.data as *const TVMFFIAny; + for i in 0..container.size { + let elem_any = &*base_ptr.add(i as usize); + if !T::check_any_strict(elem_any) { + return false; + } + } + true + } + + unsafe fn copy_to_any_view(src: &Self, data: &mut TVMFFIAny) { + data.type_index = TypeIndex::kTVMFFIArray as i32; + data.data_union.v_obj = ObjectArc::as_raw(Self::data(src)) as *mut TVMFFIObject; + data.small_str_len = 0; + } + + unsafe fn move_to_any(src: Self, data: &mut TVMFFIAny) { + data.type_index = TypeIndex::kTVMFFIArray as i32; + data.data_union.v_obj = ObjectArc::into_raw(Self::into_data(src)) as *mut TVMFFIObject; + data.small_str_len = 0; + } + + unsafe fn copy_from_any_view_after_check(data: &TVMFFIAny) -> Self { + let ptr = data.data_union.v_obj as *const ArrayObj; + crate::object::unsafe_::inc_ref(ptr as *mut TVMFFIObject); + Self::from_data(ObjectArc::from_raw(ptr)) + } + + unsafe fn move_from_any_after_check(data: &mut TVMFFIAny) -> Self { + let ptr = data.data_union.v_obj as *const ArrayObj; + let obj = Self::from_data(ObjectArc::from_raw(ptr)); + + data.type_index = TypeIndex::kTVMFFINone as i32; + data.data_union.v_int64 = 0; + + obj + } + + unsafe fn try_cast_from_any_view(data: &TVMFFIAny) -> Result { + if data.type_index != TypeIndex::kTVMFFIArray as i32 { + return Err(()); + } + + // Fast path: if types match exactly, we can just copy the reference. + if Self::check_any_strict(data) { + return Ok(Self::copy_from_any_view_after_check(data)); + } + + // Slow path: try to convert element by element. + let container = &*(data.data_union.v_obj as *const ArrayObj); + let base_ptr = container.data as *const TVMFFIAny; + let mut items = Vec::with_capacity(container.size as usize); + + for i in 0..container.size { + let any_v = &*base_ptr.add(i as usize); + if let Ok(item) = T::try_cast_from_any_view(any_v) { + items.push(item); + } else { + return Err(()); + } + } + + Ok(Array::new(items)) + } +} + +impl TryFrom for Array +where + T: AnyCompatible + Clone + 'static, +{ + type Error = crate::error::Error; + + fn try_from(value: Any) -> Result { + let temp: TryFromTemp = TryFromTemp::try_from(value)?; + Ok(TryFromTemp::into_value(temp)) + } +} + +impl<'a, T> TryFrom> for Array +where + T: AnyCompatible + Clone + 'static, +{ + type Error = crate::error::Error; + + fn try_from(value: AnyView<'a>) -> Result { + let temp: TryFromTemp = TryFromTemp::try_from(value)?; + Ok(TryFromTemp::into_value(temp)) + } +} diff --git a/rust/tvm-ffi/src/collections/mod.rs b/rust/tvm-ffi/src/collections/mod.rs index 85635a7c..ad17dcca 100644 --- a/rust/tvm-ffi/src/collections/mod.rs +++ b/rust/tvm-ffi/src/collections/mod.rs @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -pub mod shape; /// Collection types +pub mod array; +pub mod shape; pub mod tensor; diff --git a/rust/tvm-ffi/src/lib.rs b/rust/tvm-ffi/src/lib.rs index 94e87b02..fad82601 100644 --- a/rust/tvm-ffi/src/lib.rs +++ b/rust/tvm-ffi/src/lib.rs @@ -32,6 +32,7 @@ pub mod type_traits; pub use tvm_ffi_sys; pub use crate::any::{Any, AnyView}; +pub use crate::collections::array::Array; pub use crate::collections::shape::Shape; pub use crate::collections::tensor::{CPUNDAlloc, NDAllocator, Tensor}; pub use crate::device::{current_stream, with_stream}; diff --git a/rust/tvm-ffi/tests/test_array.rs b/rust/tvm-ffi/tests/test_array.rs new file mode 100644 index 00000000..fe87c5fd --- /dev/null +++ b/rust/tvm-ffi/tests/test_array.rs @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +use tvm_ffi::*; + +/// Helper to create a Tensor with a specific float value and shape +fn create_tensor(val: f32, shape: &[i64]) -> Tensor { + let dtype = DLDataType::new(DLDataTypeCode::kDLFloat, 32, 1); + let device = DLDevice::new(DLDeviceType::kDLCPU, 0); + let tensor = Tensor::from_nd_alloc(CPUNDAlloc {}, shape, dtype, device); + if let Ok(slice) = tensor.data_as_slice_mut::() { + slice[0] = val; + } + tensor +} + +/// Helper to extract the first float value from a Tensor +fn get_val(tensor: &Tensor) -> f32 { + tensor + .data_as_slice::() + .expect("Type mismatch or null")[0] +} + +#[test] +fn test_array_core_and_iteration() { + let t1 = create_tensor(10.0, &[1, 2]); + let t2 = create_tensor(20.0, &[3, 4, 5]); + + let array = Array::new(vec![t1.clone(), t2.clone()]); + + // Core Accessors + assert_eq!(array.len(), 2); + assert!(!array.is_empty()); + + // Value Integrity + assert_eq!(get_val(&Tensor::try_from(array[0]).unwrap()), 10.0); + assert_eq!(Tensor::try_from(array[0]).unwrap().ndim(), 2); + assert_eq!(Tensor::try_from(array[1]).unwrap().ndim(), 3); + + // Iteration + let vals: Vec = array.iter().map(|t| get_val(&t)).collect(); + assert_eq!(vals, vec![10.0, 20.0]); +} + +#[test] +fn test_array_any_conversions() { + let array = Array::new(vec![ + create_tensor(1.0, &[1]), + create_tensor(2.0, &[1]), + create_tensor(3.0, &[1]), + ]); + + // Test Any/AnyView Roundtrip (Verifies AnyCompatible and Trait Bounds) + let any = Any::from(array); + assert_eq!(any.type_index(), TypeIndex::kTVMFFIArray as i32); + + let back: Array = Array::try_from(any).expect("Any -> Array failed"); + assert_eq!(back.len(), 3); + assert_eq!(get_val(&back.get(2).unwrap()), 3.0); + + let view = AnyView::from(&back); + let back_from_view: Array = Array::try_from(view).expect("AnyView -> Array failed"); + assert_eq!(back_from_view.len(), 3); +} + +#[test] +fn test_array_recursive_type_checking() { + // 1. Create an Array of Shapes + let shape_array = Array::new(vec![Shape::from(vec![1, 2]), Shape::from(vec![3])]); + + // 2. Wrap it in Any + let any_val = Any::from(shape_array); + + // 3. Try to convert Any (containing Shapes) into Array + // This should FAIL because T::check_any_strict (Tensor) will fail on Shape elements + let tensor_cast = Array::::try_from(any_val.clone()); + assert!( + tensor_cast.is_err(), + "Should not be able to cast Array to Array" + ); + + // 4. Verify valid cast works + let shape_cast = Array::::try_from(any_val); + assert!( + shape_cast.is_ok(), + "Should be able to cast back to correct type" + ); +} + +#[test] +fn test_array_parametric_heterogeneity() { + // Verify Array works with different ObjectRefCore types + let shape_array = Array::new(vec![Shape::from(vec![1, 2, 3]), Shape::from(vec![10])]); + assert_eq!(shape_array.get(0).unwrap().as_slice(), &[1, 2, 3]); + assert_eq!(shape_array.get(1).unwrap().as_slice(), &[10]); + + let function_array = Array::new(vec![ + Function::get_global("ffi.String").unwrap(), + Function::get_global("ffi.Bytes").unwrap(), + ]); + assert_eq!( + into_typed_fn!( + function_array.get(0).unwrap(), + Fn(String) -> Result + )("hello".into()) + .unwrap(), + "hello" + ); + assert_eq!( + into_typed_fn!( + function_array.get(1).unwrap(), + Fn(Bytes) -> Result + )([1, 2, 3].into()) + .unwrap(), + &[1, 2, 3] + ); +}