Skip to content

Commit

Permalink
Merge pull request #476 from unum-cloud/main-dev
Browse files Browse the repository at this point in the history
Mixed Precision Kernels, Windows Builds, and Docs
  • Loading branch information
ashvardanian authored Aug 28, 2024
2 parents 242be10 + 50a6608 commit f336a06
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 15 deletions.
2 changes: 2 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@
"HNSW",
"hnswlib",
"ibin",
"ivdep",
"jaccard",
"Jemalloc",
"Kullback",
Expand All @@ -175,6 +176,7 @@
"Println",
"pytest",
"Quickstart",
"relock",
"repr",
"rtype",
"SIMD",
Expand Down
8 changes: 7 additions & 1 deletion build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,13 @@ fn main() {
.flag_if_supported("/std:c++17")
.flag_if_supported("/O2")
.flag_if_supported("/fp:fast")
.flag_if_supported("/W1"); // Reduce warnings verbosity
.flag_if_supported("/W1") // Reduce warnings verbosity
.flag_if_supported("/EHsc")
.flag_if_supported("/MD")
.flag_if_supported("/permissive-")
.flag_if_supported("/sdl-")
.define("_ALLOW_RUNTIME_LIBRARY_MISMATCH", None)
.define("_ALLOW_POINTER_TO_CONST_MISMATCH", None);
}

let mut result = build.try_compile("usearch");
Expand Down
71 changes: 63 additions & 8 deletions include/usearch/index_plugins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ enum class scalar_kind_t : std::uint8_t {
i8_k = 23,
};

/**
* @brief Maps a scalar type to its corresponding scalar_kind_t enumeration value.
*/
template <typename scalar_at> scalar_kind_t scalar_kind() noexcept {
if (std::is_same<scalar_at, b1x8_t>())
return scalar_kind_t::b1x8_k;
Expand Down Expand Up @@ -175,22 +178,43 @@ template <typename scalar_at> scalar_kind_t scalar_kind() noexcept {
return scalar_kind_t::unknown_k;
}

/**
* @brief Converts an angle from degrees to radians.
*/
template <typename at> at angle_to_radians(at angle) noexcept { return angle * at(3.14159265358979323846) / at(180); }

/**
* @brief Readability helper to compute the square of a given value.
*/
template <typename at> at square(at value) noexcept { return value * value; }

/**
* @brief Clamps a value between a lower and upper bound using a custom comparator. Similar to `std::clamp`.
* https://en.cppreference.com/w/cpp/algorithm/clamp
*/
template <typename at, typename compare_at> inline at clamp(at v, at lo, at hi, compare_at comp) noexcept {
return comp(v, lo) ? lo : comp(hi, v) ? hi : v;
}

/**
* @brief Clamps a value between a lower and upper bound. Similar to `std::clamp`.
* https://en.cppreference.com/w/cpp/algorithm/clamp
*/
template <typename at> inline at clamp(at v, at lo, at hi) noexcept {
return usearch::clamp(v, lo, hi, std::less<at>{});
}

inline bool str_equals(char const* begin, std::size_t len, char const* other_begin) noexcept {
std::size_t other_len = std::strlen(other_begin);
return len == other_len && std::strncmp(begin, other_begin, len) == 0;
/**
* @brief Compares two strings for equality, given a length for the first string.
*/
inline bool str_equals(char const* first_begin, std::size_t first_len, char const* second_begin) noexcept {
std::size_t second_len = std::strlen(second_begin);
return first_len == second_len && std::strncmp(first_begin, second_begin, first_len) == 0;
}

/**
* @brief Returns the number of bits required to represent a scalar type.
*/
inline std::size_t bits_per_scalar(scalar_kind_t scalar_kind) noexcept {
switch (scalar_kind) {
case scalar_kind_t::uuid_k: return 128;
Expand All @@ -213,6 +237,10 @@ inline std::size_t bits_per_scalar(scalar_kind_t scalar_kind) noexcept {
}
}

/**
* @brief Returns the number of bits in a scalar word for a given scalar type.
* Equivalent to `bits_per_scalar` for types that are not bit-packed.
*/
inline std::size_t bits_per_scalar_word(scalar_kind_t scalar_kind) noexcept {
switch (scalar_kind) {
case scalar_kind_t::uuid_k: return 128;
Expand All @@ -235,6 +263,9 @@ inline std::size_t bits_per_scalar_word(scalar_kind_t scalar_kind) noexcept {
}
}

/**
* @brief Returns the string name of a given scalar type.
*/
inline char const* scalar_kind_name(scalar_kind_t scalar_kind) noexcept {
switch (scalar_kind) {
case scalar_kind_t::uuid_k: return "uuid";
Expand All @@ -257,6 +288,9 @@ inline char const* scalar_kind_name(scalar_kind_t scalar_kind) noexcept {
}
}

/**
* @brief Returns the string name of a given distance metric.
*/
inline char const* metric_kind_name(metric_kind_t metric) noexcept {
switch (metric) {
case metric_kind_t::unknown_k: return "unknown";
Expand All @@ -273,6 +307,10 @@ inline char const* metric_kind_name(metric_kind_t metric) noexcept {
default: return "";
}
}

/**
* @brief Parses a string to identify the corresponding `scalar_kind_t` enumeration value.
*/
inline expected_gt<scalar_kind_t> scalar_kind_from_name(char const* name, std::size_t len) {
expected_gt<scalar_kind_t> parsed;
if (str_equals(name, len, "f32"))
Expand All @@ -292,10 +330,16 @@ inline expected_gt<scalar_kind_t> scalar_kind_from_name(char const* name, std::s
return parsed;
}

/**
* @brief Parses a string to identify the corresponding `scalar_kind_t` enumeration value.
*/
inline expected_gt<scalar_kind_t> scalar_kind_from_name(char const* name) {
return scalar_kind_from_name(name, std::strlen(name));
}

/**
* @brief Parses a string to identify the corresponding `metric_kind_t` enumeration value.
*/
inline expected_gt<metric_kind_t> metric_from_name(char const* name, std::size_t len) {
expected_gt<metric_kind_t> parsed;
if (str_equals(name, len, "l2sq") || str_equals(name, len, "euclidean_sq")) {
Expand All @@ -321,6 +365,10 @@ inline expected_gt<metric_kind_t> metric_from_name(char const* name, std::size_t
"tanimoto, sorensen");
return parsed;
}

/**
* @brief Parses a string to identify the corresponding `metric_kind_t` enumeration value.
*/
inline expected_gt<metric_kind_t> metric_from_name(char const* name) {
return metric_from_name(name, std::strlen(name));
}
Expand Down Expand Up @@ -417,7 +465,7 @@ class f16_bits_t {
inline f16_bits_t(float v) noexcept : uint16_(f32_to_f16(v)) {}
inline f16_bits_t(double v) noexcept : uint16_(f32_to_f16(static_cast<float>(v))) {}

inline bool operator<(const f16_bits_t& other) const noexcept { return float(*this) < float(other); }
inline bool operator<(f16_bits_t const& other) const noexcept { return float(*this) < float(other); }

inline f16_bits_t operator+(f16_bits_t other) const noexcept { return {float(*this) + float(other)}; }
inline f16_bits_t operator-(f16_bits_t other) const noexcept { return {float(*this) - float(other)}; }
Expand Down Expand Up @@ -1419,7 +1467,10 @@ template <typename scalar_at = float, typename result_at = float> struct metric_
}
};

struct cos_i8_t {
/**
* @brief Cosine (Angular) distance for signed 8-bit integers using 16-bit intermediates.
*/
struct metric_cos_i8_t {
using scalar_t = i8_t;
using result_t = f32_t;

Expand All @@ -1445,7 +1496,11 @@ struct cos_i8_t {
}
};

struct l2sq_i8_t {
/**
* @brief Squared Euclidean (L2) distance for signed 8-bit integers using 16-bit intermediates.
* Square root is avoided at the end, as it won't affect the ordering.
*/
struct metric_l2sq_i8_t {
using scalar_t = i8_t;
using result_t = f32_t;

Expand Down Expand Up @@ -1775,7 +1830,7 @@ class metric_punned_t {
case metric_kind_t::cos_k: {
switch (scalar_kind_) {
case scalar_kind_t::bf16_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_cos_gt<bf16_t, f32_t>>; break;
case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_cos_gt<i8_t, f32_t>>; break;
case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_cos_i8_t>; break;
case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_cos_gt<f16_t, f32_t>>; break;
case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_cos_gt<f32_t>>; break;
case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_cos_gt<f64_t>>; break;
Expand All @@ -1786,7 +1841,7 @@ class metric_punned_t {
case metric_kind_t::l2sq_k: {
switch (scalar_kind_) {
case scalar_kind_t::bf16_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_l2sq_gt<bf16_t, f32_t>>; break;
case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_l2sq_gt<i8_t, f32_t>>; break;
case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_l2sq_i8_t>; break;
case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_l2sq_gt<f16_t, f32_t>>; break;
case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_l2sq_gt<f32_t>>; break;
case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_<metric_l2sq_gt<f64_t>>; break;
Expand Down
3 changes: 2 additions & 1 deletion javascript/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ wasmer install unum/usearch
Create an index, add vectors, and perform searches with ease:

```js
const assert = require('node:assert');
const usearch = require('usearch');
const index = new usearch.Index({ metric: 'cos', connectivity: 16, dimensions: 3 });
const index = new usearch.Index({ metric: 'l2sq', connectivity: 16, dimensions: 3 });
index.add(42n, new Float32Array([0.2, 0.6, 0.4]));
const results = index.search(new Float32Array([0.2, 0.6, 0.4]), 10);

Expand Down
4 changes: 2 additions & 2 deletions javascript/usearch.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ test('Batch operations', () => {
});

test("Expected results", () => {
var index = new usearch.Index({
const index = new usearch.Index({
metric: "l2sq",
connectivity: 16,
dimensions: 3,
});
index.add(42n, new Float32Array([0.2, 0.6, 0.4]));
var results = index.search(new Float32Array([0.2, 0.6, 0.4]), 10);
const results = index.search(new Float32Array([0.2, 0.6, 0.4]), 10);

assert.equal(index.size(), 1);
assert.deepEqual(results.keys, new BigUint64Array([42n]));
Expand Down
6 changes: 3 additions & 3 deletions rust/lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ NativeIndex::NativeIndex(std::unique_ptr<index_t> index) : index_(std::move(inde
auto make_predicate(uptr_t metric, uptr_t metric_state) {
return [=](vector_key_t key) {
auto func = reinterpret_cast<bool (*)(uptr_t, vector_key_t)>(metric);
auto state = reinterpret_cast<uptr_t>(metric_state);
auto state = static_cast<uptr_t>(metric_state);
return func(key, state);
};
}
Expand Down Expand Up @@ -104,8 +104,8 @@ void NativeIndex::change_expansion_search(size_t n) const { index_->change_expan

void NativeIndex::change_metric(uptr_t metric, uptr_t state) const {
index_->change_metric(metric_punned_t::stateful( //
reinterpret_cast<std::uintptr_t>(metric), //
reinterpret_cast<std::uintptr_t>(state), //
static_cast<std::uintptr_t>(metric), //
static_cast<std::uintptr_t>(state), //
index_->metric().metric_kind(), //
index_->scalar_kind()));
}
Expand Down

0 comments on commit f336a06

Please sign in to comment.