Skip to content

Commit

Permalink
Add ability to share IVFPQ-l2 state in JNI (opensearch-project#1529)
Browse files Browse the repository at this point in the history
Adds a set of JNI functions that allow the dynamically precomputed tables
of IVFPQ-l2 indices to be shared, at load time, amongst indices that use
the same model.

In addition, refactors the JNIService interface to take KNNEngine as a parameter instead of a string.

In addition, adds unit tests to confirm that the shared state operates as
expected.

Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 committed Mar 18, 2024
1 parent bc70e06 commit 8ae57d3
Show file tree
Hide file tree
Showing 21 changed files with 711 additions and 162 deletions.
15 changes: 15 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,18 @@ namespace knn_jni {
// Return a pointer to the loaded index
jlong LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ);

// Check if a loaded index requires shared state. Currently this is the case only when the index
// (optionally wrapped in an IndexIDMap) is an IVFPQ index whose metric type is l2.
//
// Return true if the index located at indexPointerJ requires shared state, false otherwise
bool IsSharedIndexStateRequired(jlong indexPointerJ);

// Initializes the shared index state from an index. Note, this will not set the state for
// the index pointed to by indexPointerJ. To set it, SetSharedIndexState needs to be called.
//
// The index must be of type IVFPQ-l2; otherwise a std::runtime_error is thrown. The state is
// allocated on the heap and is released by passing the returned pointer to FreeSharedIndexState.
//
// Return a pointer to the shared index state
jlong InitSharedIndexState(jlong indexPointerJ);

// Sets the sharedIndexState for an index. shareIndexStatePointerJ is interpreted as a pointer to
// the state returned by InitSharedIndexState.
//
// The index located at indexPointerJ must be of type IVFPQ-l2; otherwise a std::runtime_error is thrown.
void SetSharedIndexState(jlong indexPointerJ, jlong shareIndexStatePointerJ);

// Execute a query against the index located in memory at indexPointerJ.
//
// Return an array of KNNQueryResults
Expand All @@ -49,6 +61,9 @@ namespace knn_jni {
// Free the index located in memory at indexPointerJ. The index object is deleted, so the
// pointer must not be used again after this call.
void Free(jlong indexPointer);

// Free shared index state in memory at shareIndexStatePointerJ. The pointer is interpreted as
// the state returned by InitSharedIndexState and is deleted, so it must not be used again
// after this call.
void FreeSharedIndexState(jlong shareIndexStatePointerJ);

// Perform initialization operations for the library
void InitLibrary();

Expand Down
38 changes: 35 additions & 3 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,42 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex
(JNIEnv *, jclass, jstring);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: isSharedIndexStateRequired
* Signature: (J)Z
*/
JNIEXPORT jboolean JNICALL Java_org_opensearch_knn_jni_FaissService_isSharedIndexStateRequired
(JNIEnv *, jclass, jlong);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: initSharedIndexState
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initSharedIndexState
(JNIEnv *, jclass, jlong);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: setSharedIndexState
* Signature: (JJ)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setSharedIndexState
(JNIEnv *, jclass, jlong, jlong);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: queryIndex
* Signature: (J[FI)[Lorg/opensearch/knn/index/query/KNNQueryResult;
* Signature: (J[FI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex
(JNIEnv *, jclass, jlong, jfloatArray, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: queryIndex_WithFilter
* Signature: (J[FI[J)[Lorg/opensearch/knn/index/query/KNNQueryResult;
* Method: queryIndexWithFilter
* Signature: (J[FI[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
(JNIEnv *, jclass, jlong, jfloatArray, jint, jlongArray, jint, jintArray);
Expand All @@ -66,6 +90,14 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free
(JNIEnv *, jclass, jlong);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: freeSharedIndexState
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_freeSharedIndexState
(JNIEnv *, jclass, jlong);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: initLibrary
Expand Down
101 changes: 100 additions & 1 deletion jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "faiss/MetaIndexes.h"
#include "faiss/Index.h"
#include "faiss/impl/IDSelector.h"
#include "faiss/IndexIVFPQ.h"

#include <algorithm>
#include <jni.h>
Expand Down Expand Up @@ -73,6 +74,13 @@ void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bi

std::unique_ptr<faiss::IDGrouperBitmap> buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector<uint64_t>* bitmap);

// Check if a loaded index is an IVFPQ index with l2 space type
bool isIndexIVFPQL2(faiss::Index * index);

// Gets IVFPQ index from a faiss index. For faiss, we wrap the index in the type
// IndexIDMap which has member that will point to underlying index that stores the data
faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index);

void knn_jni::faiss_wrapper::CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jobjectArray vectorsJ, jstring indexPathJ, jobject parametersJ) {

Expand Down Expand Up @@ -214,10 +222,60 @@ jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNI

std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ));
// Skipping IO_FLAG_PQ_SKIP_SDC_TABLE because the index is read only and the sdc table is only used during ingestion
faiss::Index* indexReader = faiss::read_index(indexPathCpp.c_str(), faiss::IO_FLAG_READ_ONLY | faiss::IO_FLAG_PQ_SKIP_SDC_TABLE);
// Passing IO_FLAG_SKIP_PRECOMPUTE_TABLE so that the precomputed table is not built at load time. The table is
// only needed for IVFPQ-l2 and it leads to high memory consumption if built for each segment. Instead, we will
// set it later on with `SetSharedIndexState`
faiss::Index* indexReader = faiss::read_index(indexPathCpp.c_str(), faiss::IO_FLAG_READ_ONLY | faiss::IO_FLAG_PQ_SKIP_SDC_TABLE | faiss::IO_FLAG_SKIP_PRECOMPUTE_TABLE);
return (jlong) indexReader;
}

// Reports whether the index located at indexPointerJ needs shared index state. The decision is
// delegated to isIndexIVFPQL2, i.e. shared state is required exactly for IVFPQ indices with l2 metric.
bool knn_jni::faiss_wrapper::IsSharedIndexStateRequired(jlong indexPointerJ) {
    faiss::Index * const loadedIndex = reinterpret_cast<faiss::Index *>(indexPointerJ);
    const bool requiresSharedState = isIndexIVFPQL2(loadedIndex);
    return requiresSharedState;
}

// Builds the shared index state (the IVFPQ precomputed table) from the index located at
// indexPointerJ. The index itself is not modified; SetSharedIndexState must be called to attach
// the state to an index.
//
// Throws std::runtime_error if the index is not of type IVFPQ-l2.
//
// Returns a pointer to a heap allocated faiss::AlignedTable<float>. Ownership passes to the
// caller, who releases it with FreeSharedIndexState.
jlong knn_jni::faiss_wrapper::InitSharedIndexState(jlong indexPointerJ) {
    auto * index = reinterpret_cast<faiss::Index*>(indexPointerJ);
    if (!isIndexIVFPQL2(index)) {
        throw std::runtime_error("Unable to init shared index state from index. index is not of type IVFPQ-l2");
    }

    auto * indexIVFPQ = extractIVFPQIndex(index);

    // 0 tells initialize_IVFPQ_precomputed_table to select the best value itself; the variable is
    // passed by reference and may be changed by the call.
    int use_precomputed_table = 0;

    // Hold the table in a unique_ptr while it is being populated so that it is not leaked if
    // initialize_IVFPQ_precomputed_table throws.
    std::unique_ptr<faiss::AlignedTable<float>> sharedTable(new faiss::AlignedTable<float>());
    faiss::initialize_IVFPQ_precomputed_table(
            use_precomputed_table,
            indexIVFPQ->quantizer,
            indexIVFPQ->pq,
            *sharedTable,
            indexIVFPQ->by_residual,
            indexIVFPQ->verbose);

    // Success: hand ownership over to the caller as an opaque handle.
    return reinterpret_cast<jlong>(sharedTable.release());
}

// Attaches the shared index state located at shareIndexStatePointerJ to the index located at
// indexPointerJ.
//
// Throws std::runtime_error if the index is not of type IVFPQ-l2 or if the shared index state
// pointer is null.
void knn_jni::faiss_wrapper::SetSharedIndexState(jlong indexPointerJ, jlong shareIndexStatePointerJ) {
    auto * index = reinterpret_cast<faiss::Index*>(indexPointerJ);
    if (!isIndexIVFPQL2(index)) {
        throw std::runtime_error("Unable to set shared index state from index. index is not of type IVFPQ-l2");
    }
    auto * indexIVFPQ = extractIVFPQIndex(index);

    //TODO: Currently, the only shared state is that of the AlignedTable associated with
    // IVFPQ-l2 index type (see https://github.com/opensearch-project/k-NN/issues/1507). In the future,
    // this will be generalized and more information will be needed to determine the shared type. But, until then,
    // this is fine.
    auto *alignTable = reinterpret_cast<faiss::AlignedTable<float>*>(shareIndexStatePointerJ);

    // Reject a null handle up front: installing it would leave the index marked as using a
    // precomputed table that does not exist.
    if (alignTable == nullptr) {
        throw std::runtime_error("Unable to set shared index state. Shared index state pointer is null");
    }

    // In faiss, usePrecomputedTable can have a couple different values:
    //  -1 -> dont use the table
    //   0 -> tell initialize_IVFPQ_precomputed_table to select the best value and change the value
    //   1 -> default behavior
    //   2 -> Index is of type "MultiIndexQuantizer"
    // This index will be of type IndexIVFPQ always. We never create "MultiIndexQuantizer". So, the value we
    // want is 1.
    // (ref: https://github.com/facebookresearch/faiss/blob/v1.8.0/faiss/IndexIVFPQ.cpp#L383-L410)
    int usePrecomputedTable = 1;
    indexIVFPQ->set_precomputed_table(alignTable, usePrecomputedTable);
}

jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ) {
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr, 0, parentIdsJ);
Expand Down Expand Up @@ -337,6 +395,15 @@ void knn_jni::faiss_wrapper::Free(jlong indexPointer) {
delete indexWrapper;
}

// Releases the shared index state located at shareIndexStatePointerJ.
void knn_jni::faiss_wrapper::FreeSharedIndexState(jlong shareIndexStatePointerJ) {
    //TODO: The AlignedTable associated with the IVFPQ-l2 index type is, for now, the only kind of
    // shared state (see https://github.com/opensearch-project/k-NN/issues/1507). Once this is
    // generalized, more information will be needed to determine the shared type. Until then, the
    // handle can be interpreted as an AlignedTable unconditionally.
    delete reinterpret_cast<faiss::AlignedTable<float>*>(shareIndexStatePointerJ);
}

void knn_jni::faiss_wrapper::InitLibrary() {
//set thread 1 cause ES has Search thread
//TODO make it different at search and write
Expand Down Expand Up @@ -469,3 +536,35 @@ std::unique_ptr<faiss::IDGrouperBitmap> buildIDGrouperBitmap(knn_jni::JNIUtilInt
jniUtil->ReleaseIntArrayElements(env, parentIdsJ, parentIdsArray, JNI_ABORT);
return idGrouper;
}

// Returns true when the index (after unwrapping an enclosing IndexIDMap, if present) is an
// IndexIVFPQ whose metric type is l2; returns false in every other case.
bool isIndexIVFPQL2(faiss::Index * index) {
    // dynamic_cast yields a nullptr when the runtime type does not match, so it doubles as the
    // type check. (ref: https://en.cppreference.com/w/cpp/language/dynamic_cast)
    auto * idMapWrapper = dynamic_cast<faiss::IndexIDMap *>(index);

    // When wrapped in an IndexIDMap, the index that stores the data is the wrapper's member.
    faiss::Index * innerIndex = (idMapWrapper != nullptr) ? idMapWrapper->index : index;

    auto * ivfpqIndex = dynamic_cast<faiss::IndexIVFPQ *>(innerIndex);
    if (ivfpqIndex == nullptr) {
        return false;
    }

    return ivfpqIndex->metric_type == faiss::METRIC_L2;
}

// Returns the IndexIVFPQ behind the given index, unwrapping an enclosing IndexIDMap if present.
// Throws std::runtime_error when no IndexIVFPQ is found.
faiss::IndexIVFPQ * extractIVFPQIndex(faiss::Index * index) {
    auto * idMapWrapper = dynamic_cast<faiss::IndexIDMap *>(index);
    faiss::Index * innerIndex = (idMapWrapper != nullptr) ? idMapWrapper->index : index;

    auto * ivfpqIndex = dynamic_cast<faiss::IndexIVFPQ *>(innerIndex);
    if (ivfpqIndex == nullptr) {
        throw std::runtime_error("Unable to extract IVFPQ index. IVFPQ index not present.");
    }

    return ivfpqIndex;
}
42 changes: 42 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,38 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEn
return NULL;
}

// JNI entry point for FaissService.isSharedIndexStateRequired. Any C++ exception is handed to
// jniUtil.CatchCppExceptionAndThrowJava; in that case JNI_FALSE is returned to the JVM.
JNIEXPORT jboolean JNICALL Java_org_opensearch_knn_jni_FaissService_isSharedIndexStateRequired
        (JNIEnv * env, jclass cls, jlong indexPointerJ)
{
    try {
        // Map the C++ bool onto the JNI boolean constants explicitly.
        return knn_jni::faiss_wrapper::IsSharedIndexStateRequired(indexPointerJ) ? JNI_TRUE : JNI_FALSE;
    } catch (...) {
        jniUtil.CatchCppExceptionAndThrowJava(env);
    }
    // jboolean is an integral type, so use JNI_FALSE rather than the pointer constant NULL.
    return JNI_FALSE;
}

// JNI entry point for FaissService.initSharedIndexState. Returns the pointer to the newly created
// shared index state. Any C++ exception is handed to jniUtil.CatchCppExceptionAndThrowJava; in
// that case 0 is returned to the JVM.
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initSharedIndexState
        (JNIEnv * env, jclass cls, jlong indexPointerJ)
{
    try {
        return knn_jni::faiss_wrapper::InitSharedIndexState(indexPointerJ);
    } catch (...) {
        jniUtil.CatchCppExceptionAndThrowJava(env);
    }
    // jlong is an integral type, so return the integer 0 rather than the pointer constant NULL.
    return 0;
}

// JNI entry point for FaissService.setSharedIndexState. Forwards both handles to the faiss
// wrapper; any C++ exception is handed to jniUtil.CatchCppExceptionAndThrowJava.
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setSharedIndexState(JNIEnv * env, jclass cls,
                                                                                    jlong indexPointerJ,
                                                                                    jlong shareIndexStatePointerJ)
{
    try {
        knn_jni::faiss_wrapper::SetSharedIndexState(indexPointerJ, shareIndexStatePointerJ);
        return;
    } catch (...) {
        jniUtil.CatchCppExceptionAndThrowJava(env);
    }
}

JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex(JNIEnv * env, jclass cls,
jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ)
Expand Down Expand Up @@ -109,6 +141,16 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_free(JNIEnv * en
}
}

// JNI entry point for FaissService.freeSharedIndexState. Forwards the handle to the faiss
// wrapper; any C++ exception is handed to jniUtil.CatchCppExceptionAndThrowJava.
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_freeSharedIndexState(JNIEnv * env, jclass cls,
                                                                                     jlong shareIndexStatePointerJ)
{
    try {
        knn_jni::faiss_wrapper::FreeSharedIndexState(shareIndexStatePointerJ);
        return;
    } catch (...) {
        jniUtil.CatchCppExceptionAndThrowJava(env);
    }
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_initLibrary(JNIEnv * env, jclass cls)
{
try {
Expand Down
Loading

0 comments on commit 8ae57d3

Please sign in to comment.