From 71850c14a411b43c961bfffeb529f56d4f2c0bc2 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 4 Mar 2024 14:13:24 -0800 Subject: [PATCH] Expose exception text to C and Rust API's Previously when something went wrong in the Rust or C API - all the end user would see is a `CUVS_ERROR` return code with no extra indication of what went wrong. This change exposes the exception text to both the C and Rust api's, and provides a convenience method to automatically catch c++ exceptions, and convert the exception into an error code with the text set appropiately. --- cpp/include/cuvs/core/c_api.h | 10 +++++ cpp/include/cuvs/core/exceptions.hpp | 45 ++++++++++++++++++++ cpp/src/core/c_api.cpp | 38 ++++++++--------- cpp/src/neighbors/cagra_c.cpp | 62 +++++++--------------------- rust/cuvs/src/cagra/index.rs | 7 +++- rust/cuvs/src/error.rs | 26 +++++++++++- 6 files changed, 116 insertions(+), 72 deletions(-) create mode 100644 cpp/include/cuvs/core/exceptions.hpp diff --git a/cpp/include/cuvs/core/c_api.h b/cpp/include/cuvs/core/c_api.h index b50032916..eceb917ff 100644 --- a/cpp/include/cuvs/core/c_api.h +++ b/cpp/include/cuvs/core/c_api.h @@ -67,6 +67,16 @@ cuvsError_t cuvsResourcesDestroy(cuvsResources_t res); */ cuvsError_t cuvsStreamSet(cuvsResources_t res, cudaStream_t stream); +/** @brief Returns a string describing the last seen error on this thread, or + * NULL if the last function succeeded. + */ +const char* cuvsGetLastErrorText(); + +/** + * @brief Sets a string describing an error seen on the thread. Passing NULL + * clears any previously seen error message. + */ +void cuvsSetLastErrorText(const char* error); #ifdef __cplusplus } #endif diff --git a/cpp/include/cuvs/core/exceptions.hpp b/cpp/include/cuvs/core/exceptions.hpp new file mode 100644 index 000000000..01ee42151 --- /dev/null +++ b/cpp/include/cuvs/core/exceptions.hpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "c_api.h" + +#include + +namespace cuvs::core { + +/** + * @brief Translates C++ exceptions into cuvs C-API error codes + */ +template +cuvsError_t translate_exceptions(Fn func) +{ + cuvsError_t status; + try { + func(); + status = CUVS_SUCCESS; + cuvsSetLastErrorText(NULL); + } catch (const std::exception& e) { + cuvsSetLastErrorText(e.what()); + status = CUVS_ERROR; + } catch (...) { + cuvsSetLastErrorText("unknown exception"); + status = CUVS_ERROR; + } + return status; +} +} // namespace cuvs::core diff --git a/cpp/src/core/c_api.cpp b/cpp/src/core/c_api.cpp index 133021d77..7ddb4f3e4 100644 --- a/cpp/src/core/c_api.cpp +++ b/cpp/src/core/c_api.cpp @@ -16,46 +16,42 @@ #include #include +#include #include #include #include #include +#include extern "C" cuvsError_t cuvsResourcesCreate(cuvsResources_t* res) { - cuvsError_t status; - try { + return cuvs::core::translate_exceptions([=] { auto res_ptr = new raft::resources{}; *res = reinterpret_cast(res_ptr); - status = CUVS_SUCCESS; - } catch (...) { - status = CUVS_ERROR; - } - return status; + }); } extern "C" cuvsError_t cuvsResourcesDestroy(cuvsResources_t res) { - cuvsError_t status; - try { + return cuvs::core::translate_exceptions([=] { auto res_ptr = reinterpret_cast(res); delete res_ptr; - status = CUVS_SUCCESS; - } catch (...) { - status = CUVS_ERROR; - } - return status; + }); } extern "C" cuvsError_t cuvsStreamSet(cuvsResources_t res, cudaStream_t stream) { - cuvsError_t status; - try { + return cuvs::core::translate_exceptions([=] { auto res_ptr = reinterpret_cast(res); raft::resource::set_cuda_stream(*res_ptr, static_cast(stream)); - status = CUVS_SUCCESS; - } catch (...) { - status = CUVS_ERROR; - } - return status; + }); } + +thread_local std::string last_error_text = ""; + +extern "C" const char* cuvsGetLastErrorText() +{ + return last_error_text.empty() ? NULL : last_error_text.c_str(); +} + +extern "C" void cuvsSetLastErrorText(const char* error) { last_error_text = error ? error : ""; } diff --git a/cpp/src/neighbors/cagra_c.cpp b/cpp/src/neighbors/cagra_c.cpp index 70e268fb2..9fdfe2c1e 100644 --- a/cpp/src/neighbors/cagra_c.cpp +++ b/cpp/src/neighbors/cagra_c.cpp @@ -23,6 +23,7 @@ #include #include +#include #include #include #include @@ -96,17 +97,12 @@ void _search(cuvsResources_t res, extern "C" cuvsError_t cuvsCagraIndexCreate(cuvsCagraIndex_t* index) { - try { - *index = new cuvsCagraIndex{}; - return CUVS_SUCCESS; - } catch (...) { - return CUVS_ERROR; - } + return cuvs::core::translate_exceptions([=] { *index = new cuvsCagraIndex{}; }); } extern "C" cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index_c_ptr) { - try { + return cuvs::core::translate_exceptions([=] { auto index = *index_c_ptr; if (index.dtype.code == kDLFloat) { @@ -123,10 +119,7 @@ extern "C" cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index_c_ptr) delete index_ptr; } delete index_c_ptr; - return CUVS_SUCCESS; - } catch (...) { - return CUVS_ERROR; - } + }); } extern "C" cuvsError_t cuvsCagraBuild(cuvsResources_t res, @@ -134,7 +127,7 @@ extern "C" cuvsError_t cuvsCagraBuild(cuvsResources_t res, DLManagedTensor* dataset_tensor, cuvsCagraIndex_t index) { - try { + return cuvs::core::translate_exceptions([=] { auto dataset = dataset_tensor->dl_tensor; if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) { @@ -151,13 +144,7 @@ extern "C" cuvsError_t cuvsCagraBuild(cuvsResources_t res, dataset.dtype.code, dataset.dtype.bits); } - return CUVS_SUCCESS; - } catch (const std::exception& ex) { - std::cerr << "Error occurred: " << ex.what() << std::endl; - return CUVS_ERROR; - } catch (...) { - return CUVS_ERROR; - } + }); } extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res, @@ -167,7 +154,7 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res, DLManagedTensor* neighbors_tensor, DLManagedTensor* distances_tensor) { - try { + return cuvs::core::translate_exceptions([=] { auto queries = queries_tensor->dl_tensor; auto neighbors = neighbors_tensor->dl_tensor; auto distances = distances_tensor->dl_tensor; @@ -198,57 +185,36 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res, queries.dtype.code, queries.dtype.bits); } - return CUVS_SUCCESS; - } catch (const std::exception& ex) { - std::cerr << "Error occurred: " << ex.what() << std::endl; - } catch (...) { - return CUVS_ERROR; - } + }); } extern "C" cuvsError_t cuvsCagraIndexParamsCreate(cuvsCagraIndexParams_t* params) { - try { + return cuvs::core::translate_exceptions([=] { *params = new cuvsCagraIndexParams{.intermediate_graph_degree = 128, .graph_degree = 64, .build_algo = IVF_PQ, .nn_descent_niter = 20}; - return CUVS_SUCCESS; - } catch (...) { - return CUVS_ERROR; - } + }); } extern "C" cuvsError_t cuvsCagraIndexParamsDestroy(cuvsCagraIndexParams_t params) { - try { - delete params; - return CUVS_SUCCESS; - } catch (...) { - return CUVS_ERROR; - } + return cuvs::core::translate_exceptions([=] { delete params; }); } extern "C" cuvsError_t cuvsCagraSearchParamsCreate(cuvsCagraSearchParams_t* params) { - try { + return cuvs::core::translate_exceptions([=] { *params = new cuvsCagraSearchParams{.itopk_size = 64, .search_width = 1, .hashmap_max_fill_rate = 0.5, .num_random_samplings = 1, .rand_xor_mask = 0x128394}; - return CUVS_SUCCESS; - } catch (...) { - return CUVS_ERROR; - } + }); } extern "C" cuvsError_t cuvsCagraSearchParamsDestroy(cuvsCagraSearchParams_t params) { - try { - delete params; - return CUVS_SUCCESS; - } catch (...) { - return CUVS_ERROR; - } + return cuvs::core::translate_exceptions([=] { delete params; }); } diff --git a/rust/cuvs/src/cagra/index.rs b/rust/cuvs/src/cagra/index.rs index 43f032676..3c45efafd 100644 --- a/rust/cuvs/src/cagra/index.rs +++ b/rust/cuvs/src/cagra/index.rs @@ -34,7 +34,12 @@ impl Index { let dataset: ManagedTensor = dataset.into(); let index = Index::new()?; unsafe { - check_cuvs(ffi::cuvsCagraBuild(res.0, params.0, dataset.as_ptr(), index.0))?; + check_cuvs(ffi::cuvsCagraBuild( + res.0, + params.0, + dataset.as_ptr(), + index.0, + ))?; } Ok(index) } diff --git a/rust/cuvs/src/error.rs b/rust/cuvs/src/error.rs index 618106aba..53c68a976 100644 --- a/rust/cuvs/src/error.rs +++ b/rust/cuvs/src/error.rs @@ -16,13 +16,20 @@ use std::fmt; +#[derive(Debug, Clone)] +pub struct CuvsError { + code: ffi::cuvsError_t, + text: String, +} + #[derive(Debug, Clone)] pub enum Error { CudaError(ffi::cudaError_t), - CuvsError(ffi::cuvsError_t), + CuvsError(CuvsError), } impl std::error::Error for Error {} +impl std::error::Error for CuvsError {} pub type Result = std::result::Result; @@ -35,11 +42,26 @@ impl fmt::Display for Error { } } +impl fmt::Display for CuvsError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{:?}:{}", self.code, self.text) + } +} + /// Simple wrapper to convert a cuvsError_t into a Result pub fn check_cuvs(err: ffi::cuvsError_t) -> Result<()> { match err { ffi::cuvsError_t::CUVS_SUCCESS => Ok(()), - _ => Err(Error::CuvsError(err)), + _ => { + // get a description of the error from cuvs + let cstr = unsafe { + let text_ptr = ffi::cuvsGetLastErrorText(); + std::ffi::CStr::from_ptr(text_ptr) + }; + let text = std::string::String::from_utf8_lossy(cstr.to_bytes()).to_string(); + + Err(Error::CuvsError(CuvsError { code: err, text })) + } } }