Skip to content

Commit

Permalink
working search unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
benfred committed Feb 12, 2024
1 parent 8a3ed55 commit 346f6cf
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 61 deletions.
5 changes: 5 additions & 0 deletions rust/cuvs-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ fn main() {
cuvs_build.display()
);
println!("cargo:rustc-link-lib=dylib=cuvs_c");
println!("cargo:rustc-link-lib=dylib=cudart");

// we need some extra flags both to link against cuvs, and also to run bindgen
// specifically we need to:
Expand Down Expand Up @@ -100,6 +101,10 @@ fn main() {
.allowlist_type("(cuvs|cagra|DL).*")
.allowlist_function("(cuvs|cagra).*")
.rustified_enum("(cuvs|cagra|DL).*")
// also need some basic cuda mem functions
// (TODO: should we be adding in RMM support instead here?)
.allowlist_function("(cudaMalloc|cudaFree|cudaMemcpy)")
.rustified_enum("cudaError")
.generate()
.expect("Unable to generate cagra_c bindings")
.write_to_file(out_path.join("cuvs_bindings.rs"))
Expand Down
3 changes: 3 additions & 0 deletions rust/cuvs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@ license.workspace = true
[dependencies]
ffi = { package = "cuvs-sys", path = "../cuvs-sys" }
ndarray = "0.15"

[dev-dependencies]
ndarray-rand = "*"
62 changes: 53 additions & 9 deletions rust/cuvs/src/cagra/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

use std::io::{stderr, Write};

use crate::cagra::IndexParams;
use crate::cagra::{IndexParams, SearchParams};
use crate::dlpack::ManagedTensor;
use crate::error::{check_cuvs, Result};
use crate::resources::Resources;
Expand All @@ -28,12 +28,12 @@ pub struct Index {

impl Index {
/// Builds a new index
pub fn build(res: Resources, params: IndexParams, dataset: ManagedTensor) -> Result<Index> {
pub fn build(res: &Resources, params: &IndexParams, dataset: ManagedTensor) -> Result<Index> {
let index = Index::new()?;
unsafe {
check_cuvs(ffi::cagraBuild(
res.res,
params.params,
res.0,
params.0,
dataset.as_ptr(),
index.index,
))?;
Expand All @@ -51,6 +51,26 @@ impl Index {
})
}
}

pub fn search(
self,
res: &Resources,
params: &SearchParams,
queries: ManagedTensor,
neighbors: ManagedTensor,
distances: ManagedTensor,
) -> Result<()> {
unsafe {
check_cuvs(ffi::cagraSearch(
res.0,
params.0,
self.index,
queries.as_ptr(),
neighbors.as_ptr(),
distances.as_ptr(),
))
}
}
}

impl Drop for Index {
Expand All @@ -65,21 +85,45 @@ impl Drop for Index {
#[cfg(test)]
mod tests {
use super::*;
use ndarray::s;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;

#[test]
fn test_create_empty_index() {
Index::new().unwrap();
}

#[test]
fn test_build() {
fn test_index() {
let res = Resources::new().unwrap();
let params = IndexParams::new().unwrap();

// TODO: test a more exciting dataset
let arr = ndarray::Array::<f32, _>::zeros((128, 16));
let dataset = ManagedTensor::from_ndarray(arr);
let n_features = 16;
let dataset = ndarray::Array::<f32, _>::random((256, n_features), Uniform::new(0., 1.0));
let index = Index::build(&res, &params, ManagedTensor::from_ndarray(&dataset))
.expect("failed to create cagra index");

// use the first 4 points from the dataset as queries : will test that we get them back
// as their own nearest neighbor
let n_queries = 4;
let queries = dataset.slice(s![0..n_queries, ..]);
let queries = ManagedTensor::from_ndarray(&queries).to_device().unwrap();

let k = 10;
let neighbors =
ManagedTensor::from_ndarray(&ndarray::Array::<u32, _>::zeros((n_queries, k)))
.to_device()
.unwrap();
let distances =
ManagedTensor::from_ndarray(&ndarray::Array::<f32, _>::zeros((n_queries, k)))
.to_device()
.unwrap();

let search_params = SearchParams::new().unwrap();

let index = Index::build(res, params, dataset).expect("failed to create cagra index");
index
.search(&res, &search_params, queries, neighbors, distances)
.unwrap();
}
}
29 changes: 15 additions & 14 deletions rust/cuvs/src/cagra/index_params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,49 +21,45 @@ use std::io::{stderr, Write};
pub type BuildAlgo = ffi::cagraGraphBuildAlgo;

/// Supplemental parameters to build CAGRA Index
pub struct IndexParams {
pub params: ffi::cuvsCagraIndexParams_t,
}
pub struct IndexParams(pub ffi::cuvsCagraIndexParams_t);

impl IndexParams {
pub fn new() -> Result<IndexParams> {
unsafe {
let mut params = core::mem::MaybeUninit::<ffi::cuvsCagraIndexParams_t>::uninit();
check_cuvs(ffi::cuvsCagraIndexParamsCreate(params.as_mut_ptr()))?;
Ok(IndexParams {
params: params.assume_init(),
})
Ok(IndexParams(params.assume_init()))
}
}

/// Degree of input graph for pruning
pub fn set_intermediate_graph_degree(self, intermediate_graph_degree: usize) -> IndexParams {
unsafe {
(*self.params).intermediate_graph_degree = intermediate_graph_degree;
(*self.0).intermediate_graph_degree = intermediate_graph_degree;
}
self
}

/// Degree of output graph
pub fn set_graph_degree(self, graph_degree: usize) -> IndexParams {
unsafe {
(*self.params).graph_degree = graph_degree;
(*self.0).graph_degree = graph_degree;
}
self
}

/// ANN algorithm to build knn graph
pub fn set_build_algo(self, build_algo: BuildAlgo) -> IndexParams {
unsafe {
(*self.params).build_algo = build_algo;
(*self.0).build_algo = build_algo;
}
self
}

/// Number of iterations to run if building with NN_DESCENT
pub fn set_nn_descent_niter(self, nn_descent_niter: usize) -> IndexParams {
unsafe {
(*self.params).nn_descent_niter = nn_descent_niter;
(*self.0).nn_descent_niter = nn_descent_niter;
}
self
}
Expand All @@ -73,13 +69,13 @@ impl fmt::Debug for IndexParams {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// custom debug trait here, default value will show the pointer address
// for the inner params object which isn't that useful.
write!(f, "IndexParams {{ params: {:?} }}", unsafe { *self.params })
write!(f, "IndexParams {{ params: {:?} }}", unsafe { *self.0 })
}
}

impl Drop for IndexParams {
fn drop(&mut self) {
if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraIndexParamsDestroy(self.params) }) {
if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraIndexParamsDestroy(self.0) }) {
write!(
stderr(),
"failed to call cuvsCagraIndexParamsDestroy {:?}",
Expand All @@ -103,7 +99,12 @@ mod tests {
.set_build_algo(BuildAlgo::NN_DESCENT)
.set_nn_descent_niter(10);

// make sure the setters actually updated internal representation
assert_eq!(format!("{:?}", params), "IndexParams { params: cagraIndexParams { intermediate_graph_degree: 128, graph_degree: 16, build_algo: NN_DESCENT, nn_descent_niter: 10 } }");
// make sure the setters actually updated internal representation on the c-struct
unsafe {
assert_eq!((*params.0).graph_degree, 16);
assert_eq!((*params.0).intermediate_graph_degree, 128);
assert_eq!((*params.0).build_algo, BuildAlgo::NN_DESCENT);
assert_eq!((*params.0).nn_descent_niter, 10);
}
}
}
44 changes: 21 additions & 23 deletions rust/cuvs/src/cagra/search_params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,21 @@ pub type SearchAlgo = ffi::cagraSearchAlgo;
pub type HashMode = ffi::cagraHashMode;

/// Supplemental parameters to search CAGRA index
pub struct SearchParams {
pub params: ffi::cuvsCagraSearchParams_t,
}
pub struct SearchParams(pub ffi::cuvsCagraSearchParams_t);

impl SearchParams {
pub fn new() -> Result<SearchParams> {
unsafe {
let mut params = core::mem::MaybeUninit::<ffi::cuvsCagraSearchParams_t>::uninit();
check_cuvs(ffi::cuvsCagraSearchParamsCreate(params.as_mut_ptr()))?;
Ok(SearchParams {
params: params.assume_init(),
})
Ok(SearchParams(params.assume_init()))
}
}

/// Maximum number of queries to search at the same time (batch size). Auto select when 0
pub fn set_max_queries(self, max_queries: usize) -> SearchParams {
unsafe {
(*self.params).max_queries = max_queries;
(*self.0).max_queries = max_queries;
}
self
}
Expand All @@ -50,87 +46,87 @@ impl SearchParams {
/// Higher values improve the search accuracy
pub fn set_itopk_size(self, itopk_size: usize) -> SearchParams {
unsafe {
(*self.params).itopk_size = itopk_size;
(*self.0).itopk_size = itopk_size;
}
self
}

/// Upper limit of search iterations. Auto select when 0.
pub fn set_max_iterations(self, max_iterations: usize) -> SearchParams {
unsafe {
(*self.params).max_iterations = max_iterations;
(*self.0).max_iterations = max_iterations;
}
self
}

/// Which search implementation to use.
pub fn set_algo(self, algo: SearchAlgo) -> SearchParams {
unsafe {
(*self.params).algo = algo;
(*self.0).algo = algo;
}
self
}

/// Number of threads used to calculate a single distance. 4, 8, 16, or 32.
pub fn set_team_size(self, team_size: usize) -> SearchParams {
unsafe {
(*self.params).team_size = team_size;
(*self.0).team_size = team_size;
}
self
}

/// Lower limit of search iterations.
pub fn set_min_iterations(self, min_iterations: usize) -> SearchParams {
unsafe {
(*self.params).min_iterations = min_iterations;
(*self.0).min_iterations = min_iterations;
}
self
}

/// Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0.
pub fn set_thread_block_size(self, thread_block_size: usize) -> SearchParams {
unsafe {
(*self.params).thread_block_size = thread_block_size;
(*self.0).thread_block_size = thread_block_size;
}
self
}

/// Hashmap type. Auto selection when AUTO.
pub fn set_hashmap_mode(self, hashmap_mode: HashMode) -> SearchParams {
unsafe {
(*self.params).hashmap_mode = hashmap_mode;
(*self.0).hashmap_mode = hashmap_mode;
}
self
}

/// Lower limit of hashmap bit length. More than 8.
pub fn set_hashmap_min_bitlen(self, hashmap_min_bitlen: usize) -> SearchParams {
unsafe {
(*self.params).hashmap_min_bitlen = hashmap_min_bitlen;
(*self.0).hashmap_min_bitlen = hashmap_min_bitlen;
}
self
}

/// Upper limit of hashmap fill rate. More than 0.1, less than 0.9.
pub fn set_hashmap_max_fill_rate(self, hashmap_max_fill_rate: f32) -> SearchParams {
unsafe {
(*self.params).hashmap_max_fill_rate = hashmap_max_fill_rate;
(*self.0).hashmap_max_fill_rate = hashmap_max_fill_rate;
}
self
}

/// Number of iterations of initial random seed node selection. 1 or more.
pub fn set_num_random_samplings(self, num_random_samplings: u32) -> SearchParams {
unsafe {
(*self.params).num_random_samplings = num_random_samplings;
(*self.0).num_random_samplings = num_random_samplings;
}
self
}

/// Bit mask used for initial random seed node selection.
pub fn set_rand_xor_mask(self, rand_xor_mask: u64) -> SearchParams {
unsafe {
(*self.params).rand_xor_mask = rand_xor_mask;
(*self.0).rand_xor_mask = rand_xor_mask;
}
self
}
Expand All @@ -140,15 +136,13 @@ impl fmt::Debug for SearchParams {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
// custom debug trait here, default value will show the pointer address
// for the inner params object which isn't that useful.
write!(f, "SearchParams {{ params: {:?} }}", unsafe {
*self.params
})
write!(f, "SearchParams {{ params: {:?} }}", unsafe { *self.0 })
}
}

impl Drop for SearchParams {
fn drop(&mut self) {
if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraSearchParamsDestroy(self.params) }) {
if let Err(e) = check_cuvs(unsafe { ffi::cuvsCagraSearchParamsDestroy(self.0) }) {
write!(
stderr(),
"failed to call cuvsCagraSearchParamsDestroy {:?}",
Expand All @@ -165,6 +159,10 @@ mod tests {

#[test]
fn test_search_params() {
let params = SearchParams::new().unwrap();
let params = SearchParams::new().unwrap().set_itopk_size(128);

unsafe {
assert_eq!((*params.0).itopk_size, 128);
}
}
}
Loading

0 comments on commit 346f6cf

Please sign in to comment.