From 6976acc9c81cb47f0c54d47dc6f0a99dc4017ede Mon Sep 17 00:00:00 2001 From: viclafargue Date: Wed, 7 Feb 2024 15:31:13 +0100 Subject: [PATCH] Adding bruteforce KNN --- cpp/CMakeLists.txt | 2 ++ cpp/include/cuvs/neighbors/brute_force.hpp | 39 ++++++++++++++++++++++ cpp/src/neighbors/brute_force.cpp | 35 +++++++++++++++++++ 3 files changed, 76 insertions(+) create mode 100644 cpp/include/cuvs/neighbors/brute_force.hpp create mode 100644 cpp/src/neighbors/brute_force.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 460df7bbf..1a8c9e2af 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -188,6 +188,8 @@ include(cmake/thirdparty/get_cutlass.cmake) add_library( cuvs SHARED + src/neighbors/brute_force.cpp + src/neighbors/cagra_build_float.cpp src/neighbors/cagra_build_int8.cpp src/neighbors/cagra_build_uint8.cpp diff --git a/cpp/include/cuvs/neighbors/brute_force.hpp b/cpp/include/cuvs/neighbors/brute_force.hpp new file mode 100644 index 000000000..a49e72b47 --- /dev/null +++ b/cpp/include/cuvs/neighbors/brute_force.hpp @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + + +namespace cuvs::neighbors::brute_force { + +#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \ + void knn(raft::resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric = raft::distance::DistanceType::L2Unexpanded, \ + std::optional metric_arg = std::make_optional(2.0f), \ + std::optional global_id_offset = std::nullopt); + +RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major); + +#undef RAFT_INST_BFKNN + +} // namespace cuvs::neighbors::brute_force \ No newline at end of file diff --git a/cpp/src/neighbors/brute_force.cpp b/cpp/src/neighbors/brute_force.cpp new file mode 100644 index 000000000..24df47096 --- /dev/null +++ b/cpp/src/neighbors/brute_force.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace cuvs::neighbors::brute_force { + +#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \ + void knn(raft::resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset); + +RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major); + +#undef RAFT_INST_BFKNN + +} // namespace cuvs::neighbors::brute_force \ No newline at end of file