Skip to content

Commit

Permalink
Add test case for dev IVF binary index
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Jun 19, 2024
1 parent 6e30a3b commit 2757338
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
#include "faiss_wrapper.h"

#include <vector>
#include <faiss/IndexBinaryFlat.h>
#include <faiss/IndexBinaryIVF.h>
#include <faiss/index_io.h>

#include <gtest/gtest.h>
#include <random>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
Expand Down Expand Up @@ -479,6 +485,68 @@ TEST(FaissTrainIndexTest, BasicAssertions) {
ASSERT_TRUE(trainedIndex->is_trained);
}

// Utility function to generate random binary data
std::vector<uint8_t> generateRandomBinaryData(int dim, int numVectors) {
std::vector<uint8_t> data(dim / 8 * numVectors);
std::default_random_engine engine;
std::uniform_int_distribution<uint8_t> distribution(0, 255); // Generate bytes
for (auto& byte : data) {
byte = distribution(engine);
}
return data;
}

TEST(FaissBinaryIVFIndexTest, BasicIVFSearch) {
// Dimension of the vectors, should be a multiple of 8.
int d = 256;

// Number of database vectors, training vectors, and query vectors
int nb = 1000; // Database vectors
int nt = 500; // Training vectors
int nq = 10; // Query vectors

// Generate binary data for db, training, and queries
std::vector<uint8_t> db = generateRandomBinaryData(d, nb);
std::vector<uint8_t> training = generateRandomBinaryData(d, nt);
std::vector<uint8_t> queries = generateRandomBinaryData(d, nq);

// Initializing the quantizer
faiss::IndexBinaryFlat quantizer(d);

// Number of clusters
int nlist = 100;

// Initializing index
faiss::IndexBinaryIVF index(&quantizer, d, nlist);
index.nprobe = 4; // Number of nearest clusters to be searched per query.

// Training the quantizer
index.train(nt, training.data());

// Adding the database vectors
index.add(nb, db.data());

// Number of nearest neighbors to retrieve per query vector
int k = 4;

// Output variables for the queries
std::vector<int32_t> distances(nq * k);
std::vector<faiss::idx_t> labels(nq * k);

// Querying the index
index.search(nq, queries.data(), k, distances.data(), labels.data());

// Check the results
for (int i = 0; i < nq; ++i) {
for (int j = 0; j < k; ++j) {
int idx = i * k + j;
// Ensure that distances are non-negative and labels are within valid range
ASSERT_GE(distances[idx], 0);
ASSERT_LT(labels[idx], nb);
}
}
}

TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) {
// Define the data
faiss::idx_t numIds = 200;
Expand Down

0 comments on commit 2757338

Please sign in to comment.