diff --git a/rust/cuvs/Cargo.toml b/rust/cuvs/Cargo.toml
index b5ccc4961..462011a1e 100644
--- a/rust/cuvs/Cargo.toml
+++ b/rust/cuvs/Cargo.toml
@@ -10,3 +10,4 @@ license.workspace = true
 
 [dependencies]
 ffi = { package = "cuvs-sys", path = "../cuvs-sys" }
+ndarray = "0.15"
diff --git a/rust/cuvs/src/cagra/index.rs b/rust/cuvs/src/cagra/index.rs
new file mode 100644
index 000000000..ef33d0e90
--- /dev/null
+++ b/rust/cuvs/src/cagra/index.rs
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+use std::io::{stderr, Write};
+
+use crate::cagra::IndexParams;
+use crate::dlpack::ManagedTensor;
+use crate::error::{check_cuvs, Result};
+use crate::resources::Resources;
+
+#[derive(Debug)]
+pub struct Index {
+    index: ffi::cagraIndex_t,
+}
+
+impl Index {
+    /// Builds a new index
+    pub fn build(res: Resources, params: IndexParams, dataset: ManagedTensor) -> Result<Index> {
+        let index = Index::new()?;
+        unsafe {
+            check_cuvs(ffi::cagraBuild(
+                res.res,
+                params.params,
+                dataset.as_ptr(),
+                index.index, // output handle, populated in place by the C library
+            ))?;
+        }
+        Ok(index)
+    }
+
+    /// Creates a new empty index
+    pub fn new() -> Result<Index> {
+        unsafe {
+            let mut index = core::mem::MaybeUninit::<ffi::cagraIndex_t>::uninit();
+            check_cuvs(ffi::cagraIndexCreate(index.as_mut_ptr()))?;
+            Ok(Index {
+                index: index.assume_init(), // initialized: cagraIndexCreate succeeded above
+            })
+        }
+    }
+}
+
+impl Drop for Index {
+    fn drop(&mut self) {
+        if let Err(e) = check_cuvs(unsafe { ffi::cagraIndexDestroy(self.index) }) {
+            write!(stderr(), "failed to call cagraIndexDestroy {:?}", e)
+                .expect("failed to write to stderr");
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_create_empty_index() {
+        Index::new().unwrap();
+    }
+
+    #[test]
+    fn test_build() {
+        let res = Resources::new().unwrap();
+        let params = IndexParams::new().unwrap();
+
+        // TODO: test a more exciting dataset
+        let arr = ndarray::Array::<f32, _>::zeros((128, 16));
+        let dataset = ManagedTensor::from_ndarray(arr);
+
+        let index = Index::build(res, params, dataset).expect("failed to create cagra index");
+    }
+}
diff --git a/rust/cuvs/src/dlpack.rs b/rust/cuvs/src/dlpack.rs
index a42528b74..90082b96a 100644
--- a/rust/cuvs/src/dlpack.rs
+++ b/rust/cuvs/src/dlpack.rs
@@ -14,13 +14,111 @@
  * limitations under the License.
*/ -use std::io::{stderr, Write}; +#[derive(Debug)] +pub struct ManagedTensor(ffi::DLManagedTensor); -pub use ffi::DLDeviceType; -pub use ffi::DLDataTypeCode; -pub use ffi::DLDataType; +pub trait IntoDtype { + fn ffi_dtype() -> ffi::DLDataType; +} -#[derive(Debug)] -pub struct Tensor { - pub res: ffi::DLTensor, +impl IntoDtype for f32 { + fn ffi_dtype() -> ffi::DLDataType { + ffi::DLDataType { + code: ffi::DLDataTypeCode::kDLFloat as _, + bits: 32, + lanes: 1, + } + } +} + +impl IntoDtype for f64 { + fn ffi_dtype() -> ffi::DLDataType { + ffi::DLDataType { + code: ffi::DLDataTypeCode::kDLFloat as _, + bits: 64, + lanes: 1, + } + } +} + +impl IntoDtype for i32 { + fn ffi_dtype() -> ffi::DLDataType { + ffi::DLDataType { + code: ffi::DLDataTypeCode::kDLInt as _, + bits: 32, + lanes: 1, + } + } +} + +impl IntoDtype for u32 { + fn ffi_dtype() -> ffi::DLDataType { + ffi::DLDataType { + code: ffi::DLDataTypeCode::kDLUInt as _, + bits: 32, + lanes: 1, + } + } +} + +impl ManagedTensor { + /// Create a non-owning view of a Tensor from a ndarray + pub fn from_ndarray, D: ndarray::Dimension>( + arr: ndarray::ArrayBase, + ) -> ManagedTensor { + // There is a draft PR out right now for creating dlpack directly from ndarray + // right now, but until its merged we have to implement ourselves + //https://github.com/rust-ndarray/ndarray/pull/1306/files + unsafe { + let mut ret = core::mem::MaybeUninit::::uninit(); + let tensor = ret.as_mut_ptr(); + (*tensor).data = arr.as_ptr() as *mut std::os::raw::c_void; + (*tensor).device = ffi::DLDevice { + device_type: ffi::DLDeviceType::kDLCPU, + device_id: 0, + }; + (*tensor).byte_offset = 0; + (*tensor).strides = std::ptr::null_mut(); // TODO: error if not rowmajor + (*tensor).ndim = arr.ndim() as i32; + (*tensor).shape = arr.shape().as_ptr() as *mut _; + (*tensor).dtype = T::ffi_dtype(); + ManagedTensor(ffi::DLManagedTensor { + dl_tensor: ret.assume_init(), + manager_ctx: std::ptr::null_mut(), + deleter: None, + }) + } + } + + pub fn 
as_ptr(self) -> *mut ffi::DLManagedTensor { + &self.0 as *const _ as *mut _ + } + + pub fn into_inner(self) -> ffi::DLManagedTensor { + self.0 + } +} + +impl Drop for ManagedTensor { + fn drop(&mut self) { + // TODO: if we have a deletr here, call it to free up the memory + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_ndarray() { + let arr = ndarray::Array::::zeros((8, 4)); + + let tensor = ManagedTensor::from_ndarray(arr).into_inner().dl_tensor; + + assert_eq!(tensor.ndim, 2); + + // make sure we can get the shape ok + assert_eq!(unsafe { *tensor.shape }, 8); + assert_eq!(unsafe { *tensor.shape.add(1) }, 4); + } } diff --git a/rust/cuvs/src/lib.rs b/rust/cuvs/src/lib.rs index ca954659c..3f6752848 100644 --- a/rust/cuvs/src/lib.rs +++ b/rust/cuvs/src/lib.rs @@ -15,9 +15,9 @@ */ pub mod cagra; +mod dlpack; mod error; mod resources; -mod dlpack; pub use error::{Error, Result}; pub use resources::Resources;