@@ -452,7 +452,7 @@ static py::tuple search_many_in_index( //
452
452
}
453
453
454
454
/* *
455
- * @brief Brute-force exact search implementation, compatible with
455
+ * @brief Brute-force @b exact search implementation, compatible with
456
456
* NumPy-like Tensors and other objects supporting Buffer Protocol.
457
457
*/
458
458
static py::tuple search_many_brute_force ( //
@@ -545,6 +545,81 @@ static py::tuple search_many_brute_force( //
545
545
return results;
546
546
}
547
547
548
+ /* *
549
+ * @brief Brute-force @b K-Means clustering, compatible with
550
+ * NumPy-like Tensors and other objects supporting Buffer Protocol.
551
+ */
552
+ static py::tuple cluster_many_brute_force ( //
553
+ py::buffer dataset, //
554
+ std::size_t wanted, //
555
+ std::size_t max_iterations, //
556
+ double inertia_threshold, //
557
+ double max_seconds, //
558
+ double min_shifts, //
559
+ std::uint64_t seed, //
560
+ std::size_t threads, //
561
+ scalar_kind_t scalar_kind, //
562
+ metric_kind_t metric_kind, //
563
+ progress_func_t const & progress_func) {
564
+
565
+ using distance_t = typename kmeans_clustering_t ::distance_t ;
566
+ py::buffer_info dataset_info = dataset.request ();
567
+ if (dataset_info.ndim != 2 )
568
+ throw std::invalid_argument (" Expects a matrix (rank-2 tensor) of dataset to cluster!" );
569
+
570
+ std::size_t dataset_count = static_cast <std::size_t >(dataset_info.shape [0 ]);
571
+ std::size_t dataset_dimensions = static_cast <std::size_t >(dataset_info.shape [1 ]);
572
+ std::size_t dataset_stride = static_cast <std::size_t >(dataset_info.strides [0 ]);
573
+ scalar_kind_t dataset_kind = numpy_string_to_kind (dataset_info.format );
574
+ std::size_t bytes_per_scalar = bits_per_scalar_word (dataset_kind) / CHAR_BIT;
575
+
576
+ std::vector<std::size_t > point_to_centroid_index (dataset_count, 0 );
577
+ std::vector<distance_t > point_to_centroid_distance (dataset_count, 0 );
578
+ std::vector<byte_t > centroids (wanted * dataset_dimensions * bytes_per_scalar, 0 );
579
+
580
+ if (!threads)
581
+ threads = std::thread::hardware_concurrency ();
582
+
583
+ // Dispatch brute-force search
584
+ progress_t progress{progress_func};
585
+ executor_default_t executor{threads};
586
+ kmeans_clustering_t engine;
587
+ engine.metric_kind = metric_kind;
588
+ engine.quantization_kind = scalar_kind;
589
+ engine.max_iterations = max_iterations;
590
+ engine.min_shifts = min_shifts;
591
+ engine.max_seconds = max_seconds;
592
+ engine.inertia_threshold = inertia_threshold;
593
+
594
+ kmeans_clustering_result_t result = engine ( //
595
+ reinterpret_cast <byte_t const *>(dataset_info.ptr ), dataset_count, dataset_stride, //
596
+ centroids.data (), wanted, dataset_dimensions * bytes_per_scalar, //
597
+ point_to_centroid_index.data (), point_to_centroid_distance.data (), dataset_kind, dataset_dimensions, executor,
598
+ [&](std::size_t passed, std::size_t total) { return PyErr_CheckSignals () == 0 && progress (passed, total); });
599
+
600
+ if (!result)
601
+ throw std::runtime_error (result.error .release ());
602
+
603
+ // Following constructor doesn't seem to be documented, but it's used in the source code of `pybind11`
604
+ // https://github.com/pybind/pybind11/blob/aeda49ed0b4e6e8abba7abc265ace86a6c26ba66/include/pybind11/numpy.h#L918-L919
605
+ // https://github.com/pybind/pybind11/blob/aeda49ed0b4e6e8abba7abc265ace86a6c26ba66/include/pybind11/buffer_info.h#L60-L75
606
+ py::buffer_info centroids_info;
607
+ centroids_info.ptr = reinterpret_cast <void *>(centroids.data ());
608
+ centroids_info.itemsize = dataset_info.itemsize ;
609
+ centroids_info.size = wanted * dataset_dimensions;
610
+ centroids_info.format = dataset_info.format ;
611
+ centroids_info.ndim = 2 ;
612
+ centroids_info.shape = {wanted, dataset_dimensions};
613
+ centroids_info.strides = {dataset_dimensions * bytes_per_scalar, bytes_per_scalar};
614
+
615
+ py::tuple results (3 );
616
+ results[0 ] = py::array_t <std::size_t >({dataset_count}, point_to_centroid_index.data ());
617
+ results[1 ] = py::array_t <distance_t >({dataset_count}, point_to_centroid_distance.data ());
618
+ results[2 ] = py::array (centroids_info);
619
+
620
+ return results;
621
+ }
622
+
548
623
template <typename scalar_at> struct rows_lookup_gt {
549
624
byte_t * data_;
550
625
std::size_t stride_;
@@ -936,16 +1011,33 @@ PYBIND11_MODULE(compiled, m) {
936
1011
return index_metadata (meta);
937
1012
});
938
1013
939
- m.def (" exact_search" , &search_many_brute_force, //
940
- py::arg (" dataset" ), //
941
- py::arg (" queries" ), //
942
- py::arg (" count" ) = 10 , //
943
- py::kw_only (), //
944
- py::arg (" threads" ) = 0 , //
945
- py::arg (" metric_kind" ) = metric_kind_t ::cos_k, //
946
- py::arg (" metric_signature" ) = metric_punned_signature_t ::array_array_k, //
947
- py::arg (" metric_pointer" ) = 0 , //
948
- py::arg (" progress" ) = nullptr //
1014
+ m.def ( //
1015
+ " exact_search" , &search_many_brute_force, //
1016
+ py::arg (" dataset" ), //
1017
+ py::arg (" queries" ), //
1018
+ py::arg (" count" ) = 10 , //
1019
+ py::kw_only (), //
1020
+ py::arg (" threads" ) = 0 , //
1021
+ py::arg (" metric_kind" ) = metric_kind_t ::cos_k, //
1022
+ py::arg (" metric_signature" ) = metric_punned_signature_t ::array_array_k, //
1023
+ py::arg (" metric_pointer" ) = 0 , //
1024
+ py::arg (" progress" ) = nullptr //
1025
+ );
1026
+
1027
+ m.def ( //
1028
+ " kmeans" , &cluster_many_brute_force, //
1029
+ py::arg (" dataset" ), //
1030
+ py::arg (" count" ) = 10 , //
1031
+ py::kw_only (), //
1032
+ py::arg (" max_iterations" ) = kmeans_clustering_t ::max_iterations_default_k, //
1033
+ py::arg (" inertia_threshold" ) = kmeans_clustering_t ::inertia_threshold_default_k, //
1034
+ py::arg (" max_seconds" ) = kmeans_clustering_t ::max_seconds_default_k, //
1035
+ py::arg (" min_shifts" ) = kmeans_clustering_t ::min_shifts_default_k, //
1036
+ py::arg (" seed" ) = 0 , //
1037
+ py::arg (" threads" ) = 0 , //
1038
+ py::arg (" dtype" ) = scalar_kind_t ::bf16_k, //
1039
+ py::arg (" metric_kind" ) = metric_kind_t ::l2sq_k, //
1040
+ py::arg (" progress" ) = nullptr //
949
1041
);
950
1042
951
1043
m.def (
@@ -961,18 +1053,19 @@ PYBIND11_MODULE(compiled, m) {
961
1053
962
1054
auto i = py::class_<dense_index_py_t , std::shared_ptr<dense_index_py_t >>(m, " Index" );
963
1055
964
- i.def (py::init (&make_index), //
965
- py::kw_only (), //
966
- py::arg (" ndim" ) = 0 , //
967
- py::arg (" dtype" ) = scalar_kind_t ::f32_k, //
968
- py::arg (" connectivity" ) = default_connectivity (), //
969
- py::arg (" expansion_add" ) = default_expansion_add (), //
970
- py::arg (" expansion_search" ) = default_expansion_search (), //
971
- py::arg (" metric_kind" ) = metric_kind_t ::cos_k, //
972
- py::arg (" metric_signature" ) = metric_punned_signature_t ::array_array_k, //
973
- py::arg (" metric_pointer" ) = 0 , //
974
- py::arg (" multi" ) = false , //
975
- py::arg (" enable_key_lookups" ) = true //
1056
+ i.def ( //
1057
+ py::init (&make_index), //
1058
+ py::kw_only (), //
1059
+ py::arg (" ndim" ) = 0 , //
1060
+ py::arg (" dtype" ) = scalar_kind_t ::f32_k, //
1061
+ py::arg (" connectivity" ) = default_connectivity (), //
1062
+ py::arg (" expansion_add" ) = default_expansion_add (), //
1063
+ py::arg (" expansion_search" ) = default_expansion_search (), //
1064
+ py::arg (" metric_kind" ) = metric_kind_t ::cos_k, //
1065
+ py::arg (" metric_signature" ) = metric_punned_signature_t ::array_array_k, //
1066
+ py::arg (" metric_pointer" ) = 0 , //
1067
+ py::arg (" multi" ) = false , //
1068
+ py::arg (" enable_key_lookups" ) = true //
976
1069
);
977
1070
978
1071
i.def ( //
0 commit comments