Skip to content

Commit c059ba3

Browse files
ashvardanian and mbautin committed
Improve: Separate casts_punned_t
Separating vector-casting logic will make it easier to extend the type system, potentially adding `bf16` and `u8` high-level `add`, `get`, and `search` APIs down the road. Related to #469 Co-authored-by: Mikhail Bautin <[email protected]>
1 parent 0ab569d commit c059ba3

File tree

2 files changed

+110
-88
lines changed

2 files changed

+110
-88
lines changed

include/usearch/index_dense.hpp

+45-88
Original file line numberDiff line numberDiff line change
@@ -406,8 +406,6 @@ class index_dense_gt {
406406
using tape_allocator_t = memory_mapping_allocator_gt<64>;
407407

408408
private:
409-
/// @brief Schema: input buffer, bytes in input buffer, output buffer.
410-
using cast_t = bool (*)(byte_t const*, std::size_t, byte_t*);
411409
/// @brief Punned index.
412410
using index_t = index_gt< //
413411
distance_t, vector_key_t, compressed_slot_t, //
@@ -446,19 +444,7 @@ class index_dense_gt {
446444

447445
/// @brief Temporary memory for every thread to store a casted vector.
448446
mutable cast_buffer_t cast_buffer_;
449-
struct casts_t {
450-
cast_t from_b1x8;
451-
cast_t from_i8;
452-
cast_t from_f16;
453-
cast_t from_f32;
454-
cast_t from_f64;
455-
456-
cast_t to_b1x8;
457-
cast_t to_i8;
458-
cast_t to_f16;
459-
cast_t to_f32;
460-
cast_t to_f64;
461-
} casts_;
447+
casts_punned_t casts_;
462448

463449
/// @brief An instance of a potentially stateful `metric_t` used to initialize copies and forks.
464450
metric_t metric_;
@@ -677,7 +663,7 @@ class index_dense_gt {
677663
// In some cases the metric is not provided, and will be set later.
678664
if (metric) {
679665
scalar_kind_t scalar_kind = metric.scalar_kind();
680-
index.casts_ = make_casts_(scalar_kind);
666+
index.casts_ = casts_punned_t::make(scalar_kind);
681667
index.metric_ = metric;
682668
}
683669

@@ -767,41 +753,41 @@ class index_dense_gt {
767753
};
768754

769755
// clang-format off
770-
add_result_t add(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_b1x8); }
771-
add_result_t add(vector_key_t key, i8_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_i8); }
772-
add_result_t add(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_f16); }
773-
add_result_t add(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_f32); }
774-
add_result_t add(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_f64); }
775-
776-
search_result_t search(b1x8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_b1x8); }
777-
search_result_t search(i8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_i8); }
778-
search_result_t search(f16_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f16); }
779-
search_result_t search(f32_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f32); }
780-
search_result_t search(f64_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f64); }
781-
782-
template <typename predicate_at> search_result_t filtered_search(b1x8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from_b1x8); }
783-
template <typename predicate_at> search_result_t filtered_search(i8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from_i8); }
784-
template <typename predicate_at> search_result_t filtered_search(f16_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from_f16); }
785-
template <typename predicate_at> search_result_t filtered_search(f32_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from_f32); }
786-
template <typename predicate_at> search_result_t filtered_search(f64_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from_f64); }
787-
788-
std::size_t get(vector_key_t key, b1x8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_b1x8); }
789-
std::size_t get(vector_key_t key, i8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_i8); }
790-
std::size_t get(vector_key_t key, f16_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_f16); }
791-
std::size_t get(vector_key_t key, f32_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_f32); }
792-
std::size_t get(vector_key_t key, f64_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_f64); }
793-
794-
cluster_result_t cluster(b1x8_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_b1x8); }
795-
cluster_result_t cluster(i8_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_i8); }
796-
cluster_result_t cluster(f16_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_f16); }
797-
cluster_result_t cluster(f32_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_f32); }
798-
cluster_result_t cluster(f64_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from_f64); }
799-
800-
aggregated_distances_t distance_between(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_b1x8); }
801-
aggregated_distances_t distance_between(vector_key_t key, i8_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_i8); }
802-
aggregated_distances_t distance_between(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_f16); }
803-
aggregated_distances_t distance_between(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_f32); }
804-
aggregated_distances_t distance_between(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_f64); }
756+
add_result_t add(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.b1x8); }
757+
add_result_t add(vector_key_t key, i8_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.i8); }
758+
add_result_t add(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.f16); }
759+
add_result_t add(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.f32); }
760+
add_result_t add(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from.f64); }
761+
762+
search_result_t search(b1x8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.b1x8); }
763+
search_result_t search(i8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.i8); }
764+
search_result_t search(f16_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.f16); }
765+
search_result_t search(f32_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.f32); }
766+
search_result_t search(f64_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.f64); }
767+
768+
template <typename predicate_at> search_result_t filtered_search(b1x8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from.b1x8); }
769+
template <typename predicate_at> search_result_t filtered_search(i8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from.i8); }
770+
template <typename predicate_at> search_result_t filtered_search(f16_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from.f16); }
771+
template <typename predicate_at> search_result_t filtered_search(f32_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from.f32); }
772+
template <typename predicate_at> search_result_t filtered_search(f64_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from.f64); }
773+
774+
std::size_t get(vector_key_t key, b1x8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.b1x8); }
775+
std::size_t get(vector_key_t key, i8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.i8); }
776+
std::size_t get(vector_key_t key, f16_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.f16); }
777+
std::size_t get(vector_key_t key, f32_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.f32); }
778+
std::size_t get(vector_key_t key, f64_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.f64); }
779+
780+
cluster_result_t cluster(b1x8_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.b1x8); }
781+
cluster_result_t cluster(i8_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.i8); }
782+
cluster_result_t cluster(f16_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.f16); }
783+
cluster_result_t cluster(f32_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.f32); }
784+
cluster_result_t cluster(f64_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.f64); }
785+
786+
aggregated_distances_t distance_between(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.b1x8); }
787+
aggregated_distances_t distance_between(vector_key_t key, i8_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.i8); }
788+
aggregated_distances_t distance_between(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.f16); }
789+
aggregated_distances_t distance_between(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.f32); }
790+
aggregated_distances_t distance_between(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.f64); }
805791
// clang-format on
806792

807793
/**
@@ -1154,7 +1140,7 @@ class index_dense_gt {
11541140
cast_buffer_ = cast_buffer_t(available_threads_.size() * metric_.bytes_per_vector());
11551141
if (!cast_buffer_)
11561142
return result.failed("Failed to allocate memory for the casts");
1157-
casts_ = make_casts_(head.kind_scalar);
1143+
casts_ = casts_punned_t::make(head.kind_scalar);
11581144
}
11591145

11601146
// Pull the actual proximity graph
@@ -1266,7 +1252,7 @@ class index_dense_gt {
12661252
cast_buffer_ = cast_buffer_t(available_threads_.size() * metric_.bytes_per_vector());
12671253
if (!cast_buffer_)
12681254
return result.failed("Failed to allocate memory for the casts");
1269-
casts_ = make_casts_(head.kind_scalar);
1255+
casts_ = casts_punned_t::make(head.kind_scalar);
12701256
offset += sizeof(buffer);
12711257
}
12721258

@@ -1994,7 +1980,7 @@ class index_dense_gt {
19941980
template <typename scalar_at>
19951981
add_result_t add_( //
19961982
vector_key_t key, scalar_at const* vector, //
1997-
std::size_t thread, bool force_vector_copy, cast_t const& cast) {
1983+
std::size_t thread, bool force_vector_copy, cast_punned_t const& cast) {
19981984

19991985
if (!multi() && config().enable_key_lookups && contains(key))
20001986
return add_result_t{}.failed("Duplicate keys not allowed in high-level wrappers");
@@ -2044,7 +2030,7 @@ class index_dense_gt {
20442030

20452031
template <typename scalar_at, typename predicate_at>
20462032
search_result_t search_(scalar_at const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread,
2047-
bool exact, cast_t const& cast) const {
2033+
bool exact, cast_punned_t const& cast) const {
20482034

20492035
// Cast the vector, if needed for compatibility with `metric_`
20502036
thread_lock_t lock = thread_lock_(thread);
@@ -2080,7 +2066,7 @@ class index_dense_gt {
20802066
template <typename scalar_at>
20812067
cluster_result_t cluster_( //
20822068
scalar_at const* vector, std::size_t level, //
2083-
std::size_t thread, cast_t const& cast) const {
2069+
std::size_t thread, cast_punned_t const& cast) const {
20842070

20852071
// Cast the vector, if needed for compatibility with `metric_`
20862072
thread_lock_t lock = thread_lock_(thread);
@@ -2104,7 +2090,7 @@ class index_dense_gt {
21042090
template <typename scalar_at>
21052091
aggregated_distances_t distance_between_( //
21062092
vector_key_t key, scalar_at const* vector, //
2107-
std::size_t thread, cast_t const& cast) const {
2093+
std::size_t thread, cast_punned_t const& cast) const {
21082094

21092095
// Cast the vector, if needed for compatibility with `metric_`
21102096
thread_lock_t lock = thread_lock_(thread);
@@ -2181,7 +2167,8 @@ class index_dense_gt {
21812167
}
21822168

21832169
template <typename scalar_at>
2184-
std::size_t get_(vector_key_t key, scalar_at* reconstructed, std::size_t vectors_limit, cast_t const& cast) const {
2170+
std::size_t get_(vector_key_t key, scalar_at* reconstructed, std::size_t vectors_limit,
2171+
cast_punned_t const& cast) const {
21852172

21862173
if (!multi()) {
21872174
compressed_slot_t slot;
@@ -2216,36 +2203,6 @@ class index_dense_gt {
22162203
return count_exported;
22172204
}
22182205
}
2219-
2220-
template <typename to_scalar_at> static casts_t make_casts_() {
2221-
casts_t result;
2222-
2223-
result.from_b1x8 = &cast_gt<b1x8_t, to_scalar_at>::try_;
2224-
result.from_i8 = &cast_gt<i8_t, to_scalar_at>::try_;
2225-
result.from_f16 = &cast_gt<f16_t, to_scalar_at>::try_;
2226-
result.from_f32 = &cast_gt<f32_t, to_scalar_at>::try_;
2227-
result.from_f64 = &cast_gt<f64_t, to_scalar_at>::try_;
2228-
2229-
result.to_b1x8 = &cast_gt<to_scalar_at, b1x8_t>::try_;
2230-
result.to_i8 = &cast_gt<to_scalar_at, i8_t>::try_;
2231-
result.to_f16 = &cast_gt<to_scalar_at, f16_t>::try_;
2232-
result.to_f32 = &cast_gt<to_scalar_at, f32_t>::try_;
2233-
result.to_f64 = &cast_gt<to_scalar_at, f64_t>::try_;
2234-
2235-
return result;
2236-
}
2237-
2238-
static casts_t make_casts_(scalar_kind_t scalar_kind) {
2239-
switch (scalar_kind) {
2240-
case scalar_kind_t::f64_k: return make_casts_<f64_t>();
2241-
case scalar_kind_t::f32_k: return make_casts_<f32_t>();
2242-
case scalar_kind_t::f16_k: return make_casts_<f16_t>();
2243-
case scalar_kind_t::bf16_k: return make_casts_<bf16_t>();
2244-
case scalar_kind_t::i8_k: return make_casts_<i8_t>();
2245-
case scalar_kind_t::b1x8_k: return make_casts_<b1x8_t>();
2246-
default: return {};
2247-
}
2248-
}
22492206
};
22502207

22512208
using index_dense_t = index_dense_gt<>;

0 commit comments

Comments
 (0)