diff --git a/rust/cuvs/examples/cagra.rs b/rust/cuvs/examples/cagra.rs index 879fac787..ccc1466dd 100644 --- a/rust/cuvs/examples/cagra.rs +++ b/rust/cuvs/examples/cagra.rs @@ -14,8 +14,8 @@ * limitations under the License. */ -use cuvs::cagra::{IndexParams, SearchParams, Index}; -use cuvs::{ManagedTensor, Result, Resources}; +use cuvs::cagra::{Index, IndexParams, SearchParams}; +use cuvs::{ManagedTensor, Resources, Result}; use ndarray::s; use ndarray_rand::rand_distr::Uniform; @@ -33,9 +33,11 @@ fn cagra_example() -> Result<()> { // build the cagra index let build_params = IndexParams::new()?; - let index = Index::build(&res, &build_params, &ManagedTensor::from_ndarray(&dataset))?; - println!("Indexed {}x{} datapoints into cagra index", n_datapoints, n_features); - + let index = Index::build(&res, &build_params, &dataset)?; + println!( + "Indexed {}x{} datapoints into cagra index", + n_datapoints, n_features + ); // use the first 4 points from the dataset as queries : will test that we get them back // as their own nearest neighbor @@ -47,12 +49,12 @@ fn cagra_example() -> Result<()> { // CAGRA search API requires queries and outputs to be on device memory // copy query data over, and allocate new device memory for the distances/ neighbors // outputs - let queries = ManagedTensor::from_ndarray(&queries).to_device(&res)?; + let queries = ManagedTensor::from(&queries).to_device(&res)?; let mut neighbors_host = ndarray::Array::::zeros((n_queries, k)); - let neighbors = ManagedTensor::from_ndarray(&neighbors_host).to_device(&res)?; + let neighbors = ManagedTensor::from(&neighbors_host).to_device(&res)?; let mut distances_host = ndarray::Array::::zeros((n_queries, k)); - let distances = ManagedTensor::from_ndarray(&distances_host).to_device(&res)?; + let distances = ManagedTensor::from(&distances_host).to_device(&res)?; let search_params = SearchParams::new()?; diff --git a/rust/cuvs/src/cagra/index.rs b/rust/cuvs/src/cagra/index.rs index be1543a29..25a54d095 100644 --- a/rust/cuvs/src/cagra/index.rs +++ b/rust/cuvs/src/cagra/index.rs @@ -26,7 +26,12 @@ pub struct Index(ffi::cagraIndex_t); impl Index { /// Builds a new index - pub fn build(res: &Resources, params: &IndexParams, dataset: &ManagedTensor) -> Result { + pub fn build>( + res: &Resources, + params: &IndexParams, + dataset: T, + ) -> Result { + let dataset: ManagedTensor = dataset.into(); let index = Index::new()?; unsafe { check_cuvs(ffi::cagraBuild(res.0, params.0, dataset.as_ptr(), index.0))?; @@ -92,8 +97,8 @@ mod tests { // build the cagra index let build_params = IndexParams::new().unwrap(); - let index = Index::build(&res, &build_params, &ManagedTensor::from_ndarray(&dataset)) - .expect("failed to create cagra index"); + let index = + Index::build(&res, &build_params, &dataset).expect("failed to create cagra index"); // use the first 4 points from the dataset as queries : will test that we get them back // as their own nearest neighbor @@ -105,14 +110,14 @@ mod tests { // CAGRA search API requires queries and outputs to be on device memory // copy query data over, and allocate new device memory for the distances/ neighbors // outputs - let queries = ManagedTensor::from_ndarray(&queries).to_device(&res).unwrap(); + let queries = ManagedTensor::from(&queries).to_device(&res).unwrap(); let mut neighbors_host = ndarray::Array::::zeros((n_queries, k)); - let neighbors = ManagedTensor::from_ndarray(&neighbors_host) + let neighbors = ManagedTensor::from(&neighbors_host) .to_device(&res) .unwrap(); let mut distances_host = ndarray::Array::::zeros((n_queries, k)); - let distances = ManagedTensor::from_ndarray(&distances_host) + let distances = ManagedTensor::from(&distances_host) .to_device(&res) .unwrap(); diff --git a/rust/cuvs/src/dlpack.rs b/rust/cuvs/src/dlpack.rs index 26d618a4a..b86959db1 100644 --- a/rust/cuvs/src/dlpack.rs +++ b/rust/cuvs/src/dlpack.rs @@ -14,6 +14,8 @@ * limitations under the License. */ +use std::convert::From; + use crate::error::{check_cuda, Result}; use crate::resources::Resources; @@ -25,34 +27,6 @@ pub trait IntoDtype { } impl ManagedTensor { - /// Create a non-owning view of a Tensor from a ndarray - pub fn from_ndarray, D: ndarray::Dimension>( - arr: &ndarray::ArrayBase, - ) -> ManagedTensor { - // There is a draft PR out right now for creating dlpack directly from ndarray - // right now, but until its merged we have to implement ourselves - //https://github.com/rust-ndarray/ndarray/pull/1306/files - unsafe { - let mut ret = std::mem::MaybeUninit::::uninit(); - let tensor = ret.as_mut_ptr(); - (*tensor).data = arr.as_ptr() as *mut std::os::raw::c_void; - (*tensor).device = ffi::DLDevice { - device_type: ffi::DLDeviceType::kDLCPU, - device_id: 0, - }; - (*tensor).byte_offset = 0; - (*tensor).strides = std::ptr::null_mut(); // TODO: error if not rowmajor - (*tensor).ndim = arr.ndim() as i32; - (*tensor).shape = arr.shape().as_ptr() as *mut _; - (*tensor).dtype = T::ffi_dtype(); - ManagedTensor(ffi::DLManagedTensor { - dl_tensor: ret.assume_init(), - manager_ctx: std::ptr::null_mut(), - deleter: None, - }) - } - } - pub fn as_ptr(&self) -> *mut ffi::DLManagedTensor { &self.0 as *const _ as *mut _ } @@ -117,6 +91,36 @@ unsafe extern "C" fn cuda_free_tensor(self_: *mut ffi::DLManagedTensor) { let _ = ffi::cudaFree((*self_).dl_tensor.data); } +/// Create a non-owning view of a Tensor from a ndarray +impl, D: ndarray::Dimension> + From<&ndarray::ArrayBase> for ManagedTensor +{ + fn from(arr: &ndarray::ArrayBase) -> Self { + // There is a draft PR out right now for creating dlpack directly from ndarray + // right now, but until its merged we have to implement ourselves + //https://github.com/rust-ndarray/ndarray/pull/1306/files + unsafe { + let mut ret = std::mem::MaybeUninit::::uninit(); + let tensor = ret.as_mut_ptr(); + (*tensor).data = arr.as_ptr() as *mut std::os::raw::c_void; + (*tensor).device = ffi::DLDevice { + device_type: ffi::DLDeviceType::kDLCPU, + device_id: 0, + }; + (*tensor).byte_offset = 0; + (*tensor).strides = std::ptr::null_mut(); // TODO: error if not rowmajor + (*tensor).ndim = arr.ndim() as i32; + (*tensor).shape = arr.shape().as_ptr() as *mut _; + (*tensor).dtype = T::ffi_dtype(); + ManagedTensor(ffi::DLManagedTensor { + dl_tensor: ret.assume_init(), + manager_ctx: std::ptr::null_mut(), + deleter: None, + }) + } + } +} + impl Drop for ManagedTensor { fn drop(&mut self) { unsafe { @@ -175,7 +179,7 @@ mod tests { fn test_from_ndarray() { let arr = ndarray::Array::::zeros((8, 4)); - let tensor = unsafe { (*(ManagedTensor::from_ndarray(&arr).as_ptr())).dl_tensor }; + let tensor = unsafe { (*(ManagedTensor::from(&arr).as_ptr())).dl_tensor }; assert_eq!(tensor.ndim, 2); diff --git a/rust/cuvs/src/lib.rs b/rust/cuvs/src/lib.rs index b60262706..7a6f847f5 100644 --- a/rust/cuvs/src/lib.rs +++ b/rust/cuvs/src/lib.rs @@ -19,6 +19,6 @@ mod dlpack; mod error; mod resources; +pub use dlpack::ManagedTensor; pub use error::{Error, Result}; pub use resources::Resources; -pub use dlpack::ManagedTensor;