Skip to content

Commit

Permalink
Expose exception text to C and Rust API's
Browse files Browse the repository at this point in the history
Previously when something went wrong in the Rust or C API - all the
end user would see is a `CUVS_ERROR` return code with no extra
indication of what went wrong.

This change exposes the exception text to both the C and Rust api's,
and provides a convenience method to automatically catch c++ exceptions,
and convert the exception into an error code with the text set appropiately.
  • Loading branch information
benfred committed Mar 4, 2024
1 parent 8e6979f commit 71850c1
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 72 deletions.
10 changes: 10 additions & 0 deletions cpp/include/cuvs/core/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ cuvsError_t cuvsResourcesDestroy(cuvsResources_t res);
*/
cuvsError_t cuvsStreamSet(cuvsResources_t res, cudaStream_t stream);

/** @brief Returns a string describing the last seen error on this thread, or
* NULL if the last function succeeded.
*/
const char* cuvsGetLastErrorText();

/**
* @brief Sets a string describing an error seen on the thread. Passing NULL
* clears any previously seen error message.
*/
void cuvsSetLastErrorText(const char* error);
#ifdef __cplusplus
}
#endif
Expand Down
45 changes: 45 additions & 0 deletions cpp/include/cuvs/core/exceptions.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "c_api.h"

#include <exception>

namespace cuvs::core {

/**
* @brief Translates C++ exceptions into cuvs C-API error codes
*/
template <typename Fn>
cuvsError_t translate_exceptions(Fn func)
{
cuvsError_t status;
try {
func();
status = CUVS_SUCCESS;
cuvsSetLastErrorText(NULL);
} catch (const std::exception& e) {
cuvsSetLastErrorText(e.what());
status = CUVS_ERROR;
} catch (...) {
cuvsSetLastErrorText("unknown exception");
status = CUVS_ERROR;
}
return status;
}
} // namespace cuvs::core
38 changes: 17 additions & 21 deletions cpp/src/core/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,42 @@

#include <cstdint>
#include <cuvs/core/c_api.h>
#include <cuvs/core/exceptions.hpp>
#include <memory>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <thread>

extern "C" cuvsError_t cuvsResourcesCreate(cuvsResources_t* res)
{
cuvsError_t status;
try {
return cuvs::core::translate_exceptions([=] {
auto res_ptr = new raft::resources{};
*res = reinterpret_cast<uintptr_t>(res_ptr);
status = CUVS_SUCCESS;
} catch (...) {
status = CUVS_ERROR;
}
return status;
});
}

extern "C" cuvsError_t cuvsResourcesDestroy(cuvsResources_t res)
{
cuvsError_t status;
try {
return cuvs::core::translate_exceptions([=] {
auto res_ptr = reinterpret_cast<raft::resources*>(res);
delete res_ptr;
status = CUVS_SUCCESS;
} catch (...) {
status = CUVS_ERROR;
}
return status;
});
}

extern "C" cuvsError_t cuvsStreamSet(cuvsResources_t res, cudaStream_t stream)
{
cuvsError_t status;
try {
return cuvs::core::translate_exceptions([=] {
auto res_ptr = reinterpret_cast<raft::resources*>(res);
raft::resource::set_cuda_stream(*res_ptr, static_cast<rmm::cuda_stream_view>(stream));
status = CUVS_SUCCESS;
} catch (...) {
status = CUVS_ERROR;
}
return status;
});
}

thread_local std::string last_error_text = "";

extern "C" const char* cuvsGetLastErrorText()
{
return last_error_text.empty() ? NULL : last_error_text.c_str();
}

extern "C" void cuvsSetLastErrorText(const char* error) { last_error_text = error ? error : ""; }
62 changes: 14 additions & 48 deletions cpp/src/neighbors/cagra_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <raft/core/resources.hpp>

#include <cuvs/core/c_api.h>
#include <cuvs/core/exceptions.hpp>
#include <cuvs/core/interop.hpp>
#include <cuvs/neighbors/cagra.h>
#include <cuvs/neighbors/cagra.hpp>
Expand Down Expand Up @@ -96,17 +97,12 @@ void _search(cuvsResources_t res,

extern "C" cuvsError_t cuvsCagraIndexCreate(cuvsCagraIndex_t* index)
{
try {
*index = new cuvsCagraIndex{};
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
return cuvs::core::translate_exceptions([=] { *index = new cuvsCagraIndex{}; });
}

extern "C" cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index_c_ptr)
{
try {
return cuvs::core::translate_exceptions([=] {
auto index = *index_c_ptr;

if (index.dtype.code == kDLFloat) {
Expand All @@ -123,18 +119,15 @@ extern "C" cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index_c_ptr)
delete index_ptr;
}
delete index_c_ptr;
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
});
}

extern "C" cuvsError_t cuvsCagraBuild(cuvsResources_t res,
cuvsCagraIndexParams_t params,
DLManagedTensor* dataset_tensor,
cuvsCagraIndex_t index)
{
try {
return cuvs::core::translate_exceptions([=] {
auto dataset = dataset_tensor->dl_tensor;

if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) {
Expand All @@ -151,13 +144,7 @@ extern "C" cuvsError_t cuvsCagraBuild(cuvsResources_t res,
dataset.dtype.code,
dataset.dtype.bits);
}
return CUVS_SUCCESS;
} catch (const std::exception& ex) {
std::cerr << "Error occurred: " << ex.what() << std::endl;
return CUVS_ERROR;
} catch (...) {
return CUVS_ERROR;
}
});
}

extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res,
Expand All @@ -167,7 +154,7 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res,
DLManagedTensor* neighbors_tensor,
DLManagedTensor* distances_tensor)
{
try {
return cuvs::core::translate_exceptions([=] {
auto queries = queries_tensor->dl_tensor;
auto neighbors = neighbors_tensor->dl_tensor;
auto distances = distances_tensor->dl_tensor;
Expand Down Expand Up @@ -198,57 +185,36 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res,
queries.dtype.code,
queries.dtype.bits);
}
return CUVS_SUCCESS;
} catch (const std::exception& ex) {
std::cerr << "Error occurred: " << ex.what() << std::endl;
} catch (...) {
return CUVS_ERROR;
}
});
}

extern "C" cuvsError_t cuvsCagraIndexParamsCreate(cuvsCagraIndexParams_t* params)
{
try {
return cuvs::core::translate_exceptions([=] {
*params = new cuvsCagraIndexParams{.intermediate_graph_degree = 128,
.graph_degree = 64,
.build_algo = IVF_PQ,
.nn_descent_niter = 20};
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
});
}

extern "C" cuvsError_t cuvsCagraIndexParamsDestroy(cuvsCagraIndexParams_t params)
{
try {
delete params;
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
return cuvs::core::translate_exceptions([=] { delete params; });
}

extern "C" cuvsError_t cuvsCagraSearchParamsCreate(cuvsCagraSearchParams_t* params)
{
try {
return cuvs::core::translate_exceptions([=] {
*params = new cuvsCagraSearchParams{.itopk_size = 64,
.search_width = 1,
.hashmap_max_fill_rate = 0.5,
.num_random_samplings = 1,
.rand_xor_mask = 0x128394};
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
});
}

extern "C" cuvsError_t cuvsCagraSearchParamsDestroy(cuvsCagraSearchParams_t params)
{
try {
delete params;
return CUVS_SUCCESS;
} catch (...) {
return CUVS_ERROR;
}
return cuvs::core::translate_exceptions([=] { delete params; });
}
7 changes: 6 additions & 1 deletion rust/cuvs/src/cagra/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ impl Index {
let dataset: ManagedTensor = dataset.into();
let index = Index::new()?;
unsafe {
check_cuvs(ffi::cuvsCagraBuild(res.0, params.0, dataset.as_ptr(), index.0))?;
check_cuvs(ffi::cuvsCagraBuild(
res.0,
params.0,
dataset.as_ptr(),
index.0,
))?;
}
Ok(index)
}
Expand Down
26 changes: 24 additions & 2 deletions rust/cuvs/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,20 @@

use std::fmt;

#[derive(Debug, Clone)]
pub struct CuvsError {
code: ffi::cuvsError_t,
text: String,
}

#[derive(Debug, Clone)]
pub enum Error {
CudaError(ffi::cudaError_t),
CuvsError(ffi::cuvsError_t),
CuvsError(CuvsError),
}

impl std::error::Error for Error {}
impl std::error::Error for CuvsError {}

pub type Result<T> = std::result::Result<T, Error>;

Expand All @@ -35,11 +42,26 @@ impl fmt::Display for Error {
}
}

impl fmt::Display for CuvsError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "{:?}:{}", self.code, self.text)
}
}

/// Simple wrapper to convert a cuvsError_t into a Result
pub fn check_cuvs(err: ffi::cuvsError_t) -> Result<()> {
match err {
ffi::cuvsError_t::CUVS_SUCCESS => Ok(()),
_ => Err(Error::CuvsError(err)),
_ => {
// get a description of the error from cuvs
let cstr = unsafe {
let text_ptr = ffi::cuvsGetLastErrorText();
std::ffi::CStr::from_ptr(text_ptr)
};
let text = std::string::String::from_utf8_lossy(cstr.to_bytes()).to_string();

Err(Error::CuvsError(CuvsError { code: err, text }))
}
}
}

Expand Down

0 comments on commit 71850c1

Please sign in to comment.