@@ -7,6 +7,7 @@ package faiss
77#include <faiss/c_api/IndexIVF_c.h>
88#include <faiss/c_api/IndexIVF_c_ex.h>
99#include <faiss/c_api/IndexBinary_c.h>
10+ #include <faiss/c_api/IndexBinaryIVF_c.h>
1011#include <faiss/c_api/index_factory_c.h>
1112#include <faiss/c_api/MetaIndexes_c.h>
1213#include <faiss/c_api/impl/AuxIndexStructures_c.h>
@@ -54,6 +55,10 @@ type BinaryIndex interface {
5455 SearchBinaryWithIDs (x []uint8 , k int64 , include []int64 , params json.RawMessage ) ([]int32 , []int64 , error )
5556 SearchBinaryWithoutIDs (x []uint8 , k int64 , exclude []int64 , params json.RawMessage ) (distances []int32 ,
5657 labels []int64 , err error )
58+
59+ ObtainClusterVectorCountsFromIVFIndex (vecIDs []int64 ) (map [int64 ]int64 , error )
60+ ObtainClustersWithDistancesFromIVFIndex (x []uint8 , centroidIDs []int64 ) (
61+ []int64 , []int32 , error )
5762}
5863
5964// FloatIndex defines methods specific to float-based FAISS indexes
@@ -156,6 +161,62 @@ func (idx *BinaryIndexImpl) Close() {
156161 }
157162}
158163
164+ func (idx * BinaryIndexImpl ) ObtainClusterVectorCountsFromIVFIndex (vecIDs []int64 ) (map [int64 ]int64 , error ) {
165+ if ! idx .IsIVFIndex () {
166+ return nil , fmt .Errorf ("index is not an IVF index" )
167+ }
168+ clusterIDs := make ([]int64 , len (vecIDs ))
169+ if c := C .faiss_get_lists_for_keys_binary (
170+ idx .indexPtr ,
171+ (* C .idx_t )(unsafe .Pointer (& vecIDs [0 ])),
172+ (C .size_t )(len (vecIDs )),
173+ (* C .idx_t )(unsafe .Pointer (& clusterIDs [0 ])),
174+ ); c != 0 {
175+ return nil , getLastError ()
176+ }
177+ rv := make (map [int64 ]int64 , len (vecIDs ))
178+ for _ , v := range clusterIDs {
179+ rv [v ]++
180+ }
181+ return rv , nil
182+ }
183+
184+ func (idx * BinaryIndexImpl ) ObtainClustersWithDistancesFromIVFIndex (x []uint8 , centroidIDs []int64 ) (
185+ []int64 , []int32 , error ) {
186+ // Selector to include only the centroids whose IDs are part of 'centroidIDs'.
187+ includeSelector , err := NewIDSelectorBatch (centroidIDs )
188+ if err != nil {
189+ return nil , nil , err
190+ }
191+ defer includeSelector .Delete ()
192+
193+ params , err := NewSearchParams (idx , json.RawMessage {}, includeSelector .Get (), nil )
194+ if err != nil {
195+ return nil , nil , err
196+ }
197+ defer params .Delete ()
198+
199+ // Populate these with the centroids and their distances.
200+ centroids := make ([]int64 , len (centroidIDs ))
201+ centroidDistances := make ([]int32 , len (centroidIDs ))
202+
203+ n := len (x ) / idx .D ()
204+
205+ c := C .faiss_Search_closest_eligible_centroids_binary (
206+ idx .indexPtr ,
207+ (C .idx_t )(n ),
208+ (* C .uint8_t )(& x [0 ]),
209+ (C .idx_t )(len (centroidIDs )),
210+ (* C .int32_t )(& centroidDistances [0 ]),
211+ (* C .idx_t )(& centroids [0 ]),
212+ params .sp )
213+ if c != 0 {
214+ return nil , nil , getLastError ()
215+ }
216+
217+ return centroids , centroidDistances , nil
218+ }
219+
159220func (idx * BinaryIndexImpl ) Size () uint64 {
160221 return 0
161222}
@@ -263,7 +324,7 @@ func (idx *BinaryIndexImpl) Train(vectors []uint8) error {
263324}
264325
265326func (idx * BinaryIndexImpl ) SearchBinaryWithoutIDs (x []uint8 , k int64 , exclude []int64 , params json.RawMessage ) (distances []int32 , labels []int64 , err error ) {
266- if len (exclude ) == 0 && params == nil {
327+ if len (exclude ) == 0 && len ( params ) == 0 {
267328 return idx .SearchBinary (x , k )
268329 }
269330
0 commit comments