Skip to content

Commit

Permalink
support for building cagra index
Browse files Browse the repository at this point in the history
  • Loading branch information
benfred committed Feb 7, 2024
1 parent 7ddc9de commit 8a3ed55
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 8 deletions.
1 change: 1 addition & 0 deletions rust/cuvs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ license.workspace = true

[dependencies]
ffi = { package = "cuvs-sys", path = "../cuvs-sys" }
ndarray = "0.15"
85 changes: 85 additions & 0 deletions rust/cuvs/src/cagra/index.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

use std::io::{stderr, Write};

use crate::cagra::IndexParams;
use crate::dlpack::ManagedTensor;
use crate::error::{check_cuvs, Result};
use crate::resources::Resources;

/// A CAGRA approximate-nearest-neighbours index.
///
/// Owns the underlying cuVS C handle; the handle is released when this
/// struct is dropped.
#[derive(Debug)]
pub struct Index {
    // Raw FFI handle to the cagra index, created by cagraIndexCreate
    // and destroyed in the Drop impl.
    index: ffi::cagraIndex_t,
}

impl Index {
    /// Builds a new index from a dataset, using the given build
    /// parameters and resource handle.
    pub fn build(res: Resources, params: IndexParams, dataset: ManagedTensor) -> Result<Index> {
        let index = Index::new()?;
        // Hand everything over the C API; the status code tells us
        // whether the build succeeded.
        let status = unsafe {
            ffi::cagraBuild(res.res, params.params, dataset.as_ptr(), index.index)
        };
        check_cuvs(status)?;
        Ok(index)
    }

    /// Creates a new empty index
    pub fn new() -> Result<Index> {
        let mut handle = core::mem::MaybeUninit::<ffi::cagraIndex_t>::uninit();
        unsafe {
            check_cuvs(ffi::cagraIndexCreate(handle.as_mut_ptr()))?;
            // cagraIndexCreate returned success, so the handle has been
            // written and assume_init is sound.
            Ok(Index {
                index: handle.assume_init(),
            })
        }
    }
}

impl Drop for Index {
    /// Releases the underlying C index handle.
    ///
    /// Failures from cagraIndexDestroy are reported to stderr on a
    /// best-effort basis only. The previous `.expect()` could panic inside
    /// `drop` when stderr was unwritable, and a panic during unwinding
    /// aborts the whole process — never panic in a destructor.
    fn drop(&mut self) {
        if let Err(e) = check_cuvs(unsafe { ffi::cagraIndexDestroy(self.index) }) {
            let _ = write!(stderr(), "failed to call cagraIndexDestroy {:?}", e);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_empty_index() {
        // Creating (and immediately dropping) an unbuilt index should succeed.
        Index::new().unwrap();
    }

    #[test]
    fn test_build() {
        let res = Resources::new().unwrap();
        let params = IndexParams::new().unwrap();

        // TODO: test a more exciting dataset
        let arr = ndarray::Array::<f32, _>::zeros((128, 16));
        let dataset = ManagedTensor::from_ndarray(arr);

        // Underscore binding: we only check that construction succeeds,
        // which also silences the unused-variable warning.
        let _index = Index::build(res, params, dataset).expect("failed to create cagra index");
    }
}
112 changes: 105 additions & 7 deletions rust/cuvs/src/dlpack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,111 @@
* limitations under the License.
*/

use std::io::{stderr, Write};
/// Newtype wrapper over a DLPack `DLManagedTensor`, used to pass array
/// data across the cuVS C API boundary.
#[derive(Debug)]
pub struct ManagedTensor(ffi::DLManagedTensor);

pub use ffi::DLDeviceType;
pub use ffi::DLDataTypeCode;
pub use ffi::DLDataType;
/// Maps a Rust element type to its DLPack dtype descriptor.
pub trait IntoDtype {
    /// Returns the `DLDataType` (type code, bit width, lanes) describing
    /// this element type.
    fn ffi_dtype() -> ffi::DLDataType;
}

#[derive(Debug)]
pub struct Tensor {
pub res: ffi::DLTensor,
// f32 -> 32-bit single-lane DLPack float.
impl IntoDtype for f32 {
    fn ffi_dtype() -> ffi::DLDataType {
        ffi::DLDataType {
            code: ffi::DLDataTypeCode::kDLFloat as _,
            bits: 32,
            lanes: 1,
        }
    }
}

// f64 -> 64-bit single-lane DLPack float.
impl IntoDtype for f64 {
    fn ffi_dtype() -> ffi::DLDataType {
        ffi::DLDataType {
            code: ffi::DLDataTypeCode::kDLFloat as _,
            bits: 64,
            lanes: 1,
        }
    }
}

// i32 -> 32-bit single-lane DLPack signed int.
impl IntoDtype for i32 {
    fn ffi_dtype() -> ffi::DLDataType {
        ffi::DLDataType {
            code: ffi::DLDataTypeCode::kDLInt as _,
            bits: 32,
            lanes: 1,
        }
    }
}

// u32 -> 32-bit single-lane DLPack unsigned int.
impl IntoDtype for u32 {
    fn ffi_dtype() -> ffi::DLDataType {
        ffi::DLDataType {
            code: ffi::DLDataTypeCode::kDLUInt as _,
            bits: 32,
            lanes: 1,
        }
    }
}

impl ManagedTensor {
    /// Creates a DLPack tensor from an ndarray.
    ///
    /// The array is moved onto the heap and ownership is transferred into
    /// the returned `ManagedTensor` via `manager_ctx` + `deleter`, so the
    /// `data`/`shape` pointers handed to the C API stay valid for the
    /// lifetime of the tensor. (Previously the array was dropped on return,
    /// leaving both pointers dangling.)
    pub fn from_ndarray<T: IntoDtype, S: ndarray::RawData<Elem = T>, D: ndarray::Dimension>(
        arr: ndarray::ArrayBase<S, D>,
    ) -> ManagedTensor {
        // There is a draft PR for creating dlpack directly from ndarray,
        // but until it's merged we have to implement this ourselves:
        // https://github.com/rust-ndarray/ndarray/pull/1306/files

        // DLPack deleter: frees the boxed ndarray stashed in manager_ctx.
        unsafe extern "C" fn release_array<S: ndarray::RawData, D>(
            tensor: *mut ffi::DLManagedTensor,
        ) {
            // SAFETY: manager_ctx was produced by Box::into_raw below with
            // exactly this type, and the deleter runs at most once.
            drop(Box::from_raw((*tensor).manager_ctx as *mut ndarray::ArrayBase<S, D>));
        }

        // Box the array *before* taking any pointers: for fixed-size
        // dimensions the shape storage lives inline in ArrayBase, so it
        // must not move after we record its address.
        let arr = Box::new(arr);
        unsafe {
            let mut ret = core::mem::MaybeUninit::<ffi::DLTensor>::uninit();
            let tensor = ret.as_mut_ptr();
            (*tensor).data = arr.as_ptr() as *mut std::os::raw::c_void;
            // Host-memory view; device arrays would need a different DLDevice.
            (*tensor).device = ffi::DLDevice {
                device_type: ffi::DLDeviceType::kDLCPU,
                device_id: 0,
            };
            (*tensor).byte_offset = 0;
            (*tensor).strides = std::ptr::null_mut(); // TODO: error if not rowmajor
            (*tensor).ndim = arr.ndim() as i32;
            // NOTE(review): casts *const usize to the DLPack i64 extents
            // pointer — relies on a 64-bit target; confirm this is acceptable.
            (*tensor).shape = arr.shape().as_ptr() as *mut _;
            (*tensor).dtype = T::ffi_dtype();
            ManagedTensor(ffi::DLManagedTensor {
                dl_tensor: ret.assume_init(),
                // Keep the array alive until the deleter fires.
                manager_ctx: Box::into_raw(arr) as *mut _,
                deleter: Some(release_array::<S, D>),
            })
        }
    }

    /// Returns a raw pointer to the underlying DLManagedTensor.
    ///
    /// Borrows rather than consumes `self`: the pointer is valid only while
    /// this ManagedTensor is alive. (The previous by-value signature
    /// returned a pointer into the moved-in stack temporary, which dangled
    /// as soon as the call returned.)
    pub fn as_ptr(&self) -> *mut ffi::DLManagedTensor {
        &self.0 as *const _ as *mut _
    }

    /// Consumes self and returns the inner DLManagedTensor. The caller
    /// becomes responsible for invoking its `deleter` (if any).
    pub fn into_inner(self) -> ffi::DLManagedTensor {
        // ManagedTensor implements Drop, so a plain `self.0` field move is
        // rejected by the compiler (E0509). ManuallyDrop + ptr::read moves
        // the value out without running our destructor.
        let this = std::mem::ManuallyDrop::new(self);
        // SAFETY: `this` is never dropped, so the inner value is read
        // exactly once and never double-freed.
        unsafe { std::ptr::read(&this.0) }
    }
}

impl Drop for ManagedTensor {
    /// Releases any memory owned through the DLPack `manager_ctx` by
    /// invoking the tensor's `deleter` callback, as the DLPack protocol
    /// specifies. Tensors created without a deleter are non-owning views
    /// and need no cleanup (implements the previous TODO).
    fn drop(&mut self) {
        if let Some(deleter) = self.0.deleter {
            // SAFETY: the deleter was supplied together with this
            // DLManagedTensor and is called exactly once, on drop.
            unsafe { deleter(&mut self.0) };
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_from_ndarray() {
        let arr = ndarray::Array::<f32, _>::zeros((8, 4));

        // NOTE(review): into_inner() detaches the raw DLManagedTensor from
        // the RAII wrapper, so nothing frees it here; the shape pointer
        // dereferenced below must outlive the wrapper — confirm
        // from_ndarray keeps the backing storage alive.
        let tensor = ManagedTensor::from_ndarray(arr).into_inner().dl_tensor;

        assert_eq!(tensor.ndim, 2);

        // make sure we can get the shape ok
        assert_eq!(unsafe { *tensor.shape }, 8);
        assert_eq!(unsafe { *tensor.shape.add(1) }, 4);
    }
}
2 changes: 1 addition & 1 deletion rust/cuvs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
*/

pub mod cagra;
mod dlpack;
mod error;
mod resources;
mod dlpack;

pub use error::{Error, Result};
pub use resources::Resources;

0 comments on commit 8a3ed55

Please sign in to comment.