Skip to content

Commit

Permalink
Support filter and nested field in faiss engine radial search
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Apr 24, 2024
1 parent c35f6ad commit ae39f5a
Show file tree
Hide file tree
Showing 15 changed files with 596 additions and 132 deletions.
3 changes: 2 additions & 1 deletion jni/cmake/init-faiss.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ if (NOT EXISTS ${FAISS_REPO_DIR})
endif ()

# Check if patch exist, this is to skip git apply during CI build. See CI.yml with ubuntu.
find_path(PATCH_FILE NAMES 0001-Custom-patch-to-support-multi-vector.patch 0002-Enable-precomp-table-to-be-shared-ivfpq.patch PATHS ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss NO_DEFAULT_PATH)
find_path(PATCH_FILE NAMES 0001-Custom-patch-to-support-multi-vector.patch 0002-Enable-precomp-table-to-be-shared-ivfpq.patch 0003-Custom-patch-to-support-range-search-params.patch PATHS ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss NO_DEFAULT_PATH)

# If it exists, apply patches
if (EXISTS ${PATCH_FILE})
message(STATUS "Applying custom patches.")
execute_process(COMMAND git am --3way ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0001-Custom-patch-to-support-multi-vector.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE)
execute_process(COMMAND git am --3way ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0002-Enable-precomp-table-to-be-shared-ivfpq.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE)
execute_process(COMMAND git am --3way ${CMAKE_CURRENT_SOURCE_DIR}/patches/faiss/0003-Custom-patch-to-support-range-search-params.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE)
if(RESULT_CODE)
message(FATAL_ERROR "Failed to apply patch:\n${ERROR_MSG}")
endif()
Expand Down
20 changes: 19 additions & 1 deletion jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,35 @@ namespace knn_jni {
jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension,
jlong trainVectorsPointerJ);

/*
* Perform a range search with filter against the index located in memory at indexPointerJ.
*
* @param indexPointerJ - pointer to the index
* @param queryVectorJ - the query vector
* @param radiusJ - the radius for the range search
* @param maxResultsWindowJ - the maximum number of results to return
* @param filterIdsJ - the filter ids
* @param filterIdsTypeJ - the filter ids type
* @param parentIdsJ - the parent ids
*
* @return an array of RangeQueryResults
*/
jobjectArray RangeSearchWithFilter(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ,
jfloat radiusJ, jint maxResultWindowJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ);

/*
* Perform a range search against the index located in memory at indexPointerJ.
*
* @param indexPointerJ - pointer to the index
* @param queryVectorJ - the query vector
* @param radiusJ - the radius for the range search
* @param maxResultsWindowJ - the maximum number of results to return
* @param parentIdsJ - the parent ids
*
* @return an array of RangeQueryResults
*/
jobjectArray RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ,
jfloat radiusJ, jint maxResultsWindowJ);
jfloat radiusJ, jint maxResultWindowJ, jintArray parentIdsJ);
}
}

Expand Down
14 changes: 11 additions & 3 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,19 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: rangeSearchIndex
* Signature: (J[F[F)J
* Method: rangeSearchIndexWithFilter
* Signature: (J[FJ[I)[Lorg/opensearch/knn/index/query/RangeQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndexWithFilter
(JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint, jlongArray, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: rangeSearchIndex
* Signature: (J[FJ[I)[Lorg/opensearch/knn/index/query/RangeQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex
(JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint);
(JNIEnv *, jclass, jlong, jfloatArray, jfloat, jint, jintArray);

#ifdef __cplusplus
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
From af6770b505a32b2c4eab2036d2509dec4b137f28 Mon Sep 17 00:00:00 2001
From: Junqiu Lei <[email protected]>
Date: Tue, 23 Apr 2024 17:18:56 -0700
Subject: [PATCH] Custom patch to support range search params

Signed-off-by: Junqiu Lei <[email protected]>
---
faiss/IndexIDMap.cpp | 28 ++++++++++++++++++++++++----
1 file changed, 24 insertions(+), 4 deletions(-)

diff --git a/faiss/IndexIDMap.cpp b/faiss/IndexIDMap.cpp
index 3f375e7b..11f3a847 100644
--- a/faiss/IndexIDMap.cpp
+++ b/faiss/IndexIDMap.cpp
@@ -176,11 +176,31 @@ void IndexIDMapTemplate<IndexT>::range_search(
RangeSearchResult* result,
const SearchParameters* params) const {
if (params) {
- SearchParameters internal_search_parameters;
- IDSelectorTranslated id_selector_translated(id_map, params->sel);
- internal_search_parameters.sel = &id_selector_translated;
+ IDSelectorTranslated this_idtrans(this->id_map, nullptr);
+ ScopedSelChange sel_change;
+ IDGrouperTranslated this_idgrptrans(this->id_map, nullptr);
+ ScopedGrpChange grp_change;
+
+ if (params->sel) {
+ auto idtrans = dynamic_cast<const IDSelectorTranslated*>(params->sel);
+
+ if (!idtrans) {
+ auto params_non_const = const_cast<SearchParameters*>(params);
+ this_idtrans.sel = params->sel;
+ sel_change.set(params_non_const, &this_idtrans);
+ }
+ }
+
+ if (params->grp) {
+ auto idtrans = dynamic_cast<const IDGrouperTranslated*>(params->grp);

- index->range_search(n, x, radius, result, &internal_search_parameters);
+ if (!idtrans) {
+ auto params_non_const = const_cast<SearchParameters*>(params);
+ this_idgrptrans.grp = params->grp;
+ grp_change.set(params_non_const, &this_idgrptrans);
+ }
+ }
+ index->range_search(n, x, radius, result, params);
} else {
index->range_search(n, x, radius, result);
}
--
2.39.0

72 changes: 70 additions & 2 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <jni.h>
#include <string>
#include <vector>
#include <iostream>

// Defines type of IDSelector
enum FilterIdsSelectorType{
Expand Down Expand Up @@ -589,7 +590,12 @@ faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index) {
}

jobjectArray knn_jni::faiss_wrapper::RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ,
jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ) {
jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ, jintArray parentIdsJ) {
return knn_jni::faiss_wrapper::RangeSearchWithFilter(jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ, nullptr, 0, parentIdsJ);
}

jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ,
jfloatArray queryVectorJ, jfloat radiusJ, jint maxResultWindowJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) {
if (queryVectorJ == nullptr) {
throw std::runtime_error("Query Vector cannot be null");
}
Expand All @@ -605,7 +611,69 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearch(knn_jni::JNIUtilInterface *jniU
// The res will be freed by ~RangeSearchResult() in FAISS
// The second parameter is always true, as lims is allocated by FAISS
faiss::RangeSearchResult res(1, true);
indexReader->range_search(1, rawQueryVector, radiusJ, &res);

if(filterIdsJ != nullptr) {
jlong *filteredIdsArray = jniUtil->GetLongArrayElements(env, filterIdsJ, nullptr);
int filterIdsLength = jniUtil->GetJavaLongArrayLength(env, filterIdsJ);
std::unique_ptr<faiss::IDSelector> idSelector;
if(filterIdsTypeJ == BITMAP) {
idSelector.reset(new faiss::IDSelectorJlongBitmap(filterIdsLength, filteredIdsArray));
} else {
faiss::idx_t* batchIndices = reinterpret_cast<faiss::idx_t*>(filteredIdsArray);
idSelector.reset(new faiss::IDSelectorBatch(filterIdsLength, batchIndices));
}
faiss::SearchParameters *searchParameters;
faiss::SearchParametersHNSW hnswParams;
faiss::SearchParametersIVF ivfParams;
std::unique_ptr<faiss::IDGrouperBitmap> idGrouper;
std::vector<uint64_t> idGrouperBitmap;
auto hnswReader = dynamic_cast<const faiss::IndexHNSW*>(indexReader->index);
if(hnswReader) {
// Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default
// value of ef_search = 16 which will then be used.
hnswParams.efSearch = hnswReader->hnsw.efSearch;
hnswParams.sel = idSelector.get();
if (parentIdsJ != nullptr) {
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
hnswParams.grp = idGrouper.get();
}
searchParameters = &hnswParams;
} else {
auto ivfReader = dynamic_cast<const faiss::IndexIVF*>(indexReader->index);
auto ivfFlatReader = dynamic_cast<const faiss::IndexIVFFlat*>(indexReader->index);
if(ivfReader || ivfFlatReader) {
ivfParams.sel = idSelector.get();
searchParameters = &ivfParams;
}
}
try {
indexReader->range_search(1, rawQueryVector, radiusJ, &res, searchParameters);
} catch (...) {
jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryVector, JNI_ABORT);
jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
throw;
}
} else {
faiss::SearchParameters *searchParameters = nullptr;
faiss::SearchParametersHNSW hnswParams;
std::unique_ptr<faiss::IDGrouperBitmap> idGrouper;
std::vector<uint64_t> idGrouperBitmap;
auto hnswReader = dynamic_cast<const faiss::IndexHNSW*>(indexReader->index);
if(hnswReader!= nullptr && parentIdsJ != nullptr) {
// Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default
// value of ef_search = 16 which will then be used.
hnswParams.efSearch = hnswReader->hnsw.efSearch;
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
hnswParams.grp = idGrouper.get();
searchParameters = &hnswParams;
}
try {
indexReader->range_search(1, rawQueryVector, radiusJ, &res, searchParameters);
} catch (...) {
jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryVector, JNI_ABORT);
throw;
}
}

// lims is structured to support batched queries, it has a length of nq + 1 (where nq is the number of queries),
// lims[i] - lims[i-1] gives the number of results for the i-th query. With a single query we used in k-NN,
Expand Down
19 changes: 17 additions & 2 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,26 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex(JNIEnv * env, jclass cls,
jlong indexPointerJ,
jfloatArray queryVectorJ,
jfloat radiusJ, jint maxResultWindowJ)
jfloat radiusJ, jint maxResultWindowJ,
jintArray parentIdsJ)
{
try {
return knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ);
return knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, maxResultWindowJ, parentIdsJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
return nullptr;
}

JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndexWithFilter(JNIEnv * env, jclass cls,
jlong indexPointerJ,
jfloatArray queryVectorJ,
jfloat radiusJ, jint maxResultWindowJ,
jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ)
{
try {
return knn_jni::faiss_wrapper::RangeSearchWithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ,
maxResultWindowJ, filterIdsJ, filterIdsTypeJ, parentIdsJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down
Loading

0 comments on commit ae39f5a

Please sign in to comment.