Skip to content

Commit

Permalink
Adds method parameters and validates against engine specific parameters
Browse files Browse the repository at this point in the history
Signed-off-by: Tejas Shah <[email protected]>
  • Loading branch information
shatejas committed Jun 5, 2024
1 parent d7bc3ad commit d663640
Show file tree
Hide file tree
Showing 51 changed files with 756 additions and 531 deletions.
8 changes: 4 additions & 4 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,23 @@ namespace knn_jni {
* Execute a query against the index located in memory at indexPointerJ
*
* Parameters:
* algoParams: introduces an object to have additional algo params. for instance hnsw will have efSearch
* methodParamsJ: a Java map of additional method parameters (for instance, efSearch for hnsw); may be null
*
* Return an array of KNNQueryResults
*/
jobjectArray QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jobject algoParams, jintArray parentIdsJ);
jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jintArray parentIdsJ);

/**
* Execute a query against the index located in memory at indexPointerJ along with Filters
*
* Parameters:
* algoParams: introduces an object to have additional algo params. for instance hnsw will have efSearch
* methodParamsJ: a Java map of additional method parameters (for instance, efSearch for hnsw); may be null
*
* Return an array of KNNQueryResults
*/
jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jobject algoParams, jlongArray filterIdsJ,
jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ,
jint filterIdsTypeJ, jintArray parentIdsJ);

// Free the index located in memory at indexPointerJ
Expand Down
12 changes: 0 additions & 12 deletions jni/include/jni_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,6 @@ namespace knn_jni {
// Find a java class given a particular name
virtual jclass FindClass(JNIEnv * env, const std::string& className) = 0;

virtual jboolean IsInstanceOf(JNIEnv * env, jobject object, const std::string& className) = 0;

virtual jboolean OptionalIsPresent(JNIEnv * env, jobject optional_jobject) = 0;

virtual jobject OptionalGetObject(JNIEnv * env, jobject optional_jobject) = 0;

virtual jobject CallObjectMethod(JNIEnv * env, jobject _jobject, const std::string& className, const std::string& methodName) = 0;

// Find a java method given a particular class, name and signature
virtual jmethodID FindMethod(JNIEnv * env, const std::string& className, const std::string& methodName) = 0;

Expand Down Expand Up @@ -146,11 +138,7 @@ namespace knn_jni {
void HasExceptionInStack(JNIEnv* env, const std::string& message);
void CatchCppExceptionAndThrowJava(JNIEnv* env);
jclass FindClass(JNIEnv * env, const std::string& className);
jboolean IsInstanceOf(JNIEnv * env, jobject object, const std::string& className);
jmethodID FindMethod(JNIEnv * env, const std::string& className, const std::string& methodName);
jboolean OptionalIsPresent(JNIEnv * env, jobject optional_jobject);
jobject OptionalGetObject(JNIEnv * env, jobject optional_jobject);
jobject CallObjectMethod(JNIEnv * env, jobject _jobject, const std::string& className, const std::string& methodName);
std::string ConvertJavaStringToCppString(JNIEnv * env, jstring javaString);
std::unordered_map<std::string, jobject> ConvertJavaMapToCppMap(JNIEnv *env, jobject parametersJ);
std::string ConvertJavaObjectToCppString(JNIEnv *env, jobject objectJ);
Expand Down
4 changes: 2 additions & 2 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,15 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setSharedIndexSt
/*
* Class: org_opensearch_knn_jni_FaissService
* Method: queryIndex
* Signature: (J[FI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
* Signature: (J[FILjava/util/Map;[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex
(JNIEnv *, jclass, jlong, jfloatArray, jint, jobject, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: queryIndexWithFilter
* Signature: (J[FI[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
* Signature: (J[FILjava/util/Map;[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult;
*/
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
(JNIEnv *, jclass, jlong, jfloatArray, jint, jobject, jlongArray, jint, jintArray);
Expand Down
35 changes: 18 additions & 17 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, fa
void buildFilterIdsBitMap(const int* filterIds, int filterIdsLength, uint8_t* bitsetVector);

// Gets efSearch from the method parameters, falling back to defaultEfSearch when it is not provided
int getQueryEfSearch(JNIEnv * env, knn_jni::JNIUtilInterface * jniUtil, jobject algoParams, int defaultEfSearch);
int getQueryEfSearch(JNIEnv * env, knn_jni::JNIUtilInterface * jniUtil, std::unordered_map<std::string, jobject> methodParams, int defaultEfSearch);

std::unique_ptr<faiss::IDGrouperBitmap> buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector<uint64_t>* bitmap);

Expand Down Expand Up @@ -299,12 +299,12 @@ void knn_jni::faiss_wrapper::SetSharedIndexState(jlong indexPointerJ, jlong shar
}

jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jobject algoParams, jintArray parentIdsJ) {
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, algoParams, nullptr, 0, parentIdsJ);
jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jintArray parentIdsJ) {
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, methodParamsJ, nullptr, 0, parentIdsJ);
}

jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jobject algoParams, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) {
jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) {

if (queryVectorJ == nullptr) {
throw std::runtime_error("Query Vector cannot be null");
Expand All @@ -316,6 +316,11 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
throw std::runtime_error("Invalid pointer to index");
}

std::unordered_map<std::string, jobject> methodParams;
if (methodParamsJ != nullptr) {
methodParams = jniUtil->ConvertJavaMapToCppMap(env, methodParamsJ);
}

// The ids vector will hold the top k ids from the search and the dis vector will hold the top k distances from
// the query point
std::vector<float> dis(kJ);
Expand Down Expand Up @@ -344,7 +349,7 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
auto hnswReader = dynamic_cast<const faiss::IndexHNSW*>(indexReader->index);
if(hnswReader) {
// Query param efsearch supersedes ef_search provided during index setting.
hnswParams.efSearch = getQueryEfSearch(env, jniUtil, algoParams, hnswReader->hnsw.efSearch);
hnswParams.efSearch = getQueryEfSearch(env, jniUtil, methodParams, hnswReader->hnsw.efSearch);
hnswParams.sel = idSelector.get();
if (parentIdsJ != nullptr) {
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
Expand Down Expand Up @@ -375,7 +380,7 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
auto hnswReader = dynamic_cast<const faiss::IndexHNSW*>(indexReader->index);
if(hnswReader!= nullptr) {
// Query param efsearch supersedes ef_search provided during index setting.
hnswParams.efSearch = getQueryEfSearch(env, jniUtil, algoParams, hnswReader->hnsw.efSearch);
hnswParams.efSearch = getQueryEfSearch(env, jniUtil, methodParams, hnswReader->hnsw.efSearch);
if (parentIdsJ != nullptr) {
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
hnswParams.grp = idGrouper.get();
Expand Down Expand Up @@ -412,20 +417,16 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter
return results;
}

int getQueryEfSearch(JNIEnv * env, knn_jni::JNIUtilInterface * jniUtil, jobject algoParams, int defaultEfSearch) {
if (algoParams == nullptr) {
int getQueryEfSearch(JNIEnv * env, knn_jni::JNIUtilInterface * jniUtil, std::unordered_map<std::string, jobject> methodParams, int defaultEfSearch) {
if (methodParams.empty()) {
return defaultEfSearch;
}
std::string hnswParameterJavaFilePath = "org/opensearch/knn/index/query/model/HNSWAlgoQueryParameters";
if (jniUtil->IsInstanceOf(env, algoParams, hnswParameterJavaFilePath)) {
jobject optionalEfSearch = jniUtil->CallObjectMethod(env, algoParams, hnswParameterJavaFilePath, "getEfSearch");
jobject efSearch = jniUtil->OptionalGetObject(env, optionalEfSearch);
if (efSearch == nullptr) {
return defaultEfSearch;
}
return jniUtil->ConvertJavaObjectToCppInteger(env, efSearch);
auto efSearchIt = methodParams.find(knn_jni::EF_SEARCH);
if (efSearchIt != methodParams.end()) {
return jniUtil->ConvertJavaObjectToCppInteger(env, methodParams[knn_jni::EF_SEARCH]);
}
throw std::runtime_error("The algorithm parameters is not of type HNSWAlgoQueryParameters");

return defaultEfSearch;
}

void knn_jni::faiss_wrapper::Free(jlong indexPointer) {
Expand Down
40 changes: 0 additions & 40 deletions jni/src/jni_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,6 @@ void knn_jni::JNIUtil::Initialize(JNIEnv *env) {
this->cachedMethods["java/lang/Integer:intValue"] = env->GetMethodID(tempLocalClassRef, "intValue", "()I");
env->DeleteLocalRef(tempLocalClassRef);

tempLocalClassRef = env->FindClass("java/util/Optional");
this->cachedClasses["java/util/Optional"] = (jclass) env->NewGlobalRef(tempLocalClassRef);
this->cachedMethods["java/util/Optional:isPresent"] = env->GetMethodID(tempLocalClassRef, "isPresent", "()Z");
this->cachedMethods["java/util/Optional:get"] = env->GetMethodID(tempLocalClassRef, "get", "()Ljava/lang/Object;");
env->DeleteLocalRef(tempLocalClassRef);

tempLocalClassRef = env->FindClass("org/opensearch/knn/index/query/model/HNSWAlgoQueryParameters");
this->cachedClasses["org/opensearch/knn/index/query/model/HNSWAlgoQueryParameters"] = (jclass) env->NewGlobalRef(tempLocalClassRef);
this->cachedMethods["org/opensearch/knn/index/query/model/HNSWAlgoQueryParameters:getEfSearch"] = env->GetMethodID(tempLocalClassRef, "getEfSearch", "()Ljava/util/Optional;");
env->DeleteLocalRef(tempLocalClassRef);

tempLocalClassRef = env->FindClass("org/opensearch/knn/index/query/KNNQueryResult");
this->cachedClasses["org/opensearch/knn/index/query/KNNQueryResult"] = (jclass) env->NewGlobalRef(tempLocalClassRef);
this->cachedMethods["org/opensearch/knn/index/query/KNNQueryResult:<init>"] = env->GetMethodID(tempLocalClassRef, "<init>", "(IF)V");
Expand Down Expand Up @@ -141,35 +130,6 @@ jmethodID knn_jni::JNIUtil::FindMethod(JNIEnv * env, const std::string& classNam
return this->cachedMethods[key];
}

// Reports whether `object` is an instance of the Java class named `className`.
// The class handle is resolved through this utility's FindClass before the
// check is delegated to the JNI environment.
jboolean knn_jni::JNIUtil::IsInstanceOf(JNIEnv * env, jobject object, const std::string& className) {
    return env->IsInstanceOf(object, FindClass(env, className));
}

// Invokes java.util.Optional#isPresent on `optional_jobject` and returns the result.
// Throws std::runtime_error when the isPresent method id is not in the method cache.
jboolean knn_jni::JNIUtil::OptionalIsPresent(JNIEnv * env, jobject optional_jobject) {
    const std::string methodKey = "java/util/Optional:isPresent";
    auto cachedEntry = this->cachedMethods.find(methodKey);
    if (cachedEntry == this->cachedMethods.end()) {
        throw std::runtime_error("Unable to find java/util/Optional:isPresent method");
    }

    // Reuse the iterator from the lookup above instead of indexing the cache again.
    return env->CallBooleanMethod(optional_jobject, cachedEntry->second);
}

// Unwraps a java.util.Optional: returns the wrapped object when a value is
// present, or nullptr when the Optional is empty.
jobject knn_jni::JNIUtil::OptionalGetObject(JNIEnv * env, jobject optional_jobject) {
    // Guard clause: an empty Optional maps to a null reference.
    if (!this->OptionalIsPresent(env, optional_jobject)) {
        return nullptr;
    }
    return this->CallObjectMethod(env, optional_jobject, "java/util/Optional", "get");
}

// Calls the object-returning Java method `className:methodName` on `_jobject`
// and returns its result. The method id is resolved through FindMethod, and
// HasExceptionInStack is consulted after the call with a message naming the method.
jobject knn_jni::JNIUtil::CallObjectMethod(JNIEnv * env, jobject _jobject, const std::string& className, const std::string& methodName) {
    jobject result = env->CallObjectMethod(_jobject, this->FindMethod(env, className, methodName));
    this->HasExceptionInStack(env, "Could not call method " + className + ":" + methodName);
    return result;
}

std::unordered_map<std::string, jobject> knn_jni::JNIUtil::ConvertJavaMapToCppMap(JNIEnv *env, jobject parametersJ) {
// Here, we parse parametersJ, which is a java Map<String, Object>. In order to implement this, I referred to
// https://stackoverflow.com/questions/4844022/jni-create-hashmap. All java references are local, so they will be
Expand Down
8 changes: 4 additions & 4 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,10 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setSharedIndexSt

JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex(JNIEnv * env, jclass cls,
jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jobject algoParamsJ, jintArray parentIdsJ)
jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jintArray parentIdsJ)
{
try {
return knn_jni::faiss_wrapper::QueryIndex(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, algoParamsJ, parentIdsJ);
return knn_jni::faiss_wrapper::QueryIndex(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, methodParamsJ, parentIdsJ);

} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
Expand All @@ -121,10 +121,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd
}

JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter
(JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jobject algoParamsJ, jlongArray filteredIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) {
(JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filteredIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) {

try {
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, algoParamsJ, filteredIdsJ, filterIdsTypeJ, parentIdsJ);
return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, methodParamsJ, filteredIdsJ, filterIdsTypeJ, parentIdsJ);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
Expand Down
45 changes: 11 additions & 34 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ TEST(FaissLoadIndexTest, IVFPQDisablePrecomputeTable) {
}

TEST(FaissQueryIndexTest, BasicAssertions) {
std::cout << "Test start";
// Define the index data
faiss::idx_t numIds = 100;
int dim = 16;
Expand All @@ -243,7 +242,10 @@ TEST(FaissQueryIndexTest, BasicAssertions) {

// Define query data
int k = 10;
auto* algoParams = new test_util::HNSWAlgoQueryParam({20});
int efSearch = 20;
std::unordered_map<std::string, jobject> methodParams;
methodParams[knn_jni::EF_SEARCH] = reinterpret_cast<jobject>(&efSearch);

int numQueries = 100;
std::vector<std::vector<float>> queries;

Expand All @@ -265,26 +267,15 @@ TEST(FaissQueryIndexTest, BasicAssertions) {
// Setup jni
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;
auto algoParamsJ = reinterpret_cast<jobject>(algoParams);
EXPECT_CALL(mockJNIUtil,
IsInstanceOf(jniEnv, algoParamsJ, "org/opensearch/knn/index/query/model/HNSWAlgoQueryParameters"))
.WillRepeatedly(Return(true));
EXPECT_CALL(mockJNIUtil,
CallObjectMethod(
jniEnv, algoParamsJ, "org/opensearch/knn/index/query/model/HNSWAlgoQueryParameters", "getEfSearch"))
.WillRepeatedly(Return(algoParamsJ));
EXPECT_CALL(mockJNIUtil, OptionalGetObject(jniEnv, algoParamsJ))
.WillRepeatedly(Return(algoParamsJ));
EXPECT_CALL(mockJNIUtil, ConvertJavaObjectToCppInteger(jniEnv, algoParamsJ))
.WillRepeatedly(Return(algoParams->efSearch));
auto methodParamsJ = reinterpret_cast<jobject>(&methodParams);

for (auto query : queries) {
std::unique_ptr<std::vector<std::pair<int, float> *>> results(
reinterpret_cast<std::vector<std::pair<int, float> *> *>(
knn_jni::faiss_wrapper::QueryIndex(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), k, algoParamsJ, nullptr)));
reinterpret_cast<jfloatArray>(&query), k, methodParamsJ, nullptr)));

ASSERT_EQ(k, results->size());

Expand Down Expand Up @@ -391,7 +382,6 @@ TEST(FaissQueryIndexWithParentFilterTest, BasicAssertions) {

// Define query data
int k = 20;
auto* algoParams = new test_util::HNSWAlgoQueryParam({20});
int numQueries = 100;
std::vector<std::vector<float>> queries;

Expand All @@ -410,33 +400,20 @@ TEST(FaissQueryIndexWithParentFilterTest, BasicAssertions) {
auto createdIndexWithData =
test_util::FaissAddData(createdIndex.get(), ids, vectors);

int efSearch = 100;
std::unordered_map<std::string, jobject> methodParams;
methodParams[knn_jni::EF_SEARCH] = reinterpret_cast<jobject>(&efSearch);

// Setup jni
JNIEnv *jniEnv = nullptr;
NiceMock<test_util::MockJNIUtil> mockJNIUtil;
auto algoParamsJ = reinterpret_cast<jobject>(algoParams);
EXPECT_CALL(mockJNIUtil, GetJavaIntArrayLength(jniEnv, reinterpret_cast<jintArray>(&parentIds)))
.WillRepeatedly(Return(parentIds.size()));
EXPECT_CALL(mockJNIUtil, IsInstanceOf(jniEnv, algoParamsJ, "org/opensearch/knn/index/query/model/HNSWAlgoQueryParameters"))
.WillRepeatedly(Return(true));
EXPECT_CALL(mockJNIUtil,
CallObjectMethod(
jniEnv, algoParamsJ, "org/opensearch/knn/index/query/model/HNSWAlgoQueryParameters", "getEfSearch"))
.WillRepeatedly(Return(algoParamsJ));;
EXPECT_CALL(mockJNIUtil,
OptionalGetObject(
jniEnv, algoParamsJ))
.WillRepeatedly(Return(algoParamsJ));
EXPECT_CALL(mockJNIUtil,
ConvertJavaObjectToCppInteger(
jniEnv, algoParamsJ))
.WillRepeatedly(Return(algoParams->efSearch));
for (auto query : queries) {
std::unique_ptr<std::vector<std::pair<int, float> *>> results(
reinterpret_cast<std::vector<std::pair<int, float> *> *>(
knn_jni::faiss_wrapper::QueryIndex(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jfloatArray>(&query), k, algoParamsJ,
reinterpret_cast<jfloatArray>(&query), k, reinterpret_cast<jobject>(&methodParams),
reinterpret_cast<jintArray>(&parentIds))));

// Even with k 20, result should have only 10 which is total number of groups
Expand Down
Loading

0 comments on commit d663640

Please sign in to comment.