Skip to content

Commit 4fae008

Browse files
Fix: Set dims in metric_punned_t::stateful
Co-authored-by: Ash Vardanian <[email protected]> Co-authored-by: Terence Liu <[email protected]> Co-authored-by: Terence Z. Liu <[email protected]>
1 parent dccdd8e commit 4fae008

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

c/lib.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ USEARCH_EXPORT void usearch_change_metric(usearch_index_t index, usearch_metric_
354354
USEARCH_ASSERT(index && error && "Missing arguments");
355355
auto& index_dense = *reinterpret_cast<index_dense_t*>(index);
356356
auto metric_punned =
357-
state ? metric_punned_t::stateful(reinterpret_cast<std::uintptr_t>(metric),
357+
state ? metric_punned_t::stateful(index_dense.dimensions(), reinterpret_cast<std::uintptr_t>(metric),
358358
reinterpret_cast<std::uintptr_t>(state), metric_kind_to_cpp(kind),
359359
index_dense.scalar_kind())
360360
: metric_punned_t::stateless(index_dense.dimensions(), reinterpret_cast<std::uintptr_t>(metric),

include/usearch/index_plugins.hpp

+9-5
Original file line numberDiff line numberDiff line change
@@ -1767,20 +1767,22 @@ class metric_punned_t {
17671767
* @brief Creates a metric using the provided function pointer for a stateful metric.
17681768
* The third argument is the state that will be passed to the metric function.
17691769
*
1770+
* @param dimensions The number of elements in the input arrays.
17701771
* @param metric_uintptr The function pointer to the metric function.
17711772
* @param metric_state The state to pass to the metric function.
17721773
* @param metric_kind The kind of metric to use.
17731774
* @param scalar_kind The kind of scalar to use.
17741775
* @return A metric object that can be used to compute distances between vectors.
17751776
*/
1776-
inline static metric_punned_t stateful(std::uintptr_t metric_uintptr, std::uintptr_t metric_state,
1777-
metric_kind_t metric_kind = metric_kind_t::unknown_k,
1778-
scalar_kind_t scalar_kind = scalar_kind_t::unknown_k) noexcept {
1777+
inline static metric_punned_t stateful( //
1778+
std::size_t dimensions, std::uintptr_t metric_uintptr, std::uintptr_t metric_state,
1779+
metric_kind_t metric_kind = metric_kind_t::unknown_k,
1780+
scalar_kind_t scalar_kind = scalar_kind_t::unknown_k) noexcept {
17791781
metric_punned_t metric;
17801782
metric.metric_routed_ = &metric_punned_t::invoke_array_array_third;
17811783
metric.metric_ptr_ = metric_uintptr;
17821784
metric.metric_third_arg_ = metric_state;
1783-
metric.dimensions_ = 0;
1785+
metric.dimensions_ = dimensions;
17841786
metric.metric_kind_ = metric_kind;
17851787
metric.scalar_kind_ = scalar_kind;
17861788
return metric;
@@ -2223,6 +2225,8 @@ template <typename allocator_at = std::allocator<char>> class kmeans_clustering_
22232225
scalar_kind_t original_scalar_kind, std::size_t dimensions, executor_at&& executor = executor_at{},
22242226
progress_at&& progress = progress_at{}) {
22252227

2228+
(void)progress; // TODO
2229+
22262230
// Perform sanity checks for algorithm settings.
22272231
kmeans_clustering_result_t result;
22282232
if (max_iterations < 1)
@@ -2332,7 +2336,7 @@ template <typename allocator_at = std::allocator<char>> class kmeans_clustering_
23322336

23332337
// For every point, find the closest centroid.
23342338
std::atomic<std::size_t> points_shifted{0};
2335-
executor.dynamic(points_count, [&](std::size_t thread_idx, std::size_t points_idx) {
2339+
executor.dynamic(points_count, [&](std::size_t, std::size_t points_idx) {
23362340
byte_t const* quantized_point =
23372341
points_quantized_buffer.data() + points_idx * stride_per_vector_quantized;
23382342
byte_t const* quantized_centroids = centroids_quantized_buffer.data();

rust/lib.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ void NativeIndex::change_expansion_search(size_t n) const { index_->change_expan
104104

105105
void NativeIndex::change_metric(uptr_t metric, uptr_t state) const {
106106
index_->change_metric(metric_punned_t::stateful( //
107+
index_->dimensions(), //
107108
static_cast<std::uintptr_t>(metric), //
108109
static_cast<std::uintptr_t>(state), //
109110
index_->metric().metric_kind(), //

0 commit comments

Comments
 (0)