Skip to content

Commit

Permalink
Adding bruteforce KNN
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Feb 7, 2024
1 parent 20658f8 commit 6976acc
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ include(cmake/thirdparty/get_cutlass.cmake)

add_library(
cuvs SHARED
src/neighbors/brute_force.cpp

src/neighbors/cagra_build_float.cpp
src/neighbors/cagra_build_int8.cpp
src/neighbors/cagra_build_uint8.cpp
Expand Down
39 changes: 39 additions & 0 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/distance/distance_types.hpp>
#include <raft_runtime/neighbors/brute_force.hpp>


namespace cuvs::neighbors::brute_force {

#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \
void knn(raft::resources const& handle, \
raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, INDEX_LAYOUT> index, \
raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, SEARCH_LAYOUT> search, \
raft::device_matrix_view<IDX_T, MATRIX_IDX_T, raft::row_major> indices, \
raft::device_matrix_view<DATA_T, MATRIX_IDX_T, raft::row_major> distances, \
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Unexpanded, \
std::optional<float> metric_arg = std::make_optional<float>(2.0f), \
std::optional<IDX_T> global_id_offset = std::nullopt);

RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major);

#undef RAFT_INST_BFKNN

} // namespace cuvs::neighbors::brute_force
35 changes: 35 additions & 0 deletions cpp/src/neighbors/brute_force.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuvs/neighbors/brute_force.hpp>

namespace cuvs::neighbors::brute_force {

#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \
void knn(raft::resources const& handle, \
raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, INDEX_LAYOUT> index, \
raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, SEARCH_LAYOUT> search, \
raft::device_matrix_view<IDX_T, MATRIX_IDX_T, raft::row_major> indices, \
raft::device_matrix_view<DATA_T, MATRIX_IDX_T, raft::row_major> distances, \
raft::distance::DistanceType metric, \
std::optional<float> metric_arg, \
std::optional<IDX_T> global_id_offset);

RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major);

#undef RAFT_INST_BFKNN

} // namespace cuvs::neighbors::brute_force

0 comments on commit 6976acc

Please sign in to comment.