Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds support for ef_search query parameter in FAISS engine #1707

Merged
merged 18 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,27 @@ namespace knn_jni {
// Sets the sharedIndexState for an index
void SetSharedIndexState(jlong indexPointerJ, jlong shareIndexStatePointerJ);

// Execute a query against the index located in memory at indexPointerJ.
//
// Return an array of KNNQueryResults
/**
* Execute a query against the index located in memory at indexPointerJ
*
* Parameters:
* queryEfSearch: -1 indicates to use efsearch value used during index setting
*
* Return an array of KNNQueryResults
*/
jobjectArray QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
shatejas marked this conversation as resolved.
Show resolved Hide resolved
jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ);

// Execute a query against the index located in memory at indexPointerJ along with Filters
//
// Return an array of KNNQueryResults
jfloatArray queryVectorJ, jint kJ, jint queryEfSearchJ, jintArray parentIdsJ);

/**
* Execute a query against the index located in memory at indexPointerJ along with Filters
*
* Parameters:
* queryEfSearch: -1 indicates to use efsearch value used during index setting
*
* Return an array of KNNQueryResults
*/
jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
shatejas marked this conversation as resolved.
Show resolved Hide resolved
jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ,
jfloatArray queryVectorJ, jint kJ, jint queryEfSearchJ, jlongArray filterIdsJ,
jint filterIdsTypeJ, jintArray parentIdsJ);

// Free the index located in memory at indexPointerJ
Expand Down
4 changes: 2 additions & 2 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,15 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setSharedIndexSt
* Signature: (J[FI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex
(JNIEnv *, jclass, jlong, jfloatArray, jint, jintArray);
(JNIEnv *, jclass, jlong, jfloatArray, jint, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: queryIndexWithFilter
* Signature: (J[FI[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
(JNIEnv *, jclass, jlong, jfloatArray, jint, jlongArray, jint, jintArray);
(JNIEnv *, jclass, jlong, jfloatArray, jint, jint, jlongArray, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
Expand Down
24 changes: 12 additions & 12 deletions jni/src/faiss_wrapper.cpp
shatejas marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -296,12 +296,12 @@ void knn_jni::faiss_wrapper::SetSharedIndexState(jlong indexPointerJ, jlong shar
}

jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ) {
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr, 0, parentIdsJ);
jfloatArray queryVectorJ, jint kJ, jint queryEfSearchJ, jintArray parentIdsJ) {
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, queryEfSearchJ, nullptr, 0, parentIdsJ);
}

jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) {
jfloatArray queryVectorJ, jint kJ, jint queryEfSearchJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) {

if (queryVectorJ == nullptr) {
throw std::runtime_error("Query Vector cannot be null");
Expand Down Expand Up @@ -340,9 +340,8 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
std::vector<uint64_t> idGrouperBitmap;
auto hnswReader = dynamic_cast<const faiss::IndexHNSW*>(indexReader->index);
if(hnswReader) {
// Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default
// value of ef_search = 16 which will then be used.
hnswParams.efSearch = hnswReader->hnsw.efSearch;
// Query param efseatch supersedes ef_search provided during index setting.
hnswParams.efSearch = queryEfSearchJ == -1 ? hnswReader->hnsw.efSearch : queryEfSearchJ;
hnswParams.sel = idSelector.get();
if (parentIdsJ != nullptr) {
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
Expand Down Expand Up @@ -371,12 +370,13 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
std::unique_ptr<faiss::IDGrouperBitmap> idGrouper;
std::vector<uint64_t> idGrouperBitmap;
auto hnswReader = dynamic_cast<const faiss::IndexHNSW*>(indexReader->index);
if(hnswReader!= nullptr && parentIdsJ != nullptr) {
// Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default
// value of ef_search = 16 which will then be used.
hnswParams.efSearch = hnswReader->hnsw.efSearch;
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
hnswParams.grp = idGrouper.get();
if(hnswReader!= nullptr) {
shatejas marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we test with cases what happnes when hnswReader is not null and parentIdsJ is Null and vice versa , I feel earlier condition was fine to check both not null. And Can we add test case for all Corner cased like when one is null and other is not , just to make sure it works , Since we are removing one condition from Top If Statment.

Copy link
Contributor Author

@shatejas shatejas May 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the earlier condition efSearch was not getting updated and it was using the default value of 16 with the combination of filter and parent filter being null. I will let others evaluate this

As mentioned, still working on the unit test

// Query param efseatch supersedes ef_search provided during index setting.
hnswParams.efSearch = queryEfSearchJ == -1 ? hnswReader->hnsw.efSearch : queryEfSearchJ;
if (parentIdsJ != nullptr) {
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
hnswParams.grp = idGrouper.get();
}
searchParameters = &hnswParams;
}
try {
Expand Down
8 changes: 4 additions & 4 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setSharedIndexSt

JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex(JNIEnv * env, jclass cls,
jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ)
jfloatArray queryVectorJ, jint kJ, jint queryEfSearchJ, jintArray parentIdsJ)
{
try {
return knn_jni::faiss_wrapper::QueryIndex(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, parentIdsJ);
return knn_jni::faiss_wrapper::QueryIndex(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, queryEfSearchJ, parentIdsJ);

} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
Expand All @@ -121,10 +121,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd
}

JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
(JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jlongArray filteredIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) {
(JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jint queryEfSearchJ, jlongArray filteredIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) {

try {
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ, filterIdsTypeJ, parentIdsJ);
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, queryEfSearchJ, filteredIdsJ, filterIdsTypeJ, parentIdsJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down
8 changes: 5 additions & 3 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ TEST(FaissQueryIndexTest, BasicAssertions) {

// Define query data
int k = 10;
int efSearch = 100;
int numQueries = 100;
std::vector<std::vector<float>> queries;

Expand Down Expand Up @@ -270,7 +271,7 @@ TEST(FaissQueryIndexTest, BasicAssertions) {
knn_jni::faiss_wrapper::QueryIndex(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), k, nullptr)));
reinterpret_cast<jfloatArray>(&query), k, efSearch, nullptr)));

ASSERT_EQ(k, results->size());

Expand Down Expand Up @@ -336,7 +337,7 @@ TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) {
knn_jni::faiss_wrapper::QueryIndex_WithFilter(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), k,
reinterpret_cast<jfloatArray>(&query), k, -1,
reinterpret_cast<jlongArray>(&bitmap), 0, nullptr)));

ASSERT_TRUE(results->size() <= filterIds.size());
Expand Down Expand Up @@ -376,6 +377,7 @@ TEST(FaissQueryIndexWithParentFilterTest, BasicAssertions) {

// Define query data
int k = 20;
int efSearch = 100;
int numQueries = 100;
std::vector<std::vector<float>> queries;

Expand Down Expand Up @@ -407,7 +409,7 @@ TEST(FaissQueryIndexWithParentFilterTest, BasicAssertions) {
knn_jni::faiss_wrapper::QueryIndex(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), k,
reinterpret_cast<jfloatArray>(&query), k, efSearch,
reinterpret_cast<jintArray>(&parentIds))));

// Even with k 20, result should have only 10 which is total number of groups
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public static class CreateQueryRequest {
private byte[] byteVector;
private VectorDataType vectorDataType;
private Integer k;
private Integer efSearch;
private Float radius;
private QueryBuilder filter;
private QueryShardContext context;
Expand Down
39 changes: 14 additions & 25 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@

package org.opensearch.knn.index.query;

import java.util.Arrays;
import java.util.Objects;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;
import org.apache.lucene.search.BooleanClause;
Expand All @@ -23,26 +21,28 @@
import org.opensearch.knn.index.KNNSettings;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;

/**
* Custom KNN query. Query is used for KNNEngine's that create their own custom segment files. These files need to be
* loaded and queried in a custom manner throughout the query path.
*/
@Getter
@Builder
@AllArgsConstructor
public class KNNQuery extends Query {

private final String field;
private final float[] queryVector;
private int k;
private Integer efSearch;
shatejas marked this conversation as resolved.
Show resolved Hide resolved
private final String indexName;

@Getter
@Setter
private Query filterQuery;
@Getter
private BitSetProducer parentsFilter;
@Getter
private Float radius = null;
@Getter
private Float radius;
private Context context;

public KNNQuery(
Expand Down Expand Up @@ -123,22 +123,6 @@ public KNNQuery filterQuery(Query filterQuery) {
return this;
}

public String getField() {
return this.field;
}

public float[] getQueryVector() {
return this.queryVector;
}

public int getK() {
return this.k;
}

public String getIndexName() {
return this.indexName;
}

/**
* Constructs Weight implementation for this query
*
Expand Down Expand Up @@ -183,7 +167,7 @@ public String toString(String field) {

@Override
public int hashCode() {
return Objects.hash(field, Arrays.hashCode(queryVector), k, indexName, filterQuery);
return Objects.hash(field, Arrays.hashCode(queryVector), k, indexName, filterQuery, context, parentsFilter, radius, efSearch);
}

@Override
Expand All @@ -192,10 +176,15 @@ public boolean equals(Object other) {
}

private boolean equalsTo(KNNQuery other) {
if (other == this) return true;
return Objects.equals(field, other.field)
&& Arrays.equals(queryVector, other.queryVector)
&& Objects.equals(k, other.k)
&& Objects.equals(efSearch, other.efSearch)
&& Objects.equals(radius, other.radius)
&& Objects.equals(context, other.context)
&& Objects.equals(indexName, other.indexName)
&& Objects.equals(parentsFilter, other.parentsFilter)
&& Objects.equals(filterQuery, other.filterQuery);
}

Expand Down
Loading
Loading