From 346f6cfa7e28f97bbc02e60fc3a3850e76acd19d Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 12 Feb 2024 14:19:33 -0800 Subject: [PATCH] working search unittest --- rust/cuvs-sys/build.rs | 5 +++ rust/cuvs/Cargo.toml | 3 ++ rust/cuvs/src/cagra/index.rs | 62 ++++++++++++++++++++++++---- rust/cuvs/src/cagra/index_params.rs | 29 ++++++------- rust/cuvs/src/cagra/search_params.rs | 44 ++++++++++---------- rust/cuvs/src/dlpack.rs | 47 ++++++++++++++++++--- rust/cuvs/src/error.rs | 21 ++++++++-- rust/cuvs/src/resources.rs | 8 ++-- 8 files changed, 158 insertions(+), 61 deletions(-) diff --git a/rust/cuvs-sys/build.rs b/rust/cuvs-sys/build.rs index 687717489..816a6f259 100644 --- a/rust/cuvs-sys/build.rs +++ b/rust/cuvs-sys/build.rs @@ -39,6 +39,7 @@ fn main() { cuvs_build.display() ); println!("cargo:rustc-link-lib=dylib=cuvs_c"); + println!("cargo:rustc-link-lib=dylib=cudart"); // we need some extra flags both to link against cuvs, and also to run bindgen // specifically we need to: @@ -100,6 +101,10 @@ fn main() { .allowlist_type("(cuvs|cagra|DL).*") .allowlist_function("(cuvs|cagra).*") .rustified_enum("(cuvs|cagra|DL).*") + // also need some basic cuda mem functions + // (TODO: should we be adding in RMM support instead here?) + .allowlist_function("(cudaMalloc|cudaFree|cudaMemcpy)") + .rustified_enum("cudaError") .generate() .expect("Unable to generate cagra_c bindings") .write_to_file(out_path.join("cuvs_bindings.rs")) diff --git a/rust/cuvs/Cargo.toml b/rust/cuvs/Cargo.toml index 462011a1e..cc52db026 100644 --- a/rust/cuvs/Cargo.toml +++ b/rust/cuvs/Cargo.toml @@ -11,3 +11,6 @@ license.workspace = true [dependencies] ffi = { package = "cuvs-sys", path = "../cuvs-sys" } ndarray = "0.15" + +[dev-dependencies] +ndarray-rand = "*" diff --git a/rust/cuvs/src/cagra/index.rs b/rust/cuvs/src/cagra/index.rs index ef33d0e90..1151b4b42 100644 --- a/rust/cuvs/src/cagra/index.rs +++ b/rust/cuvs/src/cagra/index.rs @@ -16,7 +16,7 @@ use std::io::{stderr, Write}; -use crate::cagra::IndexParams; +use crate::cagra::{IndexParams, SearchParams}; use crate::dlpack::ManagedTensor; use crate::error::{check_cuvs, Result}; use crate::resources::Resources; @@ -28,12 +28,12 @@ pub struct Index { impl Index { /// Builds a new index - pub fn build(res: Resources, params: IndexParams, dataset: ManagedTensor) -> Result { + pub fn build(res: &Resources, params: &IndexParams, dataset: ManagedTensor) -> Result { let index = Index::new()?; unsafe { check_cuvs(ffi::cagraBuild( - res.res, - params.params, + res.0, + params.0, dataset.as_ptr(), index.index, ))?; @@ -51,6 +51,26 @@ impl Index { }) } } + + pub fn search( + self, + res: &Resources, + params: &SearchParams, + queries: ManagedTensor, + neighbors: ManagedTensor, + distances: ManagedTensor, + ) -> Result<()> { + unsafe { + check_cuvs(ffi::cagraSearch( + res.0, + params.0, + self.index, + queries.as_ptr(), + neighbors.as_ptr(), + distances.as_ptr(), + )) + } + } } impl Drop for Index { @@ -65,6 +85,9 @@ impl Drop for Index { #[cfg(test)] mod tests { use super::*; + use ndarray::s; + use ndarray_rand::rand_distr::Uniform; + use ndarray_rand::RandomExt; #[test] fn test_create_empty_index() { @@ -72,14 +95,35 @@ mod tests { } #[test] - fn test_build() { + fn test_index() { let res = Resources::new().unwrap(); let params = IndexParams::new().unwrap(); - // TODO: test a more exciting dataset - let arr = ndarray::Array::::zeros((128, 16)); - let dataset = ManagedTensor::from_ndarray(arr); + let n_features = 16; + let dataset = ndarray::Array::::random((256, n_features), Uniform::new(0., 1.0)); + let index = Index::build(&res, ¶ms, ManagedTensor::from_ndarray(&dataset)) + .expect("failed to create cagra index"); + + // use the first 4 points from the dataset as queries : will test that we get them back + // as their own nearest neighbor + let n_queries = 4; + let queries = dataset.slice(s![0..n_queries, ..]); + let queries = ManagedTensor::from_ndarray(&queries).to_device().unwrap(); + + let k = 10; + let neighbors = + ManagedTensor::from_ndarray(&ndarray::Array::::zeros((n_queries, k))) + .to_device() + .unwrap(); + let distances = + ManagedTensor::from_ndarray(&ndarray::Array::::zeros((n_queries, k))) + .to_device() + .unwrap(); + + let search_params = SearchParams::new().unwrap(); - let index = Index::build(res, params, dataset).expect("failed to create cagra index"); + index + .search(&res, &search_params, queries, neighbors, distances) + .unwrap(); } } diff --git a/rust/cuvs/src/cagra/index_params.rs b/rust/cuvs/src/cagra/index_params.rs index eecb3b61a..4d2e57c35 100644 --- a/rust/cuvs/src/cagra/index_params.rs +++ b/rust/cuvs/src/cagra/index_params.rs @@ -21,25 +21,21 @@ use std::io::{stderr, Write}; pub type BuildAlgo = ffi::cagraGraphBuildAlgo; /// Supplemental parameters to build CAGRA Index -pub struct IndexParams { - pub params: ffi::cuvsCagraIndexParams_t, -} +pub struct IndexParams(pub ffi::cuvsCagraIndexParams_t); impl IndexParams { pub fn new() -> Result { unsafe { let mut params = core::mem::MaybeUninit::::uninit(); check_cuvs(ffi::cuvsCagraIndexParamsCreate(params.as_mut_ptr()))?; - Ok(IndexParams { - params: params.assume_init(), - }) + Ok(IndexParams(params.assume_init())) } } /// Degree of input graph for pruning pub fn set_intermediate_graph_degree(self, intermediate_graph_degree: usize) -> IndexParams { unsafe { - (*self.params).intermediate_graph_degree = intermediate_graph_degree; + (*self.0).intermediate_graph_degree = intermediate_graph_degree; } self } @@ -47,7 +43,7 @@ impl IndexParams { /// Degree of output graph pub fn set_graph_degree(self, graph_degree: usize) -> IndexParams { unsafe { - (*self.params).graph_degree = graph_degree; + (*self.0).graph_degree = graph_degree; } self } @@ -55,7 +51,7 @@ impl IndexParams { /// ANN algorithm to build knn graph pub fn set_build_algo(self, build_algo: BuildAlgo) -> IndexParams { unsafe { - (*self.params).build_algo = build_algo; + (*self.0).build_algo = build_algo; } self } @@ -63,7 +59,7 @@ impl IndexParams { /// Number of iterations to run if building with NN_DESCENT pub fn set_nn_descent_niter(self, nn_descent_niter: usize) -> IndexParams { unsafe { - (*self.params).nn_descent_niter = nn_descent_niter; + (*self.0).nn_descent_niter = nn_descent_niter; } self } @@ -73,13 +69,13 @@ impl fmt::Debug for IndexParams { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // custom debug trait here, default value will show the pointer address // for the inner params object which isn't that useful. - write!(f, "IndexParams {{ params: {:?} }}", unsafe { *self.params }) + write!(f, "IndexParams {{ params: {:?} }}", unsafe { *self.0 }) } } impl Drop for IndexParams { fn drop(&mut self) { - if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraIndexParamsDestroy(self.params) }) { + if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraIndexParamsDestroy(self.0) }) { write!( stderr(), "failed to call cuvsCagraIndexParamsDestroy {:?}", @@ -103,7 +99,12 @@ mod tests { .set_build_algo(BuildAlgo::NN_DESCENT) .set_nn_descent_niter(10); - // make sure the setters actually updated internal representation - assert_eq!(format!("{:?}", params), "IndexParams { params: cagraIndexParams { intermediate_graph_degree: 128, graph_degree: 16, build_algo: NN_DESCENT, nn_descent_niter: 10 } }"); + // make sure the setters actually updated internal representation on the c-struct + unsafe { + assert_eq!((*params.0).graph_degree, 16); + assert_eq!((*params.0).intermediate_graph_degree, 128); + assert_eq!((*params.0).build_algo, BuildAlgo::NN_DESCENT); + assert_eq!((*params.0).nn_descent_niter, 10); + } } } diff --git a/rust/cuvs/src/cagra/search_params.rs b/rust/cuvs/src/cagra/search_params.rs index f4bfd77b2..d32d6783c 100644 --- a/rust/cuvs/src/cagra/search_params.rs +++ b/rust/cuvs/src/cagra/search_params.rs @@ -22,25 +22,21 @@ pub type SearchAlgo = ffi::cagraSearchAlgo; pub type HashMode = ffi::cagraHashMode; /// Supplemental parameters to search CAGRA index -pub struct SearchParams { - pub params: ffi::cuvsCagraSearchParams_t, -} +pub struct SearchParams(pub ffi::cuvsCagraSearchParams_t); impl SearchParams { pub fn new() -> Result { unsafe { let mut params = core::mem::MaybeUninit::::uninit(); check_cuvs(ffi::cuvsCagraSearchParamsCreate(params.as_mut_ptr()))?; - Ok(SearchParams { - params: params.assume_init(), - }) + Ok(SearchParams(params.assume_init())) } } /// Maximum number of queries to search at the same time (batch size). Auto select when 0 pub fn set_max_queries(self, max_queries: usize) -> SearchParams { unsafe { - (*self.params).max_queries = max_queries; + (*self.0).max_queries = max_queries; } self } @@ -50,7 +46,7 @@ impl SearchParams { /// Higher values improve the search accuracy pub fn set_itopk_size(self, itopk_size: usize) -> SearchParams { unsafe { - (*self.params).itopk_size = itopk_size; + (*self.0).itopk_size = itopk_size; } self } @@ -58,7 +54,7 @@ impl SearchParams { /// Upper limit of search iterations. Auto select when 0. pub fn set_max_iterations(self, max_iterations: usize) -> SearchParams { unsafe { - (*self.params).max_iterations = max_iterations; + (*self.0).max_iterations = max_iterations; } self } @@ -66,7 +62,7 @@ impl SearchParams { /// Which search implementation to use. pub fn set_algo(self, algo: SearchAlgo) -> SearchParams { unsafe { - (*self.params).algo = algo; + (*self.0).algo = algo; } self } @@ -74,7 +70,7 @@ impl SearchParams { /// Number of threads used to calculate a single distance. 4, 8, 16, or 32. pub fn set_team_size(self, team_size: usize) -> SearchParams { unsafe { - (*self.params).team_size = team_size; + (*self.0).team_size = team_size; } self } @@ -82,7 +78,7 @@ impl SearchParams { /// Lower limit of search iterations. pub fn set_min_iterations(self, min_iterations: usize) -> SearchParams { unsafe { - (*self.params).min_iterations = min_iterations; + (*self.0).min_iterations = min_iterations; } self } @@ -90,7 +86,7 @@ impl SearchParams { /// Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0. pub fn set_thread_block_size(self, thread_block_size: usize) -> SearchParams { unsafe { - (*self.params).thread_block_size = thread_block_size; + (*self.0).thread_block_size = thread_block_size; } self } @@ -98,7 +94,7 @@ impl SearchParams { /// Hashmap type. Auto selection when AUTO. pub fn set_hashmap_mode(self, hashmap_mode: HashMode) -> SearchParams { unsafe { - (*self.params).hashmap_mode = hashmap_mode; + (*self.0).hashmap_mode = hashmap_mode; } self } @@ -106,7 +102,7 @@ impl SearchParams { /// Lower limit of hashmap bit length. More than 8. pub fn set_hashmap_min_bitlen(self, hashmap_min_bitlen: usize) -> SearchParams { unsafe { - (*self.params).hashmap_min_bitlen = hashmap_min_bitlen; + (*self.0).hashmap_min_bitlen = hashmap_min_bitlen; } self } @@ -114,7 +110,7 @@ impl SearchParams { /// Upper limit of hashmap fill rate. More than 0.1, less than 0.9. pub fn set_hashmap_max_fill_rate(self, hashmap_max_fill_rate: f32) -> SearchParams { unsafe { - (*self.params).hashmap_max_fill_rate = hashmap_max_fill_rate; + (*self.0).hashmap_max_fill_rate = hashmap_max_fill_rate; } self } @@ -122,7 +118,7 @@ impl SearchParams { /// Number of iterations of initial random seed node selection. 1 or more. pub fn set_num_random_samplings(self, num_random_samplings: u32) -> SearchParams { unsafe { - (*self.params).num_random_samplings = num_random_samplings; + (*self.0).num_random_samplings = num_random_samplings; } self } @@ -130,7 +126,7 @@ impl SearchParams { /// Bit mask used for initial random seed node selection. pub fn set_rand_xor_mask(self, rand_xor_mask: u64) -> SearchParams { unsafe { - (*self.params).rand_xor_mask = rand_xor_mask; + (*self.0).rand_xor_mask = rand_xor_mask; } self } @@ -140,15 +136,13 @@ impl fmt::Debug for SearchParams { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // custom debug trait here, default value will show the pointer address // for the inner params object which isn't that useful. - write!(f, "SearchParams {{ params: {:?} }}", unsafe { - *self.params - }) + write!(f, "SearchParams {{ params: {:?} }}", unsafe { *self.0 }) } } impl Drop for SearchParams { fn drop(&mut self) { - if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraSearchParamsDestroy(self.params) }) { + if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraSearchParamsDestroy(self.0) }) { write!( stderr(), "failed to call cuvsCagraSearchParamsDestroy {:?}", @@ -165,6 +159,10 @@ mod tests { #[test] fn test_search_params() { - let params = SearchParams::new().unwrap(); + let params = SearchParams::new().unwrap().set_itopk_size(128); + + unsafe { + assert_eq!((*params.0).itopk_size, 128); + } } } diff --git a/rust/cuvs/src/dlpack.rs b/rust/cuvs/src/dlpack.rs index 90082b96a..e0551c585 100644 --- a/rust/cuvs/src/dlpack.rs +++ b/rust/cuvs/src/dlpack.rs @@ -14,6 +14,8 @@ * limitations under the License. */ +use crate::error::{check_cuda, Result}; + #[derive(Debug)] pub struct ManagedTensor(ffi::DLManagedTensor); @@ -64,13 +66,13 @@ impl IntoDtype for u32 { impl ManagedTensor { /// Create a non-owning view of a Tensor from a ndarray pub fn from_ndarray, D: ndarray::Dimension>( - arr: ndarray::ArrayBase, + arr: &ndarray::ArrayBase, ) -> ManagedTensor { // There is a draft PR out right now for creating dlpack directly from ndarray // right now, but until its merged we have to implement ourselves //https://github.com/rust-ndarray/ndarray/pull/1306/files unsafe { - let mut ret = core::mem::MaybeUninit::::uninit(); + let mut ret = std::mem::MaybeUninit::::uninit(); let tensor = ret.as_mut_ptr(); (*tensor).data = arr.as_ptr() as *mut std::os::raw::c_void; (*tensor).device = ffi::DLDevice { @@ -94,14 +96,47 @@ impl ManagedTensor { &self.0 as *const _ as *mut _ } - pub fn into_inner(self) -> ffi::DLManagedTensor { - self.0 + fn bytes(&self) -> usize { + // figure out how many bytes to allocate + let mut bytes: usize = 1; + for x in 0..self.0.dl_tensor.ndim { + bytes *= unsafe { (*self.0.dl_tensor.shape.add(x as usize)) as usize }; + } + bytes *= (self.0.dl_tensor.dtype.bits / 8) as usize; + bytes + } + + pub fn to_device(self) -> Result { + unsafe { + let bytes = self.bytes(); + let mut device_data: *mut std::ffi::c_void = std::ptr::null_mut(); + + // allocate storage, copy over + check_cuda(ffi::cudaMalloc(&mut device_data as *mut _, bytes))?; + check_cuda(ffi::cudaMemcpy( + device_data, + self.0.dl_tensor.data, + bytes, + ffi::cudaMemcpyKind_cudaMemcpyDefault, + ))?; + + let mut ret = self.0.clone(); + ret.dl_tensor.data = device_data; + ret.dl_tensor.device.device_type = ffi::DLDeviceType::kDLCUDA; + // TODO: do we need to set the device id here too? + // TODO: set deleter here to call cudaFree + Ok(ManagedTensor(ret)) + } } } impl Drop for ManagedTensor { fn drop(&mut self) { - // TODO: if we have a deletr here, call it to free up the memory + unsafe { + if let Some(deleter) = self.0.deleter { + deleter(&mut self.0 as *mut _); + } + } } } @@ -113,7 +148,7 @@ mod tests { fn test_from_ndarray() { let arr = ndarray::Array::::zeros((8, 4)); - let tensor = ManagedTensor::from_ndarray(arr).into_inner().dl_tensor; + let tensor = unsafe { (*(ManagedTensor::from_ndarray(&arr).as_ptr())).dl_tensor }; assert_eq!(tensor.ndim, 2); diff --git a/rust/cuvs/src/error.rs b/rust/cuvs/src/error.rs index 0a2db20e6..618106aba 100644 --- a/rust/cuvs/src/error.rs +++ b/rust/cuvs/src/error.rs @@ -17,15 +17,21 @@ use std::fmt; #[derive(Debug, Clone)] -pub struct Error { - err: ffi::cuvsError_t, +pub enum Error { + CudaError(ffi::cudaError_t), + CuvsError(ffi::cuvsError_t), } +impl std::error::Error for Error {} + pub type Result = std::result::Result; impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "cuvsError={:?}", self.err) + match self { + Error::CudaError(cuda_error) => write!(f, "cudaError={:?}", cuda_error), + Error::CuvsError(cuvs_error) => write!(f, "cuvsError={:?}", cuvs_error), + } } } @@ -33,6 +39,13 @@ impl fmt::Display for Error { pub fn check_cuvs(err: ffi::cuvsError_t) -> Result<()> { match err { ffi::cuvsError_t::CUVS_SUCCESS => Ok(()), - _ => Err(Error { err }), + _ => Err(Error::CuvsError(err)), + } +} + +pub fn check_cuda(err: ffi::cudaError_t) -> Result<()> { + match err { + ffi::cudaError::cudaSuccess => Ok(()), + _ => Err(Error::CudaError(err)), } } diff --git a/rust/cuvs/src/resources.rs b/rust/cuvs/src/resources.rs index 899b058b1..ad7113e6b 100644 --- a/rust/cuvs/src/resources.rs +++ b/rust/cuvs/src/resources.rs @@ -18,9 +18,7 @@ use crate::error::{check_cuvs, Result}; use std::io::{stderr, Write}; #[derive(Debug)] -pub struct Resources { - pub res: ffi::cuvsResources_t, -} +pub struct Resources(pub ffi::cuvsResources_t); impl Resources { pub fn new() -> Result { @@ -28,14 +26,14 @@ impl Resources { unsafe { check_cuvs(ffi::cuvsResourcesCreate(&mut res))?; } - Ok(Resources { res }) + Ok(Resources(res)) } } impl Drop for Resources { fn drop(&mut self) { unsafe { - if let Err(e) = check_cuvs(ffi::cuvsResourcesDestroy(self.res)) { + if let Err(e) = check_cuvs(ffi::cuvsResourcesDestroy(self.0)) { write!(stderr(), "failed to call cuvsResourcesDestroy {:?}", e) .expect("failed to write to stderr"); }