Skip to content

Commit

Permalink
Fixing more namespace and compile issues
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet committed Nov 21, 2023
1 parent 815e7aa commit 80e874b
Show file tree
Hide file tree
Showing 129 changed files with 1,218 additions and 1,141 deletions.
5 changes: 5 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,11 @@ add_library(
src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu
)

target_compile_options(
cuvs INTERFACE $<$<COMPILE_LANG_AND_ID:CUDA,NVIDIA>:--expt-extended-lambda
--expt-relaxed-constexpr>
)

add_library(cuvs::cuvs ALIAS cuvs)

target_include_directories(
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/cuvs/distance/detail/fused_l2_nn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ void fusedL2NNImpl(OutT* min,
dim3 blk(P::Nthreads);
auto nblks = raft::ceildiv<int>(m, P::Nthreads);
constexpr auto maxVal = std::numeric_limits<DataT>::max();
typedef KeyValuePair<IdxT, DataT> KVPair;
typedef raft::KeyValuePair<IdxT, DataT> KVPair;

RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream));
if (initOutBuffer) {
Expand Down
30 changes: 16 additions & 14 deletions cpp/include/cuvs/distance/fused_l2_nn-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -98,33 +98,35 @@ void fusedL2NN(OutT* min,
auto py = reinterpret_cast<uintptr_t>(y);
if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) {
if (is_skinny) {
detail::fusedL2NNImpl<DataT,
OutT,
IdxT,
typename linalg::Policy4x4Skinny<DataT, 16 / sizeof(DataT)>::Policy,
ReduceOpT>(
detail::fusedL2NNImpl<
DataT,
OutT,
IdxT,
typename raft::linalg::Policy4x4Skinny<DataT, 16 / sizeof(DataT)>::Policy,
ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
} else {
detail::fusedL2NNImpl<DataT,
OutT,
IdxT,
typename linalg::Policy4x4<DataT, 16 / sizeof(DataT)>::Policy,
typename raft::linalg::Policy4x4<DataT, 16 / sizeof(DataT)>::Policy,
ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
}
} else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) {
if (is_skinny) {
detail::fusedL2NNImpl<DataT,
OutT,
IdxT,
typename linalg::Policy4x4Skinny<DataT, 8 / sizeof(DataT)>::Policy,
ReduceOpT>(
detail::fusedL2NNImpl<
DataT,
OutT,
IdxT,
typename raft::linalg::Policy4x4Skinny<DataT, 8 / sizeof(DataT)>::Policy,
ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
} else {
detail::fusedL2NNImpl<DataT,
OutT,
IdxT,
typename linalg::Policy4x4<DataT, 8 / sizeof(DataT)>::Policy,
typename raft::linalg::Policy4x4<DataT, 8 / sizeof(DataT)>::Policy,
ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
}
Expand All @@ -133,14 +135,14 @@ void fusedL2NN(OutT* min,
detail::fusedL2NNImpl<DataT,
OutT,
IdxT,
typename linalg::Policy4x4Skinny<DataT, 1>::Policy,
typename raft::linalg::Policy4x4Skinny<DataT, 1>::Policy,
ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
} else {
detail::fusedL2NNImpl<DataT,
OutT,
IdxT,
typename linalg::Policy4x4<DataT, 1>::Policy,
typename raft::linalg::Policy4x4<DataT, 1>::Policy,
ReduceOpT>(
min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
}
Expand Down
20 changes: 10 additions & 10 deletions cpp/include/cuvs/neighbors/ball_cover-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ void all_knn_query(raft::resources const& handle,
template <typename idx_t, typename value_t, typename int_t, typename matrix_idx_t>
void all_knn_query(raft::resources const& handle,
BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index,
raft::device_matrix_view<idx_t, matrix_idx_t, row_major> inds,
raft::device_matrix_view<value_t, matrix_idx_t, row_major> dists,
raft::device_matrix_view<idx_t, matrix_idx_t, raft::row_major> inds,
raft::device_matrix_view<value_t, matrix_idx_t, raft::row_major> dists,
int_t k,
bool perform_post_filtering = true,
float weight = 1.0) RAFT_EXPLICIT;
Expand All @@ -60,9 +60,9 @@ void knn_query(raft::resources const& handle,
template <typename idx_t, typename value_t, typename int_t, typename matrix_idx_t>
void knn_query(raft::resources const& handle,
const BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index,
raft::device_matrix_view<const value_t, matrix_idx_t, row_major> query,
raft::device_matrix_view<idx_t, matrix_idx_t, row_major> inds,
raft::device_matrix_view<value_t, matrix_idx_t, row_major> dists,
raft::device_matrix_view<const value_t, matrix_idx_t, raft::row_major> query,
raft::device_matrix_view<idx_t, matrix_idx_t, raft::row_major> inds,
raft::device_matrix_view<value_t, matrix_idx_t, raft::row_major> dists,
int_t k,
bool perform_post_filtering = true,
float weight = 1.0) RAFT_EXPLICIT;
Expand Down Expand Up @@ -91,8 +91,8 @@ void knn_query(raft::resources const& handle,
cuvs::neighbors::ball_cover::all_knn_query<idx_t, value_t, int_t, matrix_idx_t>( \
raft::resources const& handle, \
cuvs::neighbors::ball_cover::BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index, \
raft::device_matrix_view<idx_t, matrix_idx_t, row_major> inds, \
raft::device_matrix_view<value_t, matrix_idx_t, row_major> dists, \
raft::device_matrix_view<idx_t, matrix_idx_t, raft::row_major> inds, \
raft::device_matrix_view<value_t, matrix_idx_t, raft::row_major> dists, \
int_t k, \
bool perform_post_filtering, \
float weight); \
Expand All @@ -112,9 +112,9 @@ void knn_query(raft::resources const& handle,
cuvs::neighbors::ball_cover::knn_query<idx_t, value_t, int_t, matrix_idx_t>( \
raft::resources const& handle, \
const cuvs::neighbors::ball_cover::BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index, \
raft::device_matrix_view<const value_t, matrix_idx_t, row_major> query, \
raft::device_matrix_view<idx_t, matrix_idx_t, row_major> inds, \
raft::device_matrix_view<value_t, matrix_idx_t, row_major> dists, \
raft::device_matrix_view<const value_t, matrix_idx_t, raft::row_major> query, \
raft::device_matrix_view<idx_t, matrix_idx_t, raft::row_major> inds, \
raft::device_matrix_view<value_t, matrix_idx_t, raft::row_major> dists, \
int_t k, \
bool perform_post_filtering, \
float weight);
Expand Down
10 changes: 5 additions & 5 deletions cpp/include/cuvs/neighbors/ball_cover-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ void all_knn_query(raft::resources const& handle,
template <typename idx_t, typename value_t, typename int_t, typename matrix_idx_t>
void all_knn_query(raft::resources const& handle,
BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index,
raft::device_matrix_view<idx_t, matrix_idx_t, row_major> inds,
raft::device_matrix_view<value_t, matrix_idx_t, row_major> dists,
raft::device_matrix_view<idx_t, matrix_idx_t, raft::row_major> inds,
raft::device_matrix_view<value_t, matrix_idx_t, raft::row_major> dists,
int_t k,
bool perform_post_filtering = true,
float weight = 1.0)
Expand Down Expand Up @@ -354,9 +354,9 @@ void knn_query(raft::resources const& handle,
template <typename idx_t, typename value_t, typename int_t, typename matrix_idx_t>
void knn_query(raft::resources const& handle,
const BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index,
raft::device_matrix_view<const value_t, matrix_idx_t, row_major> query,
raft::device_matrix_view<idx_t, matrix_idx_t, row_major> inds,
raft::device_matrix_view<value_t, matrix_idx_t, row_major> dists,
raft::device_matrix_view<const value_t, matrix_idx_t, raft::row_major> query,
raft::device_matrix_view<idx_t, matrix_idx_t, raft::row_major> inds,
raft::device_matrix_view<value_t, matrix_idx_t, raft::row_major> dists,
int_t k,
bool perform_post_filtering = true,
float weight = 1.0)
Expand Down
8 changes: 4 additions & 4 deletions cpp/include/cuvs/neighbors/ball_cover_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,12 @@ class BallCoverIndex {
raft::device_vector_view<value_idx, matrix_idx> get_R_1nn_cols() { return R_1nn_cols.view(); }
raft::device_vector_view<value_t, matrix_idx> get_R_1nn_dists() { return R_1nn_dists.view(); }
raft::device_vector_view<value_t, matrix_idx> get_R_radius() { return R_radius.view(); }
raft::device_matrix_view<value_t, matrix_idx, row_major> get_R() { return R.view(); }
raft::device_matrix_view<value_t, matrix_idx, raft::row_major> get_R() { return R.view(); }
raft::device_vector_view<value_t, matrix_idx> get_R_closest_landmark_dists()
{
return R_closest_landmark_dists.view();
}
raft::device_matrix_view<const value_t, matrix_idx, row_major> get_X() const { return X; }
raft::device_matrix_view<const value_t, matrix_idx, raft::row_major> get_X() const { return X; }

cuvs::distance::DistanceType get_metric() const { return metric; }

Expand All @@ -145,7 +145,7 @@ class BallCoverIndex {
value_int n;
value_int n_landmarks;

raft::device_matrix_view<const value_t, matrix_idx, row_major> X;
raft::device_matrix_view<const value_t, matrix_idx, raft::row_major> X;

cuvs::distance::DistanceType metric;

Expand All @@ -158,7 +158,7 @@ class BallCoverIndex {

raft::device_vector<value_t, matrix_idx> R_radius;

raft::device_matrix<value_t, matrix_idx, row_major> R;
raft::device_matrix<value_t, matrix_idx, raft::row_major> R;

protected:
bool index_trained;
Expand Down
65 changes: 33 additions & 32 deletions cpp/include/cuvs/neighbors/brute_force-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,26 @@ namespace cuvs::neighbors::brute_force {
template <typename value_t, typename idx_t>
inline void knn_merge_parts(
raft::resources const& handle,
raft::device_matrix_view<const value_t, idx_t, row_major> in_keys,
raft::device_matrix_view<const idx_t, idx_t, row_major> in_values,
raft::device_matrix_view<value_t, idx_t, row_major> out_keys,
raft::device_matrix_view<idx_t, idx_t, row_major> out_values,
raft::device_matrix_view<const value_t, idx_t, raft::row_major> in_keys,
raft::device_matrix_view<const idx_t, idx_t, raft::row_major> in_values,
raft::device_matrix_view<value_t, idx_t, raft::row_major> out_keys,
raft::device_matrix_view<idx_t, idx_t, raft::row_major> out_values,
size_t n_samples,
std::optional<raft::device_vector_view<idx_t, idx_t>> translations = std::nullopt) RAFT_EXPLICIT;

template <typename T, typename Accessor>
index<T> build(raft::resources const& res,
mdspan<const T, matrix_extent<int64_t>, row_major, Accessor> dataset,
cuvs::distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
T metric_arg = 0.0) RAFT_EXPLICIT;
index<T> build(
raft::resources const& res,
raft::mdspan<const T, raft::matrix_extent<int64_t>, raft::row_major, Accessor> dataset,
cuvs::distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
T metric_arg = 0.0) RAFT_EXPLICIT;

template <typename T, typename IdxT>
void search(raft::resources const& res,
const index<T>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<T, int64_t, row_major> distances) RAFT_EXPLICIT;
raft::device_matrix_view<const T, int64_t, raft::row_major> queries,
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<T, int64_t, raft::row_major> distances) RAFT_EXPLICIT;

template <typename idx_t,
typename value_t,
Expand All @@ -61,8 +62,8 @@ template <typename idx_t,
void knn(raft::resources const& handle,
std::vector<raft::device_matrix_view<const value_t, matrix_idx, index_layout>> index,
raft::device_matrix_view<const value_t, matrix_idx, search_layout> search,
raft::device_matrix_view<idx_t, matrix_idx, row_major> indices,
raft::device_matrix_view<value_t, matrix_idx, row_major> distances,
raft::device_matrix_view<idx_t, matrix_idx, raft::row_major> indices,
raft::device_matrix_view<value_t, matrix_idx, raft::row_major> distances,
distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
std::optional<float> metric_arg = std::make_optional<float>(2.0f),
std::optional<idx_t> global_id_offset = std::nullopt,
Expand All @@ -72,8 +73,8 @@ template <typename value_t, typename idx_t, typename idx_layout, typename query_
void fused_l2_knn(raft::resources const& handle,
raft::device_matrix_view<const value_t, idx_t, idx_layout> index,
raft::device_matrix_view<const value_t, idx_t, query_layout> query,
raft::device_matrix_view<idx_t, idx_t, row_major> out_inds,
raft::device_matrix_view<value_t, idx_t, row_major> out_dists,
raft::device_matrix_view<idx_t, idx_t, raft::row_major> out_inds,
raft::device_matrix_view<value_t, idx_t, raft::row_major> out_dists,
cuvs::distance::DistanceType metric) RAFT_EXPLICIT;

} // namespace cuvs::neighbors::brute_force
Expand All @@ -89,8 +90,8 @@ void fused_l2_knn(raft::resources const& handle,
raft::resources const& handle, \
std::vector<raft::device_matrix_view<const value_t, matrix_idx, index_layout>> index, \
raft::device_matrix_view<const value_t, matrix_idx, search_layout> search, \
raft::device_matrix_view<idx_t, matrix_idx, row_major> indices, \
raft::device_matrix_view<value_t, matrix_idx, row_major> distances, \
raft::device_matrix_view<idx_t, matrix_idx, raft::row_major> indices, \
raft::device_matrix_view<value_t, matrix_idx, raft::row_major> distances, \
cuvs::distance::DistanceType metric, \
std::optional<float> metric_arg, \
std::optional<idx_t> global_id_offset, \
Expand All @@ -112,32 +113,32 @@ namespace cuvs::neighbors::brute_force {
extern template void search<float, int>(
raft::resources const& res,
const cuvs::neighbors::brute_force::index<float>& idx,
raft::device_matrix_view<const float, int64_t, row_major> queries,
raft::device_matrix_view<int, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances);
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<int, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances);

extern template void search<float, int64_t>(
raft::resources const& res,
const cuvs::neighbors::brute_force::index<float>& idx,
raft::device_matrix_view<const float, int64_t, row_major> queries,
raft::device_matrix_view<int64_t, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances);
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances);

extern template cuvs::neighbors::brute_force::index<float> build<float>(
raft::resources const& res,
raft::device_matrix_view<const float, int64_t, row_major> dataset,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
cuvs::distance::DistanceType metric,
float metric_arg);
} // namespace cuvs::neighbors::brute_force

#define instantiate_raft_neighbors_brute_force_fused_l2_knn( \
value_t, idx_t, idx_layout, query_layout) \
extern template void cuvs::neighbors::brute_force::fused_l2_knn( \
raft::resources const& handle, \
raft::device_matrix_view<const value_t, idx_t, idx_layout> index, \
raft::device_matrix_view<const value_t, idx_t, query_layout> query, \
raft::device_matrix_view<idx_t, idx_t, row_major> out_inds, \
raft::device_matrix_view<value_t, idx_t, row_major> out_dists, \
#define instantiate_raft_neighbors_brute_force_fused_l2_knn( \
value_t, idx_t, idx_layout, query_layout) \
extern template void cuvs::neighbors::brute_force::fused_l2_knn( \
raft::resources const& handle, \
raft::device_matrix_view<const value_t, idx_t, idx_layout> index, \
raft::device_matrix_view<const value_t, idx_t, query_layout> query, \
raft::device_matrix_view<idx_t, idx_t, raft::row_major> out_inds, \
raft::device_matrix_view<value_t, idx_t, raft::row_major> out_dists, \
cuvs::distance::DistanceType metric);

instantiate_raft_neighbors_brute_force_fused_l2_knn(float,
Expand Down
Loading

0 comments on commit 80e874b

Please sign in to comment.