Skip to content

Commit

Permalink
Merge pull request #271 from ashvardanian/main-dev
Browse files Browse the repository at this point in the history
Much faster exact search with slightly higher memory consumption
  • Loading branch information
ashvardanian authored Sep 20, 2023
2 parents 9145417 + 87135d3 commit a3896ac
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 38 deletions.
4 changes: 4 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fmax-errors=1")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pedantic")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fdiagnostics-color=always")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ftime-report")

if(${USEARCH_USE_OPENMP})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
Expand All @@ -32,6 +34,8 @@ elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address -fsanitize=leak -fsanitize=alignment -fsanitize=undefined")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wfatal-errors")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pedantic")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fcolor-diagnostics")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ftime-trace")

if(${USEARCH_USE_OPENMP})
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
Expand Down
17 changes: 12 additions & 5 deletions include/usearch/index_dense.hpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
#pragma once
#include <stdlib.h> // `aligned_alloc`

#include <functional> // `std::function`
#include <numeric> // `std::iota`
#include <shared_mutex> // `std::shared_mutex`
#include <thread> // `std::thread`
#include <vector> // `std::vector`
#include <functional> // `std::function`
#include <numeric> // `std::iota`
#include <thread> // `std::thread`
#include <vector> // `std::vector`

#include <usearch/index.hpp>
#include <usearch/index_plugins.hpp>

#if defined(USEARCH_DEFINED_CPP17)
#include <shared_mutex> // `std::shared_mutex`
#endif

namespace unum {
namespace usearch {

Expand Down Expand Up @@ -383,7 +386,11 @@ class index_dense_gt {
/// @brief Mutex, controlling concurrent access to `available_threads_`.
mutable std::mutex available_threads_mutex_;

#if defined(USEARCH_DEFINED_CPP17)
using shared_mutex_t = std::shared_mutex;
#else
using shared_mutex_t = unfair_shared_mutex_t;
#endif
using shared_lock_t = shared_lock_gt<shared_mutex_t>;
using unique_lock_t = std::unique_lock<shared_mutex_t>;

Expand Down
8 changes: 2 additions & 6 deletions include/usearch/index_plugins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,10 @@

#if USEARCH_USE_SIMSIMD
#if defined(USEARCH_DEFINED_LINUX)
#define SIMSIMD_TARGET_X86_AVX2 1
#define SIMSIMD_TARGET_X86_AVX512 1
#define SIMSIMD_TARGET_ARM_NEON 1
#define SIMSIMD_TARGET_ARM_SVE 1
#include <simsimd/simsimd.h>
#elif defined(USEARCH_DEFINED_APPLE)
#define SIMSIMD_TARGET_X86_AVX2 1
#define SIMSIMD_TARGET_ARM_NEON 1
#define SIMSIMD_TARGET_X86_AVX512 0
#define SIMSIMD_TARGET_ARM_SVE 0
#include <simsimd/simsimd.h>
#endif
#endif
Expand Down
49 changes: 26 additions & 23 deletions python/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,47 +526,34 @@ static void search_typed_brute_force( //
byte_t const* dataset_data = reinterpret_cast<byte_t const*>(dataset_info.ptr);
byte_t const* queries_data = reinterpret_cast<byte_t const*>(queries_info.ptr);
for (std::size_t query_idx = 0; query_idx != queries_count; ++query_idx)
counts_py1d(query_idx) = 0;
counts_py1d(query_idx) = wanted;

if (!threads)
threads = std::thread::hardware_concurrency();

std::size_t tasks_count = static_cast<std::size_t>(dataset_count * queries_count);
bitset_t query_mutexes(static_cast<std::size_t>(queries_count));
if (!query_mutexes)
throw std::bad_alloc();

// Progress status
progress_t progress_{progress};
std::atomic<std::size_t> processed{0};

// Allocate temporary memory to store the distance matrix
// Previous version didn't need temporary memory, but the performance was much lower
struct dense_key_and_distance_t {
u32_t offset;
f32_t distance;
};
std::vector<dense_key_and_distance_t> keys_and_distances(tasks_count);

executor_default_t{threads}.dynamic(tasks_count, [&](std::size_t thread_idx, std::size_t task_idx) {
//
std::size_t dataset_idx = task_idx / queries_count;
std::size_t query_idx = task_idx % queries_count;

byte_t const* dataset = dataset_data + dataset_idx * dataset_info.strides[0];
byte_t const* query = queries_data + query_idx * queries_info.strides[0];
distance_t distance = metric(dataset, query);

{
auto lock = query_mutexes.lock(query_idx);
dense_key_t* keys = &keys_py2d(query_idx, 0);
distance_t* distances = &distances_py2d(query_idx, 0);
std::size_t& matches = reinterpret_cast<std::size_t&>(counts_py1d(query_idx));
if (matches == wanted)
if (distances[wanted - 1] <= distance)
return true;

std::size_t offset = std::lower_bound(distances, distances + matches, distance) - distances;

std::size_t count_worse = matches - offset - (wanted == matches);
std::memmove(keys + offset + 1, keys + offset, count_worse * sizeof(dense_key_t));
std::memmove(distances + offset + 1, distances + offset, count_worse * sizeof(distance_t));
keys[offset] = static_cast<dense_key_t>(dataset_idx);
distances[offset] = distance;
matches += matches != wanted;
}
keys_and_distances[task_idx] = {static_cast<u32_t>(query_idx), static_cast<f32_t>(distance)};

// We don't want to check for signals from multiple threads
++processed;
Expand All @@ -576,6 +563,20 @@ static void search_typed_brute_force( //
return true;
});

// Partial-sort every query result
executor_default_t{threads}.fixed(queries_count, [&](std::size_t, std::size_t query_idx) {
auto start = keys_and_distances.data() + query_idx * dataset_count;
std::partial_sort(start, start + wanted, start + dataset_count,
[](dense_key_and_distance_t const& a, dense_key_and_distance_t const& b) {
return a.distance < b.distance;
});

dense_key_t* keys = &keys_py2d(query_idx, 0);
distance_t* distances = &distances_py2d(query_idx, 0);
for (std::size_t i = 0; i != wanted; ++i)
keys[i] = static_cast<dense_key_t>(start[i].offset), distances[i] = start[i].distance;
});

// At the end report the latest numbers, because the reporter thread may be finished earlier
progress_(processed.load(), tasks_count);
}
Expand All @@ -602,6 +603,8 @@ static py::tuple search_many_brute_force( //
Py_ssize_t queries_dimensions = queries_info.shape[1];
if (dataset_dimensions != queries_dimensions)
throw std::invalid_argument("The number of vector dimensions doesn't match!");
if (wanted > dataset_count)
throw std::invalid_argument("You can't request more matches than in the dataset!");

scalar_kind_t dataset_kind = numpy_string_to_kind(dataset_info.format);
scalar_kind_t queries_kind = numpy_string_to_kind(queries_info.format);
Expand Down
4 changes: 2 additions & 2 deletions python/scripts/test_tooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@ def test_exact_search(rows: int, cols: int):
:param int cols: The number of columns in the matrix.
"""
original = np.random.rand(rows, cols)
matches: BatchMatches = search(original, original, 10, exact=True)
matches: BatchMatches = search(original, original, min(10, rows), exact=True)
top_matches = (
[int(m.keys[0]) for m in matches] if rows > 1 else int(matches.keys[0])
)
assert np.all(top_matches == np.arange(rows))

matches: Matches = search(original, original[0], 10, exact=True)
matches: Matches = search(original, original[0], min(10, rows), exact=True)
top_match = int(matches.keys[0])
assert top_match == 0

Expand Down
10 changes: 8 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
if sys.platform == "linux":
compile_args.append("-std=c++17")
compile_args.append("-O3") # Maximize performance
compile_args.append("-g") # Simplify debugging
compile_args.append("-Wno-unknown-pragmas")
compile_args.append("-fdiagnostics-color=always")

# Simplify debugging, but the normal `-g` may make builds much longer!
compile_args.append("-g1")

macros_args.append(("USEARCH_USE_NATIVE_F16", "0"))
macros_args.append(("USEARCH_USE_SIMSIMD", "1"))
Expand All @@ -27,9 +30,12 @@
compile_args.append("-mmacosx-version-min=10.15")
compile_args.append("-std=c++17")
compile_args.append("-O3") # Maximize performance
compile_args.append("-g") # Simplify debugging
compile_args.append("-fcolor-diagnostics")
compile_args.append("-Wno-unknown-pragmas")

# Simplify debugging, but the normal `-g` may make builds much longer!
compile_args.append("-g1")

# Linking OpenMP requires additional preparation in CIBuildWheel
# macros_args.append(("USEARCH_USE_OPENMP", "1"))
# compile_args.append("-Xpreprocessor -fopenmp")
Expand Down

0 comments on commit a3896ac

Please sign in to comment.