Skip to content

Commit

Permalink
Support distance type radius search for Faiss engine (#1546)
Browse files Browse the repository at this point in the history
* Support distance type radius search for Faiss engine

Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Apr 1, 2024
1 parent 31f3b9d commit c369ec7
Show file tree
Hide file tree
Showing 17 changed files with 685 additions and 32 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.13...2.x)
### Features
* Support distance type radius search for Lucene engine [#1498](https://github.com/opensearch-project/k-NN/pull/1498)
* Support distance type radius search for Faiss engine [#1546](https://github.com/opensearch-project/k-NN/pull/1546)
### Enhancements
* Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549)
* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)
Expand Down
12 changes: 12 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ namespace knn_jni {
// Return the serialized representation
jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);

/*
 * Perform a range (radius) search for a single query vector against the index located in memory at
 * indexPointerJ.
 *
 * @param jniUtil - JNI utility through which all JNI calls of this function are made
 * @param env - JNI environment handed to every jniUtil call
 * @param indexPointerJ - pointer to the index
 * @param queryVectorJ - the query vector
 * @param radiusJ - the radius for the range search; it is forwarded to the index unchanged
 * @param maxResultsWindowJ - the maximum number of results to return; matches beyond this count are dropped
 * @return an array of org.opensearch.knn.index.query.KNNQueryResult objects, one per returned match
 * @throws std::runtime_error if queryVectorJ is null or indexPointerJ is a null pointer
 */
jobjectArray RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ,
jfloat radiusJ, jint maxResultsWindowJ);
}
}

Expand Down
8 changes: 8 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,14 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_freeVectors
(JNIEnv *, jclass, jlong);

/*
 * Class:     org_opensearch_knn_jni_FaissService
 * Method:    rangeSearchIndex
 * Signature: (J[FFI)[Lorg/opensearch/knn/index/query/KNNQueryResult;
 *
 * Arguments after (JNIEnv *, jclass): index pointer (long), query vector (float[]),
 * radius (float), max result window (int).
 * NOTE(review): the element class in the return signature is taken from the native
 * implementation, which builds KNNQueryResult objects - confirm against the Java
 * native method declaration.
 */
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex
(JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint);

#ifdef __cplusplus
}
#endif
Expand Down
44 changes: 44 additions & 0 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,3 +568,47 @@ faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index) {

throw std::runtime_error("Unable to extract IVFPQ index. IVFPQ index not present.");
}

/*
 * Perform a range (radius) search for a single query vector against the index at indexPointerJ.
 *
 * @param jniUtil - JNI utility through which all JNI calls are made
 * @param env - JNI environment handed to every jniUtil call
 * @param indexPointerJ - address of a faiss::IndexIDMap
 * @param queryVectorJ - the query vector
 * @param radiusJ - the radius, forwarded unchanged to faiss range_search
 * @param maxResultWindowJ - upper bound on the number of results returned; extra matches are dropped
 * @return an array of org.opensearch.knn.index.query.KNNQueryResult, one per returned match
 * @throws std::runtime_error if queryVectorJ is null or indexPointerJ is a null pointer; any exception
 *         raised by range_search is rethrown after the query buffer has been released
 */
jobjectArray knn_jni::faiss_wrapper::RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ,
                                                 jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ) {
    if (queryVectorJ == nullptr) {
        throw std::runtime_error("Query Vector cannot be null");
    }

    auto *indexReader = reinterpret_cast<faiss::IndexIDMap *>(indexPointerJ);

    if (indexReader == nullptr) {
        throw std::runtime_error("Invalid pointer to indexReader");
    }

    float *rawQueryVector = jniUtil->GetFloatArrayElements(env, queryVectorJ, nullptr);

    // The res will be freed by ~RangeSearchResult() in FAISS
    // The second parameter is always true, as lims is allocated by FAISS
    faiss::RangeSearchResult res(1, true);

    // Every GetFloatArrayElements must be paired with a release, otherwise the element buffer
    // (and possibly a pin on the Java array) leaks on each call. The query is only read, so
    // JNI_ABORT is used: nothing needs to be copied back into the Java array.
    try {
        indexReader->range_search(1, rawQueryVector, radiusJ, &res);
    } catch (...) {
        jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryVector, JNI_ABORT);
        throw;
    }
    jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryVector, JNI_ABORT);

    // lims is structured to support batched queries, it has a length of nq + 1 (where nq is the number of queries),
    // lims[i] - lims[i-1] gives the number of results for the i-th query. With a single query we used in k-NN,
    // res.lims[0] is always 0, and res.lims[1] gives the total number of matching entries found.
    const size_t matchCount = res.lims[1];

    // Limit the result size to maxResultWindowJ so that we don't return more than the max result window.
    // The comparison happens in size_t, before narrowing, so a match count above INT_MAX cannot wrap
    // around into a bogus (negative) array size.
    // TODO: In the future, we should prevent this via FAISS's ResultHandler.
    const bool fitsInWindow = maxResultWindowJ >= 0 && matchCount <= static_cast<size_t>(maxResultWindowJ);
    const int resultSize = fitsInWindow ? static_cast<int>(matchCount) : maxResultWindowJ;

    jclass resultClass = jniUtil->FindClass(env,"org/opensearch/knn/index/query/KNNQueryResult");
    jmethodID allArgs = jniUtil->FindMethod(env, "org/opensearch/knn/index/query/KNNQueryResult", "<init>");

    jobjectArray results = jniUtil->NewObjectArray(env, resultSize, resultClass, nullptr);

    // Only the first resultSize (label, distance) pairs are converted; the remainder is discarded.
    for (int i = 0; i < resultSize; ++i) {
        jobject result = jniUtil->NewObject(env, resultClass, allArgs, res.labels[i], res.distances[i]);
        jniUtil->SetObjectArrayElement(env, results, i, result);
    }

    return results;
}
14 changes: 14 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,17 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_freeVectors(JNIE
delete vect;
}
}

// JNI entry point for FaissService.rangeSearchIndex: delegates to faiss_wrapper::RangeSearch.
// Any C++ exception is handed to jniUtil.CatchCppExceptionAndThrowJava and nullptr is returned.
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex(JNIEnv * env, jclass cls,
                                                                                         jlong indexPointerJ,
                                                                                         jfloatArray queryVectorJ,
                                                                                         jfloat radiusJ, jint maxResultWindowJ)
{
    // Stays nullptr unless the search completes without throwing.
    jobjectArray searchResults = nullptr;
    try {
        searchResults = knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ,
                                                            maxResultWindowJ);
    } catch (...) {
        jniUtil.CatchCppExceptionAndThrowJava(env);
    }
    return searchResults;
}
113 changes: 113 additions & 0 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,3 +624,116 @@ TEST(FaissInitAndSetSharedIndexState, BasicAssertions) {
ASSERT_EQ(1, ivfpqIndex->use_precomputed_table);
knn_jni::faiss_wrapper::FreeSharedIndexState(sharedModelAddress);
}

// Range search against a small HNSW index: every query must come back with at least one result.
TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) {
    using ResultList = std::vector<std::pair<int, float> *>;

    // Index contents: numIds random vectors of dimension dim
    faiss::idx_t numIds = 200;
    int dim = 2;
    std::vector<faiss::idx_t> ids = test_util::Range(numIds);
    std::vector<float> vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax);

    faiss::MetricType metricType = faiss::METRIC_L2;
    std::string method = "HNSW32,Flat";

    // Query parameters
    // NOTE(review): the radius is presumably large enough to cover all generated data - confirm
    const float radius = 100000.0;
    const int numQueries = 2;
    const int maxResultWindow = 20000;

    // Generate the random query vectors in place
    std::vector<std::vector<float>> queries;
    for (int q = 0; q < numQueries; ++q) {
        queries.emplace_back();
        std::vector<float> &current = queries.back();
        current.reserve(dim);
        for (int d = 0; d < dim; ++d) {
            current.push_back(test_util::RandomFloat(randomDataMin, randomDataMax));
        }
    }

    // Build the index and load the data into it
    std::unique_ptr<faiss::Index> createdIndex(
        test_util::FaissCreateIndex(dim, method, metricType));
    auto createdIndexWithData =
        test_util::FaissAddData(createdIndex.get(), ids, vectors);

    // JNI is mocked; the env pointer is never dereferenced by the mock
    JNIEnv *jniEnv = nullptr;
    NiceMock<test_util::MockJNIUtil> mockJNIUtil;

    for (auto &query : queries) {
        jobjectArray rawResults = knn_jni::faiss_wrapper::RangeSearch(
            &mockJNIUtil, jniEnv,
            reinterpret_cast<jlong>(&createdIndexWithData),
            reinterpret_cast<jfloatArray>(&query), radius, maxResultWindow);

        // The returned handle is treated as a heap-allocated vector of result pairs
        std::unique_ptr<ResultList> results(reinterpret_cast<ResultList *>(rawResults));

        // assert result size is not 0
        ASSERT_NE(0, results->size());

        // Each result is individually owned and must be freed
        for (std::pair<int, float> *hit : *results) {
            delete hit;
        }
    }
}

// Range search where the match count is expected to exceed the window: the result count must equal
// maxResultWindow.
TEST(FaissRangeSearchQueryIndexTest_WhenHitMaxWindowResult, BasicAssertions){
    using ResultList = std::vector<std::pair<int, float> *>;

    // Index contents: numIds random vectors of dimension dim
    faiss::idx_t numIds = 200;
    int dim = 2;
    std::vector<faiss::idx_t> ids = test_util::Range(numIds);
    std::vector<float> vectors = test_util::RandomVectors(dim, numIds, randomDataMin, randomDataMax);

    faiss::MetricType metricType = faiss::METRIC_L2;
    std::string method = "HNSW32,Flat";

    // Query parameters; the window (10) is far smaller than the number of indexed vectors (200)
    const float radius = 100000.0;
    const int numQueries = 2;
    const int maxResultWindow = 10;

    // Generate the random query vectors in place
    std::vector<std::vector<float>> queries;
    for (int q = 0; q < numQueries; ++q) {
        queries.emplace_back();
        std::vector<float> &current = queries.back();
        current.reserve(dim);
        for (int d = 0; d < dim; ++d) {
            current.push_back(test_util::RandomFloat(randomDataMin, randomDataMax));
        }
    }

    // Build the index and load the data into it
    std::unique_ptr<faiss::Index> createdIndex(
        test_util::FaissCreateIndex(dim, method, metricType));
    auto createdIndexWithData =
        test_util::FaissAddData(createdIndex.get(), ids, vectors);

    // JNI is mocked; the env pointer is never dereferenced by the mock
    JNIEnv *jniEnv = nullptr;
    NiceMock<test_util::MockJNIUtil> mockJNIUtil;

    for (auto &query : queries) {
        jobjectArray rawResults = knn_jni::faiss_wrapper::RangeSearch(
            &mockJNIUtil, jniEnv,
            reinterpret_cast<jlong>(&createdIndexWithData),
            reinterpret_cast<jfloatArray>(&query), radius, maxResultWindow);

        // The returned handle is treated as a heap-allocated vector of result pairs
        std::unique_ptr<ResultList> results(reinterpret_cast<ResultList *>(rawResults));

        // assert result size is not 0
        ASSERT_NE(0, results->size());
        // assert result size is equal to maxResultWindow
        ASSERT_EQ(maxResultWindow, results->size());

        // Each result is individually owned and must be freed
        for (std::pair<int, float> *hit : *results) {
            delete hit;
        }
    }
}
66 changes: 65 additions & 1 deletion src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

import java.util.Arrays;
import java.util.Objects;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import org.apache.lucene.search.BooleanClause;
Expand All @@ -30,14 +32,18 @@ public class KNNQuery extends Query {

private final String field;
private final float[] queryVector;
private final int k;
private int k;
private final String indexName;

@Getter
@Setter
private Query filterQuery;
@Getter
private BitSetProducer parentsFilter;
@Getter
private Float radius = null;
@Getter
private Context context;

public KNNQuery(
final String field,
Expand Down Expand Up @@ -69,6 +75,54 @@ public KNNQuery(
this.parentsFilter = parentsFilter;
}

/**
 * Constructor for KNNQuery with query vector, index name and parent filter.
 * <p>
 * Only the four fields named below are assigned here. k, radius, filterQuery and context are not
 * set by this constructor; radius, context and filterQuery can be supplied afterwards through the
 * fluent {@link #radius(Float)}, {@link #kNNQueryContext(Context)} and {@link #filterQuery(Query)}
 * methods.
 *
 * @param field field name
 * @param queryVector query vector
 * @param indexName index name
 * @param parentsFilter parent filter
 */
public KNNQuery(String field, float[] queryVector, String indexName, BitSetProducer parentsFilter) {
this.field = field;
this.queryVector = queryVector;
this.indexName = indexName;
this.parentsFilter = parentsFilter;
}

/**
 * Fluent setter (not a constructor): stores the radius on this query and returns the same
 * instance so that calls can be chained.
 *
 * @param radius engine radius
 * @return this KNNQuery
 */
public KNNQuery radius(Float radius) {
this.radius = radius;
return this;
}

/**
 * Fluent setter (not a constructor): stores the Context on this query and returns the same
 * instance so that calls can be chained.
 *
 * @param context Context for KNNQuery
 * @return this KNNQuery
 */
public KNNQuery kNNQueryContext(Context context) {
this.context = context;
return this;
}

/**
 * Fluent setter (not a constructor): stores the filter query on this query and returns the same
 * instance so that calls can be chained.
 *
 * @param filterQuery filter query
 * @return this KNNQuery
 */
public KNNQuery filterQuery(Query filterQuery) {
this.filterQuery = filterQuery;
return this;
}

public String getField() {
return this.field;
}
Expand Down Expand Up @@ -144,4 +198,14 @@ private boolean equalsTo(KNNQuery other) {
&& Objects.equals(indexName, other.indexName)
&& Objects.equals(filterQuery, other.filterQuery);
}

/**
 * Context for KNNQuery.
 * <p>
 * Lombok generates the getter, the setter and an all-arguments constructor for the single field.
 */
@Setter
@Getter
@AllArgsConstructor
public static class Context {
// NOTE(review): presumably the index's max result window, used to cap the number of
// radial search results - confirm against the code that builds this Context.
int maxResultWindow;
}
}
29 changes: 16 additions & 13 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,12 @@

import java.util.List;
import java.util.Objects;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.SpaceType;
Expand All @@ -34,6 +25,16 @@
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.apache.lucene.search.Query;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryShardContext;

import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
Expand Down Expand Up @@ -108,8 +109,8 @@ public KNNQueryBuilder k(int k) {
* @param distance the distance threshold for the nearest neighbours
*/
public KNNQueryBuilder distance(Float distance) {
if (distance == null || distance < 0) {
throw new IllegalArgumentException("[" + NAME + "] requires distance >= 0");
if (distance == null) {
throw new IllegalArgumentException("[" + NAME + "] requires distance to be set");
}
if (k != 0) {
throw new IllegalArgumentException("[" + NAME + "] requires either k or distance must be set");
Expand Down Expand Up @@ -400,6 +401,9 @@ protected Query doToQuery(QueryShardContext context) {
// We need transform distance radius to right type of engine required radius.
Float radius = null;
if (this.distance != null) {
if (this.distance < 0 && SpaceType.INNER_PRODUCT.equals(spaceType) == false) {
throw new IllegalArgumentException("[" + NAME + "] requires distance to be non-negative for space type: " + spaceType);
}
radius = knnEngine.distanceToRadialThreshold(this.distance, spaceType);
}

Expand Down Expand Up @@ -457,7 +461,6 @@ protected Query doToQuery(QueryShardContext context) {
.radius(radius)
.filter(this.filter)
.context(context)
.radius(radius)
.build();
return RNNQueryFactory.create(createQueryRequest);
}
Expand Down
Loading

0 comments on commit c369ec7

Please sign in to comment.