Skip to content

Commit

Permalink
Support distance type radius search for Faiss engine
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Mar 19, 2024
1 parent 0470f21 commit 0b00d43
Show file tree
Hide file tree
Showing 14 changed files with 434 additions and 23 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.13...2.x)
### Features
* Support distance type radius search for Lucene engine [#1498](https://github.com/opensearch-project/k-NN/pull/1498)
* Support distance type radius search for Faiss engine [#1546](https://github.com/opensearch-project/k-NN/pull/1546)
### Enhancements
### Bug Fixes
### Infrastructure
Expand Down
4 changes: 4 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ namespace knn_jni {
// Return the serialized representation
jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);

// Execute a range search against the index located in memory at indexPointerJ.
//
// queryVectorJ holds the single query vector and radiusJ is the threshold passed to the index's
// range_search call. Throws std::runtime_error when queryVectorJ is null or indexPointerJ is 0.
//
// Return an array of KNNQueryResult objects, one per match reported by the index
jobjectArray RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ,
jfloat radiusJ);
}
}

Expand Down
8 changes: 8 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,14 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_freeVectors
(JNIEnv *, jclass, jlong);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: rangeSearchIndex
* Signature: (J[FF)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*
* NOTE(review): the Java-side native declaration of rangeSearchIndex also takes a trailing
* String engineName, which this prototype does not list - confirm the two are meant to match.
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex
(JNIEnv *, jclass, jlong, jfloatArray, jfloat);

#ifdef __cplusplus
}
#endif
Expand Down
34 changes: 34 additions & 0 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,3 +568,37 @@ faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index) {

throw std::runtime_error("Unable to extract IVFPQ index. IVFPQ index not present.");
}

// Run a range search for one query vector against the Faiss index stored at indexPointerJ.
//
// Parameters:
//   jniUtil       - JNI helper used for every interaction with the JVM
//   env           - JNI environment of the calling thread
//   indexPointerJ - address of a faiss::IndexIDMap previously loaded into memory
//   queryVectorJ  - the query vector; must not be null
//   radiusJ       - threshold forwarded unchanged to faiss::Index::range_search
//
// Returns a Java array of org.opensearch.knn.index.query.KNNQueryResult, one entry per match, in the
// order Faiss reported them.
//
// Throws std::runtime_error if queryVectorJ is null or indexPointerJ is 0; exceptions raised by Faiss
// propagate to the caller after the query buffer has been handed back to the JVM.
jobjectArray knn_jni::faiss_wrapper::RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ,
                                                 jfloatArray queryVectorJ, jfloat radiusJ) {
    if (queryVectorJ == nullptr) {
        throw std::runtime_error("Query Vector cannot be null");
    }

    auto *indexReader = reinterpret_cast<faiss::IndexIDMap *>(indexPointerJ);

    if (indexReader == nullptr) {
        throw std::runtime_error("Invalid pointer to indexReader");
    }

    // Allocate the result holder before pinning the Java array so that an allocation failure here
    // cannot leave the array elements unreleased.
    faiss::RangeSearchResult res(1, true);

    float *rawQueryVector = jniUtil->GetFloatArrayElements(env, queryVectorJ, nullptr);

    // Every GetFloatArrayElements must be paired with a release, including when the search throws.
    // JNI_ABORT is used because the query vector is only read, so nothing needs copying back.
    try {
        indexReader->range_search(1, rawQueryVector, radiusJ, &res);
    } catch (...) {
        jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryVector, JNI_ABORT);
        throw;
    }
    jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryVector, JNI_ABORT);

    // Process the results, lims[1] contains the total number of results found for single query
    const int resultSize = static_cast<int>(res.lims[1]);

    jclass resultClass = jniUtil->FindClass(env, "org/opensearch/knn/index/query/KNNQueryResult");
    jmethodID allArgs = jniUtil->FindMethod(env, "org/opensearch/knn/index/query/KNNQueryResult", "<init>");

    jobjectArray results = jniUtil->NewObjectArray(env, resultSize, resultClass, nullptr);

    for (int i = 0; i < resultSize; ++i) {
        jobject result = jniUtil->NewObject(env, resultClass, allArgs, res.labels[i], res.distances[i]);
        jniUtil->SetObjectArrayElement(env, results, i, result);
    }

    return results;
}
13 changes: 13 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,16 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_freeVectors(JNIE
delete vect;
}
}

// JNI entry point for FaissService.rangeSearchIndex.
//
// Forwards the call to knn_jni::faiss_wrapper::RangeSearch. Any C++ exception raised on the way is
// passed to jniUtil.CatchCppExceptionAndThrowJava, in which case nullptr is returned to the JVM.
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex(JNIEnv * env, jclass cls,
                                                                                         jlong indexPointerJ,
                                                                                         jfloatArray queryVectorJ, jfloat radiusJ)
{
    // Stays null unless the search completes without throwing.
    jobjectArray matches = nullptr;
    try {
        matches = knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ);
    } catch (...) {
        jniUtil.CatchCppExceptionAndThrowJava(env);
    }
    return matches;
}
60 changes: 60 additions & 0 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,3 +624,63 @@ TEST(FaissInitAndSetSharedIndexState, BasicAssertions) {
ASSERT_EQ(1, ivfpqIndex->use_precomputed_table);
knn_jni::faiss_wrapper::FreeSharedIndexState(sharedModelAddress);
}

// Verifies that RangeSearch returns at least one neighbour for random queries against a flat L2 index.
TEST(FaissRangeSearchQueryIndexTest, BasicAssertions) {
    // Define the index data. `dim` is the single source of truth for the vector dimension: it sizes
    // the stored vectors, the queries and the index itself. Faiss reads `dim` floats per vector, so
    // an index created with a larger dimension than the data would read past these buffers.
    faiss::idx_t numIds = 100;
    std::vector<faiss::idx_t> ids;
    std::vector<float> vectors;
    int dim = 1;
    ids.reserve(numIds);
    vectors.reserve(numIds * dim);
    for (int64_t i = 0; i < numIds; i++) {
        ids.push_back(i);
        for (int j = 0; j < dim; j++) {
            vectors.push_back(test_util::RandomFloat(-10.0, 10.0));
        }
    }

    faiss::MetricType metricType = faiss::METRIC_L2;
    std::string method = "Flat";

    // Define query data
    float radius = 10.0;
    int numQueries = 2;
    std::vector<std::vector<float>> queries;
    queries.reserve(numQueries);

    for (int i = 0; i < numQueries; i++) {
        std::vector<float> query;
        query.reserve(dim);
        for (int j = 0; j < dim; j++) {
            query.push_back(test_util::RandomFloat(-10.0, 10.0));
        }
        queries.push_back(query);
    }

    // Create the index with the same dimension as the generated data
    std::unique_ptr<faiss::Index> createdIndex(
            test_util::FaissCreateIndex(dim, method, metricType));
    auto createdIndexWithData =
            test_util::FaissAddData(createdIndex.get(), ids, vectors);

    // Setup jni
    JNIEnv *jniEnv = nullptr;
    NiceMock<test_util::MockJNIUtil> mockJNIUtil;

    // Iterate by reference: the query vector's address is passed through as the jfloatArray handle,
    // so there is no reason to copy each query first.
    for (auto &query : queries) {
        std::unique_ptr<std::vector<std::pair<int, float> *>> results(
                reinterpret_cast<std::vector<std::pair<int, float> *> *>(
                        knn_jni::faiss_wrapper::RangeSearch(
                                &mockJNIUtil, jniEnv,
                                reinterpret_cast<jlong>(&createdIndexWithData),
                                reinterpret_cast<jfloatArray>(&query), radius)));

        // assert result size is not 0
        ASSERT_NE(0, results->size());

        // Need to free up each result
        for (auto it : *results) {
            delete it;
        }
    }
}
42 changes: 41 additions & 1 deletion src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,18 @@ public class KNNQuery extends Query {

private final String field;
private final float[] queryVector;
private final int k;
private int k;
private final String indexName;

@Getter
@Setter
private Query filterQuery;
@Getter
private BitSetProducer parentsFilter;
@Getter
@Setter
// Radius for radius query, set to -1.0 for KNN query to avoid exception in OpenSearch Query null checks
private Float radius = -1.0f;

public KNNQuery(
final String field,
Expand Down Expand Up @@ -69,6 +73,42 @@ public KNNQuery(
this.parentsFilter = parentsFilter;
}

/**
 * Creates a KNNQuery without a k value, for use by radius search.
 * <p>
 * Only the four fields below are assigned; {@code k} is not set here.
 *
 * @param field         field name
 * @param queryVector   query vector
 * @param indexName     index name
 * @param parentsFilter parent filter stored on the query
 */
public KNNQuery(final String field, final float[] queryVector, final String indexName, final BitSetProducer parentsFilter) {
    this.indexName = indexName;
    this.field = field;
    this.queryVector = queryVector;
    this.parentsFilter = parentsFilter;
}

/**
 * Fluent setter for the search radius: stores the given radius on this query and returns the
 * same instance so calls can be chained.
 *
 * @param radius engine radius
 * @return this KNNQuery
 */
public KNNQuery radius(Float radius) {
this.radius = radius;
return this;
}

/**
 * Fluent setter for the filter query: stores the given filter on this query and returns the
 * same instance so calls can be chained.
 *
 * @param filterQuery filter query
 * @return this KNNQuery
 */
public KNNQuery filterQuery(Query filterQuery) {
this.filterQuery = filterQuery;
return this;
}

/**
 * @return the field name this query was constructed with
 */
public String getField() {
return this.field;
}
Expand Down
22 changes: 11 additions & 11 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,12 @@

import java.util.List;
import java.util.Objects;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.SpaceType;
Expand All @@ -34,6 +25,16 @@
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.apache.lucene.search.Query;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryShardContext;

import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue;
import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion;
Expand Down Expand Up @@ -457,7 +458,6 @@ protected Query doToQuery(QueryShardContext context) {
.radius(radius)
.filter(this.filter)
.context(context)
.radius(radius)
.build();
return RNNQueryFactory.create(createQueryRequest);
}
Expand Down
31 changes: 21 additions & 10 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -277,16 +277,24 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final B
throw new RuntimeException("Index has already been closed");
}
int[] parentIds = getParentIdsArray(context);
results = JNIService.queryIndex(
indexAllocation.getMemoryAddress(),
knnQuery.getQueryVector(),
knnQuery.getK(),
knnEngine,
filterIds,
filterType.getValue(),
parentIds
);

if (knnQuery.getK() > 0) {
results = JNIService.queryIndex(
indexAllocation.getMemoryAddress(),
knnQuery.getQueryVector(),
knnQuery.getK(),
knnEngine,
filterIds,
filterType.getValue(),
parentIds
);
} else {
results = JNIService.radiusQueryIndex(
indexAllocation.getMemoryAddress(),
knnQuery.getQueryVector(),
knnQuery.getRadius(),
knnEngine.getName()
);
}
} catch (Exception e) {
GRAPH_QUERY_ERRORS.increment();
throw new RuntimeException(e);
Expand Down Expand Up @@ -406,6 +414,9 @@ private boolean canDoExactSearch(final int filterIdsCount) {
filterIdsCount,
KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName())
);
if (knnQuery.getRadius() >= 0) {
return false;
}
int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName());
// Refer this GitHub around more details https://github.com/opensearch-project/k-NN/issues/1049 on the logic
if (filterIdsCount <= knnQuery.getK()) {
Expand Down
20 changes: 20 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import org.apache.lucene.search.ByteVectorSimilarityQuery;
import org.apache.lucene.search.FloatVectorSimilarityQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.KNNEngine;

Expand Down Expand Up @@ -66,6 +68,24 @@ public static Query create(RNNQueryFactory.CreateQueryRequest createQueryRequest
final VectorDataType vectorDataType = createQueryRequest.getVectorDataType();
final Query filterQuery = getFilterQuery(createQueryRequest);

BitSetProducer parentFilter = null;
if (createQueryRequest.getContext().isPresent()) {
QueryShardContext context = createQueryRequest.getContext().get();
parentFilter = context.getParentFilter();
}

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) {
KNNQuery rnnQuery = new KNNQuery(fieldName, vector, indexName, parentFilter).radius(radius);
if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(createQueryRequest.getKnnEngine())) {
log.debug("Creating custom radius search with filters for index: {}, field: {} , r: {}", indexName, fieldName, radius);
rnnQuery.filterQuery(filterQuery);
}
log.debug(
String.format("Creating custom radius search for index: %s \"\", field: %s \"\", r: %f", indexName, fieldName, radius)
);
return rnnQuery;
}

log.debug(String.format("Creating Lucene r-NN query for index: %s \"\", field: %s \"\", k: %f", indexName, fieldName, radius));
switch (vectorDataType) {
case BYTE:
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/knn/index/util/KNNEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public enum KNNEngine implements KNNLibrary {

private static final Set<KNNEngine> CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS);
private static final Set<KNNEngine> ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);
public static final Set<KNNEngine> ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE);
public static final Set<KNNEngine> ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);

private static Map<KNNEngine, Integer> MAX_DIMENSIONS_BY_ENGINE = Map.of(
KNNEngine.NMSLIB,
Expand Down
11 changes: 11 additions & 0 deletions src/main/java/org/opensearch/knn/jni/FaissService.java
Original file line number Diff line number Diff line change
Expand Up @@ -187,4 +187,15 @@ public static native KNNQueryResult[] queryIndexWithFilter(
* @param vectorsPointer to be freed
*/
public static native void freeVectors(long vectorsPointer);

/**
 * Range search index: native call that returns the neighbors of a query vector found within
 * the given radius.
 *
 * NOTE(review): the native prototype for this method takes (long, float[], float) only; it has
 * no parameter corresponding to engineName - confirm the two declarations are meant to match.
 *
 * @param indexPointer pointer to index in memory
 * @param queryVector vector to be used for query
 * @param radius search within radius threshold
 * @param engineName name of the engine to use
 * @return KNNQueryResult array of neighbors within radius
 */
public static native KNNQueryResult[] rangeSearchIndex(long indexPointer, float[] queryVector, float radius, String engineName);
}
17 changes: 17 additions & 0 deletions src/main/java/org/opensearch/knn/jni/JNIService.java
Original file line number Diff line number Diff line change
Expand Up @@ -253,4 +253,21 @@ public static long transferVectors(long vectorsPointer, float[][] trainingData)
public static void freeVectors(long vectorsPointer) {
FaissService.freeVectors(vectorsPointer);
}

/**
 * Runs a radius (range) search for one query vector against an index held in native memory.
 * <p>
 * Only the Faiss engine is handled; any other engine name is rejected.
 *
 * @param indexPointer pointer to index in memory
 * @param queryVector  vector to be used for query
 * @param radius       search within radius threshold
 * @param engineName   name of the engine to use
 * @return KNNQueryResult array of neighbors within radius
 * @throws IllegalArgumentException if engineName is not the Faiss engine's name
 */
public static KNNQueryResult[] radiusQueryIndex(long indexPointer, float[] queryVector, float radius, String engineName) {
    final boolean isFaiss = KNNEngine.FAISS.getName().equals(engineName);
    if (!isFaiss) {
        throw new IllegalArgumentException("RadiusQueryIndex not supported for provided engine");
    }
    return FaissService.rangeSearchIndex(indexPointer, queryVector, radius, engineName);
}

}
Loading

0 comments on commit 0b00d43

Please sign in to comment.