Skip to content

Commit

Permalink
add From trait for converting ndarray to ManagedTensor
Browse files Browse the repository at this point in the history
  • Loading branch information
benfred committed Feb 21, 2024
1 parent 7d5174a commit b2b137c
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 44 deletions.
18 changes: 10 additions & 8 deletions rust/cuvs/examples/cagra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
* limitations under the License.
*/

use cuvs::cagra::{IndexParams, SearchParams, Index};
use cuvs::{ManagedTensor, Result, Resources};
use cuvs::cagra::{Index, IndexParams, SearchParams};
use cuvs::{ManagedTensor, Resources, Result};

use ndarray::s;
use ndarray_rand::rand_distr::Uniform;
Expand All @@ -33,9 +33,11 @@ fn cagra_example() -> Result<()> {

// build the cagra index
let build_params = IndexParams::new()?;
let index = Index::build(&res, &build_params, &ManagedTensor::from_ndarray(&dataset))?;
println!("Indexed {}x{} datapoints into cagra index", n_datapoints, n_features);

let index = Index::build(&res, &build_params, &dataset)?;
println!(
"Indexed {}x{} datapoints into cagra index",
n_datapoints, n_features
);

// use the first 4 points from the dataset as queries : will test that we get them back
// as their own nearest neighbor
Expand All @@ -47,12 +49,12 @@ fn cagra_example() -> Result<()> {
// CAGRA search API requires queries and outputs to be on device memory
// copy query data over, and allocate new device memory for the distances/ neighbors
// outputs
let queries = ManagedTensor::from_ndarray(&queries).to_device(&res)?;
let queries = ManagedTensor::from(&queries).to_device(&res)?;
let mut neighbors_host = ndarray::Array::<u32, _>::zeros((n_queries, k));
let neighbors = ManagedTensor::from_ndarray(&neighbors_host).to_device(&res)?;
let neighbors = ManagedTensor::from(&neighbors_host).to_device(&res)?;

let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
let distances = ManagedTensor::from_ndarray(&distances_host).to_device(&res)?;
let distances = ManagedTensor::from(&distances_host).to_device(&res)?;

let search_params = SearchParams::new()?;

Expand Down
17 changes: 11 additions & 6 deletions rust/cuvs/src/cagra/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ pub struct Index(ffi::cagraIndex_t);

impl Index {
/// Builds a new index
pub fn build(res: &Resources, params: &IndexParams, dataset: &ManagedTensor) -> Result<Index> {
pub fn build<T: Into<ManagedTensor>>(
res: &Resources,
params: &IndexParams,
dataset: T,
) -> Result<Index> {
let dataset: ManagedTensor = dataset.into();
let index = Index::new()?;
unsafe {
check_cuvs(ffi::cagraBuild(res.0, params.0, dataset.as_ptr(), index.0))?;
Expand Down Expand Up @@ -92,8 +97,8 @@ mod tests {

// build the cagra index
let build_params = IndexParams::new().unwrap();
let index = Index::build(&res, &build_params, &ManagedTensor::from_ndarray(&dataset))
.expect("failed to create cagra index");
let index =
Index::build(&res, &build_params, &dataset).expect("failed to create cagra index");

// use the first 4 points from the dataset as queries : will test that we get them back
// as their own nearest neighbor
Expand All @@ -105,14 +110,14 @@ mod tests {
// CAGRA search API requires queries and outputs to be on device memory
// copy query data over, and allocate new device memory for the distances/ neighbors
// outputs
let queries = ManagedTensor::from_ndarray(&queries).to_device(&res).unwrap();
let queries = ManagedTensor::from(&queries).to_device(&res).unwrap();
let mut neighbors_host = ndarray::Array::<u32, _>::zeros((n_queries, k));
let neighbors = ManagedTensor::from_ndarray(&neighbors_host)
let neighbors = ManagedTensor::from(&neighbors_host)
.to_device(&res)
.unwrap();

let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
let distances = ManagedTensor::from_ndarray(&distances_host)
let distances = ManagedTensor::from(&distances_host)
.to_device(&res)
.unwrap();

Expand Down
62 changes: 33 additions & 29 deletions rust/cuvs/src/dlpack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

use std::convert::From;

use crate::error::{check_cuda, Result};
use crate::resources::Resources;

Expand All @@ -25,34 +27,6 @@ pub trait IntoDtype {
}

impl ManagedTensor {
/// Create a non-owning view of a Tensor from a ndarray.
///
/// Wraps the array in a `DLManagedTensor` describing host (CPU) memory.
/// No data is copied: the tensor's `data` and `shape` pointers borrow
/// directly from `arr`.
///
/// NOTE(review): the return type carries no lifetime parameter, so those
/// borrowed pointers are not tied to `arr` by the compiler — callers must
/// keep `arr` alive for as long as the tensor is used. Confirm all call
/// sites uphold this.
pub fn from_ndarray<T: IntoDtype, S: ndarray::RawData<Elem = T>, D: ndarray::Dimension>(
    arr: &ndarray::ArrayBase<S, D>,
) -> ManagedTensor {
    // There is a draft PR out right now for creating dlpack directly from
    // ndarray, but until it's merged we have to implement it ourselves:
    // https://github.com/rust-ndarray/ndarray/pull/1306/files
    unsafe {
        // SAFETY: every field of the DLTensor is written below before
        // assume_init is called, so no uninitialized memory is observed.
        let mut ret = std::mem::MaybeUninit::<ffi::DLTensor>::uninit();
        let tensor = ret.as_mut_ptr();
        (*tensor).data = arr.as_ptr() as *mut std::os::raw::c_void;
        // Host memory: CPU device type, device id 0.
        (*tensor).device = ffi::DLDevice {
            device_type: ffi::DLDeviceType::kDLCPU,
            device_id: 0,
        };
        (*tensor).byte_offset = 0;
        // Null strides signal compact row-major layout in DLPack.
        (*tensor).strides = std::ptr::null_mut(); // TODO: error if not rowmajor
        (*tensor).ndim = arr.ndim() as i32;
        // NOTE(review): casts the array's `&[usize]` shape to a mutable
        // i64 pointer; assumes a 64-bit target where the layouts match —
        // TODO confirm.
        (*tensor).shape = arr.shape().as_ptr() as *mut _;
        (*tensor).dtype = T::ffi_dtype();
        // deleter is None and manager_ctx is null: this view owns nothing,
        // so presumably dropping it must not free the borrowed data —
        // verify against the Drop impl for ManagedTensor.
        ManagedTensor(ffi::DLManagedTensor {
            dl_tensor: ret.assume_init(),
            manager_ctx: std::ptr::null_mut(),
            deleter: None,
        })
    }
}

/// Returns a raw pointer to the wrapped `DLManagedTensor`, suitable for
/// handing to C FFI entry points that take a mutable tensor handle.
pub fn as_ptr(&self) -> *mut ffi::DLManagedTensor {
    // Take the address of the inner tensor as a const pointer first, then
    // cast away constness because the C API expects a mutable pointer.
    let shared: *const ffi::DLManagedTensor = &self.0;
    shared as *mut ffi::DLManagedTensor
}
Expand Down Expand Up @@ -117,6 +91,36 @@ unsafe extern "C" fn cuda_free_tensor(self_: *mut ffi::DLManagedTensor) {
let _ = ffi::cudaFree((*self_).dl_tensor.data);
}

/// Create a non-owning view of a Tensor from a ndarray.
///
/// Wraps the array in a `DLManagedTensor` describing host (CPU) memory.
/// No data is copied: the tensor's `data` and `shape` pointers borrow
/// directly from `arr`.
///
/// NOTE(review): the return type carries no lifetime parameter, so those
/// borrowed pointers are not tied to `arr` by the compiler — callers must
/// keep `arr` alive for as long as the tensor is used. Confirm all call
/// sites uphold this.
impl<T: IntoDtype, S: ndarray::RawData<Elem = T>, D: ndarray::Dimension>
    From<&ndarray::ArrayBase<S, D>> for ManagedTensor
{
    fn from(arr: &ndarray::ArrayBase<S, D>) -> Self {
        // There is a draft PR out right now for creating dlpack directly from
        // ndarray, but until it's merged we have to implement it ourselves:
        // https://github.com/rust-ndarray/ndarray/pull/1306/files
        unsafe {
            // SAFETY: every field of the DLTensor is written below before
            // assume_init is called, so no uninitialized memory is observed.
            let mut ret = std::mem::MaybeUninit::<ffi::DLTensor>::uninit();
            let tensor = ret.as_mut_ptr();
            (*tensor).data = arr.as_ptr() as *mut std::os::raw::c_void;
            // Host memory: CPU device type, device id 0.
            (*tensor).device = ffi::DLDevice {
                device_type: ffi::DLDeviceType::kDLCPU,
                device_id: 0,
            };
            (*tensor).byte_offset = 0;
            // Null strides signal compact row-major layout in DLPack.
            (*tensor).strides = std::ptr::null_mut(); // TODO: error if not rowmajor
            (*tensor).ndim = arr.ndim() as i32;
            // NOTE(review): casts the array's `&[usize]` shape to a mutable
            // i64 pointer; assumes a 64-bit target where the layouts match —
            // TODO confirm.
            (*tensor).shape = arr.shape().as_ptr() as *mut _;
            (*tensor).dtype = T::ffi_dtype();
            // deleter is None and manager_ctx is null: this view owns nothing,
            // so presumably dropping it must not free the borrowed data —
            // verify against the Drop impl for ManagedTensor.
            ManagedTensor(ffi::DLManagedTensor {
                dl_tensor: ret.assume_init(),
                manager_ctx: std::ptr::null_mut(),
                deleter: None,
            })
        }
    }
}

impl Drop for ManagedTensor {
fn drop(&mut self) {
unsafe {
Expand Down Expand Up @@ -175,7 +179,7 @@ mod tests {
fn test_from_ndarray() {
let arr = ndarray::Array::<f32, _>::zeros((8, 4));

let tensor = unsafe { (*(ManagedTensor::from_ndarray(&arr).as_ptr())).dl_tensor };
let tensor = unsafe { (*(ManagedTensor::from(&arr).as_ptr())).dl_tensor };

assert_eq!(tensor.ndim, 2);

Expand Down
2 changes: 1 addition & 1 deletion rust/cuvs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@ mod dlpack;
mod error;
mod resources;

pub use dlpack::ManagedTensor;
pub use error::{Error, Result};
pub use resources::Resources;
pub use dlpack::ManagedTensor;

0 comments on commit b2b137c

Please sign in to comment.