@@ -406,8 +406,6 @@ class index_dense_gt {
406
406
using tape_allocator_t = memory_mapping_allocator_gt<64 >;
407
407
408
408
private:
409
- // / @brief Schema: input buffer, bytes in input buffer, output buffer.
410
- using cast_t = bool (*)(byte_t const *, std::size_t , byte_t *);
411
409
// / @brief Punned index.
412
410
using index_t = index_gt< //
413
411
distance_t , vector_key_t , compressed_slot_t , //
@@ -446,19 +444,7 @@ class index_dense_gt {
446
444
447
445
// / @brief Temporary memory for every thread to store a casted vector.
448
446
mutable cast_buffer_t cast_buffer_;
449
- struct casts_t {
450
- cast_t from_b1x8;
451
- cast_t from_i8;
452
- cast_t from_f16;
453
- cast_t from_f32;
454
- cast_t from_f64;
455
-
456
- cast_t to_b1x8;
457
- cast_t to_i8;
458
- cast_t to_f16;
459
- cast_t to_f32;
460
- cast_t to_f64;
461
- } casts_;
447
+ casts_punned_t casts_;
462
448
463
449
// / @brief An instance of a potentially stateful `metric_t` used to initialize copies and forks.
464
450
metric_t metric_;
@@ -677,7 +663,7 @@ class index_dense_gt {
677
663
// In some cases the metric is not provided, and will be set later.
678
664
if (metric) {
679
665
scalar_kind_t scalar_kind = metric.scalar_kind ();
680
- index .casts_ = make_casts_ (scalar_kind);
666
+ index .casts_ = casts_punned_t::make (scalar_kind);
681
667
index .metric_ = metric;
682
668
}
683
669
@@ -767,41 +753,41 @@ class index_dense_gt {
767
753
};
768
754
769
755
// clang-format off
770
- add_result_t add (vector_key_t key, b1x8_t const * vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_ (key, vector, thread, force_vector_copy, casts_.from_b1x8 ); }
771
- add_result_t add (vector_key_t key, i8_t const * vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_ (key, vector, thread, force_vector_copy, casts_.from_i8 ); }
772
- add_result_t add (vector_key_t key, f16_t const * vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_ (key, vector, thread, force_vector_copy, casts_.from_f16 ); }
773
- add_result_t add (vector_key_t key, f32_t const * vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_ (key, vector, thread, force_vector_copy, casts_.from_f32 ); }
774
- add_result_t add (vector_key_t key, f64_t const * vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_ (key, vector, thread, force_vector_copy, casts_.from_f64 ); }
775
-
776
- search_result_t search (b1x8_t const * vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_b1x8 ); }
777
- search_result_t search (i8_t const * vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_i8 ); }
778
- search_result_t search (f16_t const * vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f16 ); }
779
- search_result_t search (f32_t const * vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f32 ); }
780
- search_result_t search (f64_t const * vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f64 ); }
781
-
782
- template <typename predicate_at> search_result_t filtered_search (b1x8_t const * vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from_b1x8 ); }
783
- template <typename predicate_at> search_result_t filtered_search (i8_t const * vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from_i8 ); }
784
- template <typename predicate_at> search_result_t filtered_search (f16_t const * vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from_f16 ); }
785
- template <typename predicate_at> search_result_t filtered_search (f32_t const * vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from_f32 ); }
786
- template <typename predicate_at> search_result_t filtered_search (f64_t const * vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from_f64 ); }
787
-
788
- std::size_t get (vector_key_t key, b1x8_t * vector, std::size_t vectors_count = 1 ) const { return get_ (key, vector, vectors_count, casts_.to_b1x8 ); }
789
- std::size_t get (vector_key_t key, i8_t * vector, std::size_t vectors_count = 1 ) const { return get_ (key, vector, vectors_count, casts_.to_i8 ); }
790
- std::size_t get (vector_key_t key, f16_t * vector, std::size_t vectors_count = 1 ) const { return get_ (key, vector, vectors_count, casts_.to_f16 ); }
791
- std::size_t get (vector_key_t key, f32_t * vector, std::size_t vectors_count = 1 ) const { return get_ (key, vector, vectors_count, casts_.to_f32 ); }
792
- std::size_t get (vector_key_t key, f64_t * vector, std::size_t vectors_count = 1 ) const { return get_ (key, vector, vectors_count, casts_.to_f64 ); }
793
-
794
- cluster_result_t cluster (b1x8_t const * vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_ (vector, level, thread, casts_.from_b1x8 ); }
795
- cluster_result_t cluster (i8_t const * vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_ (vector, level, thread, casts_.from_i8 ); }
796
- cluster_result_t cluster (f16_t const * vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_ (vector, level, thread, casts_.from_f16 ); }
797
- cluster_result_t cluster (f32_t const * vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_ (vector, level, thread, casts_.from_f32 ); }
798
- cluster_result_t cluster (f64_t const * vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_ (vector, level, thread, casts_.from_f64 ); }
799
-
800
- aggregated_distances_t distance_between (vector_key_t key, b1x8_t const * vector, std::size_t thread = any_thread()) const { return distance_between_ (key, vector, thread, casts_.to_b1x8 ); }
801
- aggregated_distances_t distance_between (vector_key_t key, i8_t const * vector, std::size_t thread = any_thread()) const { return distance_between_ (key, vector, thread, casts_.to_i8 ); }
802
- aggregated_distances_t distance_between (vector_key_t key, f16_t const * vector, std::size_t thread = any_thread()) const { return distance_between_ (key, vector, thread, casts_.to_f16 ); }
803
- aggregated_distances_t distance_between (vector_key_t key, f32_t const * vector, std::size_t thread = any_thread()) const { return distance_between_ (key, vector, thread, casts_.to_f32 ); }
804
- aggregated_distances_t distance_between (vector_key_t key, f64_t const * vector, std::size_t thread = any_thread()) const { return distance_between_ (key, vector, thread, casts_.to_f64 ); }
756
+ add_result_t add (vector_key_t key, b1x8_t const * vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_ (key, vector, thread, force_vector_copy, casts_.from . b1x8 ); }
757
+ add_result_t add (vector_key_t key, i8_t const * vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_ (key, vector, thread, force_vector_copy, casts_.from . i8 ); }
758
+ add_result_t add (vector_key_t key, f16_t const * vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_ (key, vector, thread, force_vector_copy, casts_.from . f16 ); }
759
+ add_result_t add (vector_key_t key, f32_t const * vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_ (key, vector, thread, force_vector_copy, casts_.from . f32 ); }
760
+ add_result_t add (vector_key_t key, f64_t const * vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_ (key, vector, thread, force_vector_copy, casts_.from . f64 ); }
761
+
762
+ search_result_t search (b1x8_t const * vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from . b1x8 ); }
763
+ search_result_t search (i8_t const * vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from . i8 ); }
764
+ search_result_t search (f16_t const * vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from . f16 ); }
765
+ search_result_t search (f32_t const * vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from . f32 ); }
766
+ search_result_t search (f64_t const * vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from . f64 ); }
767
+
768
+ template <typename predicate_at> search_result_t filtered_search (b1x8_t const * vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from . b1x8 ); }
769
+ template <typename predicate_at> search_result_t filtered_search (i8_t const * vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from . i8 ); }
770
+ template <typename predicate_at> search_result_t filtered_search (f16_t const * vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from . f16 ); }
771
+ template <typename predicate_at> search_result_t filtered_search (f32_t const * vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from . f32 ); }
772
+ template <typename predicate_at> search_result_t filtered_search (f64_t const * vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_ (vector, wanted, std::forward<predicate_at>(predicate), thread, exact, casts_.from . f64 ); }
773
+
774
+ std::size_t get (vector_key_t key, b1x8_t * vector, std::size_t vectors_count = 1 ) const { return get_ (key, vector, vectors_count, casts_.to . b1x8 ); }
775
+ std::size_t get (vector_key_t key, i8_t * vector, std::size_t vectors_count = 1 ) const { return get_ (key, vector, vectors_count, casts_.to . i8 ); }
776
+ std::size_t get (vector_key_t key, f16_t * vector, std::size_t vectors_count = 1 ) const { return get_ (key, vector, vectors_count, casts_.to . f16 ); }
777
+ std::size_t get (vector_key_t key, f32_t * vector, std::size_t vectors_count = 1 ) const { return get_ (key, vector, vectors_count, casts_.to . f32 ); }
778
+ std::size_t get (vector_key_t key, f64_t * vector, std::size_t vectors_count = 1 ) const { return get_ (key, vector, vectors_count, casts_.to . f64 ); }
779
+
780
+ cluster_result_t cluster (b1x8_t const * vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_ (vector, level, thread, casts_.from . b1x8 ); }
781
+ cluster_result_t cluster (i8_t const * vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_ (vector, level, thread, casts_.from . i8 ); }
782
+ cluster_result_t cluster (f16_t const * vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_ (vector, level, thread, casts_.from . f16 ); }
783
+ cluster_result_t cluster (f32_t const * vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_ (vector, level, thread, casts_.from . f32 ); }
784
+ cluster_result_t cluster (f64_t const * vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_ (vector, level, thread, casts_.from . f64 ); }
785
+
786
+ aggregated_distances_t distance_between (vector_key_t key, b1x8_t const * vector, std::size_t thread = any_thread()) const { return distance_between_ (key, vector, thread, casts_.to . b1x8 ); }
787
+ aggregated_distances_t distance_between (vector_key_t key, i8_t const * vector, std::size_t thread = any_thread()) const { return distance_between_ (key, vector, thread, casts_.to . i8 ); }
788
+ aggregated_distances_t distance_between (vector_key_t key, f16_t const * vector, std::size_t thread = any_thread()) const { return distance_between_ (key, vector, thread, casts_.to . f16 ); }
789
+ aggregated_distances_t distance_between (vector_key_t key, f32_t const * vector, std::size_t thread = any_thread()) const { return distance_between_ (key, vector, thread, casts_.to . f32 ); }
790
+ aggregated_distances_t distance_between (vector_key_t key, f64_t const * vector, std::size_t thread = any_thread()) const { return distance_between_ (key, vector, thread, casts_.to . f64 ); }
805
791
// clang-format on
806
792
807
793
/* *
@@ -1154,7 +1140,7 @@ class index_dense_gt {
1154
1140
cast_buffer_ = cast_buffer_t (available_threads_.size () * metric_.bytes_per_vector ());
1155
1141
if (!cast_buffer_)
1156
1142
return result.failed (" Failed to allocate memory for the casts" );
1157
- casts_ = make_casts_ (head.kind_scalar );
1143
+ casts_ = casts_punned_t::make (head.kind_scalar );
1158
1144
}
1159
1145
1160
1146
// Pull the actual proximity graph
@@ -1266,7 +1252,7 @@ class index_dense_gt {
1266
1252
cast_buffer_ = cast_buffer_t (available_threads_.size () * metric_.bytes_per_vector ());
1267
1253
if (!cast_buffer_)
1268
1254
return result.failed (" Failed to allocate memory for the casts" );
1269
- casts_ = make_casts_ (head.kind_scalar );
1255
+ casts_ = casts_punned_t::make (head.kind_scalar );
1270
1256
offset += sizeof (buffer);
1271
1257
}
1272
1258
@@ -1994,7 +1980,7 @@ class index_dense_gt {
1994
1980
template <typename scalar_at>
1995
1981
add_result_t add_ ( //
1996
1982
vector_key_t key, scalar_at const * vector, //
1997
- std::size_t thread, bool force_vector_copy, cast_t const & cast) {
1983
+ std::size_t thread, bool force_vector_copy, cast_punned_t const & cast) {
1998
1984
1999
1985
if (!multi () && config ().enable_key_lookups && contains (key))
2000
1986
return add_result_t {}.failed (" Duplicate keys not allowed in high-level wrappers" );
@@ -2044,7 +2030,7 @@ class index_dense_gt {
2044
2030
2045
2031
template <typename scalar_at, typename predicate_at>
2046
2032
search_result_t search_ (scalar_at const * vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread,
2047
- bool exact, cast_t const & cast) const {
2033
+ bool exact, cast_punned_t const & cast) const {
2048
2034
2049
2035
// Cast the vector, if needed for compatibility with `metric_`
2050
2036
thread_lock_t lock = thread_lock_ (thread);
@@ -2080,7 +2066,7 @@ class index_dense_gt {
2080
2066
template <typename scalar_at>
2081
2067
cluster_result_t cluster_ ( //
2082
2068
scalar_at const * vector, std::size_t level, //
2083
- std::size_t thread, cast_t const & cast) const {
2069
+ std::size_t thread, cast_punned_t const & cast) const {
2084
2070
2085
2071
// Cast the vector, if needed for compatibility with `metric_`
2086
2072
thread_lock_t lock = thread_lock_ (thread);
@@ -2104,7 +2090,7 @@ class index_dense_gt {
2104
2090
template <typename scalar_at>
2105
2091
aggregated_distances_t distance_between_ ( //
2106
2092
vector_key_t key, scalar_at const * vector, //
2107
- std::size_t thread, cast_t const & cast) const {
2093
+ std::size_t thread, cast_punned_t const & cast) const {
2108
2094
2109
2095
// Cast the vector, if needed for compatibility with `metric_`
2110
2096
thread_lock_t lock = thread_lock_ (thread);
@@ -2181,7 +2167,8 @@ class index_dense_gt {
2181
2167
}
2182
2168
2183
2169
template <typename scalar_at>
2184
- std::size_t get_ (vector_key_t key, scalar_at* reconstructed, std::size_t vectors_limit, cast_t const & cast) const {
2170
+ std::size_t get_ (vector_key_t key, scalar_at* reconstructed, std::size_t vectors_limit,
2171
+ cast_punned_t const & cast) const {
2185
2172
2186
2173
if (!multi ()) {
2187
2174
compressed_slot_t slot;
@@ -2216,36 +2203,6 @@ class index_dense_gt {
2216
2203
return count_exported;
2217
2204
}
2218
2205
}
2219
-
2220
- template <typename to_scalar_at> static casts_t make_casts_ () {
2221
- casts_t result;
2222
-
2223
- result.from_b1x8 = &cast_gt<b1x8_t , to_scalar_at>::try_;
2224
- result.from_i8 = &cast_gt<i8_t , to_scalar_at>::try_;
2225
- result.from_f16 = &cast_gt<f16_t , to_scalar_at>::try_;
2226
- result.from_f32 = &cast_gt<f32_t , to_scalar_at>::try_;
2227
- result.from_f64 = &cast_gt<f64_t , to_scalar_at>::try_;
2228
-
2229
- result.to_b1x8 = &cast_gt<to_scalar_at, b1x8_t >::try_;
2230
- result.to_i8 = &cast_gt<to_scalar_at, i8_t >::try_;
2231
- result.to_f16 = &cast_gt<to_scalar_at, f16_t >::try_;
2232
- result.to_f32 = &cast_gt<to_scalar_at, f32_t >::try_;
2233
- result.to_f64 = &cast_gt<to_scalar_at, f64_t >::try_;
2234
-
2235
- return result;
2236
- }
2237
-
2238
- static casts_t make_casts_ (scalar_kind_t scalar_kind) {
2239
- switch (scalar_kind) {
2240
- case scalar_kind_t ::f64_k: return make_casts_<f64_t >();
2241
- case scalar_kind_t ::f32_k: return make_casts_<f32_t >();
2242
- case scalar_kind_t ::f16_k: return make_casts_<f16_t >();
2243
- case scalar_kind_t ::bf16_k: return make_casts_<bf16_t >();
2244
- case scalar_kind_t ::i8_k: return make_casts_<i8_t >();
2245
- case scalar_kind_t ::b1x8_k: return make_casts_<b1x8_t >();
2246
- default : return {};
2247
- }
2248
- }
2249
2206
};
2250
2207
2251
2208
using index_dense_t = index_dense_gt<>;
0 commit comments