diff --git a/ci/build_rust.sh b/ci/build_rust.sh
index 59f0924c8..ad2202734 100755
--- a/ci/build_rust.sh
+++ b/ci/build_rust.sh
@@ -35,5 +35,4 @@ rapids-mamba-retry install \
   libcuvs \
   libraft
 
-export EXTRA_CMAKE_ARGS=""
 bash ./build.sh rust
diff --git a/cpp/include/cuvs/core/c_api.h b/cpp/include/cuvs/core/c_api.h
index a15d7cd5c..d931d6c13 100644
--- a/cpp/include/cuvs/core/c_api.h
+++ b/cpp/include/cuvs/core/c_api.h
@@ -83,6 +83,50 @@ cuvsError_t cuvsResourcesDestroy(cuvsResources_t res);
  */
 cuvsError_t cuvsStreamSet(cuvsResources_t res, cudaStream_t stream);
 
+/**
+ * @brief Get the cudaStream_t from a cuvsResources_t
+ *
+ * @param[in] res cuvsResources_t opaque C handle
+ * @param[out] stream cudaStream_t stream to queue CUDA kernels
+ * @return cuvsError_t
+ */
+cuvsError_t cuvsStreamGet(cuvsResources_t res, cudaStream_t* stream);
+
+/**
+ * @brief Syncs the current CUDA stream on the resources object
+ *
+ * @param[in] res cuvsResources_t opaque C handle
+ * @return cuvsError_t
+ */
+cuvsError_t cuvsStreamSync(cuvsResources_t res);
+/** @} */
+
+/**
+ * @defgroup memory_c cuVS Memory Allocation
+ * @{
+ */
+
+/**
+ * @brief Allocates device memory using RMM
+ *
+ *
+ * @param[in] res cuvsResources_t opaque C handle
+ * @param[out] ptr Pointer to allocated device memory
+ * @param[in] bytes Size in bytes to allocate
+ * @return cuvsError_t
+ */
+cuvsError_t cuvsRMMAlloc(cuvsResources_t res, void** ptr, size_t bytes);
+
+/**
+ * @brief Deallocates device memory using RMM
+ *
+ * @param[in] res cuvsResources_t opaque C handle
+ * @param[in] ptr Pointer to allocated device memory to free
+ * @param[in] bytes Size in bytes of the allocation to free
+ * @return cuvsError_t
+ */
+cuvsError_t cuvsRMMFree(cuvsResources_t res, void* ptr, size_t bytes);
+
 /** @} */
 
 #ifdef __cplusplus
diff --git a/cpp/src/core/c_api.cpp b/cpp/src/core/c_api.cpp
index 7ddb4f3e4..96504a2fe 100644
--- a/cpp/src/core/c_api.cpp
+++ b/cpp/src/core/c_api.cpp
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include
 #include
 
 extern "C" cuvsError_t cuvsResourcesCreate(cuvsResources_t* res)
@@ -47,6 +48,40 @@ extern "C" cuvsError_t cuvsStreamSet(cuvsResources_t res, cudaStream_t stream)
   });
 }
 
+extern "C" cuvsError_t cuvsStreamGet(cuvsResources_t res, cudaStream_t* stream)
+{
+  return cuvs::core::translate_exceptions([=] {
+    auto res_ptr = reinterpret_cast<raft::resources*>(res);
+    *stream      = raft::resource::get_cuda_stream(*res_ptr);
+  });
+}
+
+extern "C" cuvsError_t cuvsStreamSync(cuvsResources_t res)
+{
+  return cuvs::core::translate_exceptions([=] {
+    auto res_ptr = reinterpret_cast<raft::resources*>(res);
+    raft::resource::sync_stream(*res_ptr);
+  });
+}
+
+extern "C" cuvsError_t cuvsRMMAlloc(cuvsResources_t res, void** ptr, size_t bytes)
+{
+  return cuvs::core::translate_exceptions([=] {
+    auto res_ptr = reinterpret_cast<raft::resources*>(res);
+    auto mr      = rmm::mr::get_current_device_resource();
+    *ptr         = mr->allocate(bytes, raft::resource::get_cuda_stream(*res_ptr));
+  });
+}
+
+extern "C" cuvsError_t cuvsRMMFree(cuvsResources_t res, void* ptr, size_t bytes)
+{
+  return cuvs::core::translate_exceptions([=] {
+    auto res_ptr = reinterpret_cast<raft::resources*>(res);
+    auto mr      = rmm::mr::get_current_device_resource();
+    mr->deallocate(ptr, bytes, raft::resource::get_cuda_stream(*res_ptr));
+  });
+}
+
 thread_local std::string last_error_text = "";
 
 extern "C" const char* cuvsGetLastErrorText()
diff --git a/rust/cuvs-sys/build.rs b/rust/cuvs-sys/build.rs
index 816a6f259..ec9672569 100644
--- a/rust/cuvs-sys/build.rs
+++ b/rust/cuvs-sys/build.rs
@@ -101,9 +101,8 @@ fn main() {
         .allowlist_type("(cuvs|cagra|DL).*")
         .allowlist_function("(cuvs|cagra).*")
         .rustified_enum("(cuvs|cagra|DL).*")
-        // also need some basic cuda mem functions
-        // (TODO: should we be adding in RMM support instead here?)
-        .allowlist_function("(cudaMalloc|cudaFree|cudaMemcpy)")
+        // also need some basic cuda mem functions for copying data
+        .allowlist_function("(cudaMemcpyAsync|cudaMemcpy)")
         .rustified_enum("cudaError")
         .generate()
         .expect("Unable to generate cagra_c bindings")
diff --git a/rust/cuvs/src/cagra/index.rs b/rust/cuvs/src/cagra/index.rs
index 3394889aa..6a5149f07 100644
--- a/rust/cuvs/src/cagra/index.rs
+++ b/rust/cuvs/src/cagra/index.rs
@@ -54,7 +54,7 @@ impl Index {
     /// Creates a new empty index
     pub fn new() -> Result<Index> {
         unsafe {
-            let mut index = core::mem::MaybeUninit::<ffi::cuvsCagraIndex_t>::uninit();
+            let mut index = std::mem::MaybeUninit::<ffi::cuvsCagraIndex_t>::uninit();
             check_cuvs(ffi::cuvsCagraIndexCreate(index.as_mut_ptr()))?;
             Ok(Index(index.assume_init()))
         }
diff --git a/rust/cuvs/src/cagra/index_params.rs b/rust/cuvs/src/cagra/index_params.rs
index ecc660531..2e3367e06 100644
--- a/rust/cuvs/src/cagra/index_params.rs
+++ b/rust/cuvs/src/cagra/index_params.rs
@@ -27,7 +27,7 @@ impl IndexParams {
     /// Returns a new IndexParams
     pub fn new() -> Result<IndexParams> {
         unsafe {
-            let mut params = core::mem::MaybeUninit::<ffi::cuvsCagraIndexParams_t>::uninit();
+            let mut params = std::mem::MaybeUninit::<ffi::cuvsCagraIndexParams_t>::uninit();
             check_cuvs(ffi::cuvsCagraIndexParamsCreate(params.as_mut_ptr()))?;
             Ok(IndexParams(params.assume_init()))
         }
diff --git a/rust/cuvs/src/cagra/search_params.rs b/rust/cuvs/src/cagra/search_params.rs
index 14956966e..7b891b002 100644
--- a/rust/cuvs/src/cagra/search_params.rs
+++ b/rust/cuvs/src/cagra/search_params.rs
@@ -28,7 +28,7 @@ impl SearchParams {
     /// Returns a new SearchParams object
     pub fn new() -> Result<SearchParams> {
         unsafe {
-            let mut params = core::mem::MaybeUninit::<ffi::cuvsCagraSearchParams_t>::uninit();
+            let mut params = std::mem::MaybeUninit::<ffi::cuvsCagraSearchParams_t>::uninit();
             check_cuvs(ffi::cuvsCagraSearchParamsCreate(params.as_mut_ptr()))?;
             Ok(SearchParams(params.assume_init()))
         }
diff --git a/rust/cuvs/src/dlpack.rs b/rust/cuvs/src/dlpack.rs
index a1d4e41c6..238caec9d 100644
--- a/rust/cuvs/src/dlpack.rs
+++ b/rust/cuvs/src/dlpack.rs
@@ -16,7 +16,7 @@
 
 use std::convert::From;
 
-use crate::error::{check_cuda, Result};
+use crate::error::{check_cuda, check_cuvs, Result};
 use crate::resources::Resources;
 
 /// ManagedTensor is a wrapper around a dlpack DLManagedTensor object.
@@ -33,36 +33,27 @@ impl ManagedTensor {
         &self.0 as *const _ as *mut _
     }
 
-    fn bytes(&self) -> usize {
-        // figure out how many bytes to allocate
-        let mut bytes: usize = 1;
-        for x in 0..self.0.dl_tensor.ndim {
-            bytes *= unsafe { (*self.0.dl_tensor.shape.add(x as usize)) as usize };
-        }
-        bytes *= (self.0.dl_tensor.dtype.bits / 8) as usize;
-        bytes
-    }
-
     /// Creates a new ManagedTensor on the current GPU device, and copies
     /// the data into it.
-    pub fn to_device(&self, _res: &Resources) -> Result<ManagedTensor> {
+    pub fn to_device(&self, res: &Resources) -> Result<ManagedTensor> {
         unsafe {
-            let bytes = self.bytes();
+            let bytes = dl_tensor_bytes(&self.0.dl_tensor);
             let mut device_data: *mut std::ffi::c_void = std::ptr::null_mut();
 
             // allocate storage, copy over
-            check_cuda(ffi::cudaMalloc(&mut device_data as *mut _, bytes))?;
-            check_cuda(ffi::cudaMemcpy(
+            check_cuvs(ffi::cuvsRMMAlloc(res.0, &mut device_data as *mut _, bytes))?;
+
+            check_cuda(ffi::cudaMemcpyAsync(
                 device_data,
                 self.0.dl_tensor.data,
                 bytes,
                 ffi::cudaMemcpyKind_cudaMemcpyDefault,
+                res.get_cuda_stream()?,
             ))?;
 
             let mut ret = self.0.clone();
             ret.dl_tensor.data = device_data;
-            // call cudaFree automatically to clean up data
-            ret.deleter = Some(cuda_free_tensor);
+            ret.deleter = Some(rmm_free_tensor);
             ret.dl_tensor.device.device_type = ffi::DLDeviceType::kDLCUDA;
 
             Ok(ManagedTensor(ret))
@@ -80,21 +71,32 @@ impl ManagedTensor {
         arr: &mut ndarray::ArrayBase<S, D>,
     ) -> Result<()> {
         unsafe {
-            let bytes = self.bytes();
+            let bytes = dl_tensor_bytes(&self.0.dl_tensor);
             check_cuda(ffi::cudaMemcpy(
                 arr.as_mut_ptr() as *mut std::ffi::c_void,
                 self.0.dl_tensor.data,
                 bytes,
                 ffi::cudaMemcpyKind_cudaMemcpyDefault,
             ))?;
-
             Ok(())
         }
     }
 }
 
-unsafe extern "C" fn cuda_free_tensor(self_: *mut ffi::DLManagedTensor) {
-    let _ = ffi::cudaFree((*self_).dl_tensor.data);
+/// Figures out how many bytes are in a DLTensor
+fn dl_tensor_bytes(tensor: &ffi::DLTensor) -> usize {
+    let mut bytes: usize = 1;
+    for dim in 0..tensor.ndim {
+        bytes *= unsafe { (*tensor.shape.add(dim as usize)) as usize };
+    }
+    bytes *= (tensor.dtype.bits / 8) as usize;
+    bytes
+}
+
+unsafe extern "C" fn rmm_free_tensor(self_: *mut ffi::DLManagedTensor) {
+    let bytes = dl_tensor_bytes(&(*self_).dl_tensor);
+    let res = Resources::new().unwrap();
+    let _ = ffi::cuvsRMMFree(res.0, (*self_).dl_tensor.data as *mut _, bytes);
 }
 
 /// Create a non-owning view of a Tensor from a ndarray
diff --git a/rust/cuvs/src/resources.rs b/rust/cuvs/src/resources.rs
index a5c503dc5..0c60cf669 100644
--- a/rust/cuvs/src/resources.rs
+++ b/rust/cuvs/src/resources.rs
@@ -32,6 +32,25 @@ impl Resources {
         }
         Ok(Resources(res))
     }
+
+    /// Sets the current cuda stream
+    pub fn set_cuda_stream(&self, stream: ffi::cudaStream_t) -> Result<()> {
+        unsafe { check_cuvs(ffi::cuvsStreamSet(self.0, stream)) }
+    }
+
+    /// Gets the current cuda stream
+    pub fn get_cuda_stream(&self) -> Result<ffi::cudaStream_t> {
+        unsafe {
+            let mut stream = std::mem::MaybeUninit::<ffi::cudaStream_t>::uninit();
+            check_cuvs(ffi::cuvsStreamGet(self.0, stream.as_mut_ptr()))?;
+            Ok(stream.assume_init())
+        }
+    }
+
+    /// Syncs the current cuda stream
+    pub fn sync_stream(&self) -> Result<()> {
+        unsafe { check_cuvs(ffi::cuvsStreamSync(self.0)) }
+    }
 }
 
 impl Drop for Resources {
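Not part of the diff: a minimal usage sketch of the new stream-aware path from the Rust side. It assumes the crate's top-level re-exports (`Resources`, `ManagedTensor`, `Result`), the existing non-owning `From<&ndarray::ArrayBase<...>>` view constructor hinted at by the last context line above, and a `to_host(&Resources, &mut ArrayBase<...>)` signature; anything not shown in the diff is illustrative only.

```rust
use cuvs::{ManagedTensor, Resources, Result};

fn roundtrip() -> Result<()> {
    let res = Resources::new()?;

    // Host-side input; the From impl only creates a non-owning DLPack view.
    let host = ndarray::Array2::<f32>::zeros((4, 8));
    let view = ManagedTensor::from(&host);

    // to_device() now allocates through cuvsRMMAlloc and enqueues an async copy
    // on the resources' stream; the returned tensor frees itself with cuvsRMMFree
    // via the rmm_free_tensor deleter.
    let device = view.to_device(&res)?;

    // The copy above is asynchronous, so synchronize the stream before relying
    // on the device data (or letting `host` go out of scope).
    res.sync_stream()?;

    // to_host() still issues a blocking cudaMemcpy, so `out` is valid on return.
    let mut out = ndarray::Array2::<f32>::zeros((4, 8));
    device.to_host(&res, &mut out)?;
    Ok(())
}
```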