Skip to content

Commit

Permalink
functioning unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
benfred committed Feb 13, 2024
1 parent 291c140 commit 47ea736
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 80 deletions.
81 changes: 44 additions & 37 deletions rust/cuvs/src/cagra/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,14 @@ use crate::error::{check_cuvs, Result};
use crate::resources::Resources;

#[derive(Debug)]
pub struct Index {
index: ffi::cagraIndex_t,
}
pub struct Index(ffi::cagraIndex_t);

impl Index {
/// Builds a new index
pub fn build(res: &Resources, params: &IndexParams, dataset: ManagedTensor) -> Result<Index> {
pub fn build(res: &Resources, params: &IndexParams, dataset: &ManagedTensor) -> Result<Index> {
let index = Index::new()?;
unsafe {
check_cuvs(ffi::cagraBuild(
res.0,
params.0,
dataset.as_ptr(),
index.index,
))?;
check_cuvs(ffi::cagraBuild(res.0, params.0, dataset.as_ptr(), index.0))?;
}
Ok(index)
}
Expand All @@ -46,25 +39,23 @@ impl Index {
unsafe {
let mut index = core::mem::MaybeUninit::<ffi::cagraIndex_t>::uninit();
check_cuvs(ffi::cagraIndexCreate(index.as_mut_ptr()))?;
Ok(Index {
index: index.assume_init(),
})
Ok(Index(index.assume_init()))
}
}

pub fn search(
self,
res: &Resources,
params: &SearchParams,
queries: ManagedTensor,
neighbors: ManagedTensor,
distances: ManagedTensor,
queries: &ManagedTensor,
neighbors: &ManagedTensor,
distances: &ManagedTensor,
) -> Result<()> {
unsafe {
check_cuvs(ffi::cagraSearch(
res.0,
params.0,
self.index,
self.0,
queries.as_ptr(),
neighbors.as_ptr(),
distances.as_ptr(),
Expand All @@ -75,7 +66,7 @@ impl Index {

impl Drop for Index {
fn drop(&mut self) {
if let Err(e) = check_cuvs(unsafe { ffi::cagraIndexDestroy(self.index) }) {
if let Err(e) = check_cuvs(unsafe { ffi::cagraIndexDestroy(self.0) }) {
write!(stderr(), "failed to call cagraIndexDestroy {:?}", e)
.expect("failed to write to stderr");
}
Expand All @@ -90,40 +81,56 @@ mod tests {
use ndarray_rand::RandomExt;

#[test]
fn test_create_empty_index() {
Index::new().unwrap();
}

#[test]
fn test_index() {
fn test_cagra_index() {
let res = Resources::new().unwrap();
let params = IndexParams::new().unwrap();

// Create a new random dataset to index
let n_datapoints = 256;
let n_features = 16;
let dataset = ndarray::Array::<f32, _>::random((256, n_features), Uniform::new(0., 1.0));
let index = Index::build(&res, &params, ManagedTensor::from_ndarray(&dataset))
let dataset =
ndarray::Array::<f32, _>::random((n_datapoints, n_features), Uniform::new(0., 1.0));

// build the cagra index
let build_params = IndexParams::new().unwrap();
let index = Index::build(&res, &build_params, &ManagedTensor::from_ndarray(&dataset))
.expect("failed to create cagra index");

// use the first 4 points from the dataset as queries : will test that we get them back
// as their own nearest neighbor
let n_queries = 4;
let queries = dataset.slice(s![0..n_queries, ..]);
let queries = ManagedTensor::from_ndarray(&queries).to_device().unwrap();

let k = 10;
let neighbors =
ManagedTensor::from_ndarray(&ndarray::Array::<u32, _>::zeros((n_queries, k)))
.to_device()
.unwrap();
let distances =
ManagedTensor::from_ndarray(&ndarray::Array::<f32, _>::zeros((n_queries, k)))
.to_device()
.unwrap();

// CAGRA search API requires queries and outputs to be on device memory
// copy query data over, and allocate new device memory for the distances/ neighbors
// outputs
let queries = ManagedTensor::from_ndarray(&queries).to_device().unwrap();
let mut neighbors_host = ndarray::Array::<u32, _>::zeros((n_queries, k));
let neighbors = ManagedTensor::from_ndarray(&neighbors_host)
.to_device()
.unwrap();

let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
let distances = ManagedTensor::from_ndarray(&distances_host)
.to_device()
.unwrap();

let search_params = SearchParams::new().unwrap();

index
.search(&res, &search_params, queries, neighbors, distances)
.search(&res, &search_params, &queries, &neighbors, &distances)
.unwrap();

// Copy back to host memory
distances.to_host(&mut distances_host).unwrap();
neighbors.to_host(&mut neighbors_host).unwrap();

// nearest neighbors should be themselves, since queries are from the
// dataset
assert_eq!(neighbors_host[[0, 0]], 0);
assert_eq!(neighbors_host[[1, 0]], 1);
assert_eq!(neighbors_host[[2, 0]], 2);
assert_eq!(neighbors_host[[3, 0]], 3);
}
}
106 changes: 63 additions & 43 deletions rust/cuvs/src/dlpack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,46 +23,6 @@ pub trait IntoDtype {
fn ffi_dtype() -> ffi::DLDataType;
}

impl IntoDtype for f32 {
fn ffi_dtype() -> ffi::DLDataType {
ffi::DLDataType {
code: ffi::DLDataTypeCode::kDLFloat as _,
bits: 32,
lanes: 1,
}
}
}

impl IntoDtype for f64 {
fn ffi_dtype() -> ffi::DLDataType {
ffi::DLDataType {
code: ffi::DLDataTypeCode::kDLFloat as _,
bits: 64,
lanes: 1,
}
}
}

impl IntoDtype for i32 {
fn ffi_dtype() -> ffi::DLDataType {
ffi::DLDataType {
code: ffi::DLDataTypeCode::kDLInt as _,
bits: 32,
lanes: 1,
}
}
}

impl IntoDtype for u32 {
fn ffi_dtype() -> ffi::DLDataType {
ffi::DLDataType {
code: ffi::DLDataTypeCode::kDLUInt as _,
bits: 32,
lanes: 1,
}
}
}

impl ManagedTensor {
/// Create a non-owning view of a Tensor from a ndarray
pub fn from_ndarray<T: IntoDtype, S: ndarray::RawData<Elem = T>, D: ndarray::Dimension>(
Expand Down Expand Up @@ -92,7 +52,7 @@ impl ManagedTensor {
}
}

pub fn as_ptr(self) -> *mut ffi::DLManagedTensor {
pub fn as_ptr(&self) -> *mut ffi::DLManagedTensor {
&self.0 as *const _ as *mut _
}

Expand All @@ -106,7 +66,7 @@ impl ManagedTensor {
bytes
}

pub fn to_device(self) -> Result<ManagedTensor> {
pub fn to_device(&self) -> Result<ManagedTensor> {
unsafe {
let bytes = self.bytes();
let mut device_data: *mut std::ffi::c_void = std::ptr::null_mut();
Expand All @@ -125,10 +85,30 @@ impl ManagedTensor {
// call cudaFree automatically to clean up data
ret.deleter = Some(cuda_free_tensor);
ret.dl_tensor.device.device_type = ffi::DLDeviceType::kDLCUDA;

Ok(ManagedTensor(ret))
}
}
pub fn to_host<
T: IntoDtype,
S: ndarray::RawData<Elem = T> + ndarray::RawDataMut,
D: ndarray::Dimension,
>(
&self,
arr: &mut ndarray::ArrayBase<S, D>,
) -> Result<()> {
unsafe {
let bytes = self.bytes();
check_cuda(ffi::cudaMemcpy(
arr.as_mut_ptr() as *mut std::ffi::c_void,
self.0.dl_tensor.data,
bytes,
ffi::cudaMemcpyKind_cudaMemcpyDefault,
))?;

Ok(())
}
}
}

unsafe extern "C" fn cuda_free_tensor(self_: *mut ffi::DLManagedTensor) {
Expand All @@ -145,6 +125,46 @@ impl Drop for ManagedTensor {
}
}

impl IntoDtype for f32 {
fn ffi_dtype() -> ffi::DLDataType {
ffi::DLDataType {
code: ffi::DLDataTypeCode::kDLFloat as _,
bits: 32,
lanes: 1,
}
}
}

impl IntoDtype for f64 {
fn ffi_dtype() -> ffi::DLDataType {
ffi::DLDataType {
code: ffi::DLDataTypeCode::kDLFloat as _,
bits: 64,
lanes: 1,
}
}
}

impl IntoDtype for i32 {
fn ffi_dtype() -> ffi::DLDataType {
ffi::DLDataType {
code: ffi::DLDataTypeCode::kDLInt as _,
bits: 32,
lanes: 1,
}
}
}

impl IntoDtype for u32 {
fn ffi_dtype() -> ffi::DLDataType {
ffi::DLDataType {
code: ffi::DLDataTypeCode::kDLUInt as _,
bits: 32,
lanes: 1,
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit 47ea736

Please sign in to comment.