Skip to content

Commit

Permalink
Add resources to to_host/to_device functions
Browse files Browse the repository at this point in the history
eventually we'll want the stream arg on these
  • Loading branch information
benfred committed Feb 20, 2024
1 parent d6ea993 commit 7d5174a
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 11 deletions.
10 changes: 5 additions & 5 deletions rust/cuvs/examples/cagra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,20 @@ fn cagra_example() -> Result<()> {
// CAGRA search API requires queries and outputs to be on device memory
// copy query data over, and allocate new device memory for the distances/ neighbors
// outputs
let queries = ManagedTensor::from_ndarray(&queries).to_device()?;
let queries = ManagedTensor::from_ndarray(&queries).to_device(&res)?;
let mut neighbors_host = ndarray::Array::<u32, _>::zeros((n_queries, k));
let neighbors = ManagedTensor::from_ndarray(&neighbors_host).to_device()?;
let neighbors = ManagedTensor::from_ndarray(&neighbors_host).to_device(&res)?;

let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
let distances = ManagedTensor::from_ndarray(&distances_host).to_device()?;
let distances = ManagedTensor::from_ndarray(&distances_host).to_device(&res)?;

let search_params = SearchParams::new()?;

index.search(&res, &search_params, &queries, &neighbors, &distances)?;

// Copy back to host memory
distances.to_host(&mut distances_host)?;
neighbors.to_host(&mut neighbors_host)?;
distances.to_host(&res, &mut distances_host)?;
neighbors.to_host(&res, &mut neighbors_host)?;

// nearest neighbors should be themselves, since queries are from the
// dataset
Expand Down
10 changes: 5 additions & 5 deletions rust/cuvs/src/cagra/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,15 @@ mod tests {
// CAGRA search API requires queries and outputs to be on device memory
// copy query data over, and allocate new device memory for the distances/ neighbors
// outputs
let queries = ManagedTensor::from_ndarray(&queries).to_device().unwrap();
let queries = ManagedTensor::from_ndarray(&queries).to_device(&res).unwrap();
let mut neighbors_host = ndarray::Array::<u32, _>::zeros((n_queries, k));
let neighbors = ManagedTensor::from_ndarray(&neighbors_host)
.to_device()
.to_device(&res)
.unwrap();

let mut distances_host = ndarray::Array::<f32, _>::zeros((n_queries, k));
let distances = ManagedTensor::from_ndarray(&distances_host)
.to_device()
.to_device(&res)
.unwrap();

let search_params = SearchParams::new().unwrap();
Expand All @@ -123,8 +123,8 @@ mod tests {
.unwrap();

// Copy back to host memory
distances.to_host(&mut distances_host).unwrap();
neighbors.to_host(&mut neighbors_host).unwrap();
distances.to_host(&res, &mut distances_host).unwrap();
neighbors.to_host(&res, &mut neighbors_host).unwrap();

// nearest neighbors should be themselves, since queries are from the
// dataset
Expand Down
4 changes: 3 additions & 1 deletion rust/cuvs/src/dlpack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

use crate::error::{check_cuda, Result};
use crate::resources::Resources;

#[derive(Debug)]
pub struct ManagedTensor(ffi::DLManagedTensor);
Expand Down Expand Up @@ -66,7 +67,7 @@ impl ManagedTensor {
bytes
}

pub fn to_device(&self) -> Result<ManagedTensor> {
pub fn to_device(&self, _res: &Resources) -> Result<ManagedTensor> {
unsafe {
let bytes = self.bytes();
let mut device_data: *mut std::ffi::c_void = std::ptr::null_mut();
Expand Down Expand Up @@ -95,6 +96,7 @@ impl ManagedTensor {
D: ndarray::Dimension,
>(
&self,
_res: &Resources,
arr: &mut ndarray::ArrayBase<S, D>,
) -> Result<()> {
unsafe {
Expand Down

0 comments on commit 7d5174a

Please sign in to comment.