diff --git a/CHANGELOG.md b/CHANGELOG.md index 62006a6f3..d38818a68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.15...2.x) ### Features +* Adds dynamic query parameter ef_search [#1783](https://github.com/opensearch-project/k-NN/pull/1783) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/jni/CMakeLists.txt b/jni/CMakeLists.txt index 595fa6fea..e2003e0f7 100644 --- a/jni/CMakeLists.txt +++ b/jni/CMakeLists.txt @@ -147,8 +147,10 @@ if ("${WIN32}" STREQUAL "") add_executable( jni_test tests/faiss_wrapper_test.cpp + tests/faiss_wrapper_unit_test.cpp tests/faiss_util_test.cpp tests/nmslib_wrapper_test.cpp + tests/nmslib_wrapper_unit_test.cpp tests/test_util.cpp tests/commons_test.cpp ) diff --git a/jni/cmake/init-nmslib.cmake b/jni/cmake/init-nmslib.cmake index a735bcbd8..387dce6bc 100644 --- a/jni/cmake/init-nmslib.cmake +++ b/jni/cmake/init-nmslib.cmake @@ -13,12 +13,13 @@ if (NOT EXISTS ${NMS_REPO_DIR}) endif () # Check if patch exist, this is to skip git apply during CI build. See CI.yml with ubuntu. -find_path(PATCH_FILE NAMES 0001-Initialize-maxlevel-during-add-from-enterpoint-level.patch PATHS ${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib NO_DEFAULT_PATH) +find_path(PATCH_FILE NAMES 0001-Initialize-maxlevel-during-add-from-enterpoint-level.patch 0002-Adds-ability-to-pass-ef-parameter-in-the-query-for-h.patch PATHS ${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib NO_DEFAULT_PATH) # If it exists, apply patches if (EXISTS ${PATCH_FILE}) message(STATUS "Applying custom patches.") execute_process(COMMAND git ${GIT_PATCH_COMMAND} --3way --ignore-space-change --ignore-whitespace ${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0001-Initialize-maxlevel-during-add-from-enterpoint-level.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/nmslib ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE) + execute_process(COMMAND git ${GIT_PATCH_COMMAND} --3way --ignore-space-change --ignore-whitespace ${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0002-Adds-ability-to-pass-ef-parameter-in-the-query-for-h.patch WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/nmslib ERROR_VARIABLE ERROR_MSG RESULT_VARIABLE RESULT_CODE) if(RESULT_CODE) message(FATAL_ERROR "Failed to apply patch:\n${ERROR_MSG}") diff --git a/jni/include/commons.h b/jni/include/commons.h index 05367a693..67a141c8b 100644 --- a/jni/include/commons.h +++ b/jni/include/commons.h @@ -33,5 +33,10 @@ namespace knn_jni { * @param memoryAddress address to be freed. */ void freeVectorData(jlong); + + /** + * Extracts query time efSearch from method parameters + **/ + int getIntegerMethodParameter(JNIEnv *, knn_jni::JNIUtilInterface *, std::unordered_map, std::string, int); } } diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 958eca8ac..aa747862a 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -45,17 +45,27 @@ namespace knn_jni { // Sets the sharedIndexState for an index void SetSharedIndexState(jlong indexPointerJ, jlong shareIndexStatePointerJ); - // Execute a query against the index located in memory at indexPointerJ. - // - // Return an array of KNNQueryResults + /** + * Execute a query against the index located in memory at indexPointerJ + * + * Parameters: + * methodParamsJ: introduces a map to have additional method parameters + * + * Return an array of KNNQueryResults + */ jobjectArray QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ); - - // Execute a query against the index located in memory at indexPointerJ along with Filters - // - // Return an array of KNNQueryResults + jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jintArray parentIdsJ); + + /** + * Execute a query against the index located in memory at indexPointerJ along with Filters + * + * Parameters: + * methodParamsJ: introduces a map to have additional method parameters + * + * Return an array of KNNQueryResults + */ jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ, + jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); // Free the index located in memory at indexPointerJ diff --git a/jni/include/nmslib_wrapper.h b/jni/include/nmslib_wrapper.h index 08494644f..27a013c10 100644 --- a/jni/include/nmslib_wrapper.h +++ b/jni/include/nmslib_wrapper.h @@ -37,7 +37,7 @@ namespace knn_jni { // // Return an array of KNNQueryResults jobjectArray QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ); + jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ); // Free the index located in memory at indexPointerJ void Free(jlong indexPointer); diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index e16677db7..0453864e4 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -69,18 +69,18 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setSharedIndexSt /* * Class: org_opensearch_knn_jni_FaissService * Method: queryIndex - * Signature: (J[FI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult; + * Signature: (J[FI[Ljava/util/MapI)[Lorg/opensearch/knn/index/query/KNNQueryResult; */ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex - (JNIEnv *, jclass, jlong, jfloatArray, jint, jintArray); + (JNIEnv *, jclass, jlong, jfloatArray, jint, jobject, jintArray); /* * Class: org_opensearch_knn_jni_FaissService * Method: queryIndexWithFilter - * Signature: (J[FI[JI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult; + * Signature: (J[FI[JLjava/util/MapI[I)[Lorg/opensearch/knn/index/query/KNNQueryResult; */ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter - (JNIEnv *, jclass, jlong, jfloatArray, jint, jlongArray, jint, jintArray); + (JNIEnv *, jclass, jlong, jfloatArray, jint, jobject, jlongArray, jint, jintArray); /* * Class: org_opensearch_knn_jni_FaissService diff --git a/jni/include/org_opensearch_knn_jni_NmslibService.h b/jni/include/org_opensearch_knn_jni_NmslibService.h index 31422955f..a9d5238b7 100644 --- a/jni/include/org_opensearch_knn_jni_NmslibService.h +++ b/jni/include/org_opensearch_knn_jni_NmslibService.h @@ -40,7 +40,7 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_NmslibService_loadIndex * Signature: (J[FI)[Lorg/opensearch/knn/index/query/KNNQueryResult; */ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_NmslibService_queryIndex - (JNIEnv *, jclass, jlong, jfloatArray, jint); + (JNIEnv *, jclass, jlong, jfloatArray, jint, jobject); /* * Class: org_opensearch_knn_jni_NmslibService diff --git a/jni/patches/nmslib/0002-Adds-ability-to-pass-ef-parameter-in-the-query-for-h.patch b/jni/patches/nmslib/0002-Adds-ability-to-pass-ef-parameter-in-the-query-for-h.patch new file mode 100644 index 000000000..e55e52cb3 --- /dev/null +++ b/jni/patches/nmslib/0002-Adds-ability-to-pass-ef-parameter-in-the-query-for-h.patch @@ -0,0 +1,303 @@ +From d700d93d5efda7349b90f6f0b2373580ced8097d Mon Sep 17 00:00:00 2001 +From: Tejas Shah +Date: Mon, 27 May 2024 22:02:12 -0700 +Subject: [PATCH] Adds ability to pass ef parameter in the query for hnsw + +It defaults to index ef_ value if its not type HNSWQuery +--- + similarity_search/include/hnswquery.h | 37 +++++++++++++++++++ + similarity_search/include/method/hnsw.h | 1 + + similarity_search/include/space.h | 4 ++ + similarity_search/src/hnswquery.cc | 32 ++++++++++++++++ + similarity_search/src/method/hnsw.cc | 28 +++++++++----- + .../src/method/hnsw_distfunc_opt.cc | 12 +++--- + 6 files changed, 100 insertions(+), 14 deletions(-) + create mode 100644 similarity_search/include/hnswquery.h + create mode 100644 similarity_search/src/hnswquery.cc + +diff --git a/similarity_search/include/hnswquery.h b/similarity_search/include/hnswquery.h +new file mode 100644 +index 0000000..a4f65ac +--- /dev/null ++++ b/similarity_search/include/hnswquery.h +@@ -0,0 +1,37 @@ ++/** ++ * Non-metric Space Library ++ * ++ * Main developers: Bilegsaikhan Naidan, Leonid Boytsov, Yury Malkov, Ben Frederickson, David Novak ++ * ++ * For the complete list of contributors and further details see: ++ * https://github.com/nmslib/nmslib ++ * ++ * Copyright (c) 2013-2018 ++ * ++ * This code is released under the ++ * Apache License Version 2.0 http://www.apache.org/licenses/. ++ * ++ */ ++ ++#ifndef HNSWQUERY_H ++#define HNSWQUERY_H ++#include "global.h" ++#include "knnquery.h" ++ ++namespace similarity { ++ ++template ++class HNSWQuery : public KNNQuery { ++public: ++ ~HNSWQuery(); ++ HNSWQuery(const Space& space, const Object *query_object, unsigned K, unsigned ef = 100, float eps = 0); ++ ++ unsigned getEf() { return ef_; } ++ ++protected: ++ unsigned ef_; ++}; ++ ++} ++ ++#endif //HNSWQUERY_H +diff --git a/similarity_search/include/method/hnsw.h b/similarity_search/include/method/hnsw.h +index 57d99d0..e6dcea7 100644 +--- a/similarity_search/include/method/hnsw.h ++++ b/similarity_search/include/method/hnsw.h +@@ -474,6 +474,7 @@ namespace similarity { + void baseSearchAlgorithmV1Merge(KNNQuery *query); + void SearchOld(KNNQuery *query, bool normalize); + void SearchV1Merge(KNNQuery *query, bool normalize); ++ size_t extractEf(KNNQuery *query, size_t defaultEf) const; + + int getRandomLevel(double revSize) + { +diff --git a/similarity_search/include/space.h b/similarity_search/include/space.h +index fedad46..a0e9ea9 100644 +--- a/similarity_search/include/space.h ++++ b/similarity_search/include/space.h +@@ -63,6 +63,9 @@ class Query; + template + class KNNQuery; + ++ template ++class HNSWQuery; ++ + template + class RangeQuery; + +@@ -263,6 +266,7 @@ class Space { + friend class Query; + friend class RangeQuery; + friend class KNNQuery; ++ friend class HNSWQuery; + friend class Experiments; + /* + * This function is private, but it will be accessible by the friend class Query +diff --git a/similarity_search/src/hnswquery.cc b/similarity_search/src/hnswquery.cc +new file mode 100644 +index 0000000..4ee7b38 +--- /dev/null ++++ b/similarity_search/src/hnswquery.cc +@@ -0,0 +1,32 @@ ++/** ++* Non-metric Space Library ++ * ++ * Main developers: Bilegsaikhan Naidan, Leonid Boytsov, Yury Malkov, Ben Frederickson, David Novak ++ * ++ * For the complete list of contributors and further details see: ++ * https://github.com/nmslib/nmslib ++ * ++ * Copyright (c) 2013-2018 ++ * ++ * This code is released under the ++ * Apache License Version 2.0 http://www.apache.org/licenses/. ++ * ++ */ ++ ++#include "hnswquery.h" ++ ++namespace similarity { ++ ++template ++HNSWQuery::HNSWQuery(const Space &space, const Object* query_object, const unsigned K, unsigned ef, float eps) ++ : KNNQuery(space, query_object, K, eps), ++ ef_(ef) { ++} ++ ++template ++HNSWQuery::~HNSWQuery() = default; ++ ++template class HNSWQuery; ++template class HNSWQuery; ++template class HNSWQuery; ++} +diff --git a/similarity_search/src/method/hnsw.cc b/similarity_search/src/method/hnsw.cc +index 35b372c..69ee9e4 100644 +--- a/similarity_search/src/method/hnsw.cc ++++ b/similarity_search/src/method/hnsw.cc +@@ -46,6 +46,7 @@ + #include + #include + ++#include "hnswquery.h" + #include "sort_arr_bi.h" + #define MERGE_BUFFER_ALGO_SWITCH_THRESHOLD 100 + +@@ -101,9 +102,16 @@ namespace similarity { + return nullptr; + } + ++ template ++ size_t Hnsw::extractEf(KNNQuery* searchQuery, size_t defaultEf) const { ++ auto* hnswQueryPtr = dynamic_cast*>(searchQuery); ++ if (hnswQueryPtr) { ++ return hnswQueryPtr->getEf(); ++ } ++ return defaultEf; ++ } + +- +-// This is the counter to keep the size of neighborhood information (for one node) ++ // This is the counter to keep the size of neighborhood information (for one node) + // TODO Can this one overflow? I really doubt + typedef uint32_t SIZEMASS_TYPE; + +@@ -718,10 +726,11 @@ namespace similarity { + void + Hnsw::Search(KNNQuery *query, IdType) const + { ++ size_t ef = this->extractEf(query, ef_); + if (this->data_.empty() && this->data_rearranged_.empty()) { + return; + } +- bool useOld = searchAlgoType_ == kOld || (searchAlgoType_ == kHybrid && ef_ >= 1000); ++ bool useOld = searchAlgoType_ == kOld || (searchAlgoType_ == kHybrid && ef >= 1000); + // cout << "Ef = " << ef_ << " use old = " << useOld << endl; + switch (searchMethod_) { + case 0: +@@ -1148,6 +1157,7 @@ namespace similarity { + PREFETCH((char *)(massVisited + (*iter)->getId()), _MM_HINT_T0); + } + // calculate distance to each neighbor ++ size_t ef = this->extractEf(query, ef_); + for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) { + curId = (*iter)->getId(); + +@@ -1155,12 +1165,12 @@ namespace similarity { + massVisited[curId] = currentV; + currObj = (*iter)->getData(); + d = query->DistanceObjLeft(currObj); +- if (closestDistQueue1.top().getDistance() > d || closestDistQueue1.size() < ef_) { ++ if (closestDistQueue1.top().getDistance() > d || closestDistQueue1.size() < ef) { + { + query->CheckAndAddToResult(d, currObj); + candidateQueue.emplace(d, *iter); + closestDistQueue1.emplace(d, *iter); +- if (closestDistQueue1.size() > ef_) { ++ if (closestDistQueue1.size() > ef) { + closestDistQueue1.pop(); + } + } +@@ -1185,6 +1195,7 @@ namespace similarity { + + const Object *currObj = provider->getData(); + ++ size_t ef = this->extractEf(query, ef_); + dist_t d = query->DistanceObjLeft(currObj); + dist_t curdist = d; + HnswNode *curNode = provider; +@@ -1209,7 +1220,7 @@ namespace similarity { + } + } + +- SortArrBI sortedArr(max(ef_, query->GetK())); ++ SortArrBI sortedArr(max(ef, query->GetK())); + sortedArr.push_unsorted_grow(curdist, curNode); + + int_fast32_t currElem = 0; +@@ -1225,8 +1236,7 @@ namespace similarity { + // PHASE TWO OF THE SEARCH + // Extraction of the neighborhood to find k nearest neighbors. + //////////////////////////////////////////////////////////////////////////////// +- +- while (currElem < min(sortedArr.size(), ef_)) { ++ while (currElem < min(sortedArr.size(), ef)) { + auto &e = queueData[currElem]; + CHECK(!e.used); + e.used = true; +@@ -1255,7 +1265,7 @@ namespace similarity { + currObj = (*iter)->getData(); + d = query->DistanceObjLeft(currObj); + +- if (d < topKey || sortedArr.size() < ef_) { ++ if (d < topKey || sortedArr.size() < ef) { + CHECK_MSG(itemBuff.size() > itemQty, + "Perhaps a bug: buffer size is not enough " + + ConvertToString(itemQty) + " >= " + ConvertToString(itemBuff.size())); +diff --git a/similarity_search/src/method/hnsw_distfunc_opt.cc b/similarity_search/src/method/hnsw_distfunc_opt.cc +index 5c219cd..1913936 100644 +--- a/similarity_search/src/method/hnsw_distfunc_opt.cc ++++ b/similarity_search/src/method/hnsw_distfunc_opt.cc +@@ -120,6 +120,7 @@ namespace similarity { + PREFETCH(data_level0_memory_ + (*(data + 1)) * memoryPerObject_ + offsetData_, _MM_HINT_T0); + PREFETCH((char *)(data + 2), _MM_HINT_T0); + ++ size_t ef = this->extractEf(query, ef_); + for (int j = 1; j <= size; j++) { + int tnum = *(data + j); + PREFETCH((char *)(massVisited + *(data + j + 1)), _MM_HINT_T0); +@@ -131,7 +132,7 @@ namespace similarity { + massVisited[tnum] = currentV; + char *currObj1 = (data_level0_memory_ + tnum * memoryPerObject_ + offsetData_); + dist_t d = (fstdistfunc_(pVectq, (float *)(currObj1 + 16), qty, TmpRes)); +- if (closestDistQueuei.top().getDistance() > d || closestDistQueuei.size() < ef_) { ++ if (closestDistQueuei.top().getDistance() > d || closestDistQueuei.size() < ef) { + candidateQueuei.emplace(-d, tnum); + PREFETCH(data_level0_memory_ + candidateQueuei.top().element * memoryPerObject_ + offsetLevel0_, + _MM_HINT_T0); +@@ -139,7 +140,7 @@ namespace similarity { + query->CheckAndAddToResult(d, data_rearranged_[tnum]); + closestDistQueuei.emplace(d, tnum); + +- if (closestDistQueuei.size() > ef_) { ++ if (closestDistQueuei.size() > ef) { + closestDistQueuei.pop(); + } + } +@@ -153,6 +154,7 @@ namespace similarity { + void + Hnsw::SearchV1Merge(KNNQuery *query, bool normalize) + { ++ size_t ef = this->extractEf(query, ef_); + float *pVectq = (float *)((char *)query->QueryObject()->data()); + TMP_RES_ARRAY(TmpRes); + size_t qty = query->QueryObject()->datalength() >> 2; +@@ -197,7 +199,7 @@ namespace similarity { + } + } + +- SortArrBI sortedArr(max(ef_, query->GetK())); ++ SortArrBI sortedArr(max(ef, query->GetK())); + sortedArr.push_unsorted_grow(curdist, curNodeNum); + + int_fast32_t currElem = 0; +@@ -208,7 +210,7 @@ namespace similarity { + + massVisited[curNodeNum] = currentV; + +- while (currElem < min(sortedArr.size(), ef_)) { ++ while (currElem < min(sortedArr.size(), ef)) { + auto &e = queueData[currElem]; + CHECK(!e.used); + e.used = true; +@@ -237,7 +239,7 @@ namespace similarity { + char *currObj1 = (data_level0_memory_ + tnum * memoryPerObject_ + offsetData_); + dist_t d = (fstdistfunc_(pVectq, (float *)(currObj1 + 16), qty, TmpRes)); + +- if (d < topKey || sortedArr.size() < ef_) { ++ if (d < topKey || sortedArr.size() < ef) { + CHECK_MSG(itemBuff.size() > itemQty, + "Perhaps a bug: buffer size is not enough " + + ConvertToString(itemQty) + " >= " + ConvertToString(itemBuff.size())); +-- +2.44.0 + diff --git a/jni/src/commons.cpp b/jni/src/commons.cpp index 3c03ac49d..c2b2354cc 100644 --- a/jni/src/commons.cpp +++ b/jni/src/commons.cpp @@ -38,4 +38,16 @@ void knn_jni::commons::freeVectorData(jlong memoryAddressJ) { delete vect; } } -#endif //OPENSEARCH_KNN_COMMONS_H \ No newline at end of file + +int knn_jni::commons::getIntegerMethodParameter(JNIEnv * env, knn_jni::JNIUtilInterface * jniUtil, std::unordered_map methodParams, std::string methodParam, int defaultValue) { + if (methodParams.empty()) { + return defaultValue; + } + auto efSearchIt = methodParams.find(methodParam); + if (efSearchIt != methodParams.end()) { + return jniUtil->ConvertJavaObjectToCppInteger(env, methodParams[methodParam]); + } + + return defaultValue; +} +#endif //OPENSEARCH_KNN_COMMONS_H diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 5a0910d9a..198f733be 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -22,6 +22,7 @@ #include "faiss/Index.h" #include "faiss/impl/IDSelector.h" #include "faiss/IndexIVFPQ.h" +#include "commons.h" #include #include @@ -296,12 +297,12 @@ void knn_jni::faiss_wrapper::SetSharedIndexState(jlong indexPointerJ, jlong shar } jobjectArray knn_jni::faiss_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ) { - return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, nullptr, 0, parentIdsJ); + jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jintArray parentIdsJ) { + return knn_jni::faiss_wrapper::QueryIndex_WithFilter(jniUtil, env, indexPointerJ, queryVectorJ, kJ, methodParamsJ, nullptr, 0, parentIdsJ); } jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { + jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { if (queryVectorJ == nullptr) { throw std::runtime_error("Query Vector cannot be null"); @@ -313,6 +314,11 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter throw std::runtime_error("Invalid pointer to index"); } + std::unordered_map methodParams; + if (methodParamsJ != nullptr) { + methodParams = jniUtil->ConvertJavaMapToCppMap(env, methodParamsJ); + } + // The ids vector will hold the top k ids from the search and the dis vector will hold the top k distances from // the query point std::vector dis(kJ); @@ -340,9 +346,8 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter std::vector idGrouperBitmap; auto hnswReader = dynamic_cast(indexReader->index); if(hnswReader) { - // Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default - // value of ef_search = 16 which will then be used. - hnswParams.efSearch = hnswReader->hnsw.efSearch; + // Query param efsearch supersedes ef_search provided during index setting. + hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch); hnswParams.sel = idSelector.get(); if (parentIdsJ != nullptr) { idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); @@ -371,12 +376,13 @@ jobjectArray knn_jni::faiss_wrapper::QueryIndex_WithFilter(knn_jni::JNIUtilInter std::unique_ptr idGrouper; std::vector idGrouperBitmap; auto hnswReader = dynamic_cast(indexReader->index); - if(hnswReader!= nullptr && parentIdsJ != nullptr) { - // Setting the ef_search value equal to what was provided during index creation. SearchParametersHNSW has a default - // value of ef_search = 16 which will then be used. - hnswParams.efSearch = hnswReader->hnsw.efSearch; - idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); - hnswParams.grp = idGrouper.get(); + if(hnswReader!= nullptr) { + // Query param efsearch supersedes ef_search provided during index setting. + hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch); + if (parentIdsJ != nullptr) { + idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap); + hnswParams.grp = idGrouper.get(); + } searchParameters = &hnswParams; } try { diff --git a/jni/src/nmslib_wrapper.cpp b/jni/src/nmslib_wrapper.cpp index 6ea80d727..21b34eb83 100644 --- a/jni/src/nmslib_wrapper.cpp +++ b/jni/src/nmslib_wrapper.cpp @@ -12,6 +12,8 @@ #include "jni_util.h" #include "nmslib_wrapper.h" +#include "commons.h" + #include "init.h" #include "index.h" #include "params.h" @@ -24,6 +26,8 @@ #include #include +#include "hnswquery.h" + std::string TranslateSpaceType(const std::string& spaceType); @@ -220,7 +224,7 @@ jlong knn_jni::nmslib_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JN } jobjectArray knn_jni::nmslib_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ) { + jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ) { if (queryVectorJ == nullptr) { throw std::runtime_error("Query Vector cannot be null"); @@ -243,14 +247,34 @@ jobjectArray knn_jni::nmslib_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jni jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); throw; } + jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryvector, JNI_ABORT); + std::unordered_map methodParams; + if (methodParamsJ != nullptr) { + methodParams = jniUtil->ConvertJavaMapToCppMap(env, methodParamsJ); + } - similarity::KNNQuery knnQuery(*(indexWrapper->space), queryObject.get(), kJ); - indexWrapper->index->Search(&knnQuery); + int queryEfSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, -1); + similarity::KNNQuery* query; // TODO: Replace with smart pointers https://github.com/opensearch-project/k-NN/issues/1785 + std::unique_ptr> neighbors; + try { + if (queryEfSearch == -1) { + query = new similarity::KNNQuery(*(indexWrapper->space), queryObject.get(), kJ); + } else { + query = new similarity::HNSWQuery(*(indexWrapper->space), queryObject.get(), kJ, queryEfSearch); + } - std::unique_ptr> neighbors(knnQuery.Result()->Clone()); + indexWrapper->index->Search(query); + neighbors.reset(query->Result()->Clone()); + } catch (...) { + if (query != nullptr) { + delete query; + } + throw; + } + delete query; + int resultSize = neighbors->Size(); - jclass resultClass = jniUtil->FindClass(env,"org/opensearch/knn/index/query/KNNQueryResult"); jmethodID allArgs = jniUtil->FindMethod(env, "org/opensearch/knn/index/query/KNNQueryResult", ""); @@ -265,6 +289,7 @@ jobjectArray knn_jni::nmslib_wrapper::QueryIndex(knn_jni::JNIUtilInterface * jni result = jniUtil->NewObject(env, resultClass, allArgs, id, distance); jniUtil->SetObjectArrayElement(env, results, i, result); } + return results; } diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 0aa51987d..57353f9e1 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -109,10 +109,10 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setSharedIndexSt JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndex(JNIEnv * env, jclass cls, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ) + jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jintArray parentIdsJ) { try { - return knn_jni::faiss_wrapper::QueryIndex(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, parentIdsJ); + return knn_jni::faiss_wrapper::QueryIndex(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, methodParamsJ, parentIdsJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); @@ -121,10 +121,10 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryInd } JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_queryIndexWithFilter - (JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jlongArray filteredIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { + (JNIEnv * env, jclass cls, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filteredIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { try { - return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, filteredIdsJ, filterIdsTypeJ, parentIdsJ); + return knn_jni::faiss_wrapper::QueryIndex_WithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, methodParamsJ, filteredIdsJ, filterIdsTypeJ, parentIdsJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/src/org_opensearch_knn_jni_NmslibService.cpp b/jni/src/org_opensearch_knn_jni_NmslibService.cpp index d037d3337..e265827cd 100644 --- a/jni/src/org_opensearch_knn_jni_NmslibService.cpp +++ b/jni/src/org_opensearch_knn_jni_NmslibService.cpp @@ -61,10 +61,10 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_NmslibService_loadIndex(JNIE JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_NmslibService_queryIndex(JNIEnv * env, jclass cls, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ) + jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ) { try { - return knn_jni::nmslib_wrapper::QueryIndex(&jniUtil, env, indexPointerJ, queryVectorJ, kJ); + return knn_jni::nmslib_wrapper::QueryIndex(&jniUtil, env, indexPointerJ, queryVectorJ, kJ, methodParamsJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/tests/commons_test.cpp b/jni/tests/commons_test.cpp index 09323f0fb..630358919 100644 --- a/jni/tests/commons_test.cpp +++ b/jni/tests/commons_test.cpp @@ -71,3 +71,22 @@ TEST(CommonsTests, BasicAssertions) { } } } + +TEST(CommonTests, GetIntegerMethodParam) { + JNIEnv *jniEnv = nullptr; + testing::NiceMock mockJNIUtil; + + std::unordered_map methodParams1; + int efSearch = 10; + methodParams1[knn_jni::EF_SEARCH] = reinterpret_cast(&efSearch); + + int actualValue1 = knn_jni::commons::getIntegerMethodParameter(jniEnv, &mockJNIUtil, methodParams1, knn_jni::EF_SEARCH, 1); + EXPECT_EQ(efSearch, actualValue1); + + int actualValue2 = knn_jni::commons::getIntegerMethodParameter(jniEnv, &mockJNIUtil, methodParams1, "param", 1); + EXPECT_EQ(1, actualValue2); + + std::unordered_map methodParams2; + int actualValue3 = knn_jni::commons::getIntegerMethodParameter(jniEnv, &mockJNIUtil, methodParams2, knn_jni::EF_SEARCH, 1); + EXPECT_EQ(1, actualValue3); +} diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 4cd3b319e..1db3df42c 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -245,6 +245,10 @@ TEST(FaissQueryIndexTest, BasicAssertions) { // Define query data int k = 10; + int efSearch = 20; + std::unordered_map methodParams; + methodParams[knn_jni::EF_SEARCH] = reinterpret_cast(&efSearch); + int numQueries = 100; std::vector> queries; @@ -266,6 +270,7 @@ TEST(FaissQueryIndexTest, BasicAssertions) { // Setup jni JNIEnv *jniEnv = nullptr; NiceMock mockJNIUtil; + auto methodParamsJ = reinterpret_cast(&methodParams); for (auto query : queries) { std::unique_ptr *>> results( @@ -273,7 +278,7 @@ TEST(FaissQueryIndexTest, BasicAssertions) { knn_jni::faiss_wrapper::QueryIndex( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), k, nullptr))); + reinterpret_cast(&query), k, methodParamsJ, nullptr))); ASSERT_EQ(k, results->size()); @@ -339,7 +344,7 @@ TEST(FaissQueryIndexWithFilterTest1435, BasicAssertions) { knn_jni::faiss_wrapper::QueryIndex_WithFilter( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), k, + reinterpret_cast(&query), k, nullptr, reinterpret_cast(&bitmap), 0, nullptr))); ASSERT_TRUE(results->size() <= filterIds.size()); @@ -397,20 +402,20 @@ TEST(FaissQueryIndexWithParentFilterTest, BasicAssertions) { auto createdIndexWithData = test_util::FaissAddData(createdIndex.get(), ids, vectors); + int efSearch = 100; + std::unordered_map methodParams; + methodParams[knn_jni::EF_SEARCH] = reinterpret_cast(&efSearch); + // Setup jni JNIEnv *jniEnv = nullptr; NiceMock mockJNIUtil; - EXPECT_CALL(mockJNIUtil, - GetJavaIntArrayLength( - jniEnv, reinterpret_cast(&parentIds))) - .WillRepeatedly(Return(parentIds.size())); for (auto query : queries) { std::unique_ptr *>> results( reinterpret_cast *> *>( knn_jni::faiss_wrapper::QueryIndex( &mockJNIUtil, jniEnv, reinterpret_cast(&createdIndexWithData), - reinterpret_cast(&query), k, + reinterpret_cast(&query), k, reinterpret_cast(&methodParams), reinterpret_cast(&parentIds)))); // Even with k 20, result should have only 10 which is total number of groups diff --git a/jni/tests/faiss_wrapper_unit_test.cpp b/jni/tests/faiss_wrapper_unit_test.cpp new file mode 100644 index 000000000..ea9131dd7 --- /dev/null +++ b/jni/tests/faiss_wrapper_unit_test.cpp @@ -0,0 +1,251 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +#include "faiss_wrapper.h" + +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "jni_util.h" +#include "jni.h" +#include "test_util.h" +#include "faiss/IndexHNSW.h" +#include "faiss/IndexIDMap.h" + +using ::testing::NiceMock; + +using idx_t = faiss::idx_t; + +struct MockIndex : faiss::IndexHNSW { + explicit MockIndex(idx_t d) : faiss::IndexHNSW(d, 32) { + } +}; + + +struct MockIdMap : faiss::IndexIDMap { + mutable idx_t nCalled; + mutable const float *xCalled; + mutable idx_t kCalled; + mutable float *distancesCalled; + mutable idx_t *labelsCalled; + mutable const faiss::SearchParametersHNSW *paramsCalled; + + explicit MockIdMap(MockIndex *index) : faiss::IndexIDMapTemplate(index) { + } + + void search( + idx_t n, + const float *x, + idx_t k, + float *distances, + idx_t *labels, + const faiss::SearchParameters *params) const override { + nCalled = n; + xCalled = x; + kCalled = k; + distancesCalled = distances; + labelsCalled = labels; + paramsCalled = dynamic_cast(params); + } + + void resetMock() const { + nCalled = 0; + xCalled = nullptr; + kCalled = 0; + distancesCalled = nullptr; + labelsCalled = nullptr; + paramsCalled = nullptr; + } +}; + +struct QueryIndexHNSWTestInput { + string description; + int k; + int efSearch; + int filterIdType; + bool filterIdsPresent; + bool parentIdsPresent; +}; + + + +class FaissWrappeterParametrizedTestFixture : public testing::TestWithParam { +public: + FaissWrappeterParametrizedTestFixture() : index_(3), id_map_(&index_) { + index_.hnsw.efSearch = 100; // assigning 100 to make sure default of 16 is not used anywhere + }; + +protected: + MockIndex index_; + MockIdMap id_map_; +}; + +namespace query_index_test { + + std::unordered_map methodParams; + + + TEST_P(FaissWrappeterParametrizedTestFixture, QueryIndexHNSWTests) { + //Given + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + + QueryIndexHNSWTestInput const &input = GetParam(); + float query[] = {1.2, 2.3, 3.4}; + + int efSearch = input.efSearch; + int expectedEfSearch = 100; //default set in mock + std::unordered_map methodParams; + if (efSearch != -1) { + expectedEfSearch = input.efSearch; + methodParams[knn_jni::EF_SEARCH] = reinterpret_cast(&efSearch); + } + + std::vector *parentIdPtr = nullptr; + if (input.parentIdsPresent) { + std::vector parentId; + parentId.reserve(2); + parentId.push_back(1); + parentId.push_back(2); + parentIdPtr = &parentId; + + EXPECT_CALL(mockJNIUtil, + GetJavaIntArrayLength( + jniEnv, reinterpret_cast(parentIdPtr))) + .WillOnce(testing::Return(parentId.size())); + + EXPECT_CALL(mockJNIUtil, + GetIntArrayElements( + jniEnv, reinterpret_cast(parentIdPtr), nullptr)) + .WillOnce(testing::Return(new int[2]{1, 2})); + } + + // When + knn_jni::faiss_wrapper::QueryIndex( + &mockJNIUtil, jniEnv, + reinterpret_cast(&id_map_), + reinterpret_cast(&query), input.k, reinterpret_cast(&methodParams), + reinterpret_cast(parentIdPtr)); + + //Then + int actualEfSearch = id_map_.paramsCalled->efSearch; + // Asserting the captured argument + EXPECT_EQ(input.k, id_map_.kCalled); + EXPECT_EQ(expectedEfSearch, actualEfSearch); + if (input.parentIdsPresent) { + faiss::IDGrouper *grouper = id_map_.paramsCalled->grp; + EXPECT_TRUE(grouper != nullptr); + } + + id_map_.resetMock(); + } + + INSTANTIATE_TEST_CASE_P( + QueryIndexHNSWTests, + FaissWrappeterParametrizedTestFixture, + ::testing::Values( + QueryIndexHNSWTestInput{"algoParams present, parent absent", 10, 200, 0, false, false}, + QueryIndexHNSWTestInput{"algoParams absent, parent absent", 10, -1, 0, false, false}, + QueryIndexHNSWTestInput{"algoParams present, parent present", 10, 200, 0, false, true}, + QueryIndexHNSWTestInput{"algoParams absent, parent present", 10, -1, 0, false, true} + ) + ); +} + +namespace query_index_with_filter_test { + + TEST_P(FaissWrappeterParametrizedTestFixture, QueryIndexWithFilterHNSWTests) { + //Given + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + + QueryIndexHNSWTestInput const &input = GetParam(); + float query[] = {1.2, 2.3, 3.4}; + + std::vector *parentIdPtr = nullptr; + if (input.parentIdsPresent) { + std::vector parentId; + parentId.reserve(2); + parentId.push_back(1); + parentId.push_back(2); + parentIdPtr = &parentId; + + EXPECT_CALL(mockJNIUtil, + GetJavaIntArrayLength( + jniEnv, reinterpret_cast(parentIdPtr))) + .WillOnce(testing::Return(parentId.size())); + + EXPECT_CALL(mockJNIUtil, + GetIntArrayElements( + jniEnv, reinterpret_cast(parentIdPtr), nullptr)) + .WillOnce(testing::Return(new int[2]{1, 2})); + } + + std::vector *filterptr = nullptr; + if (input.filterIdsPresent) { + std::vector filter; + filter.reserve(2); + filter.push_back(1); + filter.push_back(2); + filterptr = &filter; + } + + int efSearch = input.efSearch; + int expectedEfSearch = 100; //default set in mock + std::unordered_map methodParams; + if (efSearch != -1) { + expectedEfSearch = input.efSearch; + methodParams[knn_jni::EF_SEARCH] = reinterpret_cast(&efSearch); + } + + // When + knn_jni::faiss_wrapper::QueryIndex_WithFilter( + &mockJNIUtil, jniEnv, + reinterpret_cast(&id_map_), + reinterpret_cast(&query), input.k, reinterpret_cast(&methodParams), + reinterpret_cast(filterptr), + input.filterIdType, + reinterpret_cast(parentIdPtr)); + + //Then + int actualEfSearch = id_map_.paramsCalled->efSearch; + // Asserting the captured argument + EXPECT_EQ(input.k, id_map_.kCalled); + EXPECT_EQ(expectedEfSearch, actualEfSearch); + if (input.parentIdsPresent) { + faiss::IDGrouper *grouper = id_map_.paramsCalled->grp; + EXPECT_TRUE(grouper != nullptr); + } + if (input.filterIdsPresent) { + faiss::IDSelector *sel = id_map_.paramsCalled->sel; + EXPECT_TRUE(sel != nullptr); + } + id_map_.resetMock(); + } + + INSTANTIATE_TEST_CASE_P( + QueryIndexWithFilterHNSWTests, + FaissWrappeterParametrizedTestFixture, + ::testing::Values( + QueryIndexHNSWTestInput{"algoParams present, parent absent, filter absent", 10, 200, 0, false, false}, + QueryIndexHNSWTestInput{"algoParams present, parent absent, filter absent, filter type 1", 10, 200, 1, false, false}, + QueryIndexHNSWTestInput{"algoParams absent, parent absent, filter present", 10, -1, 0, true, false}, + QueryIndexHNSWTestInput{"algoParams absent, parent absent, filter present, filter type 1", 10, -1, 1, true, false}, + QueryIndexHNSWTestInput{"algoParams present, parent present, filter absent", 10, 200, 0, false, true}, + QueryIndexHNSWTestInput{"algoParams present, parent present, filter absent, filter type 1", 10, 150, 1, false, true}, + QueryIndexHNSWTestInput{"algoParams absent, parent present, filter present", 10, -1, 0, true, true}, + QueryIndexHNSWTestInput{"algoParams absent, parent present, filter present, filter type 1",10, -1, 1, true, true} + ) + ); +} diff --git a/jni/tests/nmslib_wrapper_test.cpp b/jni/tests/nmslib_wrapper_test.cpp index 1fd9471b0..4e0c57044 100644 --- a/jni/tests/nmslib_wrapper_test.cpp +++ b/jni/tests/nmslib_wrapper_test.cpp @@ -182,8 +182,11 @@ TEST(NmslibQueryIndexTest, BasicAssertions) { // Define query data int k = 10; + int efSearch = 20; int numQueries = 100; std::vector> queries; + std::unordered_map methodParams; + methodParams[knn_jni::EF_SEARCH] = reinterpret_cast(&efSearch); for (int i = 0; i < numQueries; i++) { std::vector query; @@ -205,7 +208,7 @@ TEST(NmslibQueryIndexTest, BasicAssertions) { knn_jni::nmslib_wrapper::QueryIndex( &mockJNIUtil, jniEnv, reinterpret_cast(indexWrapper.get()), - reinterpret_cast(&query), k))); + reinterpret_cast(&query), k, nullptr))); ASSERT_EQ(k, results->size()); diff --git a/jni/tests/nmslib_wrapper_unit_test.cpp b/jni/tests/nmslib_wrapper_unit_test.cpp new file mode 100644 index 000000000..ff6b94fc7 --- /dev/null +++ b/jni/tests/nmslib_wrapper_unit_test.cpp @@ -0,0 +1,135 @@ +/* +* SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ +#include "hnswquery.h" +#include "knnquery.h" +#include "nmslib_wrapper.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "jni_util.h" +#include "jni.h" +#include "test_util.h" +#include "method/hnsw.h" +#include "space/space_dummy.h" + +namespace nmslib_query_index_test { + + using ::testing::NiceMock; + + struct QueryIndexHNSWTestInput { + string description; + int k; + int efSearch; + bool expectedHNSWQuery; + }; + + struct MockNMSIndex : similarity::Hnsw { + mutable int kCalled; + mutable int efCalled = -1; + + explicit MockNMSIndex(const similarity::Space &space, const similarity::ObjectVector &data): Hnsw(false, + space, data) { + std::vector input; + input.emplace_back("ef=10"); + similarity::AnyParams ef(input); + this->Hnsw::SetQueryTimeParams(ef); + } + + void Search(similarity::KNNQuery *query, similarity::IdType id) const override { + auto hnsw = dynamic_cast *>(query); + if (hnsw != nullptr) { + kCalled = hnsw->GetK(); + efCalled = hnsw->getEf(); + } else { + kCalled = query->GetK(); + } + similarity::Object object(5, 0, 3*sizeof(float), new float[] { 2.2f, 2.5f, 2.6f }); + similarity::Object* objectPtr = &object; + bool added = query->CheckAndAddToResult(0.0f, objectPtr); + }; + + void resetMocks() const { + kCalled = -1; + efCalled = -1; + } + }; + + class NmslibWrapperParametrizedTestFixture : public testing::TestWithParam { + public: + NmslibWrapperParametrizedTestFixture() : space_(nullptr), index_(nullptr) { + similarity::initLibrary(); + std::string spaceType = knn_jni::L2; + space_ = similarity::SpaceFactoryRegistry::Instance().CreateSpace( + spaceType, similarity::AnyParams()); + index_ = new MockNMSIndex(*space_, similarity::ObjectVector()); + }; + + protected: + MockNMSIndex* index_; + similarity::Space* space_; // Moved from local to member variable + }; + + + TEST_P(NmslibWrapperParametrizedTestFixture, QueryIndexHNSWTests) { + //Given + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + + + QueryIndexHNSWTestInput const &input = GetParam(); + float query[] = { 1.2f, 2.3f, 3.4f }; + + std::string spaceType = knn_jni::L2; + std::unique_ptr indexWrapper( + new knn_jni::nmslib_wrapper::IndexWrapper(spaceType)); + indexWrapper->index.reset(index_); + + int efSearch = input.efSearch; + std::unordered_map methodParams; + if (efSearch != -1) { + methodParams[knn_jni::EF_SEARCH] = reinterpret_cast(&efSearch); + } + EXPECT_CALL(mockJNIUtil, + GetJavaFloatArrayLength( + jniEnv, reinterpret_cast(query))) + .WillOnce(testing::Return(3)); + + EXPECT_CALL(mockJNIUtil, + ReleaseFloatArrayElements( + jniEnv, reinterpret_cast(query), query, JNI_ABORT)); + EXPECT_CALL(mockJNIUtil, + GetFloatArrayElements( + jniEnv, reinterpret_cast(query), nullptr)) + .WillOnce(testing::Return(query)); + + knn_jni::nmslib_wrapper::QueryIndex( + &mockJNIUtil, jniEnv, + reinterpret_cast(indexWrapper.get()), + reinterpret_cast(&query), input.k, reinterpret_cast(&methodParams)); + + if (input.expectedHNSWQuery) { + EXPECT_EQ(input.efSearch, index_->efCalled); + EXPECT_EQ(input.k, index_->kCalled); + } else { + EXPECT_EQ(input.k, index_->kCalled); + } + index_->resetMocks(); + } + + INSTANTIATE_TEST_CASE_P( + QueryIndexHNSWTests, + NmslibWrapperParametrizedTestFixture, + ::testing::Values( + QueryIndexHNSWTestInput{"methodParams present", 10, 200, true}, + QueryIndexHNSWTestInput{"methodParams absent", 5, -1, false } + ) + ); +} diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java index fed2d39da..7df651a3c 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java @@ -177,7 +177,7 @@ public void searchKNNModel(String testModelID) throws IOException { } // Confirm that the model gets created using Get Model API - public void validateModelCreated(String modelId) throws IOException, InterruptedException { + public void validateModelCreated(String modelId) throws Exception { Response getResponse = getModel(modelId, null); String responseBody = EntityUtils.toString(getResponse.getEntity()); assertNotNull(responseBody); diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/QueryANNIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/QueryANNIT.java new file mode 100644 index 000000000..566f40383 --- /dev/null +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/QueryANNIT.java @@ -0,0 +1,37 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.bwc; + +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; + +public class QueryANNIT extends AbstractRestartUpgradeTestCase { + + private static final String TEST_FIELD = "test-field"; + private static final int DIMENSIONS = 5; + private static final int K = 5; + private static final Integer EF_SEARCH = 10; + private static final int NUM_DOCS = 10; + + public void testQueryANN() throws Exception { + if (isRunningAgainstOldCluster()) { + createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS)); + addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, 0, NUM_DOCS); + validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, K); + } else { + validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, K); + validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, K, Map.of(METHOD_PARAMETER_EF_SEARCH, EF_SEARCH)); + deleteKNNIndex(testIndex); + } + } +} diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/IndexingIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/IndexingIT.java index 73adb6db8..10df1a79b 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/IndexingIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/IndexingIT.java @@ -5,8 +5,6 @@ package org.opensearch.knn.bwc; -import java.io.IOException; - import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER; public class IndexingIT extends AbstractRollingUpgradeTestCase { @@ -129,7 +127,7 @@ public void testKNNIndexCreation_withMethodMapper() throws Exception { } // validation steps for indexing after upgrading each node from old version to new version - public void validateKNNIndexingOnUpgrade(int totalDocsCount, int docId) throws IOException { + public void validateKNNIndexingOnUpgrade(int totalDocsCount, int docId) throws Exception { validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, totalDocsCount, K); addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, docId, NUM_DOCS); diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/QueryANNIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/QueryANNIT.java new file mode 100644 index 000000000..080e63241 --- /dev/null +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/QueryANNIT.java @@ -0,0 +1,49 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.bwc; + +import java.util.Map; + +import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; + +public class QueryANNIT extends AbstractRollingUpgradeTestCase { + private static final String TEST_FIELD = "test-field"; + private static final int DIMENSIONS = 5; + private static final int K = 5; + private static final Integer EF_SEARCH = 10; + private static final int NUM_DOCS = 10; + private static final String ALGORITHM = "hnsw"; + + public void testQueryANNIT() throws Exception { + waitForClusterHealthGreen(NODES_BWC_CLUSTER); + switch (getClusterType()) { + case OLD: + createKnnIndex( + testIndex, + getKNNDefaultIndexSettings(), + createKnnIndexMapping(TEST_FIELD, DIMENSIONS, ALGORITHM, FAISS_NAME) + ); + addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, 0, NUM_DOCS); + validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, K); + break; + case MIXED: + validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, K); + break; + case UPGRADED: + validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, K); + validateKNNSearch(testIndex, TEST_FIELD, DIMENSIONS, NUM_DOCS, K, Map.of(METHOD_PARAMETER_EF_SEARCH, EF_SEARCH)); + deleteKNNIndex(testIndex); + } + } +} diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/WarmupIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/WarmupIT.java index a34f4b3cf..9cbb99d87 100644 --- a/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/WarmupIT.java +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/WarmupIT.java @@ -5,7 +5,6 @@ package org.opensearch.knn.bwc; -import java.io.IOException; import java.util.Collections; import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER; @@ -47,7 +46,7 @@ public void testKNNWarmup() throws Exception { } // validation steps for KNN Warmup after upgrading each node from old version to new version - public void validateKNNWarmupOnUpgrade(int totalDocsCount, int docId) throws IOException { + public void validateKNNWarmupOnUpgrade(int totalDocsCount, int docId) throws Exception { int graphCount = getTotalGraphsInCache(); knnWarmup(Collections.singletonList(testIndex)); assertTrue(getTotalGraphsInCache() > graphCount); diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index 0b4538ec8..e73212afe 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -26,6 +26,7 @@ public class KNNConstants { public static final String K = "k"; public static final String TYPE_KNN_VECTOR = "knn_vector"; public static final String PROPERTIES = "properties"; + public static final String METHOD_PARAMETER = "method_parameters"; public static final String METHOD_PARAMETER_EF_SEARCH = "ef_search"; public static final String METHOD_PARAMETER_EF_CONSTRUCTION = "ef_construction"; public static final String METHOD_PARAMETER_M = "m"; diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index e4523bb5e..5cd4aaf81 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -20,6 +20,7 @@ import org.opensearch.common.ValidationException; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.query.request.MethodParameter; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; @@ -46,15 +47,29 @@ public class IndexUtil { private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT = Version.V_2_12_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT = Version.V_2_13_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH = Version.V_2_14_0; - public static final Map minimalRequiredVersionMap = new HashMap() { - { - put("filter", MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER); - put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED); - put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT); - put(MODEL_METHOD_COMPONENT_CONTEXT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT); - put(KNNConstants.RADIAL_SEARCH_KEY, MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH); + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS = Version.V_2_16_0; + // public so neural search can access it + public static final Map minimalRequiredVersionMap = initializeMinimalRequiredVersionMap(); + + private static Map initializeMinimalRequiredVersionMap() { + final Map versionMap = new HashMap<>() { + { + put("filter", MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER); + put("ignore_unmapped", MINIMAL_SUPPORTED_VERSION_FOR_IGNORE_UNMAPPED); + put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT); + put(MODEL_METHOD_COMPONENT_CONTEXT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT); + put(KNNConstants.RADIAL_SEARCH_KEY, MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH); + put(KNNConstants.METHOD_PARAMETER, MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS); + } + }; + + for (final MethodParameter methodParameter : MethodParameter.values()) { + if (methodParameter.getVersion() != null) { + versionMap.put(methodParameter.getName(), methodParameter.getVersion()); + } } - }; + return Collections.unmodifiableMap(versionMap); + } /** * Determines the size of a file on disk in kilobytes diff --git a/src/main/java/org/opensearch/knn/index/MethodComponent.java b/src/main/java/org/opensearch/knn/index/MethodComponent.java index 256d55ee5..b344772f7 100644 --- a/src/main/java/org/opensearch/knn/index/MethodComponent.java +++ b/src/main/java/org/opensearch/knn/index/MethodComponent.java @@ -19,11 +19,13 @@ import org.opensearch.knn.index.util.IndexHyperParametersUtil; import org.opensearch.knn.training.VectorSpaceInfo; -import java.util.ArrayList; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.function.BiFunction; +import java.util.List; +import java.util.ArrayList; + +import static org.opensearch.knn.validation.ParameterValidator.validateParameters; /** * MethodComponent defines the structure of an individual component that can make up an index @@ -75,32 +77,7 @@ public Map getAsMap(MethodComponentContext methodComponentContex */ public ValidationException validate(MethodComponentContext methodComponentContext) { Map providedParameters = methodComponentContext.getParameters(); - List errorMessages = new ArrayList<>(); - - if (providedParameters == null) { - return null; - } - - ValidationException parameterValidation; - for (Map.Entry parameter : providedParameters.entrySet()) { - if (!parameters.containsKey(parameter.getKey())) { - errorMessages.add(String.format("Invalid parameter for method \"%s\".", getName())); - continue; - } - - parameterValidation = parameters.get(parameter.getKey()).validate(parameter.getValue()); - if (parameterValidation != null) { - errorMessages.addAll(parameterValidation.validationErrors()); - } - } - - if (errorMessages.isEmpty()) { - return null; - } - - ValidationException validationException = new ValidationException(); - validationException.addValidationErrors(errorMessages); - return validationException; + return validateParameters(parameters, providedParameters); } /** diff --git a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java index 3146cd33e..a02c090b1 100644 --- a/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/BaseQueryFactory.java @@ -20,6 +20,7 @@ import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; +import java.util.Map; import java.util.Optional; /** @@ -42,6 +43,7 @@ public static class CreateQueryRequest { private float[] vector; private byte[] byteVector; private VectorDataType vectorDataType; + private Map methodParameters; private Integer k; private Float radius; private QueryBuilder filter; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index 0862b2d93..d123cc149 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -5,10 +5,8 @@ package org.opensearch.knn.index.query; -import java.util.Arrays; -import java.util.Objects; - import lombok.AllArgsConstructor; +import lombok.Builder; import lombok.Getter; import lombok.Setter; import org.apache.lucene.search.BooleanClause; @@ -23,26 +21,29 @@ import org.opensearch.knn.index.KNNSettings; import java.io.IOException; +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; /** * Custom KNN query. Query is used for KNNEngine's that create their own custom segment files. These files need to be * loaded and queried in a custom manner throughout the query path. */ +@Getter +@Builder +@AllArgsConstructor public class KNNQuery extends Query { private final String field; private final float[] queryVector; private int k; + private Map methodParameters; private final String indexName; - @Getter @Setter private Query filterQuery; - @Getter private BitSetProducer parentsFilter; - @Getter - private Float radius = null; - @Getter + private Float radius; private Context context; public KNNQuery( @@ -123,22 +124,6 @@ public KNNQuery filterQuery(Query filterQuery) { return this; } - public String getField() { - return this.field; - } - - public float[] getQueryVector() { - return this.queryVector; - } - - public int getK() { - return this.k; - } - - public String getIndexName() { - return this.indexName; - } - /** * Constructs Weight implementation for this query * @@ -183,7 +168,17 @@ public String toString(String field) { @Override public int hashCode() { - return Objects.hash(field, Arrays.hashCode(queryVector), k, indexName, filterQuery); + return Objects.hash( + field, + Arrays.hashCode(queryVector), + k, + indexName, + filterQuery, + context, + parentsFilter, + radius, + methodParameters + ); } @Override @@ -192,10 +187,15 @@ public boolean equals(Object other) { } private boolean equalsTo(KNNQuery other) { + if (other == this) return true; return Objects.equals(field, other.field) && Arrays.equals(queryVector, other.queryVector) && Objects.equals(k, other.k) + && Objects.equals(methodParameters, other.methodParameters) + && Objects.equals(radius, other.radius) + && Objects.equals(context, other.context) && Objects.equals(indexName, other.indexName) + && Objects.equals(parentsFilter, other.parentsFilter) && Objects.equals(filterQuery, other.filterQuery); } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 88bcc84bc..1dec98c90 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -5,10 +5,14 @@ package org.opensearch.knn.index.query; +import lombok.AccessLevel; +import lombok.AllArgsConstructor; +import lombok.Getter; import lombok.extern.log4j.Log4j2; import org.apache.commons.lang.StringUtils; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; +import org.opensearch.common.ValidationException; import org.opensearch.core.ParseField; import org.opensearch.core.common.ParsingException; import org.opensearch.core.common.Strings; @@ -22,11 +26,15 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.IndexUtil; +import org.opensearch.knn.index.util.EngineSpecificMethodContext; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.VectorQueryType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; +import org.opensearch.knn.index.query.parser.MethodParametersParser; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; @@ -35,18 +43,25 @@ import java.io.IOException; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Objects; import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; import static org.opensearch.knn.common.KNNValidationUtil.validateByteVectorValue; import static org.opensearch.knn.index.IndexUtil.isClusterOnOrAfterMinRequiredVersion; import static org.opensearch.knn.index.IndexUtil.minimalRequiredVersionMap; +import static org.opensearch.knn.index.query.parser.MethodParametersParser.validateMethodParameters; import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; +import static org.opensearch.knn.validation.ParameterValidator.validateParameters; /** * Helper class to build the KNN query */ +// The builder validates the member variables so access to the constructor is prohibited to not accidentally bypass validations +@AllArgsConstructor(access = AccessLevel.PRIVATE) @Log4j2 public class KNNQueryBuilder extends AbstractQueryBuilder { private static ModelDao modelDao; @@ -57,6 +72,8 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped"); public static final ParseField MAX_DISTANCE_FIELD = new ParseField(MAX_DISTANCE); public static final ParseField MIN_SCORE_FIELD = new ParseField(MIN_SCORE); + public static final ParseField EF_SEARCH_FIELD = new ParseField(METHOD_PARAMETER_EF_SEARCH); + public static final ParseField METHOD_PARAMS_FIELD = new ParseField(METHOD_PARAMETER); public static final int K_MAX = 10000; /** * The name for the knn query @@ -67,18 +84,27 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { */ private final String fieldName; private final float[] vector; - private int k = 0; - private Float maxDistance = null; - private Float minScore = null; + @Getter + private int k; + @Getter + private Float maxDistance; + @Getter + private Float minScore; + @Getter + private Map methodParameters; + @Getter private QueryBuilder filter; - private boolean ignoreUnmapped = false; + @Getter + private boolean ignoreUnmapped; /** * Constructs a new query with the given field name and vector * * @param fieldName Name of the field * @param vector Array of floating points + * @deprecated Use {@code {@link KNNQueryBuilder.Builder}} instead */ + @Deprecated public KNNQueryBuilder(String fieldName, float[] vector) { if (Strings.isNullOrEmpty(fieldName)) { throw new IllegalArgumentException(String.format("[%s] requires fieldName", NAME)); @@ -94,61 +120,136 @@ public KNNQueryBuilder(String fieldName, float[] vector) { } /** - * Builder method for k - * - * @param k K nearest neighbours for the given vector + * lombok SuperBuilder annotation requires a builder annotation on parent class to work well + * {@link AbstractQueryBuilder#boost()} and {@link AbstractQueryBuilder#queryName()} both need to be called + * A custom builder helps with the calls to the parent class, simultaneously addressing the problem of telescoping + * constructors in this class. */ - public KNNQueryBuilder k(Integer k) { - if (k == null) { - throw new IllegalArgumentException(String.format("[%s] requires k to be set", NAME)); + public static class Builder { + private String fieldName; + private float[] vector; + private Integer k; + private Map methodParameters; + private Float maxDistance; + private Float minScore; + private QueryBuilder filter; + private boolean ignoreUnmapped; + private String queryName; + private float boost = DEFAULT_BOOST; + + private Builder() {} + + public Builder fieldName(String fieldName) { + this.fieldName = fieldName; + return this; } - validateSingleQueryType(k, maxDistance, minScore); - if (k <= 0 || k > K_MAX) { - throw new IllegalArgumentException(String.format("[%s] requires k to be in the range (0, %d]", NAME, K_MAX)); + + public Builder vector(float[] vector) { + this.vector = vector; + return this; } - this.k = k; - return this; - } - /** - * Builder method for maxDistance - * - * @param maxDistance the maxDistance threshold for the nearest neighbours - */ - public KNNQueryBuilder maxDistance(Float maxDistance) { - if (maxDistance == null) { - throw new IllegalArgumentException(String.format("[%s] requires maxDistance to be set", NAME)); + public Builder k(Integer k) { + this.k = k; + return this; } - validateSingleQueryType(k, maxDistance, minScore); - this.maxDistance = maxDistance; - return this; - } - /** - * Builder method for minScore - * - * @param minScore the minScore threshold for the nearest neighbours - */ - public KNNQueryBuilder minScore(Float minScore) { - if (minScore == null) { - throw new IllegalArgumentException(String.format("[%s] requires minScore to be set", NAME)); + public Builder methodParameters(Map methodParameters) { + this.methodParameters = methodParameters; + return this; + } + + public Builder maxDistance(Float maxDistance) { + this.maxDistance = maxDistance; + return this; + } + + public Builder minScore(Float minScore) { + this.minScore = minScore; + return this; + } + + public Builder ignoreUnmapped(boolean ignoreUnmapped) { + this.ignoreUnmapped = ignoreUnmapped; + return this; + } + + public Builder filter(QueryBuilder filter) { + this.filter = filter; + return this; + } + + public Builder queryName(String queryName) { + this.queryName = queryName; + return this; + } + + public Builder boost(float boost) { + this.boost = boost; + return this; + } + + public KNNQueryBuilder build() { + validate(); + int k = this.k == null ? 0 : this.k; + return new KNNQueryBuilder(fieldName, vector, k, maxDistance, minScore, methodParameters, filter, ignoreUnmapped).boost(boost) + .queryName(queryName); } - validateSingleQueryType(k, maxDistance, minScore); - if (minScore <= 0) { - throw new IllegalArgumentException(String.format("[%s] requires minScore to be greater than 0", NAME)); + + private void validate() { + if (Strings.isNullOrEmpty(fieldName)) { + throw new IllegalArgumentException(String.format("[%s] requires fieldName", NAME)); + } + + if (vector == null) { + throw new IllegalArgumentException(String.format("[%s] requires query vector", NAME)); + } else if (vector.length == 0) { + throw new IllegalArgumentException(String.format("[%s] query vector is empty", NAME)); + } + + if (k == null && minScore == null && maxDistance == null) { + throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME)); + } + + if ((k != null && maxDistance != null) || (maxDistance != null && minScore != null) || (k != null && minScore != null)) { + throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME)); + } + + VectorQueryType vectorQueryType = VectorQueryType.MAX_DISTANCE; + if (k != null) { + vectorQueryType = VectorQueryType.K; + if (k <= 0 || k > K_MAX) { + final String errorMessage = "[" + NAME + "] requires k to be in the range (0, " + K_MAX + "]"; + throw new IllegalArgumentException(errorMessage); + } + } + + if (minScore != null) { + vectorQueryType = VectorQueryType.MIN_SCORE; + if (minScore <= 0) { + throw new IllegalArgumentException(String.format("[%s] requires minScore to be greater than 0", NAME)); + } + } + + if (methodParameters != null) { + ValidationException validationException = validateMethodParameters(methodParameters); + if (validationException != null) { + throw new IllegalArgumentException( + String.format("[%s] errors in method parameter [%s]", NAME, validationException.getMessage()) + ); + } + } + + // Update stats + vectorQueryType.getQueryStatCounter().increment(); + if (filter != null) { + vectorQueryType.getQueryWithFilterStatCounter().increment(); + } } - this.minScore = minScore; - return this; } - /** - * Builder method for filter - * - * @param filter QueryBuilder - */ - public KNNQueryBuilder filter(QueryBuilder filter) { - this.filter = filter; - return this; + public static KNNQueryBuilder.Builder builder() { + return new KNNQueryBuilder.Builder(); } /** @@ -158,12 +259,14 @@ public KNNQueryBuilder filter(QueryBuilder filter) { * @param vector Array of floating points * @param k K nearest neighbours for the given vector */ + @Deprecated public KNNQueryBuilder(String fieldName, float[] vector, int k) { this(fieldName, vector, k, null); } + @Deprecated public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder filter) { - if (StringUtils.isBlank(fieldName)) { + if (Strings.isNullOrEmpty(fieldName)) { throw new IllegalArgumentException(String.format("[%s] requires fieldName", NAME)); } if (vector == null) { @@ -230,6 +333,10 @@ public KNNQueryBuilder(StreamInput in) throws IOException { if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { minScore = in.readOptionalFloat(); } + if (isClusterOnOrAfterMinRequiredVersion(METHOD_PARAMETER)) { + methodParameters = MethodParametersParser.streamInput(in, IndexUtil::isClusterOnOrAfterMinRequiredVersion); + } + } catch (IOException ex) { throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex); } @@ -243,9 +350,10 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep Float maxDistance = null; Float minScore = null; QueryBuilder filter = null; - boolean ignoreUnmapped = false; String queryName = null; String currentFieldName = null; + boolean ignoreUnmapped = false; + Map methodParameters = null; XContentParser.Token token; while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { @@ -263,16 +371,16 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep boost = parser.floatValue(); } else if (K_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { k = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false); + } else if (IGNORE_UNMAPPED_FIELD.getPreferredName().equals(currentFieldName)) { + if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) { + ignoreUnmapped = parser.booleanValue(); + } } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { queryName = parser.text(); } else if (MAX_DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { maxDistance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); } else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { minScore = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false); - } else if (IGNORE_UNMAPPED_FIELD.getPreferredName().equals(currentFieldName)) { - if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) { - ignoreUnmapped = parser.booleanValue(); - } } else { throw new ParsingException( parser.getTokenLocation(), @@ -304,6 +412,8 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep ) ); } + } else if (METHOD_PARAMS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + methodParameters = MethodParametersParser.fromXContent(parser); } else { throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]"); } @@ -321,29 +431,18 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep } } - VectorQueryType vectorQueryType = validateSingleQueryType(k, maxDistance, minScore); - vectorQueryType.getQueryStatCounter().increment(); - if (filter != null) { - vectorQueryType.getQueryWithFilterStatCounter().increment(); - } - - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, ObjectsToFloats(vector)).filter(filter) + return KNNQueryBuilder.builder() + .queryName(queryName) .boost(boost) - .queryName(queryName); - - if (isClusterOnOrAfterMinRequiredVersion("ignoreUnmapped")) { - knnQueryBuilder.ignoreUnmapped(ignoreUnmapped); - } - - if (k != null) { - knnQueryBuilder.k(k); - } else if (maxDistance != null) { - knnQueryBuilder.maxDistance(maxDistance); - } else if (minScore != null) { - knnQueryBuilder.minScore(minScore); - } - - return knnQueryBuilder; + .fieldName(fieldName) + .vector(ObjectsToFloats(vector)) + .k(k) + .maxDistance(maxDistance) + .minScore(minScore) + .methodParameters(methodParameters) + .ignoreUnmapped(ignoreUnmapped) + .filter(filter) + .build(); } @Override @@ -365,6 +464,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { out.writeOptionalFloat(minScore); } + if (isClusterOnOrAfterMinRequiredVersion(METHOD_PARAMETER)) { + MethodParametersParser.streamOutput(out, methodParameters, IndexUtil::isClusterOnOrAfterMinRequiredVersion); + } } /** @@ -381,36 +483,6 @@ public Object vector() { return this.vector; } - public int getK() { - return this.k; - } - - public float getMaxDistance() { - return this.maxDistance; - } - - public float getMinScore() { - return this.minScore; - } - - public QueryBuilder getFilter() { - return this.filter; - } - - /** - * Sets whether the query builder should ignore unmapped paths (and run a - * {@link MatchNoDocsQuery} in place of this query) or throw an exception if - * the path is unmapped. - */ - public KNNQueryBuilder ignoreUnmapped(boolean ignoreUnmapped) { - this.ignoreUnmapped = ignoreUnmapped; - return this; - } - - public boolean getIgnoreUnmapped() { - return this.ignoreUnmapped; - } - @Override public void doXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(NAME); @@ -430,6 +502,9 @@ public void doXContent(XContentBuilder builder, Params params) throws IOExceptio if (minScore != null) { builder.field(MIN_SCORE_FIELD.getPreferredName(), minScore); } + if (methodParameters != null) { + MethodParametersParser.doXContent(builder, methodParameters); + } printBoostAndQueryName(builder); builder.endObject(); builder.endObject(); @@ -450,6 +525,7 @@ protected Query doToQuery(QueryShardContext context) { KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldMapper.KNNVectorFieldType) mappedFieldType; int fieldDimension = knnVectorFieldType.getDimension(); KNNMethodContext knnMethodContext = knnVectorFieldType.getKnnMethodContext(); + MethodComponentContext methodComponentContext = null; KNNEngine knnEngine = KNNEngine.DEFAULT; VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType(); SpaceType spaceType = knnVectorFieldType.getSpaceType(); @@ -463,10 +539,32 @@ protected Query doToQuery(QueryShardContext context) { fieldDimension = modelMetadata.getDimension(); knnEngine = modelMetadata.getKnnEngine(); spaceType = modelMetadata.getSpaceType(); + methodComponentContext = modelMetadata.getMethodComponentContext(); + } else if (knnMethodContext != null) { // If the dimension is set but the knnMethodContext is not then the field is using the legacy mapping knnEngine = knnMethodContext.getKnnEngine(); spaceType = knnMethodContext.getSpaceType(); + methodComponentContext = knnMethodContext.getMethodComponentContext(); + } + + final String method = methodComponentContext != null ? methodComponentContext.getName() : null; + if (StringUtils.isNotBlank(method)) { + final EngineSpecificMethodContext engineSpecificMethodContext = knnEngine.getMethodContext(method); + ValidationException validationException = validateParameters( + engineSpecificMethodContext.supportedMethodParameters(), + (Map) methodParameters + ); + if (validationException != null) { + throw new IllegalArgumentException( + String.format( + "Parameters not valid for [%s]:[%s] combination: [%s]", + knnEngine, + method, + validationException.getMessage() + ) + ); + } } // Currently, k-NN supports distance and score types radial search @@ -525,6 +623,7 @@ protected Query doToQuery(QueryShardContext context) { .byteVector(VectorDataType.BYTE == vectorDataType ? byteVector : null) .vectorDataType(vectorDataType) .k(this.k) + .methodParameters(this.methodParameters) .filter(this.filter) .context(context) .build(); @@ -569,41 +668,20 @@ protected boolean doEquals(KNNQueryBuilder other) { return Objects.equals(fieldName, other.fieldName) && Arrays.equals(vector, other.vector) && Objects.equals(k, other.k) + && Objects.equals(minScore, other.minScore) + && Objects.equals(maxDistance, other.maxDistance) + && Objects.equals(methodParameters, other.methodParameters) && Objects.equals(filter, other.filter) && Objects.equals(ignoreUnmapped, other.ignoreUnmapped); } @Override protected int doHashCode() { - return Objects.hash(fieldName, Arrays.hashCode(vector), k, filter, ignoreUnmapped); + return Objects.hash(fieldName, Arrays.hashCode(vector), k, methodParameters, filter, ignoreUnmapped, maxDistance, minScore); } @Override public String getWriteableName() { return NAME; } - - private static VectorQueryType validateSingleQueryType(Integer k, Float distance, Float score) { - int countSetFields = 0; - VectorQueryType vectorQueryType = null; - - if (k != null && k != 0) { - countSetFields++; - vectorQueryType = VectorQueryType.K; - } - if (distance != null) { - countSetFields++; - vectorQueryType = VectorQueryType.MAX_DISTANCE; - } - if (score != null) { - countSetFields++; - vectorQueryType = VectorQueryType.MIN_SCORE; - } - - if (countSetFields != 1) { - throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME)); - } - - return vectorQueryType; - } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java index ec1f53d13..36987c750 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java @@ -17,7 +17,9 @@ import org.opensearch.knn.index.util.KNNEngine; import java.util.Locale; +import java.util.Map; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES; @@ -71,6 +73,7 @@ public static Query create(CreateQueryRequest createQueryRequest) { final byte[] byteVector = createQueryRequest.getByteVector(); final VectorDataType vectorDataType = createQueryRequest.getVectorDataType(); final Query filterQuery = getFilterQuery(createQueryRequest); + final Map methodParameters = createQueryRequest.getMethodParameters(); BitSetProducer parentFilter = null; if (createQueryRequest.getContext().isPresent()) { @@ -79,20 +82,37 @@ public static Query create(CreateQueryRequest createQueryRequest) { } if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) { - if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(createQueryRequest.getKnnEngine())) { - log.debug("Creating custom k-NN query with filters for index: {}, field: {} , k: {}", indexName, fieldName, k); - return new KNNQuery(fieldName, vector, k, indexName, filterQuery, parentFilter); - } - log.debug(String.format("Creating custom k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); - return new KNNQuery(fieldName, vector, k, indexName, parentFilter); + final Query validatedFilterQuery = validateFilterQuerySupport(filterQuery, createQueryRequest.getKnnEngine()); + log.debug( + "Creating custom k-NN query for index:{}, field:{}, k:{}, filterQuery:{}, efSearch:{}", + indexName, + fieldName, + k, + validatedFilterQuery, + methodParameters + ); + return KNNQuery.builder() + .field(fieldName) + .queryVector(vector) + .indexName(indexName) + .parentsFilter(parentFilter) + .k(k) + .methodParameters(methodParameters) + .filterQuery(validatedFilterQuery) + .build(); } + Integer requestEfSearch = null; + if (methodParameters != null && methodParameters.containsKey(METHOD_PARAMETER_EF_SEARCH)) { + requestEfSearch = (Integer) methodParameters.get(METHOD_PARAMETER_EF_SEARCH); + } + int luceneK = requestEfSearch == null ? k : Math.max(k, requestEfSearch); log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k)); switch (vectorDataType) { case BYTE: - return getKnnByteVectorQuery(fieldName, byteVector, k, filterQuery, parentFilter); + return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter); case FLOAT: - return getKnnFloatVectorQuery(fieldName, vector, k, filterQuery, parentFilter); + return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter); default: throw new IllegalArgumentException( String.format( @@ -106,6 +126,14 @@ public static Query create(CreateQueryRequest createQueryRequest) { } } + private static Query validateFilterQuerySupport(final Query filterQuery, final KNNEngine knnEngine) { + log.debug("filter query {}, knnEngine {}", filterQuery, knnEngine); + if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(knnEngine)) { + return filterQuery; + } + return null; + } + /** * If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenByteKnnVectorQuery} * which will dedupe search result per parent so that we can get k parent results at the end. diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index bac8c03d4..794c9af1c 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -283,6 +283,7 @@ private Map doANNSearch(final LeafReaderContext context, final B indexAllocation.getMemoryAddress(), knnQuery.getQueryVector(), knnQuery.getK(), + knnQuery.getMethodParameters(), knnEngine, filterIds, filterType.getValue(), diff --git a/src/main/java/org/opensearch/knn/index/query/parser/MethodParametersParser.java b/src/main/java/org/opensearch/knn/index/query/parser/MethodParametersParser.java new file mode 100644 index 000000000..e2ba8f26e --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/parser/MethodParametersParser.java @@ -0,0 +1,137 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.query.parser; + +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import org.opensearch.common.ValidationException; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.knn.index.query.request.MethodParameter; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +import static org.opensearch.knn.index.query.KNNQueryBuilder.METHOD_PARAMS_FIELD; +import static org.opensearch.knn.index.query.KNNQueryBuilder.NAME; + +@EqualsAndHashCode +@Getter +@AllArgsConstructor +public class MethodParametersParser { + + // Validation on rest layer + public static ValidationException validateMethodParameters(final Map methodParameters) { + final List errors = new ArrayList<>(); + for (final Map.Entry methodParameter : methodParameters.entrySet()) { + final MethodParameter parameter = MethodParameter.enumOf(methodParameter.getKey()); + if (parameter != null) { + final ValidationException validationException = parameter.validate(methodParameter.getValue()); + if (validationException != null) { + errors.add(validationException.getMessage()); + } + } else { // Should never happen if used in the right sequence + errors.add(methodParameter.getKey() + " is not a valid method parameter"); + } + } + + if (!errors.isEmpty()) { + ValidationException validationException = new ValidationException(); + validationException.addValidationErrors(errors); + return validationException; + } + return null; + } + + // deserialize for node to node communication + public static Map streamInput(StreamInput in, Function minClusterVersionCheck) throws IOException { + if (!in.readBoolean()) { + return null; + } + + final Map methodParameters = new HashMap<>(); + for (final MethodParameter methodParameter : MethodParameter.values()) { + if (minClusterVersionCheck.apply(methodParameter.getName())) { + String name = in.readString(); + Object value = in.readGenericValue(); + if (value != null) { + methodParameters.put(name, methodParameter.parse(value)); + } + } + } + + return !methodParameters.isEmpty() ? methodParameters : null; + } + + // serialize for node to node communication + public static void streamOutput(StreamOutput out, Map methodParameters, Function minClusterVersionCheck) + throws IOException { + if (methodParameters == null || methodParameters.isEmpty()) { + out.writeBoolean(false); + } else { + out.writeBoolean(true); + // All values are written to deserialize without ambiguity + for (final MethodParameter methodParameter : MethodParameter.values()) { + if (minClusterVersionCheck.apply(methodParameter.getName())) { + out.writeString(methodParameter.getName()); + out.writeGenericValue(methodParameters.get(methodParameter.getName())); + } + } + } + } + + public static void doXContent(final XContentBuilder builder, final Map methodParameters) throws IOException { + if (methodParameters == null || methodParameters.isEmpty()) { + return; + } + builder.startObject(METHOD_PARAMS_FIELD.getPreferredName()); + for (final Map.Entry entry : methodParameters.entrySet()) { + if (entry.getKey() != null && entry.getValue() != null) { + builder.field(entry.getKey(), entry.getValue()); + } + } + builder.endObject(); + } + + public static Map fromXContent(final XContentParser parser) throws IOException { + final Map methodParametersJson = parser.map(); + if (methodParametersJson.isEmpty()) { + throw new ParsingException(parser.getTokenLocation(), METHOD_PARAMS_FIELD.getPreferredName() + " cannot be empty"); + } + + final Map methodParameters = new HashMap<>(); + for (Map.Entry requestParameter : methodParametersJson.entrySet()) { + final String name = requestParameter.getKey(); + final Object value = requestParameter.getValue(); + final MethodParameter parameter = MethodParameter.enumOf(name); + if (parameter == null) { + throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown method parameter found [" + name + "]"); + } + + try { + // This makes sure that we throw parsing exception on rest layer. + methodParameters.put(name, parameter.parse(value)); + } catch (final Exception exception) { + throw new ParsingException(parser.getTokenLocation(), exception.getMessage()); + } + } + return methodParameters.isEmpty() ? null : methodParameters; + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/request/MethodParameter.java b/src/main/java/org/opensearch/knn/index/query/request/MethodParameter.java new file mode 100644 index 000000000..0b76cad4a --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/request/MethodParameter.java @@ -0,0 +1,77 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.query.request; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.Version; +import org.opensearch.common.ValidationException; +import org.opensearch.core.ParseField; + +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; +import static org.opensearch.knn.index.query.KNNQueryBuilder.EF_SEARCH_FIELD; + +/** + * MethodParameters are engine and algorithm related parameters that clients can pass in knn query + * This enum holds metadata which helps parse and have basic validation related to MethodParameter + */ +@Getter +@RequiredArgsConstructor +public enum MethodParameter { + + EF_SEARCH(METHOD_PARAMETER_EF_SEARCH, Version.V_2_16_0, EF_SEARCH_FIELD) { + @Override + public Integer parse(Object value) { + try { + return Integer.parseInt(String.valueOf(value)); + } catch (final NumberFormatException e) { + throw new IllegalArgumentException(METHOD_PARAMETER_EF_SEARCH + " value must be an integer"); + } + } + + @Override + public ValidationException validate(Object value) { + final Integer ef = parse(value); + if (ef != null && ef > 0) { + return null; + } + ; + ValidationException validationException = new ValidationException(); + validationException.addValidationError(METHOD_PARAMETER_EF_SEARCH + " should be greater than 0"); + return validationException; + } + }; + + private final String name; + private final Version version; + private final ParseField parseField; + + private static Map PARAMETERS_DIR; + + public abstract T parse(Object value); + + // These are preliminary validations on rest layer + public abstract ValidationException validate(Object value); + + public static MethodParameter enumOf(final String name) { + if (PARAMETERS_DIR == null) { + PARAMETERS_DIR = new HashMap<>(); + for (final MethodParameter methodParameter : MethodParameter.values()) { + PARAMETERS_DIR.put(methodParameter.name, methodParameter); + } + } + return PARAMETERS_DIR.get(name); + } +} diff --git a/src/main/java/org/opensearch/knn/index/util/AbstractKNNLibrary.java b/src/main/java/org/opensearch/knn/index/util/AbstractKNNLibrary.java index 0fe311094..0e0c56128 100644 --- a/src/main/java/org/opensearch/knn/index/util/AbstractKNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/util/AbstractKNNLibrary.java @@ -22,6 +22,7 @@ public abstract class AbstractKNNLibrary implements KNNLibrary { protected final Map methods; + protected final Map engineMethods; @Getter protected final String version; @@ -34,6 +35,15 @@ public KNNMethod getMethod(String methodName) { return method; } + @Override + public EngineSpecificMethodContext getMethodContext(String methodName) { + EngineSpecificMethodContext method = engineMethods.get(methodName); + if (method == null) { + throw new IllegalArgumentException(String.format("Invalid method name: %s", methodName)); + } + return method; + } + @Override public ValidationException validateMethod(KNNMethodContext knnMethodContext) { String methodName = knnMethodContext.getMethodComponentContext().getName(); diff --git a/src/main/java/org/opensearch/knn/index/util/DefaultHnswContext.java b/src/main/java/org/opensearch/knn/index/util/DefaultHnswContext.java new file mode 100644 index 000000000..c16f1b05e --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/util/DefaultHnswContext.java @@ -0,0 +1,33 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.util; + +import com.google.common.collect.ImmutableMap; +import org.opensearch.knn.index.Parameter; +import org.opensearch.knn.index.query.request.MethodParameter; + +import java.util.Map; + +/** + * Default HNSW context for all engines. Have a different implementation if engine context differs. + */ +public final class DefaultHnswContext implements EngineSpecificMethodContext { + + private final Map> supportedMethodParameters = ImmutableMap.>builder() + .put(MethodParameter.EF_SEARCH.getName(), new Parameter.IntegerParameter(MethodParameter.EF_SEARCH.getName(), null, value -> true)) + .build(); + + @Override + public Map> supportedMethodParameters() { + return supportedMethodParameters; + } +} diff --git a/src/main/java/org/opensearch/knn/index/util/EngineSpecificMethodContext.java b/src/main/java/org/opensearch/knn/index/util/EngineSpecificMethodContext.java new file mode 100644 index 000000000..f669704ad --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/util/EngineSpecificMethodContext.java @@ -0,0 +1,31 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.util; + +import org.opensearch.knn.index.Parameter; + +import java.util.Collections; +import java.util.Map; + +/** + * Holds context related to a method for a particular engine + * Each engine can have a specific set of parameters that it supports during index and build time. This context holds + * the information for each engine method combination. + * + * TODO: Move KnnMethod in here + */ +public interface EngineSpecificMethodContext { + + Map> supportedMethodParameters(); + + EngineSpecificMethodContext EMPTY = Collections::emptyMap; +} diff --git a/src/main/java/org/opensearch/knn/index/util/Faiss.java b/src/main/java/org/opensearch/knn/index/util/Faiss.java index bbb58bf1e..7cf31ba3c 100644 --- a/src/main/java/org/opensearch/knn/index/util/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/util/Faiss.java @@ -330,7 +330,13 @@ private Faiss( String extension, Map> scoreTransform ) { - super(methods, scoreTranslation, currentVersion, extension); + super( + methods, + Map.of(METHOD_HNSW, new DefaultHnswContext(), METHOD_IVF, EngineSpecificMethodContext.EMPTY), + scoreTranslation, + currentVersion, + extension + ); this.scoreTransform = scoreTransform; } diff --git a/src/main/java/org/opensearch/knn/index/util/JVMLibrary.java b/src/main/java/org/opensearch/knn/index/util/JVMLibrary.java index e1d48cb0a..850679dc2 100644 --- a/src/main/java/org/opensearch/knn/index/util/JVMLibrary.java +++ b/src/main/java/org/opensearch/knn/index/util/JVMLibrary.java @@ -23,8 +23,8 @@ public abstract class JVMLibrary extends AbstractKNNLibrary { * @param methods Map of k-NN methods that the library supports * @param version String representing version of library */ - JVMLibrary(Map methods, String version) { - super(methods, version); + JVMLibrary(Map methods, Map engineMethodMetadataMap, String version) { + super(methods, engineMethodMetadataMap, version); } @Override diff --git a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java index 556785783..ee8be9c5c 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNEngine.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNEngine.java @@ -149,6 +149,11 @@ public KNNMethod getMethod(String methodName) { return knnLibrary.getMethod(methodName); } + @Override + public EngineSpecificMethodContext getMethodContext(String methodName) { + return knnLibrary.getMethodContext(methodName); + } + @Override public float score(float rawScore, SpaceType spaceType) { return knnLibrary.score(rawScore, spaceType); diff --git a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java index cac5af2bb..f9d8429d3 100644 --- a/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java +++ b/src/main/java/org/opensearch/knn/index/util/KNNLibrary.java @@ -58,6 +58,13 @@ public interface KNNLibrary { */ KNNMethod getMethod(String methodName); + /** + * Gets metadata related to methods supported by the library + * @param methodName + * @return + */ + EngineSpecificMethodContext getMethodContext(String methodName); + /** * Generate the Lucene score from the rawScore returned by the library. With k-NN, often times the library * will return a score where the lower the score, the better the result. This is the opposite of how Lucene scores diff --git a/src/main/java/org/opensearch/knn/index/util/Lucene.java b/src/main/java/org/opensearch/knn/index/util/Lucene.java index 630d7a2c2..ae6ea3a70 100644 --- a/src/main/java/org/opensearch/knn/index/util/Lucene.java +++ b/src/main/java/org/opensearch/knn/index/util/Lucene.java @@ -67,7 +67,7 @@ public class Lucene extends JVMLibrary { * @param distanceTransform Map of space type to distance transformation function */ Lucene(Map methods, String version, Map> distanceTransform) { - super(methods, version); + super(methods, Map.of(METHOD_HNSW, new DefaultHnswContext()), version); this.distanceTransform = distanceTransform; } diff --git a/src/main/java/org/opensearch/knn/index/util/NativeLibrary.java b/src/main/java/org/opensearch/knn/index/util/NativeLibrary.java index 5e264ed12..99d4aeeb9 100644 --- a/src/main/java/org/opensearch/knn/index/util/NativeLibrary.java +++ b/src/main/java/org/opensearch/knn/index/util/NativeLibrary.java @@ -36,11 +36,12 @@ abstract class NativeLibrary extends AbstractKNNLibrary { */ NativeLibrary( Map methods, + Map engineMethods, Map> scoreTranslation, String version, String extension ) { - super(methods, version); + super(methods, engineMethods, version); this.scoreTranslation = scoreTranslation; this.extension = extension; this.initialized = new AtomicBoolean(false); diff --git a/src/main/java/org/opensearch/knn/index/util/Nmslib.java b/src/main/java/org/opensearch/knn/index/util/Nmslib.java index 64af43520..7b18ed11d 100644 --- a/src/main/java/org/opensearch/knn/index/util/Nmslib.java +++ b/src/main/java/org/opensearch/knn/index/util/Nmslib.java @@ -66,7 +66,7 @@ private Nmslib( String currentVersion, String extension ) { - super(methods, scoreTranslation, currentVersion, extension); + super(methods, Map.of(METHOD_HNSW, new DefaultHnswContext()), scoreTranslation, currentVersion, extension); } @Override diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 53980bbb7..77b786421 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -129,7 +129,13 @@ public static native void createIndexFromTemplate( * @param parentIds list of parent doc ids when the knn field is a nested field * @return KNNQueryResult array of k neighbors */ - public static native KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k, int[] parentIds); + public static native KNNQueryResult[] queryIndex( + long indexPointer, + float[] queryVector, + int k, + Map methodParameters, + int[] parentIds + ); /** * Query an index with filter @@ -145,6 +151,7 @@ public static native KNNQueryResult[] queryIndexWithFilter( long indexPointer, float[] queryVector, int k, + Map methodParameters, long[] filterIds, int filterIdsType, int[] parentIds diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 20c418819..6563da296 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -12,6 +12,7 @@ package org.opensearch.knn.jni; import org.apache.commons.lang.ArrayUtils; +import org.opensearch.common.Nullable; import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.util.KNNEngine; @@ -173,13 +174,14 @@ public static KNNQueryResult[] queryIndex( long indexPointer, float[] queryVector, int k, + @Nullable Map methodParameters, KNNEngine knnEngine, long[] filteredIds, int filterIdsType, int[] parentIds ) { if (KNNEngine.NMSLIB == knnEngine) { - return NmslibService.queryIndex(indexPointer, queryVector, k); + return NmslibService.queryIndex(indexPointer, queryVector, k, methodParameters); } if (KNNEngine.FAISS == knnEngine) { @@ -188,9 +190,17 @@ public static KNNQueryResult[] queryIndex( // filterIds. FilterIds is coming as empty then its the case where we need to do search with Faiss engine // normally. if (ArrayUtils.isNotEmpty(filteredIds)) { - return FaissService.queryIndexWithFilter(indexPointer, queryVector, k, filteredIds, filterIdsType, parentIds); + return FaissService.queryIndexWithFilter( + indexPointer, + queryVector, + k, + methodParameters, + filteredIds, + filterIdsType, + parentIds + ); } - return FaissService.queryIndex(indexPointer, queryVector, k, parentIds); + return FaissService.queryIndex(indexPointer, queryVector, k, methodParameters, parentIds); } throw new IllegalArgumentException(String.format("QueryIndex not supported for provided engine : %s", knnEngine.getName())); } diff --git a/src/main/java/org/opensearch/knn/jni/NmslibService.java b/src/main/java/org/opensearch/knn/jni/NmslibService.java index 7fdc278d2..294c5a208 100644 --- a/src/main/java/org/opensearch/knn/jni/NmslibService.java +++ b/src/main/java/org/opensearch/knn/jni/NmslibService.java @@ -69,7 +69,7 @@ class NmslibService { * @param k neighbors to be returned * @return KNNQueryResult array of k neighbors */ - public static native KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k); + public static native KNNQueryResult[] queryIndex(long indexPointer, float[] queryVector, int k, Map methodParameters); /** * Free native memory pointer diff --git a/src/main/java/org/opensearch/knn/validation/ParameterValidator.java b/src/main/java/org/opensearch/knn/validation/ParameterValidator.java new file mode 100644 index 000000000..15925fffa --- /dev/null +++ b/src/main/java/org/opensearch/knn/validation/ParameterValidator.java @@ -0,0 +1,64 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.validation; + +import org.opensearch.common.Nullable; +import org.opensearch.common.ValidationException; +import org.opensearch.knn.index.Parameter; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public final class ParameterValidator { + + /** + * A function which validates request parameters. + * @param validParameters A set of valid parameters that can be requestParameters can be validated against + * @param requestParameters parameters from the request + * @return + */ + @Nullable + public static ValidationException validateParameters( + final Map> validParameters, + final Map requestParameters + ) { + + if (validParameters == null) { + throw new IllegalArgumentException("validParameters cannot be null"); + } + + if (requestParameters == null || requestParameters.isEmpty()) { + return null; + } + + final List errorMessages = new ArrayList<>(); + for (Map.Entry parameter : requestParameters.entrySet()) { + if (validParameters.containsKey(parameter.getKey())) { + final ValidationException parameterValidation = validParameters.get(parameter.getKey()).validate(parameter.getValue()); + if (parameterValidation != null) { + errorMessages.addAll(parameterValidation.validationErrors()); + } + } else { + errorMessages.add("Unknown parameter '" + parameter.getKey() + "' found"); + } + } + + if (errorMessages.isEmpty()) { + return null; + } + + final ValidationException validationException = new ValidationException(); + validationException.addValidationErrors(errorMessages); + return validationException; + } +} diff --git a/src/test/java/org/opensearch/knn/index/FaissHNSWFlatE2EIT.java b/src/test/java/org/opensearch/knn/index/FaissHNSWFlatE2EIT.java new file mode 100644 index 000000000..5b7a99ce9 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/FaissHNSWFlatE2EIT.java @@ -0,0 +1,194 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import com.google.common.collect.ImmutableList; +import com.google.common.primitives.Floats; +import lombok.AllArgsConstructor; +import lombok.SneakyThrows; +import org.apache.http.util.EntityUtils; +import org.junit.BeforeClass; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.knn.KNNRestTestCase; +import org.opensearch.knn.KNNResult; +import org.opensearch.knn.TestUtils; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.plugin.script.KNNScoringUtil; + +import java.io.IOException; +import java.net.URL; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; + +@AllArgsConstructor +public class FaissHNSWFlatE2EIT extends KNNRestTestCase { + + private String description; + private int k; + private Map methodParameters; + private boolean deleteRandomDocs; + + static TestUtils.TestData testData; + + @BeforeClass + public static void setUpClass() throws IOException { + if (FaissHNSWFlatE2EIT.class.getClassLoader() == null) { + throw new IllegalStateException("ClassLoader of FaissIT Class is null"); + } + URL testIndexVectors = FaissHNSWFlatE2EIT.class.getClassLoader().getResource("data/test_vectors_1000x128.json"); + URL testQueries = FaissHNSWFlatE2EIT.class.getClassLoader().getResource("data/test_queries_100x128.csv"); + assert testIndexVectors != null; + assert testQueries != null; + testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath()); + } + + @ParametersFactory(argumentFormatting = "description:%1$s; k:%2$s; efSearch:%3$s, deleteDocs:%4$s") + public static Collection parameters() { + return Arrays.asList( + $$( + $("Valid k, valid efSearch efSearch value", 10, Map.of(METHOD_PARAMETER_EF_SEARCH, 300), false), + $("Valid k, efsearch absent", 10, null, false), + $("Has delete docs, ef_search", 10, Map.of(METHOD_PARAMETER_EF_SEARCH, 300), true), + $("Has delete docs", 10, null, true) + ) + ); + } + + @SneakyThrows + public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() { + String indexName = "test-index-1"; + String fieldName = "test-field-1"; + + KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); + SpaceType spaceType = SpaceType.L2; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + Integer dimension = testData.indexData.vectors[0].length; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(NAME, hnswMethod.getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(indexName, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + fieldName, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents in the index + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + // Delete few Docs + if (deleteRandomDocs) { + final Set docIdsToBeDeleted = new HashSet<>(); + while (docIdsToBeDeleted.size() < 10) { + docIdsToBeDeleted.add(randomInt(testData.indexData.docs.length - 1)); + } + + for (Integer id : docIdsToBeDeleted) { + deleteKnnDoc(indexName, Integer.toString(testData.indexData.docs[id])); + } + refreshAllNonSystemIndices(); + forceMergeKnnIndex(indexName, 3); + + assertEquals(testData.indexData.docs.length - 10, getDocCount(indexName)); + } + + // Test search queries + for (int i = 0; i < testData.queries.length; i++) { + final KNNQueryBuilder queryBuilder = KNNQueryBuilder.builder() + .fieldName(fieldName) + .vector(testData.queries[i]) + .k(k) + .methodParameters(methodParameters) + .build(); + Response response = searchKNNIndex(indexName, queryBuilder, k); + String responseBody = EntityUtils.toString(response.getEntity()); + List knnResults = parseSearchResponse(responseBody, fieldName); + assertEquals(k, knnResults.size()); + + List actualScores = parseSearchResponseScore(responseBody, fieldName); + for (int j = 0; j < k; j++) { + float[] primitiveArray = knnResults.get(j).getVector(); + assertEquals( + KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), + actualScores.get(j), + 0.0001 + ); + } + } + + // Delete index + deleteKNNIndex(indexName); + + // Search every 5 seconds 14 times to confirm graph gets evicted + int intervals = 14; + for (int i = 0; i < intervals; i++) { + if (getTotalGraphsInCache() == 0) { + return; + } + Thread.sleep(5 * 1000); + } + + fail("Graphs are not getting evicted"); + } +} diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 3a9b7d596..2e7f772b0 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -37,12 +37,10 @@ import java.net.URL; import java.util.ArrayList; import java.util.Arrays; -import java.util.HashSet; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Random; -import java.util.Set; import java.util.TreeMap; import java.util.stream.Collectors; @@ -90,197 +88,6 @@ public static void setUpClass() throws IOException { testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath()); } - @SneakyThrows - public void testEndToEnd_whenMethodIsHNSWFlat_thenSucceed() { - String indexName = "test-index-1"; - String fieldName = "test-field-1"; - - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); - SpaceType spaceType = SpaceType.L2; - - List mValues = ImmutableList.of(16, 32, 64, 128); - List efConstructionValues = ImmutableList.of(16, 32, 64, 128); - List efSearchValues = ImmutableList.of(16, 32, 64, 128); - - Integer dimension = testData.indexData.vectors[0].length; - - // Create an index - XContentBuilder builder = XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(fieldName) - .field("type", "knn_vector") - .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) - .field(NAME, hnswMethod.getMethodComponent().getName()) - .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) - .endObject() - .endObject() - .endObject() - .endObject() - .endObject(); - - Map mappingMap = xContentBuilderToMap(builder); - String mapping = builder.toString(); - - createKnnIndex(indexName, mapping); - assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); - - // Index the test data - for (int i = 0; i < testData.indexData.docs.length; i++) { - addKnnDoc( - indexName, - Integer.toString(testData.indexData.docs[i]), - fieldName, - Floats.asList(testData.indexData.vectors[i]).toArray() - ); - } - - // Assert we have the right number of documents in the index - refreshAllNonSystemIndices(); - assertEquals(testData.indexData.docs.length, getDocCount(indexName)); - - int k = 10; - for (int i = 0; i < testData.queries.length; i++) { - Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, testData.queries[i], k), k); - String responseBody = EntityUtils.toString(response.getEntity()); - List knnResults = parseSearchResponse(responseBody, fieldName); - assertEquals(k, knnResults.size()); - - List actualScores = parseSearchResponseScore(responseBody, fieldName); - for (int j = 0; j < k; j++) { - float[] primitiveArray = knnResults.get(j).getVector(); - assertEquals( - KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), - actualScores.get(j), - 0.0001 - ); - } - } - - // Delete index - deleteKNNIndex(indexName); - - // Search every 5 seconds 14 times to confirm graph gets evicted - int intervals = 14; - for (int i = 0; i < intervals; i++) { - if (getTotalGraphsInCache() == 0) { - return; - } - - Thread.sleep(5 * 1000); - } - - fail("Graphs are not getting evicted"); - } - - @SneakyThrows - public void testEndToEnd_whenMethodIsHNSWFlatAndHasDeletedDocs_thenSucceed() { - String indexName = "test-index-1"; - String fieldName = "test-field-1"; - - KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); - SpaceType spaceType = SpaceType.L2; - - List mValues = ImmutableList.of(16, 32, 64, 128); - List efConstructionValues = ImmutableList.of(16, 32, 64, 128); - List efSearchValues = ImmutableList.of(16, 32, 64, 128); - - Integer dimension = testData.indexData.vectors[0].length; - - // Create an index - XContentBuilder builder = XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(fieldName) - .field("type", "knn_vector") - .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) - .field(NAME, hnswMethod.getMethodComponent().getName()) - .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) - .endObject() - .endObject() - .endObject() - .endObject() - .endObject(); - - Map mappingMap = xContentBuilderToMap(builder); - String mapping = builder.toString(); - - createKnnIndex(indexName, mapping); - assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); - - // Index the test data - for (int i = 0; i < testData.indexData.docs.length; i++) { - addKnnDoc( - indexName, - Integer.toString(testData.indexData.docs[i]), - fieldName, - Floats.asList(testData.indexData.vectors[i]).toArray() - ); - } - - // Assert we have the right number of documents in the index - refreshAllNonSystemIndices(); - assertEquals(testData.indexData.docs.length, getDocCount(indexName)); - - final Set docIdsToBeDeleted = new HashSet<>(); - while (docIdsToBeDeleted.size() < 10) { - docIdsToBeDeleted.add(randomInt(testData.indexData.docs.length - 1)); - } - - for (Integer id : docIdsToBeDeleted) { - deleteKnnDoc(indexName, Integer.toString(testData.indexData.docs[id])); - } - refreshAllNonSystemIndices(); - forceMergeKnnIndex(indexName, 3); - - assertEquals(testData.indexData.docs.length - 10, getDocCount(indexName)); - - int k = 10; - for (int i = 0; i < testData.queries.length; i++) { - Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, testData.queries[i], k), k); - String responseBody = EntityUtils.toString(response.getEntity()); - List knnResults = parseSearchResponse(responseBody, fieldName); - assertEquals(k, knnResults.size()); - - List actualScores = parseSearchResponseScore(responseBody, fieldName); - for (int j = 0; j < k; j++) { - float[] primitiveArray = knnResults.get(j).getVector(); - assertEquals( - KNNEngine.FAISS.score(KNNScoringUtil.l2Squared(testData.queries[i], primitiveArray), spaceType), - actualScores.get(j), - 0.0001 - ); - } - } - - // Delete index - deleteKNNIndex(indexName); - - // Search every 5 seconds 14 times to confirm graph gets evicted - int intervals = 14; - for (int i = 0; i < intervals; i++) { - if (getTotalGraphsInCache() == 0) { - return; - } - - Thread.sleep(5 * 1000); - } - - fail("Graphs are not getting evicted"); - } - @SneakyThrows public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHNSWFlat_thenSucceed() { KNNMethod hnswMethod = KNNEngine.FAISS.getMethod(METHOD_HNSW); @@ -1447,7 +1254,7 @@ public void testDocDeletion() throws IOException { deleteKnnDoc(INDEX_NAME, "1"); } - public void testKNNQuery_withModelDifferentCombination_thenSuccess() throws IOException, InterruptedException { + public void testKNNQuery_withModelDifferentCombination_thenSuccess() throws Exception { String modelId = "test-model"; int dimension = 128; diff --git a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java index bf9a6b776..38895246a 100644 --- a/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java +++ b/src/test/java/org/opensearch/knn/index/LuceneEngineIT.java @@ -122,7 +122,7 @@ public void testQuery_documentsMissingField() throws Exception { validateQueries(spaceType, FIELD_NAME); } - public void testQuery_multipleEngines() throws IOException { + public void testQuery_multipleEngines() throws Exception { String luceneField = "lucene-field"; SpaceType luceneSpaceType = SpaceType.COSINESIMIL; String nmslibField = "nmslib-field"; @@ -175,7 +175,7 @@ public void testQuery_multipleEngines() throws IOException { validateQueries(nmslibSpaceType, nmslibField); } - public void testAddDoc() throws IOException { + public void testAddDoc() throws Exception { List mValues = ImmutableList.of(16, 32, 64, 128); List efConstructionValues = ImmutableList.of(16, 32, 64, 128); @@ -499,13 +499,18 @@ private void baseQueryTest(SpaceType spaceType) throws Exception { } validateQueries(spaceType, FIELD_NAME); + validateQueries(spaceType, FIELD_NAME, Map.of("ef_search", 100)); } - private void validateQueries(SpaceType spaceType, String fieldName) throws IOException { + private void validateQueries(SpaceType spaceType, String fieldName) throws Exception { + validateQueries(spaceType, fieldName, null); + } + + private void validateQueries(SpaceType spaceType, String fieldName, Map methodParameters) throws Exception { int k = LuceneEngineIT.TEST_INDEX_VECTORS.length; for (float[] queryVector : TEST_QUERY_VECTORS) { - Response response = searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(fieldName, queryVector, k), k); + Response response = searchKNNIndex(INDEX_NAME, buildLuceneKSearchQuery(fieldName, k, queryVector, methodParameters), k); String responseBody = EntityUtils.toString(response.getEntity()); List knnResults = parseSearchResponse(responseBody, fieldName); assertEquals(k, knnResults.size()); @@ -520,6 +525,27 @@ private void validateQueries(SpaceType spaceType, String fieldName) throws IOExc } } + @SneakyThrows + private XContentBuilder buildLuceneKSearchQuery(String fieldName, int k, float[] vector, Map methodParams) { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("knn") + .startObject(fieldName) + .field("vector", vector) + .field("k", k); + if (methodParams != null) { + builder.startObject("method_parameters"); + for (Map.Entry entry : methodParams.entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + builder.endObject(); + } + + builder.endObject().endObject().endObject().endObject(); + return builder; + } + private List queryResults(final float[] searchVector, final int k) throws Exception { final String responseBody = EntityUtils.toString( searchKNNIndex(INDEX_NAME, new KNNQueryBuilder(FIELD_NAME, searchVector, k), k).getEntity() diff --git a/src/test/java/org/opensearch/knn/index/NmslibIT.java b/src/test/java/org/opensearch/knn/index/NmslibIT.java index b76a26e69..22168b3e4 100644 --- a/src/test/java/org/opensearch/knn/index/NmslibIT.java +++ b/src/test/java/org/opensearch/knn/index/NmslibIT.java @@ -13,6 +13,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.primitives.Floats; +import lombok.SneakyThrows; import org.apache.http.util.EntityUtils; import org.junit.BeforeClass; import org.opensearch.client.Response; @@ -52,7 +53,77 @@ public static void setUpClass() throws IOException { testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath()); } - public void testEndToEnd() throws IOException, InterruptedException { + public void testInvalidMethodParameters() throws Exception { + String indexName = "test-index-1"; + String fieldName = "test-field-1"; + Integer dimension = testData.indexData.vectors[0].length; + KNNMethod hnswMethod = KNNEngine.NMSLIB.getMethod(KNNConstants.METHOD_HNSW); + SpaceType spaceType = SpaceType.L1; + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName()) + .field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()) + .startObject(KNNConstants.PARAMETERS) + .field(KNNConstants.METHOD_PARAMETER_M, 32) + .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, 100) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + final Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(indexName, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + + // Index the test data + // Adding only doc to cut on integ test time + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[0]), + fieldName, + Floats.asList(testData.indexData.vectors[0]).toArray() + ); + + expectThrows( + IllegalArgumentException.class, + () -> searchKNNIndex( + indexName, + KNNQueryBuilder.builder() + .k(10) + .methodParameters(Map.of("foo", "bar")) + .vector(testData.queries[0]) + .fieldName(fieldName) + .build(), + 10 + ) + ); + expectThrows( + IllegalArgumentException.class, + () -> searchKNNIndex( + indexName, + KNNQueryBuilder.builder() + .k(10) + .methodParameters(Map.of("ef_search", "bar")) + .vector(testData.queries[0]) + .fieldName(fieldName) + .build(), + 10 + ) + ); + } + + public void testEndToEnd() throws Exception { String indexName = "test-index-1"; String fieldName = "test-field-1"; @@ -104,9 +175,42 @@ public void testEndToEnd() throws IOException, InterruptedException { refreshAllIndices(); assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + // search index + // without method parameters + validateSearch(indexName, fieldName, spaceType, null); + // With valid method params + validateSearch(indexName, fieldName, spaceType, Map.of("ef_search", 50)); + + // Delete index + deleteKNNIndex(indexName); + + // Search every 5 seconds 14 times to confirm graph gets evicted + int intervals = 14; + for (int i = 0; i < intervals; i++) { + if (getTotalGraphsInCache() == 0) { + return; + } + + Thread.sleep(5 * 1000); + } + + fail("Graphs are not getting evicted"); + } + + @SneakyThrows + private void validateSearch( + final String indexName, + final String fieldName, + SpaceType spaceType, + final Map methodParams + ) { int k = 10; for (int i = 0; i < testData.queries.length; i++) { - Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, testData.queries[i], k), k); + Response response = searchKNNIndex( + indexName, + KNNQueryBuilder.builder().fieldName(fieldName).vector(testData.queries[i]).k(k).methodParameters(methodParams).build(), + k + ); String responseBody = EntityUtils.toString(response.getEntity()); List knnResults = parseSearchResponse(responseBody, fieldName); assertEquals(k, knnResults.size()); @@ -121,21 +225,6 @@ public void testEndToEnd() throws IOException, InterruptedException { ); } } - - // Delete index - deleteKNNIndex(indexName); - - // Search every 5 seconds 14 times to confirm graph gets evicted - int intervals = 14; - for (int i = 0; i < intervals; i++) { - if (getTotalGraphsInCache() == 0) { - return; - } - - Thread.sleep(5 * 1000); - } - - fail("Graphs are not getting evicted"); } public void testAddDoc() throws Exception { diff --git a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java index d6ecdb0c6..a751754fd 100644 --- a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java @@ -58,7 +58,7 @@ public static void setUpClass() throws IOException { testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath()); } - public void testEndToEnd() throws IOException, InterruptedException { + public void testEndToEnd() throws Exception { String indexName = "test-index-1"; KNNEngine knnEngine1 = KNNEngine.NMSLIB; KNNEngine knnEngine2 = KNNEngine.FAISS; diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index 2ce3a7c83..f7c9f3eb8 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -59,6 +59,7 @@ import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.index.KNNSettings.MODEL_CACHE_SIZE_LIMIT_SETTING; @@ -70,6 +71,9 @@ public class KNN80DocValuesConsumerTests extends KNNTestCase { + private static final int EF_SEARCH = 10; + private static final Map HNSW_METHODPARAMETERS = Map.of(METHOD_PARAMETER_EF_SEARCH, EF_SEARCH); + private static Directory directory; private static Codec codec; @@ -202,7 +206,7 @@ public void testAddKNNBinaryField_fromScratch_nmslibCurrent() throws IOException assertValidFooter(state.directory, expectedFile); // The document should be readable by nmslib - assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + assertLoadableByEngine(null, state, expectedFile, knnEngine, spaceType, dimension); // The graph creation statistics should be updated assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); @@ -255,7 +259,7 @@ public void testAddKNNBinaryField_fromScratch_nmslibLegacy() throws IOException assertValidFooter(state.directory, expectedFile); // The document should be readable by nmslib - assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + assertLoadableByEngine(null, state, expectedFile, knnEngine, spaceType, dimension); // The graph creation statistics should be updated assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); @@ -316,7 +320,7 @@ public void testAddKNNBinaryField_fromScratch_faissCurrent() throws IOException assertValidFooter(state.directory, expectedFile); // The document should be readable by faiss - assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + assertLoadableByEngine(HNSW_METHODPARAMETERS, state, expectedFile, knnEngine, spaceType, dimension); // The graph creation statistics should be updated assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); @@ -411,7 +415,7 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio assertValidFooter(state.directory, expectedFile); // The document should be readable by faiss - assertLoadableByEngine(state, expectedFile, knnEngine, spaceType, dimension); + assertLoadableByEngine(HNSW_METHODPARAMETERS, state, expectedFile, knnEngine, spaceType, dimension); // The graph creation statistics should be updated assertEquals(1 + initialRefreshOperations, (long) KNNGraphValue.REFRESH_TOTAL_OPERATIONS.getValue()); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index c4d50ec27..80e94caf8 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -326,6 +326,7 @@ public static void assertValidFooter(Directory dir, String filename) throws IOEx } public static void assertLoadableByEngine( + Map methodParameters, SegmentWriteState state, String fileName, KNNEngine knnEngine, @@ -337,7 +338,7 @@ public static void assertLoadableByEngine( long indexPtr = JNIService.loadIndex(filePath, Maps.newHashMap(ImmutableMap.of(SPACE_TYPE, spaceType.getValue())), knnEngine); int k = 2; float[] queryVector = new float[dimension]; - KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, knnEngine, null, 0, null); + KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, methodParameters, knnEngine, null, 0, null); assertTrue(results.length > 0); JNIService.free(indexPtr, knnEngine); } diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index a84974202..876303523 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -76,7 +76,7 @@ public void testIndexLoadStrategy_load() throws IOException { // Confirm that the file was loaded by querying float[] query = new float[dimension]; Arrays.fill(query, numVectors + 1); - KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, knnEngine, null, 0, null); + KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, null, knnEngine, null, 0, null); assertTrue(results.length > 0); } diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderInvalidParamsTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderInvalidParamsTests.java new file mode 100644 index 000000000..74c1cca58 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderInvalidParamsTests.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import lombok.AllArgsConstructor; +import org.opensearch.knn.KNNTestCase; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Map; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; + +@AllArgsConstructor +public class KNNQueryBuilderInvalidParamsTests extends KNNTestCase { + + private static final float[] QUERY_VECTOR = new float[] { 1.2f, 2.3f, 4.5f }; + private static final String FIELD_NAME = "test_vector"; + + private String description; + private String expectedMessage; + private KNNQueryBuilder.Builder knnQueryBuilderBuilder; + + @ParametersFactory(argumentFormatting = "description:%1$s; expectedMessage:%2$s; querybuilder:%3$s") + public static Collection invalidParameters() { + return Arrays.asList( + $$( + $("fieldName absent", "[knn] requires fieldName", KNNQueryBuilder.builder().k(1).vector(QUERY_VECTOR)), + $("vector absent", "[knn] requires query vector", KNNQueryBuilder.builder().k(1).fieldName(FIELD_NAME)), + $( + "vector empty", + "[knn] query vector is empty", + KNNQueryBuilder.builder().k(1).fieldName(FIELD_NAME).vector(new float[] {}) + ), + $( + "Neither knn nor radial search", + "[knn] requires exactly one of k, distance or score to be set", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR) + ), + $( + "max distance and k present", + "[knn] requires exactly one of k, distance or score to be set", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).k(1).maxDistance(10f) + ), + $( + "min_score and k present", + "[knn] requires exactly one of k, distance or score to be set", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).k(1).minScore(1.0f) + ), + $( + "max_dist and min_score present", + "[knn] requires exactly one of k, distance or score to be set", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).maxDistance(1.0f).minScore(1.0f) + ), + $( + "max_dist, k and min_score present", + "[knn] requires exactly one of k, distance or score to be set", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).k(1).maxDistance(1.0f).minScore(1.0f) + ), + $( + "-ve k value", + "[knn] requires k to be in the range (0, 10000]", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).k(-1) + ), + $( + "k value greater than max", + "[knn] requires k to be in the range (0, 10000]", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).k(10001) + ), + $( + "efSearch 0", + "[knn] errors in method parameter [Validation Failed: 1: Validation Failed: 1: ef_search should be greater than 0;;]", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).methodParameters(Map.of("ef_search", 0)).k(10) + ), + $( + "efSearch -ve", + "[knn] errors in method parameter [Validation Failed: 1: Validation Failed: 1: ef_search should be greater than 0;;]", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).methodParameters(Map.of("ef_search", -10)).k(10) + ), + $( + "min score less than 0", + "[knn] requires minScore to be greater than 0", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).minScore(-1f) + ) + ) + ); + } + + public void testInvalidBuilder() { + Throwable exception = expectThrows(IllegalArgumentException.class, () -> knnQueryBuilderBuilder.build()); + assertEquals(expectedMessage, expectedMessage, exception.getMessage()); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 4b9872131..10c0155ae 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -7,7 +7,6 @@ import com.google.common.collect.ImmutableMap; import org.apache.lucene.search.FloatVectorSimilarityQuery; -import java.util.Locale; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; @@ -15,21 +14,21 @@ import org.opensearch.cluster.ClusterModule; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.index.Index; import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.index.IndexSettings; +import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.QueryShardContext; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.knn.KNNTestCase; -import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.core.index.Index; -import org.opensearch.index.mapper.NumberFieldMapper; -import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.KNNClusterUtil; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; @@ -46,22 +45,26 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Locale; +import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; +import static java.util.Collections.emptyMap; import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -import static org.opensearch.knn.common.KNNConstants.DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO; -import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; +import static org.opensearch.knn.index.query.KNNQueryBuilder.EF_SEARCH_FIELD; import static org.opensearch.knn.index.util.KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH; public class KNNQueryBuilderTests extends KNNTestCase { private static final String FIELD_NAME = "myvector"; private static final int K = 1; + private static final int EF_SEARCH = 10; + private static final Map HNSW_METHOD_PARAMS = Map.of("ef_search", EF_SEARCH); private static final Float MAX_DISTANCE = 1.0f; private static final Float MIN_SCORE = 0.5f; private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("field", "value"); @@ -91,7 +94,10 @@ public void testInvalidDistance() { /** * null distance */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(null)); + expectThrows( + IllegalArgumentException.class, + () -> KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).maxDistance(null).build() + ); } public void testInvalidScore() { @@ -99,17 +105,26 @@ public void testInvalidScore() { /** * null min_score */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(null)); + expectThrows( + IllegalArgumentException.class, + () -> KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).minScore(null).build() + ); /** * negative min_score */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(-1.0f)); + expectThrows( + IllegalArgumentException.class, + () -> KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).minScore(-1.0f).build() + ); /** * min_score = 0 */ - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(0.0f)); + expectThrows( + IllegalArgumentException.class, + () -> KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).minScore(0.0f).build() + ); } public void testEmptyVector() { @@ -129,13 +144,19 @@ public void testEmptyVector() { * null query vector with distance */ float[] queryVector2 = null; - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector2).maxDistance(MAX_DISTANCE)); + expectThrows( + IllegalArgumentException.class, + () -> KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector2).maxDistance(MAX_DISTANCE).build() + ); /** * empty query vector with distance */ float[] queryVector3 = {}; - expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector3).maxDistance(MAX_DISTANCE)); + expectThrows( + IllegalArgumentException.class, + () -> KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector3).maxDistance(MAX_DISTANCE).build() + ); } public void testFromXContent() throws Exception { @@ -154,9 +175,37 @@ public void testFromXContent() throws Exception { assertEquals(knnQueryBuilder, actualBuilder); } + public void testFromXContent_KnnWithMethodParameters() throws Exception { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .methodParameters(HNSW_METHOD_PARAMS) + .build(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK()); + builder.startObject(org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER); + builder.field(EF_SEARCH_FIELD.getPreferredName(), EF_SEARCH); + builder.endObject(); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + public void testFromXContent_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(MAX_DISTANCE) + .build(); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); @@ -172,12 +221,16 @@ public void testFromXContent_whenDoRadiusSearch_whenDistanceThreshold_thenSuccee public void testFromXContent_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MAX_DISTANCE); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .minScore(MAX_DISTANCE) + .build(); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); + builder.field(KNNQueryBuilder.MIN_SCORE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); builder.endObject(); builder.endObject(); XContentParser contentParser = createParser(builder); @@ -208,6 +261,37 @@ public void testFromXContent_withFilter() throws Exception { assertEquals(knnQueryBuilder, actualBuilder); } + public void testFromXContent_KnnWithEfSearch_withFilter() throws Exception { + final ClusterService clusterService = mockClusterService(Version.CURRENT); + + final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); + knnClusterUtil.initialize(clusterService); + + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .filter(TERM_QUERY) + .methodParameters(HNSW_METHOD_PARAMS) + .build(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder.startObject(); + builder.startObject(knnQueryBuilder.fieldName()); + builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); + builder.field(KNNQueryBuilder.K_FIELD.getPreferredName(), knnQueryBuilder.getK()); + builder.startObject(org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER); + builder.field(EF_SEARCH_FIELD.getPreferredName(), EF_SEARCH); + builder.endObject(); + builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); + builder.endObject(); + builder.endObject(); + XContentParser contentParser = createParser(builder); + contentParser.nextToken(); + KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); + assertEquals(knnQueryBuilder, actualBuilder); + } + public void testFromXcontent_WithFilter_UnsupportedClusterVersion() throws Exception { final ClusterService clusterService = mockClusterService(Version.V_2_3_0); @@ -237,7 +321,13 @@ public void testFromXContent_wenDoRadiusSearch_whenDistanceThreshold_whenFilter_ knnClusterUtil.initialize(clusterService); float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE).filter(TERM_QUERY); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(MAX_DISTANCE) + .filter(TERM_QUERY) + .build(); + XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); @@ -259,12 +349,17 @@ public void testFromXContent_wenDoRadiusSearch_whenScoreThreshold_whenFilter_the knnClusterUtil.initialize(clusterService); float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE).filter(TERM_QUERY); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .minScore(MIN_SCORE) + .filter(TERM_QUERY) + .build(); XContentBuilder builder = XContentFactory.jsonBuilder(); builder.startObject(); builder.startObject(knnQueryBuilder.fieldName()); builder.field(KNNQueryBuilder.VECTOR_FIELD.getPreferredName(), knnQueryBuilder.vector()); - builder.field(KNNQueryBuilder.MAX_DISTANCE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); + builder.field(KNNQueryBuilder.MIN_SCORE_FIELD.getPreferredName(), knnQueryBuilder.getMinScore()); builder.field(KNNQueryBuilder.FILTER_FIELD.getPreferredName(), knnQueryBuilder.getFilter()); builder.endObject(); builder.endObject(); @@ -406,7 +501,11 @@ public void testDoToQuery_Normal() throws Exception { public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(MAX_DISTANCE) + .build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -414,7 +513,10 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); @@ -422,13 +524,19 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenDistanceThreshold_th assertTrue(query.toString().contains("resultSimilarity=" + resultSimilarity)); assertTrue( - query.toString().contains("traversalSimilarity=" + DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO * resultSimilarity) + query.toString() + .contains( + "traversalSimilarity=" + + org.opensearch.knn.common.KNNConstants.DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO * resultSimilarity + ) ); } public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).minScore(MIN_SCORE).build(); + Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -436,7 +544,10 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenS when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); FloatVectorSimilarityQuery query = (FloatVectorSimilarityQuery) knnQueryBuilder.doToQuery(mockQueryShardContext); @@ -446,7 +557,12 @@ public void testDoToQuery_whenNormal_whenDoRadiusSearch_whenScoreThreshold_thenS public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(negativeDistance) + .build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -454,7 +570,10 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSuppor when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); when(mockKNNVectorField.getKnnMethodContext()).thenReturn( new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) ); @@ -470,7 +589,12 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenSuppor public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupportedSpaceType_thenException() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(negativeDistance) + .build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -478,7 +602,10 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupp when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); when(mockKNNVectorField.getKnnMethodContext()).thenReturn( new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) ); @@ -492,7 +619,8 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassNegativeDistance_whenUnSupp public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSupportedSpaceType_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float score = 5f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(score); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).minScore(score).build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -500,7 +628,10 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSuppor when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); when(mockKNNVectorField.getKnnMethodContext()).thenReturn( new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) ); @@ -516,7 +647,7 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenSuppor public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupportedSpaceType_thenException() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float score = 5f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(score); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).minScore(score).build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -524,7 +655,10 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupp when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); when(mockKNNVectorField.getKnnMethodContext()).thenReturn( new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) ); @@ -538,7 +672,12 @@ public void testDoToQuery_whenDoRadiusSearch_whenPassScoreMoreThanOne_whenUnsupp public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(negativeDistance) + .build(); + Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -546,7 +685,10 @@ public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSu when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); when(mockKNNVectorField.getKnnMethodContext()).thenReturn( new KNNMethodContext(KNNEngine.FAISS, SpaceType.INNER_PRODUCT, methodComponentContext) ); @@ -562,7 +704,13 @@ public void testDoToQuery_whenPassNegativeDistance_whenSupportedSpaceType_thenSu public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_thenException() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; float negativeDistance = -1.0f; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(negativeDistance); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(negativeDistance) + .build(); + Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -570,7 +718,10 @@ public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_then when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); when(mockKNNVectorField.getKnnMethodContext()).thenReturn( new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext) ); @@ -581,9 +732,15 @@ public void testDoToQuery_whenPassNegativeDistance_whenUnSupportedSpaceType_then expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); } - public void testDoToQuery_KnnQueryWithFilter() throws Exception { + public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception { + // Given float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .filter(TERM_QUERY) + .build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -591,25 +748,42 @@ public void testDoToQuery_KnnQueryWithFilter() throws Exception { when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + + // When Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); + + // Then assertNotNull(query); assertTrue(query.getClass().isAssignableFrom(KnnFloatVectorQuery.class)); } public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE).filter(TERM_QUERY); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(MAX_DISTANCE) + .filter(TERM_QUERY) + .build(); + Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); @@ -620,14 +794,22 @@ public void testDoToQuery_whenDoRadiusSearch_whenDistanceThreshold_whenFilter_th public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE).filter(TERM_QUERY); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(MAX_DISTANCE) + .filter(TERM_QUERY) + .build(); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.LUCENE, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); @@ -637,23 +819,61 @@ public void testDoToQuery_whenDoRadiusSearch_whenScoreThreshold_whenFilter_thenS } public void testDoToQuery_WhenknnQueryWithFilterAndFaissEngine_thenSuccess() { + // Given float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + + // When + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .filter(TERM_QUERY) + .methodParameters(HNSW_METHOD_PARAMS) + .build(); + Query query = knnQueryBuilder.doToQuery(mockQueryShardContext); + + // Then assertNotNull(query); assertTrue(query.getClass().isAssignableFrom(KNNQuery.class)); + assertEquals(HNSW_METHOD_PARAMS, ((KNNQuery) query).getMethodParameters()); } + /** This test should be uncommented once we have nprobs. Considering engine instance is static its not possible to test this right now + public void testDoToQuery_ThrowsIllegalArgumentExceptionForUnknownMethodParameter() { + + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); + when(mockKNNVectorField.getDimension()).thenReturn(4); + when(mockKNNVectorField.getKnnMethodContext()).thenReturn( + new KNNMethodContext(KNNEngine.LUCENE, SpaceType.COSINESIMIL, new MethodComponentContext("hnsw", Map.of())) + ); + + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .methodParameters(Map.of("ef_search", 10)) + .build(); + + expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext)); + }**/ + public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY); @@ -663,7 +883,10 @@ public void testDoToQuery_whenknnQueryWithFilterAndNmsLibEngine_thenException() when(mockQueryShardContext.index()).thenReturn(dummyIndex); when(mockKNNVectorField.getDimension()).thenReturn(4); when(mockKNNVectorField.getSpaceType()).thenReturn(SpaceType.L2); - MethodComponentContext methodComponentContext = new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()); + MethodComponentContext methodComponentContext = new MethodComponentContext( + org.opensearch.knn.common.KNNConstants.METHOD_HNSW, + ImmutableMap.of() + ); KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.NMSLIB, SpaceType.L2, methodComponentContext); when(mockKNNVectorField.getKnnMethodContext()).thenReturn(knnMethodContext); when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField); @@ -691,6 +914,7 @@ public void testDoToQuery_FromModel() { when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); + when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap())); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); KNNQueryBuilder.initialize(modelDao); @@ -704,7 +928,13 @@ public void testDoToQuery_FromModel() { public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).maxDistance(MAX_DISTANCE); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .maxDistance(MAX_DISTANCE) + .build(); + Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -721,6 +951,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); + when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap())); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); KNNQueryBuilder.initialize(modelDao); @@ -737,7 +968,9 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_thenSucceed() { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector).minScore(MIN_SCORE); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(queryVector).minScore(MIN_SCORE).build(); + Index dummyIndex = new Index("dummy", "dummy"); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); @@ -754,6 +987,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_th when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); + when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap())); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); KNNQueryBuilder.initialize(modelDao); @@ -837,39 +1071,38 @@ public void testDoToQuery_InvalidZeroByteVector() { public void testSerialization() throws Exception { // For k-NN search - assertSerialization(Version.CURRENT, Optional.empty(), K, null, null); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), K, null, null); - assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null); + assertSerialization(Version.CURRENT, Optional.empty(), K, null, null, null); + assertSerialization(Version.CURRENT, Optional.empty(), K, Map.of("ef_search", EF_SEARCH), null, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), K, Map.of("ef_search", EF_SEARCH), null, null); + assertSerialization(Version.V_2_3_0, Optional.empty(), K, Map.of("ef_search", EF_SEARCH), null, null); + assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null, null); // For distance threshold search - assertSerialization(Version.CURRENT, Optional.empty(), null, MAX_DISTANCE, null); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, MAX_DISTANCE, null); + assertSerialization(Version.CURRENT, Optional.empty(), null, null, null, MAX_DISTANCE); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, null, MAX_DISTANCE); // For score threshold search - assertSerialization(Version.CURRENT, Optional.empty(), null, null, MIN_SCORE); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, MIN_SCORE); - } - - public void testIgnoreUnmapped() throws IOException { - float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); - knnQueryBuilder.ignoreUnmapped(true); - assertTrue(knnQueryBuilder.getIgnoreUnmapped()); - Query query = knnQueryBuilder.doToQuery(mock(QueryShardContext.class)); - assertNotNull(query); - assertThat(query, instanceOf(MatchNoDocsQuery.class)); - knnQueryBuilder.ignoreUnmapped(false); - expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.doToQuery(mock(QueryShardContext.class))); + assertSerialization(Version.CURRENT, Optional.empty(), null, null, null, MIN_SCORE); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, null, MIN_SCORE); } private void assertSerialization( final Version version, final Optional queryBuilderOptional, Integer k, + Map methodParameters, Float distance, Float score ) throws Exception { - final KNNQueryBuilder knnQueryBuilder = getKnnQueryBuilder(queryBuilderOptional, k, distance, score); + final KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(QUERY_VECTOR) + .maxDistance(distance) + .minScore(score) + .k(k) + .methodParameters(methodParameters) + .filter(queryBuilderOptional.orElse(null)) + .build(); final ClusterService clusterService = mockClusterService(version); @@ -901,28 +1134,34 @@ private void assertSerialization( } else { assertNull(deserializedKnnQueryBuilder.getFilter()); } + assertMethodParameters(version, methodParameters, deserializedKnnQueryBuilder.getMethodParameters()); } } } - private static KNNQueryBuilder getKnnQueryBuilder(Optional queryBuilderOptional, Integer k, Float distance, Float score) { - final KNNQueryBuilder knnQueryBuilder; - if (k != null) { - knnQueryBuilder = queryBuilderOptional.isPresent() - ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, k, queryBuilderOptional.get()) - : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, k); - } else if (distance != null) { - knnQueryBuilder = queryBuilderOptional.isPresent() - ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).maxDistance(distance).filter(queryBuilderOptional.get()) - : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).maxDistance(distance); - } else if (score != null) { - knnQueryBuilder = queryBuilderOptional.isPresent() - ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).minScore(score).filter(queryBuilderOptional.get()) - : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).minScore(score); - } else { - throw new IllegalArgumentException("Either k or distance must be provided"); + private void assertMethodParameters(Version version, Map expectedMethodParameters, Map actualMethodParameters) { + if (!version.onOrAfter(Version.V_2_16_0)) { + assertNull(actualMethodParameters); + } else if (expectedMethodParameters != null) { + if (version.onOrAfter(Version.V_2_16_0)) { + assertEquals(expectedMethodParameters.get("ef_search"), actualMethodParameters.get("ef_search")); + } } - return knnQueryBuilder; + } + + public void testIgnoreUnmapped() throws IOException { + float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; + KNNQueryBuilder.Builder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(queryVector) + .k(K) + .ignoreUnmapped(true); + assertTrue(knnQueryBuilder.build().isIgnoreUnmapped()); + Query query = knnQueryBuilder.build().doToQuery(mock(QueryShardContext.class)); + assertNotNull(query); + assertThat(query, instanceOf(MatchNoDocsQuery.class)); + knnQueryBuilder.ignoreUnmapped(false); + expectThrows(IllegalArgumentException.class, () -> knnQueryBuilder.build().doToQuery(mock(QueryShardContext.class))); } public void testRadialSearch_whenUnsupportedEngine_thenThrowException() { @@ -933,9 +1172,15 @@ public void testRadialSearch_whenUnsupportedEngine_thenThrowException() { KNNMethodContext knnMethodContext = new KNNMethodContext( knnEngine, SpaceType.L2, - new MethodComponentContext(METHOD_HNSW, ImmutableMap.of()) + new MethodComponentContext(org.opensearch.knn.common.KNNConstants.METHOD_HNSW, ImmutableMap.of()) ); - KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR).maxDistance(MAX_DISTANCE); + + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(QUERY_VECTOR) + .maxDistance(MAX_DISTANCE) + .build(); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); Index dummyIndex = new Index("dummy", "dummy"); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderValidParamsTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderValidParamsTests.java new file mode 100644 index 000000000..4b97df4b4 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderValidParamsTests.java @@ -0,0 +1,90 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import lombok.AllArgsConstructor; +import org.opensearch.knn.KNNTestCase; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Map; + +import static com.carrotsearch.randomizedtesting.RandomizedTest.$; +import static com.carrotsearch.randomizedtesting.RandomizedTest.$$; + +@AllArgsConstructor +public class KNNQueryBuilderValidParamsTests extends KNNTestCase { + + private static final float[] QUERY_VECTOR = new float[] { 1.2f, 2.3f, 4.5f }; + private static final String FIELD_NAME = "test_vector"; + + private String description; + private KNNQueryBuilder expected; + private Integer k; + private Map methodParameters; + private Float maxDistance; + private Float minScore; + + @ParametersFactory(argumentFormatting = "description:%1$s; k:%3$s, efSearch:%4$s, maxDist:%5$s, minScore:%6$s") + public static Collection validParameters() { + return Arrays.asList( + $$( + $( + "valid knn with k", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).k(10).build(), + 10, + null, + null, + null + ), + $( + "valid knn with k and efSearch", + KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(QUERY_VECTOR) + .k(10) + .methodParameters(Map.of("ef_search", 12)) + .build(), + 10, + Map.of("ef_search", 12), + null, + null + ), + $( + "valid knn with maxDis", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).maxDistance(10.0f).build(), + null, + null, + 10.0f, + null + ), + $( + "valid knn with minScore", + KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).minScore(10.0f).build(), + null, + null, + null, + 10.0f + ) + ) + ); + } + + public void testValidBuilder() { + assertEquals( + expected, + KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(QUERY_VECTOR) + .k(k) + .methodParameters(methodParameters) + .maxDistance(maxDistance) + .minScore(minScore) + .build() + ); + } +} diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java index 1bb17cfae..56e81d237 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryFactoryTests.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.query; import org.apache.lucene.index.Term; +import org.apache.lucene.search.KnnByteVectorQuery; import org.apache.lucene.search.KnnFloatVectorQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.TermQuery; @@ -27,12 +28,14 @@ import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.opensearch.knn.common.KNNConstants.DEFAULT_VECTOR_DATA_TYPE_FIELD; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; public class KNNQueryFactoryTests extends KNNTestCase { private static final String FILTER_FILED_NAME = "foo"; @@ -45,6 +48,7 @@ public class KNNQueryFactoryTests extends KNNTestCase { private final String testIndexName = "test-index"; private final String testFieldName = "test-field"; private final int testK = 10; + private final Map methodParameters = Map.of(METHOD_PARAMETER_EF_SEARCH, 100); public void testCreateCustomKNNQuery() { for (KNNEngine knnEngine : KNNEngine.getEnginesThatCreateCustomSegmentFiles()) { @@ -82,6 +86,98 @@ public void testCreateLuceneDefaultQuery() { } } + public void testLuceneFloatVectorQuery() { + Query actualQuery1 = KNNQueryFactory.create( + BaseQueryFactory.CreateQueryRequest.builder() + .knnEngine(KNNEngine.LUCENE) + .vector(testQueryVector) + .k(testK) + .indexName(testIndexName) + .fieldName(testFieldName) + .methodParameters(methodParameters) + .vectorDataType(VectorDataType.FLOAT) + .build() + ); + + // efsearch > k + Query expectedQuery1 = new KnnFloatVectorQuery(testFieldName, testQueryVector, 100, null); + assertEquals(expectedQuery1, actualQuery1); + + // efsearch < k + actualQuery1 = KNNQueryFactory.create( + BaseQueryFactory.CreateQueryRequest.builder() + .knnEngine(KNNEngine.LUCENE) + .vector(testQueryVector) + .k(testK) + .indexName(testIndexName) + .fieldName(testFieldName) + .methodParameters(Map.of("ef_search", 1)) + .vectorDataType(VectorDataType.FLOAT) + .build() + ); + expectedQuery1 = new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null); + assertEquals(expectedQuery1, actualQuery1); + + actualQuery1 = KNNQueryFactory.create( + BaseQueryFactory.CreateQueryRequest.builder() + .knnEngine(KNNEngine.LUCENE) + .vector(testQueryVector) + .k(testK) + .indexName(testIndexName) + .fieldName(testFieldName) + .vectorDataType(VectorDataType.FLOAT) + .build() + ); + expectedQuery1 = new KnnFloatVectorQuery(testFieldName, testQueryVector, testK, null); + assertEquals(expectedQuery1, actualQuery1); + } + + public void testLuceneByteVectorQuery() { + Query actualQuery1 = KNNQueryFactory.create( + BaseQueryFactory.CreateQueryRequest.builder() + .knnEngine(KNNEngine.LUCENE) + .byteVector(testByteQueryVector) + .k(testK) + .indexName(testIndexName) + .fieldName(testFieldName) + .methodParameters(methodParameters) + .vectorDataType(VectorDataType.BYTE) + .build() + ); + + // efsearch > k + Query expectedQuery1 = new KnnByteVectorQuery(testFieldName, testByteQueryVector, 100, null); + assertEquals(expectedQuery1, actualQuery1); + + // efsearch < k + actualQuery1 = KNNQueryFactory.create( + BaseQueryFactory.CreateQueryRequest.builder() + .knnEngine(KNNEngine.LUCENE) + .byteVector(testByteQueryVector) + .k(testK) + .indexName(testIndexName) + .fieldName(testFieldName) + .methodParameters(Map.of("ef_search", 1)) + .vectorDataType(VectorDataType.BYTE) + .build() + ); + expectedQuery1 = new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null); + assertEquals(expectedQuery1, actualQuery1); + + actualQuery1 = KNNQueryFactory.create( + BaseQueryFactory.CreateQueryRequest.builder() + .knnEngine(KNNEngine.LUCENE) + .byteVector(testByteQueryVector) + .k(testK) + .indexName(testIndexName) + .fieldName(testFieldName) + .vectorDataType(VectorDataType.BYTE) + .build() + ); + expectedQuery1 = new KnnByteVectorQuery(testFieldName, testByteQueryVector, testK, null); + assertEquals(expectedQuery1, actualQuery1); + } + public void testCreateLuceneQueryWithFilter() { List luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values()) .filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) @@ -106,28 +202,71 @@ public void testCreateLuceneQueryWithFilter() { } public void testCreateFaissQueryWithFilter_withValidValues_thenSuccess() { + // Given final KNNEngine knnEngine = KNNEngine.FAISS; final QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); MappedFieldType testMapper = mock(MappedFieldType.class); when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); when(testMapper.termQuery(Mockito.any(), Mockito.eq(mockQueryShardContext))).thenReturn(FILTER_QUERY); + + final KNNQuery expectedQuery = KNNQuery.builder() + .indexName(testIndexName) + .filterQuery(FILTER_QUERY) + .field(testFieldName) + .queryVector(testQueryVector) + .k(testK) + .methodParameters(methodParameters) + .build(); + + // When final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() .knnEngine(knnEngine) .indexName(testIndexName) .fieldName(testFieldName) .vector(testQueryVector) .k(testK) + .methodParameters(methodParameters) .context(mockQueryShardContext) .filter(FILTER_QUERY_BUILDER) .build(); - final Query query = KNNQueryFactory.create(createQueryRequest); - assertTrue(query instanceof KNNQuery); - assertEquals(testIndexName, ((KNNQuery) query).getIndexName()); - assertEquals(testFieldName, ((KNNQuery) query).getField()); - assertEquals(testQueryVector, ((KNNQuery) query).getQueryVector()); - assertEquals(testK, ((KNNQuery) query).getK()); - assertEquals(FILTER_QUERY, ((KNNQuery) query).getFilterQuery()); + final Query actual = KNNQueryFactory.create(createQueryRequest); + + // Then + assertEquals(expectedQuery, actual); + } + + public void testCreateFaissQueryWithFilter_withValidValues_nullEfSearch_thenSuccess() { + // Given + final KNNEngine knnEngine = KNNEngine.FAISS; + final QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + MappedFieldType testMapper = mock(MappedFieldType.class); + when(mockQueryShardContext.fieldMapper(any())).thenReturn(testMapper); + when(testMapper.termQuery(Mockito.any(), Mockito.eq(mockQueryShardContext))).thenReturn(FILTER_QUERY); + + final KNNQuery expectedQuery = KNNQuery.builder() + .indexName(testIndexName) + .filterQuery(FILTER_QUERY) + .field(testFieldName) + .queryVector(testQueryVector) + .k(testK) + .build(); + + // When + final KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder() + .knnEngine(knnEngine) + .indexName(testIndexName) + .fieldName(testFieldName) + .vector(testQueryVector) + .k(testK) + .context(mockQueryShardContext) + .filter(FILTER_QUERY_BUILDER) + .build(); + + final Query actual = KNNQueryFactory.create(createQueryRequest); + + // Then + assertEquals(expectedQuery, actual); } public void testCreate_whenLuceneWithParentFilter_thenReturnDiversifyingQuery() { diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 0d15b5f5f..021e3a825 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -67,11 +67,13 @@ import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.when; import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; @@ -83,6 +85,8 @@ public class KNNWeightTests extends KNNTestCase { private static final Set SEGMENT_FILES_NMSLIB = Set.of("_0.cfe", "_0_2011_target_field.hnswc"); private static final Set SEGMENT_FILES_FAISS = Set.of("_0.cfe", "_0_2011_target_field.faissc"); private static final String CIRCUIT_BREAKER_LIMIT_100KB = "100Kb"; + private static final Integer EF_SEARCH = 10; + private static final Map HNSW_METHOD_PARAMETERS = Map.of(METHOD_PARAMETER_EF_SEARCH, EF_SEARCH); private static final Map DOC_ID_TO_SCORES = Map.of(10, 0.4f, 101, 0.05f, 100, 0.8f, 50, 0.52f); private static final Map FILTERED_DOC_ID_TO_SCORES = Map.of(101, 0.05f, 100, 0.8f, 50, 0.52f); @@ -159,7 +163,7 @@ public void testQueryScoreForFaissWithModel() { SpaceType spaceType = SpaceType.L2; final Function scoreTranslator = spaceType::scoreTranslation; final String modelId = "modelId"; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), any(), any(), anyInt(), any())) + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), eq(K), isNull(), any(), any(), anyInt(), any())) .thenReturn(getKNNQueryResults()); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); @@ -303,7 +307,7 @@ public void testShardWithoutFiles() { @SneakyThrows public void testEmptyQueryResults() { final KNNQueryResult[] knnQueryResults = new KNNQueryResult[] {}; - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), any(), any(), anyInt(), any())) + jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), eq(K), isNull(), any(), any(), anyInt(), any())) .thenReturn(knnQueryResults); final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); @@ -346,6 +350,7 @@ public void testEmptyQueryResults() { @SneakyThrows public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { + // Given int k = 3; final int[] filterDocIds = new int[] { 0, 1, 2, 3, 4, 5 }; FixedBitSet filterBitSet = new FixedBitSet(filterDocIds.length); @@ -353,7 +358,16 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { filterBitSet.set(docId); } jniServiceMockedStatic.when( - () -> JNIService.queryIndex(anyLong(), any(), anyInt(), any(), eq(filterBitSet.getBits()), anyInt(), any()) + () -> JNIService.queryIndex( + anyLong(), + eq(QUERY_VECTOR), + eq(k), + eq(HNSW_METHOD_PARAMETERS), + any(), + eq(filterBitSet.getBits()), + anyInt(), + any() + ) ).thenReturn(getFilteredKNNQueryResults()); final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); final SegmentReader reader = mock(SegmentReader.class); @@ -366,7 +380,15 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { when(liveDocsBits.length()).thenReturn(1000); when(leafReaderContext.reader()).thenReturn(reader); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, k, INDEX_NAME, FILTER_QUERY, null); + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(k) + .indexName(INDEX_NAME) + .filterQuery(FILTER_QUERY) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); + final Weight filterQueryWeight = mock(Weight.class); final Scorer filterScorer = mock(Scorer.class); when(filterQueryWeight.scorer(leafReaderContext)).thenReturn(filterScorer); @@ -406,15 +428,26 @@ public void testANNWithFilterQuery_whenDoingANN_thenSuccess() { when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); when(fieldInfo.attributes()).thenReturn(attributesMap); + // When final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); - assertNotNull(knnScorer); + // Then + assertNotNull(knnScorer); final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); assertNotNull(docIdSetIterator); assertEquals(FILTERED_DOC_ID_TO_SCORES.size(), docIdSetIterator.cost()); jniServiceMockedStatic.verify( - () -> JNIService.queryIndex(anyLong(), any(), anyInt(), any(), eq(filterBitSet.getBits()), anyInt(), any()) + () -> JNIService.queryIndex( + anyLong(), + eq(QUERY_VECTOR), + eq(k), + eq(HNSW_METHOD_PARAMETERS), + any(), + eq(filterBitSet.getBits()), + anyInt(), + any() + ) ); final List actualDocIds = new ArrayList<>(); @@ -677,17 +710,47 @@ public void testANNWithParentsFilter_whenDoingANN_thenBitSetIsPassedToJNI() { // Prepare query and weight when(bitSetProducer.getBitSet(leafReaderContext)).thenReturn(bitset); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, 1, INDEX_NAME, null, bitSetProducer); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(1) + .indexName(INDEX_NAME) + .methodParameters(HNSW_METHOD_PARAMETERS) + .parentsFilter(bitSetProducer) + .build(); + final KNNWeight knnWeight = new KNNWeight(query, 0.0f, null); - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), any(), any(), anyInt(), eq(parentsFilter))) - .thenReturn(getKNNQueryResults()); + jniServiceMockedStatic.when( + () -> JNIService.queryIndex( + anyLong(), + eq(QUERY_VECTOR), + eq(1), + eq(HNSW_METHOD_PARAMETERS), + any(), + any(), + anyInt(), + eq(parentsFilter) + ) + ).thenReturn(getKNNQueryResults()); // Execute Scorer knnScorer = knnWeight.scorer(leafReaderContext); // Verify - jniServiceMockedStatic.verify(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), any(), any(), anyInt(), eq(parentsFilter))); + jniServiceMockedStatic.verify( + () -> JNIService.queryIndex( + anyLong(), + eq(QUERY_VECTOR), + eq(1), + eq(HNSW_METHOD_PARAMETERS), + any(), + any(), + anyInt(), + eq(parentsFilter) + ) + ); assertNotNull(knnScorer); final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); assertNotNull(docIdSetIterator); @@ -811,10 +874,17 @@ private void testQueryScore( final Set segmentFiles, final Map fileAttributes ) throws IOException { - jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), anyInt(), any(), any(), anyInt(), any())) - .thenReturn(getKNNQueryResults()); + jniServiceMockedStatic.when( + () -> JNIService.queryIndex(anyLong(), eq(QUERY_VECTOR), eq(K), eq(HNSW_METHOD_PARAMETERS), any(), any(), anyInt(), any()) + ).thenReturn(getKNNQueryResults()); - final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(QUERY_VECTOR) + .k(K) + .indexName(INDEX_NAME) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); final float boost = (float) randomDoubleBetween(0, 10, true); final KNNWeight knnWeight = new KNNWeight(query, boost); diff --git a/src/test/java/org/opensearch/knn/index/query/parser/MethodParametersParserTests.java b/src/test/java/org/opensearch/knn/index/query/parser/MethodParametersParserTests.java new file mode 100644 index 000000000..f2924fb4f --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/query/parser/MethodParametersParserTests.java @@ -0,0 +1,84 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.query.parser; + +import lombok.SneakyThrows; +import org.opensearch.common.ValidationException; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.knn.KNNTestCase; + +import java.util.Map; + +import static org.opensearch.knn.index.query.parser.MethodParametersParser.doXContent; +import static org.opensearch.knn.index.query.parser.MethodParametersParser.validateMethodParameters; + +public class MethodParametersParserTests extends KNNTestCase { + + public void testValidateMethodParameters() { + ValidationException validationException = validateMethodParameters(Map.of("dummy", 0)); + assertEquals("Validation Failed: 1: dummy is not a valid method parameter;", validationException.getMessage()); + + ValidationException validationException2 = validateMethodParameters(Map.of("ef_search", 0)); + assertTrue(validationException2.getMessage().contains("Validation Failed: 1: ef_search should be greater than 0")); + + ValidationException validationException3 = validateMethodParameters(Map.of("ef_search", 10)); + assertNull(validationException3); + } + + @SneakyThrows + public void testDoXContent() { + Map params = Map.of("ef_search", 10); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("method_parameters") + .field("ef_search", 10) + .endObject() + .endObject(); + + XContentBuilder builder2 = XContentFactory.jsonBuilder().startObject(); + doXContent(builder2, params); + builder2.endObject(); + assertEquals(builder.toString(), builder2.toString()); + + XContentBuilder b3 = XContentFactory.jsonBuilder(); + XContentBuilder b4 = XContentFactory.jsonBuilder(); + + doXContent(b4, null); + assertEquals(b3.toString(), b4.toString()); + } + + @SneakyThrows + public void testFromXContent() { + // efsearch string + XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("ef_search", "string").endObject(); + XContentParser parser1 = createParser(builder); + expectThrows(ParsingException.class, () -> MethodParametersParser.fromXContent(parser1)); + + // unknown method parameter + builder = XContentFactory.jsonBuilder().startObject().field("unknown", "10").endObject(); + XContentParser parser2 = createParser(builder); + expectThrows(ParsingException.class, () -> MethodParametersParser.fromXContent(parser2)); + + // Valid + builder = XContentFactory.jsonBuilder().startObject().field("ef_search", 10).endObject(); + XContentParser parser3 = createParser(builder); + assertEquals(Map.of("ef_search", 10), MethodParametersParser.fromXContent(parser3)); + + // empty map + builder = XContentFactory.jsonBuilder().startObject().endObject(); + XContentParser parser4 = createParser(builder); + expectThrows(ParsingException.class, () -> MethodParametersParser.fromXContent(parser4)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java b/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java index 9e6bd67ea..0aab78042 100644 --- a/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java +++ b/src/test/java/org/opensearch/knn/index/util/AbstractKNNLibraryTests.java @@ -15,6 +15,7 @@ import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponent; import org.opensearch.knn.index.MethodComponentContext; +import org.opensearch.knn.index.Parameter; import org.opensearch.knn.index.SpaceType; import java.io.IOException; @@ -77,6 +78,20 @@ public ValidationException validate(KNNMethodContext knnMethodContext) { assertNotNull(testAbstractKNNLibrary2.validateMethod(knnMethodContext2)); } + public void testEngineSpecificMethods() throws IOException { + String methodName1 = "test-method-1"; + EngineSpecificMethodContext context = () -> Map.of("myparameter", new Parameter.BooleanParameter("myparameter", false, o -> o)); + + TestAbstractKNNLibrary testAbstractKNNLibrary1 = new TestAbstractKNNLibrary( + Collections.emptyMap(), + Map.of(methodName1, context), + "" + ); + + assertNotNull(testAbstractKNNLibrary1.getMethodContext(methodName1)); + assertTrue(testAbstractKNNLibrary1.getMethodContext(methodName1).supportedMethodParameters().containsKey("myparameter")); + } + public void testGetMethodAsMap() { String methodName = "test-method-1"; SpaceType spaceType = SpaceType.DEFAULT; @@ -109,7 +124,15 @@ public void testGetMethodAsMap() { private static class TestAbstractKNNLibrary extends AbstractKNNLibrary { public TestAbstractKNNLibrary(Map methods, String currentVersion) { - super(methods, currentVersion); + super(methods, Collections.emptyMap(), currentVersion); + } + + public TestAbstractKNNLibrary( + Map methods, + Map engineSpecificMethodContextMap, + String currentVersion + ) { + super(methods, engineSpecificMethodContextMap, currentVersion); } @Override diff --git a/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java b/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java index 3c3afbee6..814712560 100644 --- a/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java +++ b/src/test/java/org/opensearch/knn/index/util/NativeLibraryTests.java @@ -62,7 +62,7 @@ public TestNativeLibrary( String currentVersion, String extension ) { - super(methods, scoreTranslation, currentVersion, extension); + super(methods, Collections.emptyMap(), scoreTranslation, currentVersion, extension); } @Override diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index d6ae13e92..e71930d48 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -527,6 +527,7 @@ public void testQueryIndex_faiss_sqfp16_valid() { String sqfp16IndexDescription = "HNSW16,SQfp16"; int k = 10; + Map methodParameters = Map.of("ef_search", 12); float[][] truncatedVectors = truncateToFp16Range(testData.indexData.vectors); long memoryAddress = JNICommons.storeVectorData(0, truncatedVectors, (long) truncatedVectors.length * truncatedVectors[0].length); Path tmpFile = createTempFile(); @@ -544,13 +545,22 @@ public void testQueryIndex_faiss_sqfp16_valid() { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.FAISS, null, 0, null); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, methodParameters, KNNEngine.FAISS, null, 0, null); assertEquals(k, results.length); } // Filter will result in no ids for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.FAISS, new long[] { 0 }, 0, null); + KNNQueryResult[] results = JNIService.queryIndex( + pointer, + query, + k, + methodParameters, + KNNEngine.FAISS, + new long[] { 0 }, + 0, + null + ); assertEquals(0, results.length); } } @@ -736,12 +746,15 @@ public void testLoadIndex_faiss_valid() throws IOException { } public void testQueryIndex_invalidEngine() { - expectThrows(IllegalArgumentException.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.LUCENE, null, 0, null)); + expectThrows( + IllegalArgumentException.class, + () -> JNIService.queryIndex(0L, new float[] {}, 0, null, KNNEngine.LUCENE, null, 0, null) + ); } public void testQueryIndex_nmslib_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.NMSLIB, null, 0, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, null, KNNEngine.NMSLIB, null, 0, null)); } public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { @@ -765,7 +778,7 @@ public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { ); assertNotEquals(0, pointer); - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, KNNEngine.NMSLIB, null, 0, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, null, KNNEngine.NMSLIB, null, 0, null)); } public void testQueryIndex_nmslib_valid() throws IOException { @@ -792,7 +805,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.NMSLIB, null, 0, null); + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, null, KNNEngine.NMSLIB, null, 0, null); assertEquals(k, results.length); } } @@ -800,7 +813,7 @@ public void testQueryIndex_nmslib_valid() throws IOException { public void testQueryIndex_faiss_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, KNNEngine.FAISS, null, 0, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, null, KNNEngine.FAISS, null, 0, null)); } public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { @@ -820,12 +833,13 @@ public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, KNNEngine.FAISS, null, 0, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, null, KNNEngine.FAISS, null, 0, null)); } public void testQueryIndex_faiss_valid() throws IOException { int k = 10; + int efSearch = 100; List methods = ImmutableList.of(faissMethod); List spaces = ImmutableList.of(SpaceType.L2, SpaceType.INNER_PRODUCT); @@ -850,13 +864,31 @@ public void testQueryIndex_faiss_valid() throws IOException { assertNotEquals(0, pointer); for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.FAISS, null, 0, null); + KNNQueryResult[] results = JNIService.queryIndex( + pointer, + query, + k, + Map.of("ef_search", efSearch), + KNNEngine.FAISS, + null, + 0, + null + ); assertEquals(k, results.length); } // Filter will result in no ids for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.FAISS, new long[] { 0 }, 0, null); + KNNQueryResult[] results = JNIService.queryIndex( + pointer, + query, + k, + Map.of("ef_search", efSearch), + KNNEngine.FAISS, + new long[] { 0 }, + 0, + null + ); assertEquals(0, results.length); } } @@ -866,6 +898,7 @@ public void testQueryIndex_faiss_valid() throws IOException { public void testQueryIndex_faiss_parentIds() throws IOException { int k = 100; + int efSearch = 100; List methods = ImmutableList.of(faissMethod); List spaces = ImmutableList.of(SpaceType.L2, SpaceType.INNER_PRODUCT); @@ -892,7 +925,16 @@ public void testQueryIndex_faiss_parentIds() throws IOException { assertNotEquals(0, pointer); for (float[] query : testDataNested.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, KNNEngine.FAISS, null, 0, parentIds); + KNNQueryResult[] results = JNIService.queryIndex( + pointer, + query, + k, + Map.of("ef_search", efSearch), + KNNEngine.FAISS, + null, + 0, + parentIds + ); // Verify there is no more than one result from same parent Set parentIdSet = toParentIdSet(results, idToParentIdMap); assertEquals(results.length, parentIdSet.size()); @@ -1223,7 +1265,7 @@ private void assertQueryResultsMatch(float[][] testQueries, int k, List in for (float[] query : testQueries) { KNNQueryResult[][] allResults = new KNNQueryResult[indexAddresses.size()][]; for (int i = 0; i < indexAddresses.size(); i++) { - allResults[i] = JNIService.queryIndex(indexAddresses.get(i), query, k, KNNEngine.FAISS, null, 0, null); + allResults[i] = JNIService.queryIndex(indexAddresses.get(i), query, k, null, KNNEngine.FAISS, null, 0, null); assertEquals(k, allResults[i].length); } diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java index 9af1f49cc..cf7b869f8 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java @@ -523,7 +523,7 @@ public void trainKnnModel(String modelId, String trainingIndexName, String train assertEquals(RestStatus.OK, RestStatus.fromCode(trainResponse.getStatusLine().getStatusCode())); } - public void validateModelCreated(String modelId) throws IOException, InterruptedException { + public void validateModelCreated(String modelId) throws Exception { Response getResponse = getModel(modelId, null); String responseBody = EntityUtils.toString(getResponse.getEntity()); assertNotNull(responseBody); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNWarmupHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNWarmupHandlerIT.java index a2078c291..d433cd285 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNWarmupHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNWarmupHandlerIT.java @@ -36,7 +36,7 @@ public void testNonKnnIndex() throws IOException { knnWarmup(Collections.singletonList("not-knn-index")); } - public void testEmptyIndex() throws IOException { + public void testEmptyIndex() throws Exception { int graphCountBefore = getTotalGraphsInCache(); createKnnIndex(testIndexName, getKNNDefaultIndexSettings(), createKnnIndexMapping(testFieldName, dimensions)); @@ -45,7 +45,7 @@ public void testEmptyIndex() throws IOException { assertEquals(graphCountBefore, getTotalGraphsInCache()); } - public void testSingleIndex() throws IOException { + public void testSingleIndex() throws Exception { int graphCountBefore = getTotalGraphsInCache(); createKnnIndex(testIndexName, getKNNDefaultIndexSettings(), createKnnIndexMapping(testFieldName, dimensions)); addKnnDoc(testIndexName, "1", testFieldName, new Float[] { 6.0f, 6.0f }); @@ -55,7 +55,7 @@ public void testSingleIndex() throws IOException { assertEquals(graphCountBefore + 1, getTotalGraphsInCache()); } - public void testMultipleIndices() throws IOException { + public void testMultipleIndices() throws Exception { int graphCountBefore = getTotalGraphsInCache(); createKnnIndex(testIndexName + "1", getKNNDefaultIndexSettings(), createKnnIndexMapping(testFieldName, dimensions)); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNWarmupHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNWarmupHandlerIT.java index f74135c0c..e6345faba 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNWarmupHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestLegacyKNNWarmupHandlerIT.java @@ -43,7 +43,7 @@ public void testNonKnnIndex() throws IOException { executeWarmupRequest(Collections.singletonList("not-knn-index"), KNNPlugin.LEGACY_KNN_BASE_URI); } - public void testEmptyIndex() throws IOException { + public void testEmptyIndex() throws Exception { int graphCountBefore = getTotalGraphsInCache(); createKnnIndex(testIndexName, getKNNDefaultIndexSettings(), createKnnIndexMapping(testFieldName, dimensions)); @@ -52,7 +52,7 @@ public void testEmptyIndex() throws IOException { assertEquals(graphCountBefore, getTotalGraphsInCache()); } - public void testSingleIndex() throws IOException { + public void testSingleIndex() throws Exception { int graphCountBefore = getTotalGraphsInCache(); createKnnIndex(testIndexName, getKNNDefaultIndexSettings(), createKnnIndexMapping(testFieldName, dimensions)); addKnnDoc(testIndexName, "1", testFieldName, new Float[] { 6.0f, 6.0f }); @@ -62,7 +62,7 @@ public void testSingleIndex() throws IOException { assertEquals(graphCountBefore + 1, getTotalGraphsInCache()); } - public void testMultipleIndices() throws IOException { + public void testMultipleIndices() throws Exception { int graphCountBefore = getTotalGraphsInCache(); createKnnIndex(testIndexName + "1", getKNNDefaultIndexSettings(), createKnnIndexMapping(testFieldName, dimensions)); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java index a22e1acb8..1ba6eae9b 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java @@ -20,7 +20,6 @@ import org.opensearch.knn.KNNRestTestCase; import org.opensearch.core.rest.RestStatus; -import java.io.IOException; import java.util.Map; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; @@ -35,7 +34,7 @@ public class RestTrainModelHandlerIT extends KNNRestTestCase { - public void testTrainModel_fail_notEnoughData() throws IOException, InterruptedException { + public void testTrainModel_fail_notEnoughData() throws Exception { // Check that training fails properly when there is not enough data @@ -326,7 +325,7 @@ public void testTrainModel_success_withId() throws Exception { assertTrainingSucceeds(modelId, 30, 1000); } - public void testTrainModel_success_noId() throws IOException, InterruptedException { + public void testTrainModel_success_noId() throws Exception { // Test to check if training succeeds when no id is passed in String trainingIndexName = "train-index"; diff --git a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java index 46240e830..d66241376 100644 --- a/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java +++ b/src/test/java/org/opensearch/knn/plugin/stats/suppliers/LibraryInitializedSupplierTests.java @@ -12,6 +12,7 @@ package org.opensearch.knn.plugin.stats.suppliers; import org.opensearch.common.ValidationException; +import org.opensearch.knn.index.util.EngineSpecificMethodContext; import org.opensearch.knn.index.KNNMethod; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; @@ -59,6 +60,11 @@ public KNNMethod getMethod(String methodName) { return null; } + @Override + public EngineSpecificMethodContext getMethodContext(String methodName) { + return null; + } + @Override public float score(float rawScore, SpaceType spaceType) { return 0; diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index adb1726d4..e71399c1e 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -14,7 +14,6 @@ import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.core.xcontent.DeprecationHandler; -import org.opensearch.core.xcontent.MediaType; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; @@ -34,7 +33,7 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.MediaType; import org.opensearch.index.query.ExistsQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder; @@ -124,7 +123,7 @@ public static void dumpCoverage() throws IOException, MalformedObjectNameExcepti // jacoco.dir is set in esplugin-coverage.gradle, if it doesn't exist we don't // want to collect coverage so we can return early String jacocoBuildPath = System.getProperty("jacoco.dir"); - if (StringUtils.isBlank(jacocoBuildPath)) { + if (org.opensearch.core.common.Strings.isNullOrEmpty(jacocoBuildPath)) { return; } @@ -194,18 +193,7 @@ protected Response searchKNNIndex(String index, KNNQueryBuilder knnQueryBuilder, XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query"); knnQueryBuilder.doXContent(builder, ToXContent.EMPTY_PARAMS); builder.endObject().endObject(); - - Request request = new Request("POST", "/" + index + "/_search"); - - request.addParameter("size", Integer.toString(resultSize)); - request.addParameter("explain", Boolean.toString(true)); - request.addParameter("search_type", "query_then_fetch"); - request.setJsonEntity(builder.toString()); - - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - - return response; + return searchKNNIndex(index, builder, resultSize); } /** @@ -260,8 +248,10 @@ protected Response performSearch(final String indexName, final String query) thr */ protected List parseSearchResponse(String responseBody, String fieldName) throws IOException { @SuppressWarnings("unchecked") - List hits = (List) ((Map) createParser(XContentType.JSON.xContent(), responseBody).map() - .get("hits")).get("hits"); + List hits = (List) ((Map) createParser( + MediaTypeRegistry.getDefaultMediaType().xContent(), + responseBody + ).map().get("hits")).get("hits"); @SuppressWarnings("unchecked") List knnSearchResponses = hits.stream().map(hit -> { @@ -283,8 +273,10 @@ protected List parseSearchResponse(String responseBody, String fieldN protected List parseSearchResponseScore(String responseBody, String fieldName) throws IOException { @SuppressWarnings("unchecked") - List hits = (List) ((Map) createParser(XContentType.JSON.xContent(), responseBody).map() - .get("hits")).get("hits"); + List hits = (List) ((Map) createParser( + MediaTypeRegistry.getDefaultMediaType().xContent(), + responseBody + ).map().get("hits")).get("hits"); @SuppressWarnings("unchecked") List knnSearchResponses = hits.stream() @@ -299,8 +291,10 @@ protected List parseSearchResponseScore(String responseBody, String field */ protected Double parseAggregationResponse(String responseBody, String aggregationName) throws IOException { @SuppressWarnings("unchecked") - Map aggregations = ((Map) createParser(XContentType.JSON.xContent(), responseBody).map() - .get("aggregations")); + Map aggregations = ((Map) createParser( + MediaTypeRegistry.getDefaultMediaType().xContent(), + responseBody + ).map().get("aggregations")); final Map values = (Map) aggregations.get(aggregationName); return Double.valueOf(String.valueOf(values.get("value"))); @@ -450,7 +444,7 @@ protected String createKnnIndexNestedMapping(Integer dimensions, String fieldPat * @return index mapping a map */ @SuppressWarnings("unchecked") - public Map getIndexMappingAsMap(String index) throws IOException { + public Map getIndexMappingAsMap(String index) throws Exception { Request request = new Request("GET", "/" + index + "/_mapping"); Response response = client().performRequest(request); @@ -459,12 +453,12 @@ public Map getIndexMappingAsMap(String index) throws IOException String responseBody = EntityUtils.toString(response.getEntity()); - Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); + Map responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map(); return (Map) ((Map) responseMap.get(index)).get("mappings"); } - public int getDocCount(String indexName) throws IOException { + public int getDocCount(String indexName) throws Exception { Request request = new Request("GET", "/" + indexName + "/_count"); Response response = client().performRequest(request); @@ -473,7 +467,7 @@ public int getDocCount(String indexName) throws IOException { String responseBody = EntityUtils.toString(response.getEntity()); - Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); + Map responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map(); return (Integer) responseMap.get("count"); } @@ -632,12 +626,14 @@ protected void deleteKnnDoc(String index, String docId) throws IOException { /** * Retrieve document by index and document id */ - protected Map getKnnDoc(final String index, final String docId) throws IOException { + protected Map getKnnDoc(final String index, final String docId) throws Exception { final Request request = new Request("GET", "/" + index + "/_doc/" + docId); final Response response = client().performRequest(request); - final Map responseMap = createParser(XContentType.JSON.xContent(), EntityUtils.toString(response.getEntity())) - .map(); + final Map responseMap = createParser( + MediaTypeRegistry.getDefaultMediaType().xContent(), + EntityUtils.toString(response.getEntity()) + ).map(); assertNotNull(responseMap); assertTrue((Boolean) responseMap.get(DOCUMENT_FIELD_FOUND)); @@ -733,7 +729,7 @@ protected Response clearCache(List indices) throws IOException { * Parse KNN Cluster stats from response */ protected Map parseClusterStatsResponse(String responseBody) throws IOException { - Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); + Map responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map(); responseMap.remove("cluster_name"); responseMap.remove("_nodes"); responseMap.remove("nodes"); @@ -745,7 +741,10 @@ protected Map parseClusterStatsResponse(String responseBody) thr */ protected List> parseNodeStatsResponse(String responseBody) throws IOException { @SuppressWarnings("unchecked") - Map responseMap = (Map) createParser(XContentType.JSON.xContent(), responseBody).map().get("nodes"); + Map responseMap = (Map) createParser( + MediaTypeRegistry.getDefaultMediaType().xContent(), + responseBody + ).map().get("nodes"); @SuppressWarnings("unchecked") List> nodeResponses = responseMap.keySet() @@ -761,8 +760,10 @@ protected List> parseNodeStatsResponse(String responseBody) */ @SuppressWarnings("unchecked") protected int parseTotalSearchHits(String searchResponseBody) throws IOException { - Map responseMap = (Map) createParser(XContentType.JSON.xContent(), searchResponseBody).map() - .get("hits"); + Map responseMap = (Map) createParser( + MediaTypeRegistry.getDefaultMediaType().xContent(), + searchResponseBody + ).map().get("hits"); return (int) ((Map) responseMap.get("total")).get("value"); } @@ -789,7 +790,7 @@ protected List parseIds(String searchResponseBody) throws IOException { * Get the total number of graphs in the cache across all nodes */ @SuppressWarnings("unchecked") - protected int getTotalGraphsInCache() throws IOException { + protected int getTotalGraphsInCache() throws Exception { Response response = getKnnStats(Collections.emptyList(), Collections.emptyList()); String responseBody = EntityUtils.toString(response.getEntity()); @@ -1046,7 +1047,7 @@ public void bulkAddKnnDocs(String index, String fieldName, float[][] indexVector } // Method that returns index vectors of the documents that were added before into the index - public float[][] getIndexVectorsFromIndex(String testIndex, String testField, int docCount, int dimensions) throws IOException { + public float[][] getIndexVectorsFromIndex(String testIndex, String testField, int docCount, int dimensions) throws Exception { float[][] vectors = new float[docCount][dimensions]; QueryBuilder qb = new MatchAllQueryBuilder(); @@ -1073,7 +1074,7 @@ public float[][] getIndexVectorsFromIndex(String testIndex, String testField, in } // Method that performs bulk search for multiple queries and stores the resulting documents ids into list - public List> bulkSearch(String testIndex, String testField, float[][] queryVectors, int k) throws IOException { + public List> bulkSearch(String testIndex, String testField, float[][] queryVectors, int k) throws Exception { List> searchResults = new ArrayList<>(); List kVectors; @@ -1110,12 +1111,22 @@ public void addKNNDocs(String testIndex, String testField, int dimension, int fi } } + public void validateKNNSearch(String testIndex, String testField, int dimension, int numDocs, int k) throws Exception { + validateKNNSearch(testIndex, testField, dimension, numDocs, k, null); + } + // Validate KNN search on a KNN index by generating the query vector from the number of documents in the index - public void validateKNNSearch(String testIndex, String testField, int dimension, int numDocs, int k) throws IOException { + public void validateKNNSearch(String testIndex, String testField, int dimension, int numDocs, int k, Map methodParameters) + throws Exception { float[] queryVector = new float[dimension]; Arrays.fill(queryVector, (float) numDocs); - Response searchResponse = searchKNNIndex(testIndex, new KNNQueryBuilder(testField, queryVector, k), k); + Response searchResponse = searchKNNIndex( + testIndex, + KNNQueryBuilder.builder().k(k).methodParameters(methodParameters).fieldName(testField).vector(queryVector).build(), + k + ); + List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), testField); assertEquals(k, results.size()); @@ -1371,7 +1382,7 @@ public void deleteModel(String modelId) throws IOException { assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); } - public void assertTrainingSucceeds(String modelId, int attempts, int delayInMillis) throws InterruptedException, IOException { + public void assertTrainingSucceeds(String modelId, int attempts, int delayInMillis) throws InterruptedException, Exception { int attemptNum = 0; Response response; Map responseMap; @@ -1382,7 +1393,8 @@ public void assertTrainingSucceeds(String modelId, int attempts, int delayInMill response = getModel(modelId, null); - responseMap = createParser(XContentType.JSON.xContent(), EntityUtils.toString(response.getEntity())).map(); + responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), EntityUtils.toString(response.getEntity())) + .map(); modelState = ModelState.getModelState((String) responseMap.get(MODEL_STATE)); if (modelState == ModelState.CREATED) { @@ -1395,7 +1407,7 @@ public void assertTrainingSucceeds(String modelId, int attempts, int delayInMill fail("Training did not succeed after " + attempts + " attempts with a delay of " + delayInMillis + " ms."); } - public void assertTrainingFails(String modelId, int attempts, int delayInMillis) throws InterruptedException, IOException { + public void assertTrainingFails(String modelId, int attempts, int delayInMillis) throws Exception { int attemptNum = 0; Response response; Map responseMap; @@ -1406,7 +1418,8 @@ public void assertTrainingFails(String modelId, int attempts, int delayInMillis) response = getModel(modelId, null); - responseMap = createParser(XContentType.JSON.xContent(), EntityUtils.toString(response.getEntity())).map(); + responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), EntityUtils.toString(response.getEntity())) + .map(); modelState = ModelState.getModelState((String) responseMap.get(MODEL_STATE)); if (modelState == ModelState.FAILED) { @@ -1526,9 +1539,9 @@ public interface IProxy { protected void refreshAllNonSystemIndices() throws Exception { Response response = adminClient().performRequest(new Request("GET", "/_cat/indices?format=json&expand_wildcards=all")); - MediaType xContentType = MediaType.fromMediaType(response.getEntity().getContentType().getValue()); + MediaType mediaType = MediaType.fromMediaType(response.getEntity().getContentType().getValue()); try ( - XContentParser parser = xContentType.xContent() + XContentParser parser = mediaType.xContent() .createParser( NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION,