diff --git a/CHANGELOG.md b/CHANGELOG.md index bd5b9ff72..5637adda0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Adds dynamic query parameter ef_search [#1783](https://github.com/opensearch-project/k-NN/pull/1783) * Adds dynamic query parameter ef_search in radial search faiss engine [#1790](https://github.com/opensearch-project/k-NN/pull/1790) * Add binary format support with HNSW method in Faiss Engine [#1781](https://github.com/opensearch-project/k-NN/pull/1781) +* Add binary format support with IVF method in Faiss Engine [#1784](https://github.com/opensearch-project/k-NN/pull/1784) ### Enhancements ### Bug Fixes * Fixing the arithmetic to find the number of vectors to stream from java to jni layer.[#1804](https://github.com/opensearch-project/k-NN/pull/1804) diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 2b9bc2c76..d25bf8f7c 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -29,6 +29,12 @@ namespace knn_jni { jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, jobject parametersJ); + // Create an index with ids and vectors. Instead of creating a new index, this function creates the index + // based off of the template index passed in. The index is serialized to indexPathJ. + void CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, + jobject parametersJ); + // Load an index from indexPathJ into memory. // // Return a pointer to the loaded index @@ -80,6 +86,12 @@ namespace knn_jni { jobjectArray QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); + // Execute a query against the binary index located in memory at indexPointerJ along with Filters + // + // Return an array of KNNQueryResults + jobjectArray QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, + jbyteArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ); + // Free the index located in memory at indexPointerJ void Free(jlong indexPointer, jboolean isBinaryIndexJ); @@ -96,6 +108,13 @@ namespace knn_jni { jbyteArray TrainIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, jlong trainVectorsPointerJ); + // Create an empty binary index defined by the values in the Java map, parametersJ. Train the index with + // the vector of floats located at trainVectorsPointerJ. + // + // Return the serialized representation + jbyteArray TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, + jlong trainVectorsPointerJ); + /* * Perform a range search with filter against the index located in memory at indexPointerJ. * diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 7cc071ff3..025fb12e8 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -43,6 +43,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryInde JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: createBinaryIndexFromTemplate + * Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V + */ + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate + (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject); + /* * Class: org_opensearch_knn_jni_FaissService * Method: loadIndex @@ -139,6 +147,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_initLibrary JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex (JNIEnv *, jclass, jobject, jint, jlong); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: trainBinaryIndex + * Signature: (Ljava/util/Map;IJ)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex + (JNIEnv *, jclass, jobject, jint, jlong); + /* * Class: org_opensearch_knn_jni_FaissService * Method: transferVectors diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 9abb2357f..92393245e 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -70,6 +70,9 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, // Train an index with data provided void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x); +// Train a binary index with data provided +void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x); + // Converts the int FilterIds to Faiss ids type array. void convertFilterIdsToFaissIdType(const int* filterIds, int filterIdsLength, faiss::idx_t* convertedFilterIds); @@ -223,6 +226,76 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * faiss::write_index(&idMap, indexPathCpp.c_str()); } +void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, + jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, + jbyteArray templateIndexJ, jobject parametersJ) { + if (idsJ == nullptr) { + throw std::runtime_error("IDs cannot be null"); + } + + if (vectorsAddressJ <= 0) { + throw std::runtime_error("VectorsAddress cannot be less than 0"); + } + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); + } + + if (indexPathJ == nullptr) { + throw std::runtime_error("Index path cannot be null"); + } + + if (templateIndexJ == nullptr) { + throw std::runtime_error("Template index cannot be null"); + } + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + // Read data set + // Read vectors from memory address + auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); + int dim = (int)dimJ; + if (dim % 8 != 0) { + throw std::runtime_error("Dimensions should be multiply of 8"); + } + int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8)); + int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); + if (numIds != numVectors) { + throw std::runtime_error("Number of IDs does not match number of vectors"); + } + + // Get vector of bytes from jbytearray + int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ); + jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr); + + faiss::VectorIOReader vectorIoReader; + for (int i = 0; i < indexBytesCount; i++) { + vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]); + } + jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT); + + // Create faiss index + std::unique_ptr indexWriter; + indexWriter.reset(faiss::read_index_binary(&vectorIoReader, 0)); + + auto idVector = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); + faiss::IndexBinaryIDMap idMap = faiss::IndexBinaryIDMap(indexWriter.get()); + idMap.add_with_ids(numVectors, reinterpret_cast(inputVectors->data()), idVector.data()); + // Releasing the vectorsAddressJ memory as that is not required once we have created the index. + // This is not the ideal approach, please refer this gh issue for long term solution: + // https://github.com/opensearch-project/k-NN/issues/1600 + delete inputVectors; + // Write the index to disk + std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + faiss::write_index_binary(&idMap, indexPathCpp.c_str()); +} + jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) { if (indexPathJ == nullptr) { throw std::runtime_error("Index path cannot be null"); @@ -624,6 +697,57 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti return ret; } +jbyteArray knn_jni::faiss_wrapper::TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, + jint dimensionJ, jlong trainVectorsPointerJ) { + // First, we need to build the index + if (parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } + + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + + jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); + std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); + faiss::MetricType metric = TranslateSpaceToMetric(spaceTypeCpp); + + // Create faiss index + jobject indexDescriptionJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::INDEX_DESCRIPTION); + std::string indexDescriptionCpp(jniUtil->ConvertJavaObjectToCppString(env, indexDescriptionJ)); + + std::unique_ptr indexWriter; + indexWriter.reset(faiss::index_binary_factory((int) dimensionJ, indexDescriptionCpp.c_str())); + + // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + omp_set_num_threads(threadCount); + } + + // Train index if needed + auto *trainingVectorsPointerCpp = reinterpret_cast*>(trainVectorsPointerJ); + int numVectors = trainingVectorsPointerCpp->size()/(int) dimensionJ; + if(!indexWriter->is_trained) { + InternalTrainBinaryIndex(indexWriter.get(), numVectors, trainingVectorsPointerCpp->data()); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + // Now that indexWriter is trained, we just load the bytes into an array and return + faiss::VectorIOWriter vectorIoWriter; + faiss::write_index_binary(indexWriter.get(), &vectorIoWriter); + + // Wrap in smart pointer + std::unique_ptr jbytesBuffer; + jbytesBuffer.reset(new jbyte[vectorIoWriter.data.size()]); + int c = 0; + for (auto b : vectorIoWriter.data) { + jbytesBuffer[c++] = (jbyte) b; + } + + jbyteArray ret = jniUtil->NewByteArray(env, vectorIoWriter.data.size()); + jniUtil->SetByteArrayRegion(env, ret, 0, vectorIoWriter.data.size(), jbytesBuffer.get()); + return ret; +} + faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType) { if (spaceType == knn_jni::L2) { return faiss::METRIC_L2; @@ -682,6 +806,15 @@ void InternalTrainIndex(faiss::Index * index, faiss::idx_t n, const float* x) { } } +void InternalTrainBinaryIndex(faiss::IndexBinary * index, faiss::idx_t n, const float* x) { + if (auto * indexIvf = dynamic_cast(index)) { + indexIvf->make_direct_map(); + } + if (!index->is_trained) { + index->train(n, reinterpret_cast(x)); + } +} + std::unique_ptr buildIDGrouperBitmap(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, jintArray parentIdsJ, std::vector* bitmap) { int *parentIdsArray = jniUtil->GetIntArrayElements(env, parentIdsJ, nullptr); int parentIdsLength = jniUtil->GetJavaIntArrayLength(env, parentIdsJ); diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 6e447b034..2394e2951 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -90,6 +90,21 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromT } } +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate(JNIEnv * env, jclass cls, + jintArray idsJ, + jlong vectorsAddressJ, + jint dimJ, + jstring indexPathJ, + jbyteArray templateIndexJ, + jobject parametersJ) +{ + try { + knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEnv * env, jclass cls, jstring indexPathJ) { try { @@ -220,6 +235,19 @@ JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainIndex return nullptr; } +JNIEXPORT jbyteArray JNICALL Java_org_opensearch_knn_jni_FaissService_trainBinaryIndex(JNIEnv * env, jclass cls, + jobject parametersJ, + jint dimensionJ, + jlong trainVectorsPointerJ) +{ + try { + return knn_jni::faiss_wrapper::TrainBinaryIndex(&jniUtil, env, parametersJ, dimensionJ, trainVectorsPointerJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return nullptr; +} + JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors(JNIEnv * env, jclass cls, jlong vectorsPointerJ, jobjectArray vectorsJ) diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index a85852027..77aae79c3 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -67,6 +67,7 @@ public class KNNConstants { public static final String SEARCH_SIZE_PARAMETER = "search_size"; public static final String VECTOR_DATA_TYPE_FIELD = "data_type"; + public static final String MODEL_VECTOR_DATA_TYPE_KEY = VECTOR_DATA_TYPE_FIELD; public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT; public static final String RADIAL_SEARCH_KEY = "radial_search"; diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index 5815b343e..327c53844 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -48,6 +48,7 @@ public class IndexUtil { private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT = Version.V_2_13_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH = Version.V_2_14_0; private static final Version MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS = Version.V_2_16_0; + private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VECTOR_DATA_TYPE = Version.V_2_16_0; // public so neural search can access it public static final Map minimalRequiredVersionMap = initializeMinimalRequiredVersionMap(); @@ -59,6 +60,7 @@ private static Map initializeMinimalRequiredVersionMap() { put(MODEL_METHOD_COMPONENT_CONTEXT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT); put(KNNConstants.RADIAL_SEARCH_KEY, MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH); put(KNNConstants.METHOD_PARAMETER, MINIMAL_SUPPORTED_VERSION_FOR_METHOD_PARAMETERS); + put(KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VECTOR_DATA_TYPE); } }; @@ -135,7 +137,8 @@ public static ValidationException validateKnnField( IndexMetadata indexMetadata, String field, int expectedDimension, - ModelDao modelDao + ModelDao modelDao, + VectorDataType expectedVectorDataType ) { // Index metadata should not be null if (indexMetadata == null) { @@ -248,6 +251,29 @@ public static ValidationException validateKnnField( return exception; } + // Return if vector data type does not need to be checked + if (expectedVectorDataType == null) { + return null; + } + + // Determine the data type of the training index + VectorDataType trainIndexDataType = fieldMap.containsKey(VECTOR_DATA_TYPE_FIELD) + ? VectorDataType.get((String) fieldMap.get(VECTOR_DATA_TYPE_FIELD)) + : VectorDataType.FLOAT; + + // Check if the data type matches the expected vector data type + if (trainIndexDataType != expectedVectorDataType) { + exception.addValidationError( + String.format( + "Field \"%s\" has data type %s, which is different from data type used in the training request: %s", + field, + trainIndexDataType.getValue(), + expectedVectorDataType.getValue() + ) + ); + return exception; + } + return null; } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 50c1c9271..4403f6599 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -5,7 +5,6 @@ package org.opensearch.knn.index.codec.KNN80Codec; -import com.google.common.collect.ImmutableMap; import lombok.NonNull; import lombok.extern.log4j.Log4j2; import org.apache.lucene.store.ChecksumIndexInput; @@ -112,7 +111,18 @@ private KNNEngine getKNNEngine(@NonNull FieldInfo field) { } private VectorTransfer getVectorTransfer(FieldInfo field) { - if (VectorDataType.BINARY.getValue().equalsIgnoreCase(field.attributes().get(KNNConstants.VECTOR_DATA_TYPE_FIELD))) { + boolean isBinary = false; + + // Check if the field has a model ID and retrieve the model's vector data type + if (field.attributes().containsKey(MODEL_ID)) { + Model model = ModelCache.getInstance().get(field.attributes().get(MODEL_ID)); + isBinary = model.getModelMetadata().getVectorDataType() == VectorDataType.BINARY; + } else if (VectorDataType.BINARY.getValue().equalsIgnoreCase(field.attributes().get(KNNConstants.VECTOR_DATA_TYPE_FIELD))) { + isBinary = true; + } + + // Return the appropriate VectorTransfer instance based on the vector data type + if (isBinary) { return new VectorTransferByte(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); } return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes()); @@ -154,7 +164,7 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, if (model.getModelBlob() == null) { throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId)); } - indexCreator = () -> createKNNIndexFromTemplate(model.getModelBlob(), pair, knnEngine, indexPath); + indexCreator = () -> createKNNIndexFromTemplate(model, pair, knnEngine, indexPath); } else { indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath); } @@ -188,18 +198,22 @@ private void recordRefreshStats() { KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment(); } - private void createKNNIndexFromTemplate(byte[] model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { - Map parameters = ImmutableMap.of( - KNNConstants.INDEX_THREAD_QTY, - KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) - ); + private void createKNNIndexFromTemplate(Model model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) { + Map parameters = new HashMap<>(); + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + + // Update index description of Faiss for binary data type + if (KNNEngine.FAISS == knnEngine && VectorDataType.BINARY.equals(model.getModelMetadata().getVectorDataType())) { + parameters.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue()); + } + AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.createIndexFromTemplate( pair.docs, pair.getVectorAddress(), pair.getDimension(), indexPath, - model, + model.getModelBlob(), parameters, knnEngine ); diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 7f7c83f3e..7c7f446dd 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -603,7 +603,8 @@ protected void parseCreateField(ParseContext context) throws IOException { context, fieldType().getDimension(), fieldType().getSpaceType(), - getMethodComponentContext(fieldType().getKnnMethodContext()) + getMethodComponentContext(fieldType().getKnnMethodContext()), + fieldType().getVectorDataType() ); } @@ -646,8 +647,13 @@ protected List getFieldsForByteVector(final byte[] array, final FieldType return fields; } - protected void parseCreateField(ParseContext context, int dimension, SpaceType spaceType, MethodComponentContext methodComponentContext) - throws IOException { + protected void parseCreateField( + ParseContext context, + int dimension, + SpaceType spaceType, + MethodComponentContext methodComponentContext, + VectorDataType vectorDataType + ) throws IOException { validateIfKNNPluginEnabled(); validateIfCircuitBreakerIsNotTriggered(); diff --git a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java index 554871279..adaaef28e 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java @@ -62,6 +62,12 @@ protected void parseCreateField(ParseContext context) throws IOException { ); } - parseCreateField(context, modelMetadata.getDimension(), modelMetadata.getSpaceType(), modelMetadata.getMethodComponentContext()); + parseCreateField( + context, + modelMetadata.getDimension(), + modelMetadata.getSpaceType(), + modelMetadata.getMethodComponentContext(), + modelMetadata.getVectorDataType() + ); } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index a10e04788..5baaf59cd 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -515,6 +515,7 @@ protected Query doToQuery(QueryShardContext context) { knnEngine = modelMetadata.getKnnEngine(); spaceType = modelMetadata.getSpaceType(); methodComponentContext = modelMetadata.getMethodComponentContext(); + vectorDataType = modelMetadata.getVectorDataType(); } else if (knnMethodContext != null) { // If the dimension is set but the knnMethodContext is not then the field is using the legacy mapping diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index c08997c26..235e66411 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -213,6 +213,7 @@ private Map doANNSearch(final LeafReaderContext context, final B KNNEngine knnEngine; SpaceType spaceType; + VectorDataType vectorDataType; // Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's // metadata. @@ -225,11 +226,15 @@ private Map doANNSearch(final LeafReaderContext context, final B knnEngine = modelMetadata.getKnnEngine(); spaceType = modelMetadata.getSpaceType(); + vectorDataType = modelMetadata.getVectorDataType(); } else { String engineName = fieldInfo.attributes().getOrDefault(KNN_ENGINE, KNNEngine.NMSLIB.getName()); knnEngine = KNNEngine.getEngine(engineName); String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue()); spaceType = SpaceType.getSpace(spaceTypeName); + vectorDataType = VectorDataType.get( + fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()) + ); } /* @@ -261,12 +266,7 @@ private Map doANNSearch(final LeafReaderContext context, final B new NativeMemoryEntryContext.IndexEntryContext( indexPath.toString(), NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), - getParametersAtLoading( - spaceType, - knnEngine, - knnQuery.getIndexName(), - VectorDataType.get(fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue())) - ), + getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName(), vectorDataType), knnQuery.getIndexName(), modelId ), diff --git a/src/main/java/org/opensearch/knn/index/util/Faiss.java b/src/main/java/org/opensearch/knn/index/util/Faiss.java index 711c206f5..4e39c1af1 100644 --- a/src/main/java/org/opensearch/knn/index/util/Faiss.java +++ b/src/main/java/org/opensearch/knn/index/util/Faiss.java @@ -305,7 +305,7 @@ public class Faiss extends NativeLibrary { return ((4L * centroids * dimension) / BYTES_PER_KILOBYTES) + 1; }) .build() - ).addSpaces(SpaceType.UNDEFINED, SpaceType.L2, SpaceType.INNER_PRODUCT).build() + ).addSpaces(SpaceType.UNDEFINED, SpaceType.L2, SpaceType.INNER_PRODUCT, SpaceType.HAMMING_BIT).build() ); final static Faiss INSTANCE = new Faiss( diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 0bc6c5edb..37edcd3ae 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -292,6 +292,7 @@ private void putInternal(Model model, ActionListener listener, Do put(KNNConstants.MODEL_DESCRIPTION, modelMetadata.getDescription()); put(KNNConstants.MODEL_ERROR, modelMetadata.getError()); put(KNNConstants.MODEL_NODE_ASSIGNMENT, modelMetadata.getNodeAssignment()); + put(KNNConstants.VECTOR_DATA_TYPE_FIELD, modelMetadata.getVectorDataType()); MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext(); if (!methodComponentContext.getName().isEmpty()) { diff --git a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java index f3a5506cd..d11173f1d 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelMetadata.java +++ b/src/main/java/org/opensearch/knn/indices/ModelMetadata.java @@ -26,6 +26,7 @@ import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; @@ -48,6 +49,7 @@ public class ModelMetadata implements Writeable, ToXContentObject { final private String timestamp; final private String description; final private String trainingNodeAssignment; + final private VectorDataType vectorDataType; private MethodComponentContext methodComponentContext; private String error; @@ -81,6 +83,12 @@ public ModelMetadata(StreamInput in) throws IOException { } else { this.methodComponentContext = MethodComponentContext.EMPTY; } + + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY)) { + this.vectorDataType = VectorDataType.get(in.readOptionalString()); + } else { + this.vectorDataType = VectorDataType.FLOAT; + } } /** @@ -95,6 +103,7 @@ public ModelMetadata(StreamInput in) throws IOException { * @param error error message associated with model * @param trainingNodeAssignment node assignment for the model * @param methodComponentContext method component context associated with model + * @param vectorDataType vector data type of the model */ public ModelMetadata( KNNEngine knnEngine, @@ -105,7 +114,8 @@ public ModelMetadata( String description, String error, String trainingNodeAssignment, - MethodComponentContext methodComponentContext + MethodComponentContext methodComponentContext, + VectorDataType vectorDataType ) { this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null"); this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null"); @@ -128,6 +138,7 @@ public ModelMetadata( this.error = Objects.requireNonNull(error, "error must not be null"); this.trainingNodeAssignment = Objects.requireNonNull(trainingNodeAssignment, "node assignment must not be null"); this.methodComponentContext = Objects.requireNonNull(methodComponentContext, "method context must not be null"); + this.vectorDataType = Objects.requireNonNull(vectorDataType, "vector data type must not be null"); } /** @@ -211,6 +222,10 @@ public MethodComponentContext getMethodComponentContext() { return methodComponentContext; } + public VectorDataType getVectorDataType() { + return vectorDataType; + } + /** * setter for model's state * @@ -241,7 +256,8 @@ public String toString() { description, error, trainingNodeAssignment, - methodComponentContext.toClusterStateString() + methodComponentContext.toClusterStateString(), + vectorDataType.getValue() ); } @@ -259,6 +275,7 @@ public boolean equals(Object obj) { equalsBuilder.append(getTimestamp(), other.getTimestamp()); equalsBuilder.append(getDescription(), other.getDescription()); equalsBuilder.append(getError(), other.getError()); + equalsBuilder.append(getVectorDataType(), other.getVectorDataType()); return equalsBuilder.isEquals(); } @@ -273,6 +290,7 @@ public int hashCode() { .append(getDescription()) .append(getError()) .append(getMethodComponentContext()) + .append(getVectorDataType()) .toHashCode(); } @@ -308,7 +326,8 @@ public static ModelMetadata fromString(String modelMetadataString) { description, error, "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); } else if (modelMetadataArray.length == 8) { log.debug("Model metadata contains training node assignment. Assuming empty method component context."); @@ -329,7 +348,8 @@ public static ModelMetadata fromString(String modelMetadataString) { description, error, trainingNodeAssignment, - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); } else if (modelMetadataArray.length == 9) { log.debug("Model metadata contains training node assignment and method context"); @@ -351,12 +371,37 @@ public static ModelMetadata fromString(String modelMetadataString) { description, error, trainingNodeAssignment, - methodComponentContext + methodComponentContext, + VectorDataType.FLOAT + ); + } else if (modelMetadataArray.length == 10) { + log.debug("Model metadata contains training node assignment and method context and vector data type"); + KNNEngine knnEngine = KNNEngine.getEngine(modelMetadataArray[0]); + SpaceType spaceType = SpaceType.getSpace(modelMetadataArray[1]); + int dimension = Integer.parseInt(modelMetadataArray[2]); + ModelState modelState = ModelState.getModelState(modelMetadataArray[3]); + String timestamp = modelMetadataArray[4]; + String description = modelMetadataArray[5]; + String error = modelMetadataArray[6]; + String trainingNodeAssignment = modelMetadataArray[7]; + MethodComponentContext methodComponentContext = MethodComponentContext.fromClusterStateString(modelMetadataArray[8]); + VectorDataType vectorDataType = VectorDataType.get(modelMetadataArray[9]); + return new ModelMetadata( + knnEngine, + spaceType, + dimension, + modelState, + timestamp, + description, + error, + trainingNodeAssignment, + methodComponentContext, + vectorDataType ); } else { throw new IllegalArgumentException( "Illegal format for model metadata. Must be of the form " - + "\",,,,,,\" or \",,,,,,,\" or \",,,,,,,,\"." + + "\",,,,,,,\" or \",,,,,,,,\" or \",,,,,,,,,\"." ); } } @@ -387,6 +432,7 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m Object error = modelSourceMap.get(KNNConstants.MODEL_ERROR); Object trainingNodeAssignment = modelSourceMap.get(KNNConstants.MODEL_NODE_ASSIGNMENT); Object methodComponentContext = modelSourceMap.get(KNNConstants.MODEL_METHOD_COMPONENT_CONTEXT); + Object vectorDataType = modelSourceMap.get(KNNConstants.VECTOR_DATA_TYPE_FIELD); if (trainingNodeAssignment == null) { trainingNodeAssignment = ""; @@ -416,7 +462,8 @@ public static ModelMetadata getMetadataFromSourceMap(final Map m objectToString(description), objectToString(error), objectToString(trainingNodeAssignment), - (MethodComponentContext) methodComponentContext + (MethodComponentContext) methodComponentContext, + VectorDataType.get(objectToString(vectorDataType)) ); return modelMetadata; } @@ -436,6 +483,9 @@ public void writeTo(StreamOutput out) throws IOException { if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), IndexUtil.MODEL_METHOD_COMPONENT_CONTEXT_KEY)) { getMethodComponentContext().writeTo(out); } + if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY)) { + out.writeOptionalString(vectorDataType.getValue()); + } } @Override @@ -456,6 +506,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws getMethodComponentContext().toXContent(builder, params); builder.endObject(); } + if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(KNNConstants.MODEL_VECTOR_DATA_TYPE_KEY)) { + builder.field(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); + } return builder; } } diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index 21de90765..1f23f6fcd 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -96,6 +96,25 @@ public static native void createIndexFromTemplate( Map parameters ); + /** + * Create a binary index for the native library with a provided template index + * + * @param ids array of ids mapping to the data passed in + * @param vectorsAddress address of native memory where vectors are stored + * @param dim dimension of the vector to be indexed + * @param indexPath path to save index file to + * @param templateIndex empty template index + * @param parameters additional build time parameters + */ + public static native void createBinaryIndexFromTemplate( + int[] ids, + long vectorsAddress, + int dim, + String indexPath, + byte[] templateIndex, + Map parameters + ); + /** * Load an index into memory * @@ -249,6 +268,16 @@ public static native KNNQueryResult[] queryBinaryIndexWithFilter( */ public static native byte[] trainIndex(Map indexParameters, int dimension, long trainVectorsPointer); + /** + * Train an empty binary index + * + * @param indexParameters parameters used to build index + * @param dimension dimension for the index + * @param trainVectorsPointer pointer to where training vectors are stored in native memory + * @return bytes array of trained template index + */ + public static native byte[] trainBinaryIndex(Map indexParameters, int dimension, long trainVectorsPointer); + /** *

* The function is deprecated. Use {@link JNICommons#storeVectorData(long, float[][], long)} diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index cefd0af53..2a8d3ea8f 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -83,8 +83,13 @@ public static void createIndexFromTemplate( KNNEngine knnEngine ) { if (KNNEngine.FAISS == knnEngine) { - FaissService.createIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); - return; + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + FaissService.createBinaryIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); + return; + } else { + FaissService.createIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); + return; + } } throw new IllegalArgumentException( @@ -308,6 +313,9 @@ public static void freeSharedIndexState(long shareIndexStateAddr, KNNEngine knnE */ public static byte[] trainIndex(Map indexParameters, int dimension, long trainVectorsPointer, KNNEngine knnEngine) { if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, indexParameters)) { + return FaissService.trainBinaryIndex(indexParameters, dimension, trainVectorsPointer); + } return FaissService.trainIndex(indexParameters, dimension, trainVectorsPointer); } diff --git a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java index fb8ccc4ce..c1a843540 100644 --- a/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java +++ b/src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java @@ -17,6 +17,7 @@ import org.opensearch.index.mapper.NumberFieldMapper; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.plugin.KNNPlugin; import org.opensearch.knn.plugin.transport.TrainingJobRouterAction; @@ -40,6 +41,7 @@ import static org.opensearch.knn.common.KNNConstants.SEARCH_SIZE_PARAMETER; import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER; import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; /** * Rest Handler for model training api endpoint. @@ -83,6 +85,7 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr String trainingIndex = (String) DEFAULT_NOT_SET_OBJECT_VALUE; String trainingField = (String) DEFAULT_NOT_SET_OBJECT_VALUE; String description = (String) DEFAULT_NOT_SET_OBJECT_VALUE; + VectorDataType vectorDataType = (VectorDataType) DEFAULT_NOT_SET_OBJECT_VALUE; int dimension = DEFAULT_NOT_SET_INT_VALUE; int maximumVectorCount = DEFAULT_NOT_SET_INT_VALUE; @@ -110,6 +113,8 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr } else if (MODEL_DESCRIPTION.equals(fieldName) && ensureNotSet(fieldName, description)) { description = parser.textOrNull(); ModelUtil.blockCommasInModelDescription(description); + } else if (VECTOR_DATA_TYPE_FIELD.equals(fieldName) && ensureNotSet(fieldName, vectorDataType)) { + vectorDataType = VectorDataType.get(parser.text()); } else { throw new IllegalArgumentException("Unable to parse token. \"" + fieldName + "\" is not a valid " + "parameter."); } @@ -126,6 +131,10 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr description = ""; } + if (vectorDataType == DEFAULT_NOT_SET_OBJECT_VALUE) { + vectorDataType = VectorDataType.FLOAT; + } + TrainingModelRequest trainingModelRequest = new TrainingModelRequest( modelId, knnMethodContext, @@ -133,7 +142,8 @@ private TrainingModelRequest createTransportRequest(RestRequest restRequest) thr trainingIndex, trainingField, preferredNodeId, - description + description, + vectorDataType ); if (maximumVectorCount != DEFAULT_NOT_SET_INT_VALUE) { diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index 5f3913ac5..11ec7aada 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -21,6 +21,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.training.VectorSpaceInfo; @@ -41,6 +42,7 @@ public class TrainingModelRequest extends ActionRequest { private final String trainingField; private final String preferredNodeId; private final String description; + private final VectorDataType vectorDataType; private int maximumVectorCount; private int searchSize; @@ -65,7 +67,8 @@ public TrainingModelRequest( String trainingIndex, String trainingField, String preferredNodeId, - String description + String description, + VectorDataType vectorDataType ) { super(); this.modelId = modelId; @@ -75,6 +78,7 @@ public TrainingModelRequest( this.trainingField = trainingField; this.preferredNodeId = preferredNodeId; this.description = description; + this.vectorDataType = vectorDataType; // Set these as defaults initially. If call wants to override them, they can use the setters. this.maximumVectorCount = Integer.MAX_VALUE; // By default, get all vectors in the index @@ -103,6 +107,7 @@ public TrainingModelRequest(StreamInput in) throws IOException { this.maximumVectorCount = in.readInt(); this.searchSize = in.readInt(); this.trainingDataSizeInKB = in.readInt(); + this.vectorDataType = VectorDataType.get(in.readOptionalString()); } /** @@ -213,6 +218,10 @@ public int getSearchSize() { return searchSize; } + public VectorDataType getVectorDataType() { + return vectorDataType; + } + /** * Setter for search size. * @@ -314,7 +323,13 @@ public ActionRequestValidationException validate() { } // Validate the training field - ValidationException fieldValidation = IndexUtil.validateKnnField(indexMetadata, this.trainingField, this.dimension, modelDao); + ValidationException fieldValidation = IndexUtil.validateKnnField( + indexMetadata, + this.trainingField, + this.dimension, + modelDao, + this.vectorDataType + ); if (fieldValidation != null) { exception = exception == null ? new ActionRequestValidationException() : exception; exception.addValidationErrors(fieldValidation.validationErrors()); @@ -336,5 +351,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeInt(this.maximumVectorCount); out.writeInt(this.searchSize); out.writeInt(this.trainingDataSizeInKB); + out.writeOptionalString(this.vectorDataType.getValue()); } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java index 33b420e2c..58ac41b31 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java @@ -68,7 +68,8 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener modelAnonymousEntryContext, request.getDimension(), request.getDescription(), - clusterService.localNode().getEphemeralId() + clusterService.localNode().getEphemeralId(), + request.getVectorDataType() ); KNNCounter.TRAINING_REQUESTS.increment(); diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index aa2786c0a..ad6d99ebc 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -17,6 +17,7 @@ import org.opensearch.common.UUIDs; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.memory.NativeMemoryAllocation; @@ -32,6 +33,8 @@ import java.util.Map; import java.util.Objects; +import static org.opensearch.knn.index.util.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; + /** * Encapsulates all information required to generate and train a model. */ @@ -66,7 +69,8 @@ public TrainingJob( NativeMemoryEntryContext.AnonymousEntryContext modelAnonymousEntryContext, int dimension, String description, - String nodeAssignment + String nodeAssignment, + VectorDataType vectorDataType ) { // Generate random base64 string if one is not provided this.modelId = StringUtils.isNotBlank(modelId) ? modelId : UUIDs.randomBase64UUID(); @@ -84,7 +88,8 @@ public TrainingJob( description, "", nodeAssignment, - knnMethodContext.getMethodComponentContext() + knnMethodContext.getMethodComponentContext(), + vectorDataType ), null, this.modelId @@ -182,6 +187,14 @@ public void run() { KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY) ); + if (VectorDataType.BINARY.equals(model.getModelMetadata().getVectorDataType())) { + trainParameters.put( + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + FAISS_BINARY_INDEX_DESCRIPTION_PREFIX + trainParameters.get(KNNConstants.INDEX_DESCRIPTION_PARAMETER).toString() + ); + trainParameters.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue()); + } + byte[] modelBlob = JNIService.trainIndex( trainParameters, model.getModelMetadata().getDimension(), diff --git a/src/main/java/org/opensearch/knn/training/VectorReader.java b/src/main/java/org/opensearch/knn/training/VectorReader.java index aeebae129..8eb7b475c 100644 --- a/src/main/java/org/opensearch/knn/training/VectorReader.java +++ b/src/main/java/org/opensearch/knn/training/VectorReader.java @@ -89,7 +89,7 @@ public void read( throw validationException; } - ValidationException fieldValidationException = IndexUtil.validateKnnField(indexMetadata, fieldName, -1, null); + ValidationException fieldValidationException = IndexUtil.validateKnnField(indexMetadata, fieldName, -1, null, null); if (fieldValidationException != null) { validationException = validationException == null ? new ValidationException() : validationException; validationException.addValidationErrors(validationException.validationErrors()); diff --git a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java index f9c0161d6..06431bf07 100644 --- a/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java +++ b/src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java @@ -46,8 +46,17 @@ import java.util.concurrent.ExecutionException; import static org.mockito.Mockito.when; -import static org.opensearch.knn.common.KNNConstants.*; +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.MODEL_BLOB_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; +import static org.opensearch.knn.common.KNNConstants.MODEL_ERROR; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; +import static org.opensearch.knn.common.KNNConstants.MODEL_STATE; +import static org.opensearch.knn.common.KNNConstants.MODEL_TIMESTAMP; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; public class KNNSingleNodeTestCase extends OpenSearchSingleNodeTestCase { @Override @@ -201,7 +210,8 @@ protected void writeModelToModelSystemIndex(Model model) throws IOException, Exe .field(MODEL_STATE, modelMetadata.getState().getName()) .field(MODEL_TIMESTAMP, modelMetadata.getTimestamp().toString()) .field(MODEL_DESCRIPTION, modelMetadata.getDescription()) - .field(MODEL_ERROR, modelMetadata.getError()); + .field(MODEL_ERROR, modelMetadata.getError()) + .field(VECTOR_DATA_TYPE_FIELD, modelMetadata.getVectorDataType().getValue()); if (model.getModelBlob() != null) { builder.field(MODEL_BLOB_PARAMETER, Base64.getEncoder().encodeToString(model.getModelBlob())); diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index b9116b0b1..6819db1bb 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -44,6 +44,7 @@ import java.util.TreeMap; import java.util.stream.Collectors; +import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ; @@ -55,18 +56,25 @@ import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE; import static org.opensearch.knn.common.KNNConstants.FP16_MIN_VALUE; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; +import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.NAME; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; public class FaissIT extends KNNRestTestCase { private static final String DOC_ID_1 = "doc1"; @@ -107,13 +115,13 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN .startObject(FIELD_NAME) .field("type", "knn_vector") .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, hnswMethod.getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) .endObject() .endObject() @@ -166,13 +174,13 @@ public void testEndToEnd_whenDoRadiusSearch_whenScoreThreshold_whenMethodIsHNSWF .startObject(FIELD_NAME) .field("type", "knn_vector") .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, hnswMethod.getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) .endObject() .endObject() @@ -226,13 +234,13 @@ public void testEndToEnd_whenDoRadiusSearch_whenMoreThanOneScoreThreshold_whenMe .startObject(FIELD_NAME) .field("type", "knn_vector") .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, hnswMethod.getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) .endObject() .endObject() @@ -296,8 +304,8 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN .field(NAME, METHOD_HNSW) .field(KNN_ENGINE, FAISS_NAME) .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) .startObject(METHOD_ENCODER_PARAMETER) .field(NAME, ENCODER_PQ) @@ -424,8 +432,8 @@ public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { .field(NAME, METHOD_HNSW) .field(KNN_ENGINE, FAISS_NAME) .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) .startObject(METHOD_ENCODER_PARAMETER) .field(NAME, ENCODER_PQ) @@ -531,13 +539,13 @@ public void testHNSWSQFP16_whenIndexedAndQueried_thenSucceed() { .startObject(fieldName) .field("type", "knn_vector") .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, hnswMethod.getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) .startObject(METHOD_ENCODER_PARAMETER) .field(NAME, ENCODER_SQ) @@ -644,13 +652,13 @@ public void testHNSWSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() { .startObject(fieldName) .field("type", "knn_vector") .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, hnswMethod.getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) .startObject(METHOD_ENCODER_PARAMETER) .field(NAME, ENCODER_SQ) @@ -744,13 +752,13 @@ public void testHNSWSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_then .startObject(fieldName) .field("type", "knn_vector") .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, hnswMethod.getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2.getValue()) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) .startObject(METHOD_ENCODER_PARAMETER) .field(NAME, ENCODER_SQ) @@ -997,7 +1005,7 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed( .field(NAME, METHOD_HNSW) .field(KNN_ENGINE, FAISS_NAME) .startObject(PARAMETERS) - .field(KNNConstants.METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) .startObject(METHOD_ENCODER_PARAMETER) .field(NAME, ENCODER_PQ) .startObject(PARAMETERS) @@ -1204,7 +1212,7 @@ public void testDocUpdate() throws IOException { .startObject(fieldName) .field("type", "knn_vector") .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, hnswMethod.getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) @@ -1240,7 +1248,7 @@ public void testDocDeletion() throws IOException { .startObject(fieldName) .field("type", "knn_vector") .field("dimension", dimension) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, hnswMethod.getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) @@ -1418,7 +1426,7 @@ public void testFiltering_whenUsingFaissExactSearchWithIP_thenMatchExpectedScore .startObject(FIELD_NAME) .field("type", "knn_vector") .field("dimension", 2) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, KNNEngine.FAISS.getMethod(METHOD_HNSW).getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.INNER_PRODUCT.getValue()) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) @@ -1593,6 +1601,119 @@ public void testIVF_InvalidPQM_thenFail() { ); } + @SneakyThrows + public void testIVF_whenBinaryFormat_thenSuccess() { + String modelId = "test-model-ivf-binary"; + int dimension = 8; + + String trainingIndexName = "train-index-ivf-binary"; + String trainingFieldName = "train-field-ivf-binary"; + + String trainIndexMapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(trainingFieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .field("data_type", VectorDataType.BINARY.getValue()) + .startObject(KNN_METHOD) + .field(NAME, KNNEngine.FAISS.getMethod(METHOD_HNSW).getMethodComponent().getName()) + .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.HAMMING_BIT.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, 24) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, 128) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .toString(); + + createKnnIndex(trainingIndexName, trainIndexMapping); + + int trainingDataCount = 200; + bulkIngestRandomBinaryVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension); + + XContentBuilder trainModelXContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(TRAIN_INDEX_PARAMETER, trainingIndexName) + .field(TRAIN_FIELD_PARAMETER, trainingFieldName) + .field(DIMENSION, dimension) + .field(MODEL_DESCRIPTION, "My model description") + .field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue()) + .field( + KNN_METHOD, + Map.of( + NAME, + METHOD_IVF, + KNN_ENGINE, + FAISS_NAME, + METHOD_PARAMETER_SPACE_TYPE, + SpaceType.HAMMING_BIT.getValue(), + PARAMETERS, + Map.of(METHOD_PARAMETER_NLIST, 4, METHOD_PARAMETER_NPROBES, 2) + ) + ) + .endObject(); + + trainModel(modelId, trainModelXContentBuilder); + + // Make sure training succeeds after 30 seconds + assertTrainingSucceeds(modelId, 30, 1000); + + // Create knn index from model + String fieldName = "test-field-name-ivf-binary"; + String indexName = "test-index-name-ivf-binary"; + String indexMapping = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field(MODEL_ID, modelId) + .endObject() + .endObject() + .endObject() + .toString(); + + createKnnIndex(indexName, getKNNDefaultIndexSettings(), indexMapping); + Integer[] vector1 = { 11 }; + Integer[] vector2 = { 22 }; + Integer[] vector3 = { 33 }; + Integer[] vector4 = { 44 }; + addKnnDoc(indexName, "1", fieldName, vector1); + addKnnDoc(indexName, "2", fieldName, vector2); + addKnnDoc(indexName, "3", fieldName, vector3); + addKnnDoc(indexName, "4", fieldName, vector4); + + Integer[] queryVector = { 15 }; + int k = 2; + + XContentBuilder queryBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject("query") + .startObject("knn") + .startObject(fieldName) + .field("vector", queryVector) + .field("k", k) + .endObject() + .endObject() + .endObject() + .endObject(); + Response searchResponse = searchKNNIndex(indexName, queryBuilder, k); + List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName); + assertEquals(k, results.size()); + for (int i = 1; i <= k; i++) { + assertEquals(Integer.toString(i), results.get(i - 1).getDocId()); + } + + deleteKNNIndex(indexName); + Thread.sleep(1000 * 45); + deleteModel(modelId); + deleteKNNIndex(trainingIndexName); + validateGraphEviction(); + } + protected void setupKNNIndexForFilterQuery() throws Exception { // Create Mappings XContentBuilder builder = XContentFactory.jsonBuilder() @@ -1601,7 +1722,7 @@ protected void setupKNNIndexForFilterQuery() throws Exception { .startObject(FIELD_NAME) .field("type", "knn_vector") .field("dimension", 3) - .startObject(KNNConstants.KNN_METHOD) + .startObject(KNN_METHOD) .field(NAME, KNNEngine.FAISS.getMethod(METHOD_HNSW).getMethodComponent().getName()) .field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2) .field(KNN_ENGINE, KNNEngine.FAISS.getName()) diff --git a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java index d500fc342..d7867d383 100644 --- a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java @@ -117,7 +117,7 @@ public void testValidateKnnField_NestedField() { when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null); assertNull(e); } @@ -138,7 +138,7 @@ public void testValidateKnnField_NonNestedField() { when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null); assertNull(e); } @@ -158,7 +158,7 @@ public void testValidateKnnField_NonKnnField() { when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null); assert Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Field \"" + field + "\" is not of type knn_vector.;"); } @@ -182,7 +182,7 @@ public void testValidateKnnField_WrongFieldPath() { when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null); assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Field \"" + field + "\" does not exist.;")); } @@ -206,7 +206,7 @@ public void testValidateKnnField_EmptyField() { when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null); System.out.println(Objects.requireNonNull(e).getMessage()); @@ -223,7 +223,7 @@ public void testValidateKnnField_EmptyIndexMetadata() { when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension); when(modelDao.getMetadata(anyString())).thenReturn(trainingFieldModelMetadata); - ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao); + ValidationException e = IndexUtil.validateKnnField(indexMetadata, field, dimension, modelDao, null); assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Invalid index. Index does not contain a mapping;")); } diff --git a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java index 11a8bdb15..e9b78e7ec 100644 --- a/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNCreateIndexFromModelTests.java @@ -63,7 +63,8 @@ public void testCreateIndexFromModel() throws IOException, InterruptedException "", "", "test-node", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); Model model = new Model(modelMetadata, modelBlob, modelId); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java index 847cad04e..8ad5c90d7 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumerTests.java @@ -424,7 +424,8 @@ public void testAddKNNBinaryField_fromModel_faiss() throws IOException, Executio "Empty description", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBytes, modelId diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index b82bc85e0..66fe9770d 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -20,15 +20,16 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.MethodComponentContext; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.VectorField; import org.opensearch.knn.index.query.KNNQueryFactory; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.query.KNNQuery; -import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.query.KNNWeight; -import org.opensearch.knn.index.SpaceType; -import org.opensearch.knn.index.VectorField; import org.apache.lucene.codecs.Codec; import org.apache.lucene.document.Document; import org.apache.lucene.document.FieldType; @@ -213,7 +214,8 @@ public void testBuildFromModelTemplate(Codec codec) throws IOException, Executio "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); Model mockModel = new Model(modelMetadata1, modelBlob, modelId); diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index c3ddcf185..0e3fc6fb4 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -175,7 +175,8 @@ public void testBuilder_build_fromModel() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); builder.modelId.setValue(modelId); Mapper.BuilderContext builderContext = new Mapper.BuilderContext(settings, new ContentPath()); @@ -676,7 +677,8 @@ public void testKNNVectorFieldMapper_merge_fromModel() throws IOException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); when(mockModelDao.getMetadata(modelId)).thenReturn(mockModelMetadata); @@ -747,7 +749,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext() + luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), + VectorDataType.FLOAT ); // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnVectorField @@ -791,7 +794,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withFloats() { parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext() + luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), + VectorDataType.FLOAT ); // Document should have 1 field: one for KnnVectorField @@ -826,7 +830,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext() + luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), + VectorDataType.BYTE ); // Document should have 2 fields: one for VectorField (binary doc values) and one for KnnByteVectorField @@ -869,7 +874,8 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() { parseContext, TEST_DIMENSION, luceneFieldMapper.fieldType().spaceType, - luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext() + luceneFieldMapper.fieldType().knnMethodContext.getMethodComponentContext(), + VectorDataType.BYTE ); // Document should have 1 field: one for KnnByteVectorField diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 06e370026..0ab5a34cc 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -903,6 +903,7 @@ public void testDoToQuery_FromModel() { when(modelMetadata.getSpaceType()).thenReturn(SpaceType.COSINESIMIL); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap())); + when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.FLOAT); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); KNNQueryBuilder.initialize(modelDao); @@ -940,6 +941,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap())); + when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.FLOAT); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); KNNQueryBuilder.initialize(modelDao); @@ -975,6 +977,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_th when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); + when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.FLOAT); when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap())); ModelDao modelDao = mock(ModelDao.class); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 9769bcc15..25b24812c 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -38,6 +38,7 @@ import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.KNNCodecVersion; @@ -62,6 +63,7 @@ import java.util.function.Function; import java.util.stream.Collectors; +import static java.util.Collections.emptyMap; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; @@ -199,6 +201,8 @@ public void testQueryScoreForFaissWithModel() { when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS); when(modelMetadata.getSpaceType()).thenReturn(spaceType); when(modelMetadata.getState()).thenReturn(ModelState.CREATED); + when(modelMetadata.getVectorDataType()).thenReturn(VectorDataType.FLOAT); + when(modelMetadata.getMethodComponentContext()).thenReturn(new MethodComponentContext("ivf", emptyMap())); when(modelDao.getMetadata(eq("modelId"))).thenReturn(modelMetadata); KNNWeight.initialize(modelDao); diff --git a/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java b/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java index 3a5255cd3..7a42e8a25 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelCacheTests.java @@ -19,6 +19,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.time.ZoneOffset; @@ -45,7 +46,8 @@ public void testGet_normal() throws ExecutionException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), "hello".getBytes(), modelId @@ -82,7 +84,8 @@ public void testGet_modelDoesNotFitInCache() throws ExecutionException, Interrup "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[BYTES_PER_KILOBYTES + 1], modelId @@ -140,7 +143,8 @@ public void testGetTotalWeight() throws ExecutionException, InterruptedException "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[size1], modelId1 @@ -156,7 +160,8 @@ public void testGetTotalWeight() throws ExecutionException, InterruptedException "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[size2], modelId2 @@ -200,7 +205,8 @@ public void testRemove_normal() throws ExecutionException, InterruptedException "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[size1], modelId1 @@ -216,8 +222,8 @@ public void testRemove_normal() throws ExecutionException, InterruptedException "", "", "", - MethodComponentContext.EMPTY - + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[size2], modelId2 @@ -266,7 +272,8 @@ public void testRebuild_normal() throws ExecutionException, InterruptedException "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), "hello".getBytes(), modelId @@ -312,7 +319,8 @@ public void testRebuild_afterSettingUpdate() throws ExecutionException, Interrup "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[modelSize], modelId @@ -381,7 +389,8 @@ public void testContains() throws ExecutionException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[modelSize1], modelId1 @@ -423,7 +432,8 @@ public void testRemoveAll() throws ExecutionException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[modelSize1], modelId1 @@ -441,7 +451,8 @@ public void testRemoveAll() throws ExecutionException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[modelSize2], modelId2 @@ -487,7 +498,8 @@ public void testModelCacheEvictionDueToSize() throws ExecutionException, Interru "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[BYTES_PER_KILOBYTES * 2], modelId diff --git a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index 75c523332..e3619975e 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -35,6 +35,7 @@ import org.opensearch.knn.common.exception.DeleteModelException; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.plugin.transport.DeleteModelResponse; import org.opensearch.knn.plugin.transport.GetModelResponse; @@ -139,7 +140,8 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId @@ -159,7 +161,8 @@ public void testModelIndexHealth() throws InterruptedException, ExecutionExcepti "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId @@ -187,7 +190,8 @@ public void testPut_withId() throws InterruptedException, IOException { "", "", "", - new MethodComponentContext("test", Collections.emptyMap()) + new MethodComponentContext("test", Collections.emptyMap()), + VectorDataType.FLOAT ), modelBlob, modelId @@ -248,7 +252,8 @@ public void testPut_withoutModel() throws InterruptedException, IOException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId @@ -310,7 +315,8 @@ public void testPut_invalid_badState() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, "any-id" @@ -347,7 +353,8 @@ public void testUpdate() throws IOException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), null, modelId @@ -386,7 +393,8 @@ public void testUpdate() throws IOException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId @@ -437,7 +445,8 @@ public void testGet() throws IOException, InterruptedException, ExecutionExcepti "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId @@ -456,7 +465,8 @@ public void testGet() throws IOException, InterruptedException, ExecutionExcepti "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), null, modelId @@ -493,7 +503,8 @@ public void testGetMetadata() throws IOException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); Model model = new Model(modelMetadata, modelBlob, modelId); @@ -570,7 +581,8 @@ public void testDelete() throws IOException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId @@ -604,7 +616,8 @@ public void testDelete() throws IOException, InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId1 @@ -672,7 +685,8 @@ public void testDeleteModelInTrainingWithStepListeners() throws IOException, Exe "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId @@ -714,7 +728,8 @@ public void testDeleteWithStepListeners() throws IOException, InterruptedExcepti "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId diff --git a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java index 74715671f..cd36496f8 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelMetadataTests.java @@ -19,6 +19,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.io.IOException; @@ -45,7 +46,8 @@ public void testStreams() throws IOException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); BytesStreamOutput streamOutput = new BytesStreamOutput(); @@ -67,7 +69,8 @@ public void testGetKnnEngine() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(knnEngine, modelMetadata.getKnnEngine()); @@ -84,7 +87,8 @@ public void testGetSpaceType() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(spaceType, modelMetadata.getSpaceType()); @@ -101,7 +105,8 @@ public void testGetDimension() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(dimension, modelMetadata.getDimension()); @@ -118,7 +123,8 @@ public void testGetState() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(modelState, modelMetadata.getState()); @@ -135,7 +141,8 @@ public void testGetTimestamp() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(timeValue, modelMetadata.getTimestamp()); @@ -152,7 +159,8 @@ public void testDescription() { description, "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(description, modelMetadata.getDescription()); @@ -169,7 +177,8 @@ public void testGetError() { "", error, "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(error, modelMetadata.getError()); @@ -186,7 +195,8 @@ public void testSetState() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(modelState, modelMetadata.getState()); @@ -207,7 +217,8 @@ public void testSetError() { "", error, "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(error, modelMetadata.getError()); @@ -244,7 +255,9 @@ public void testToString() { + "," + nodeAssignment + "," - + methodComponentContext.toClusterStateString(); + + methodComponentContext.toClusterStateString() + + "," + + VectorDataType.FLOAT.getValue(); ModelMetadata modelMetadata = new ModelMetadata( knnEngine, @@ -255,7 +268,8 @@ public void testToString() { description, error, nodeAssignment, - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); assertEquals(expected, modelMetadata.toString()); @@ -275,7 +289,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata2 = new ModelMetadata( KNNEngine.FAISS, @@ -286,7 +301,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata3 = new ModelMetadata( @@ -298,7 +314,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata4 = new ModelMetadata( KNNEngine.FAISS, @@ -309,7 +326,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata5 = new ModelMetadata( KNNEngine.FAISS, @@ -320,7 +338,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata6 = new ModelMetadata( KNNEngine.FAISS, @@ -331,7 +350,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata7 = new ModelMetadata( KNNEngine.FAISS, @@ -342,7 +362,8 @@ public void testEquals() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata8 = new ModelMetadata( KNNEngine.FAISS, @@ -353,7 +374,8 @@ public void testEquals() { "diff descript", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata9 = new ModelMetadata( KNNEngine.FAISS, @@ -364,7 +386,8 @@ public void testEquals() { "", "diff error", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata10 = new ModelMetadata( @@ -376,7 +399,8 @@ public void testEquals() { "", "", "", - new MethodComponentContext("test", Collections.emptyMap()) + new MethodComponentContext("test", Collections.emptyMap()), + VectorDataType.FLOAT ); assertEquals(modelMetadata1, modelMetadata1); @@ -406,7 +430,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata2 = new ModelMetadata( KNNEngine.FAISS, @@ -417,7 +442,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata3 = new ModelMetadata( @@ -429,7 +455,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata4 = new ModelMetadata( KNNEngine.FAISS, @@ -440,7 +467,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata5 = new ModelMetadata( KNNEngine.FAISS, @@ -451,7 +479,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata6 = new ModelMetadata( KNNEngine.FAISS, @@ -462,7 +491,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata7 = new ModelMetadata( KNNEngine.FAISS, @@ -473,7 +503,8 @@ public void testHashCode() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata8 = new ModelMetadata( KNNEngine.FAISS, @@ -484,7 +515,8 @@ public void testHashCode() { "diff descript", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata9 = new ModelMetadata( KNNEngine.FAISS, @@ -495,7 +527,8 @@ public void testHashCode() { "", "diff error", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata modelMetadata10 = new ModelMetadata( @@ -507,7 +540,8 @@ public void testHashCode() { "", "", "", - new MethodComponentContext("test", Collections.emptyMap()) + new MethodComponentContext("test", Collections.emptyMap()), + VectorDataType.FLOAT ); assertEquals(modelMetadata1.hashCode(), modelMetadata1.hashCode()); @@ -550,7 +584,9 @@ public void testFromString() { + "," + nodeAssignment + "," - + methodComponentContext.toClusterStateString(); + + methodComponentContext.toClusterStateString() + + "," + + VectorDataType.FLOAT.getValue(); String stringRep2 = knnEngine.getName() + "," @@ -564,7 +600,9 @@ public void testFromString() { + "," + description + "," - + error; + + error + + "," + + VectorDataType.FLOAT.getValue(); ModelMetadata expected1 = new ModelMetadata( knnEngine, @@ -575,7 +613,8 @@ public void testFromString() { description, error, nodeAssignment, - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata expected2 = new ModelMetadata( @@ -587,7 +626,8 @@ public void testFromString() { description, error, "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); ModelMetadata fromString1 = ModelMetadata.fromString(stringRep1); @@ -620,7 +660,8 @@ public void testFromResponseMap() throws IOException { description, error, nodeAssignment, - methodComponentContext + methodComponentContext, + VectorDataType.FLOAT ); ModelMetadata expected2 = new ModelMetadata( @@ -632,7 +673,8 @@ public void testFromResponseMap() throws IOException { description, error, "", - emptyMethodComponentContext + emptyMethodComponentContext, + VectorDataType.FLOAT ); Map metadataAsMap = new HashMap<>(); metadataAsMap.put(KNNConstants.KNN_ENGINE, knnEngine.getName()); @@ -643,6 +685,7 @@ public void testFromResponseMap() throws IOException { metadataAsMap.put(KNNConstants.MODEL_DESCRIPTION, description); metadataAsMap.put(KNNConstants.MODEL_ERROR, error); metadataAsMap.put(KNNConstants.MODEL_NODE_ASSIGNMENT, nodeAssignment); + metadataAsMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); builder = methodComponentContext.toXContent(builder, ToXContent.EMPTY_PARAMS).endObject(); @@ -678,7 +721,8 @@ public void testBlockCommasInDescription() { description, error, nodeAssignment, - methodComponentContext + methodComponentContext, + VectorDataType.FLOAT ) ); assertEquals("Model description cannot contain any commas: ','", e.getMessage()); diff --git a/src/test/java/org/opensearch/knn/indices/ModelTests.java b/src/test/java/org/opensearch/knn/indices/ModelTests.java index 13579acad..773a10a2c 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelTests.java @@ -15,6 +15,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import java.time.ZoneOffset; @@ -41,7 +42,8 @@ public void testInvalidConstructor() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), null, "test-model" @@ -62,7 +64,8 @@ public void testInvalidDimension() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[16], "test-model" @@ -80,7 +83,8 @@ public void testInvalidDimension() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[16], "test-model" @@ -98,7 +102,8 @@ public void testInvalidDimension() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[16], "test-model" @@ -117,7 +122,8 @@ public void testGetModelMetadata() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); Model model = new Model(modelMetadata, new byte[16], "test-model"); assertEquals(modelMetadata, model.getModelMetadata()); @@ -135,7 +141,8 @@ public void testGetModelBlob() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, "test-model" @@ -155,7 +162,8 @@ public void testGetLength() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[size], "test-model" @@ -172,7 +180,8 @@ public void testGetLength() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), null, "test-model" @@ -192,7 +201,8 @@ public void testSetModelBlob() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), blob1, "test-model" @@ -209,17 +219,50 @@ public void testEquals() { String time = ZonedDateTime.now(ZoneOffset.UTC).toString(); Model model1 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L1, + 2, + ModelState.CREATED, + time, + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.FLOAT + ), new byte[16], "test-model-1" ); Model model2 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L1, + 2, + ModelState.CREATED, + time, + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.FLOAT + ), new byte[16], "test-model-1" ); Model model3 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L2, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L2, + 2, + ModelState.CREATED, + time, + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.FLOAT + ), new byte[16], "test-model-2" ); @@ -234,17 +277,50 @@ public void testHashCode() { String time = ZonedDateTime.now(ZoneOffset.UTC).toString(); Model model1 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L1, + 2, + ModelState.CREATED, + time, + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.FLOAT + ), new byte[16], "test-model-1" ); Model model2 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L1, + 2, + ModelState.CREATED, + time, + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.FLOAT + ), new byte[16], "test-model-1" ); Model model3 = new Model( - new ModelMetadata(KNNEngine.DEFAULT, SpaceType.L1, 2, ModelState.CREATED, time, "", "", "", MethodComponentContext.EMPTY), + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.L1, + 2, + ModelState.CREATED, + time, + "", + "", + "", + MethodComponentContext.EMPTY, + VectorDataType.FLOAT + ), new byte[16], "test-model-2" ); @@ -274,7 +350,8 @@ public void testModelFromSourceMap() { description, error, nodeAssignment, - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); Map modelAsMap = new HashMap<>(); modelAsMap.put(KNNConstants.MODEL_ID, modelID); @@ -287,6 +364,7 @@ public void testModelFromSourceMap() { modelAsMap.put(KNNConstants.MODEL_ERROR, error); modelAsMap.put(KNNConstants.MODEL_NODE_ASSIGNMENT, nodeAssignment); modelAsMap.put(KNNConstants.MODEL_BLOB_PARAMETER, "aGVsbG8="); + modelAsMap.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()); byte[] blob1 = "hello".getBytes(); Model expected = new Model(metadata, blob1, modelID); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java index d9949aaf2..5d16fe59d 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestKNNStatsHandlerIT.java @@ -36,7 +36,20 @@ import static org.opensearch.knn.TestUtils.KNN_VECTOR; import static org.opensearch.knn.TestUtils.PROPERTIES; import static org.opensearch.knn.TestUtils.VECTOR_TYPE; -import static org.opensearch.knn.common.KNNConstants.*; +import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.LUCENE_NAME; +import static org.opensearch.knn.common.KNNConstants.MAX_DISTANCE; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.MIN_SCORE; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; +import static org.opensearch.knn.common.KNNConstants.NMSLIB_NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.NAME; /** * Integration tests to check the correctness of RestKNNStatsHandler diff --git a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java index a6985e72a..2106e31ac 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/GetModelResponseTests.java @@ -20,6 +20,7 @@ import org.opensearch.knn.index.KNNClusterUtil; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelMetadata; @@ -43,7 +44,8 @@ private ModelMetadata getModelMetadata(ModelState state) { "test model", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); } @@ -68,7 +70,7 @@ public void testXContent() throws IOException { Model model = new Model(getModelMetadata(ModelState.CREATED), testModelBlob, modelId); GetModelResponse getModelResponse = new GetModelResponse(model); String expectedResponseString = - "{\"model_id\":\"test-model\",\"model_blob\":\"aGVsbG8=\",\"state\":\"created\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\",\"model_definition\":{\"name\":\"\",\"parameters\":{}}}"; + "{\"model_id\":\"test-model\",\"model_blob\":\"aGVsbG8=\",\"state\":\"created\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\",\"model_definition\":{\"name\":\"\",\"parameters\":{}},\"data_type\":\"float\"}"; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); getModelResponse.toXContent(xContentBuilder, null); assertEquals(expectedResponseString, xContentBuilder.toString()); @@ -84,7 +86,7 @@ public void testXContentWithNoModelBlob() throws IOException { Model model = new Model(getModelMetadata(ModelState.FAILED), null, modelId); GetModelResponse getModelResponse = new GetModelResponse(model); String expectedResponseString = - "{\"model_id\":\"test-model\",\"model_blob\":\"\",\"state\":\"failed\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\",\"model_definition\":{\"name\":\"\",\"parameters\":{}}}"; + "{\"model_id\":\"test-model\",\"model_blob\":\"\",\"state\":\"failed\",\"timestamp\":\"2021-03-27 10:15:30 AM +05:30\",\"description\":\"test model\",\"error\":\"\",\"space_type\":\"l2\",\"dimension\":4,\"engine\":\"nmslib\",\"training_node_assignment\":\"\",\"model_definition\":{\"name\":\"\",\"parameters\":{}},\"data_type\":\"float\"}"; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); getModelResponse.toXContent(xContentBuilder, null); assertEquals(expectedResponseString, xContentBuilder.toString()); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java index a2da83dad..3b25bc0eb 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/RemoveModelFromCacheTransportActionTests.java @@ -19,6 +19,7 @@ import org.opensearch.knn.KNNSingleNodeTestCase; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelCache; @@ -78,7 +79,8 @@ public void testNodeOperation_modelInCache() throws ExecutionException, Interrup "description", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), new byte[128], modelId diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java index 56c50aca1..bdca896ff 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingJobRouterTransportActionTests.java @@ -24,6 +24,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.KNNMethodContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.transport.TransportService; @@ -307,7 +308,8 @@ public void testTrainingIndexSize() { trainingIndexName, "training-field", null, - "description" + "description", + VectorDataType.FLOAT ); // Mock client to return the right number of docs diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index b39c48635..28e09a61e 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -25,6 +25,7 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; @@ -61,7 +62,8 @@ public void testStreams() throws IOException { trainingIndex, trainingField, preferredNode, - description + description, + VectorDataType.FLOAT ); BytesStreamOutput streamOutput = new BytesStreamOutput(); @@ -74,6 +76,7 @@ public void testStreams() throws IOException { assertEquals(original1.getTrainingIndex(), copy1.getTrainingIndex()); assertEquals(original1.getTrainingField(), copy1.getTrainingField()); assertEquals(original1.getPreferredNodeId(), copy1.getPreferredNodeId()); + assertEquals(original1.getVectorDataType(), copy1.getVectorDataType()); // Also, check when preferred node and model id and description are null TrainingModelRequest original2 = new TrainingModelRequest( @@ -83,7 +86,8 @@ public void testStreams() throws IOException { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); streamOutput = new BytesStreamOutput(); @@ -96,6 +100,7 @@ public void testStreams() throws IOException { assertEquals(original2.getTrainingIndex(), copy2.getTrainingIndex()); assertEquals(original2.getTrainingField(), copy2.getTrainingField()); assertEquals(original2.getPreferredNodeId(), copy2.getPreferredNodeId()); + assertEquals(original2.getVectorDataType(), copy2.getVectorDataType()); } public void testGetters() { @@ -117,7 +122,8 @@ public void testGetters() { trainingIndex, trainingField, preferredNode, - description + description, + VectorDataType.FLOAT ); trainingModelRequest.setMaximumVectorCount(maxVectorCount); @@ -156,7 +162,8 @@ public void testValidation_invalid_modelIdAlreadyExists() { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -170,7 +177,8 @@ public void testValidation_invalid_modelIdAlreadyExists() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); @@ -211,7 +219,8 @@ public void testValidation_blocked_modelId() { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return true to recognize that the modelId is in graveyard @@ -257,7 +266,8 @@ public void testValidation_invalid_invalidMethodContext() { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return null so that no exception is produced @@ -300,7 +310,8 @@ public void testValidation_invalid_trainingIndexDoesNotExist() { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return null so that no exception is produced @@ -346,7 +357,8 @@ public void testValidation_invalid_trainingFieldDoesNotExist() { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return null so that no exception is produced @@ -397,7 +409,8 @@ public void testValidation_invalid_trainingFieldNotKnnVector() { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return null so that no exception is produced @@ -452,7 +465,8 @@ public void testValidation_invalid_dimensionDoesNotMatch() { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return null so that no exception is produced @@ -509,7 +523,8 @@ public void testValidation_invalid_preferredNodeDoesNotExist() { trainingIndex, trainingField, preferredNode, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -574,7 +589,8 @@ public void testValidation_invalid_descriptionToLong() { trainingIndex, trainingField, null, - description + description, + VectorDataType.FLOAT ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -618,7 +634,8 @@ public void testValidation_valid_trainingIndexBuiltFromMethod() { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate @@ -655,7 +672,8 @@ public void testValidation_valid_trainingIndexBuiltFromModel() { trainingIndex, trainingField, null, - null + null, + VectorDataType.FLOAT ); // Mock the model dao to return metadata for modelId to recognize it is a duplicate diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java index 950ce1fd0..221f50fe3 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelTransportActionTests.java @@ -17,6 +17,7 @@ import org.opensearch.knn.KNNSingleNodeTestCase; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; @@ -72,7 +73,8 @@ public void testDoExecute() throws InterruptedException, ExecutionException, IOE trainingIndexName, trainingFieldName, null, - "test-detector" + "test-detector", + VectorDataType.FLOAT ); trainingModelRequest.setTrainingDataSizeInKB(estimateVectorSetSizeInKB(trainingDataCount, dimension)); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java index 5be907ebd..e0d7c521c 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java @@ -17,6 +17,7 @@ import org.opensearch.knn.common.exception.DeleteModelException; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelGraveyard; @@ -210,7 +211,8 @@ public void testClusterManagerOperation_GetIndicesUsingModel() throws IOExceptio "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), modelBlob, modelId diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java index 3719d124a..2a016d98b 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataRequestTests.java @@ -15,6 +15,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; @@ -42,7 +43,8 @@ public void testStreams() throws IOException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); UpdateModelMetadataRequest updateModelMetadataRequest = new UpdateModelMetadataRequest(modelId, isRemoveRequest, modelMetadata); @@ -67,7 +69,8 @@ public void testValidate() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); UpdateModelMetadataRequest updateModelMetadataRequest1 = new UpdateModelMetadataRequest("test", true, null); @@ -107,7 +110,8 @@ public void testGetModelMetadata() { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); UpdateModelMetadataRequest updateModelMetadataRequest = new UpdateModelMetadataRequest("test", true, modelMetadata); diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java index ab0e4f506..e16b720f0 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelMetadataTransportActionTests.java @@ -19,6 +19,7 @@ import org.opensearch.knn.KNNSingleNodeTestCase; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; @@ -68,7 +69,8 @@ public void testClusterManagerOperation() throws InterruptedException { "", "", "", - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ); // Get update transport action diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java index 06b96c57c..57ecb8323 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java @@ -18,6 +18,7 @@ import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; @@ -67,7 +68,8 @@ public void testGetModelId() { mock(NativeMemoryEntryContext.AnonymousEntryContext.class), 10, "", - "test-node" + "test-node", + VectorDataType.FLOAT ); assertEquals(modelId, trainingJob.getModelId()); @@ -96,7 +98,8 @@ public void testGetModel() { mock(NativeMemoryEntryContext.AnonymousEntryContext.class), dimension, description, - nodeAssignment + nodeAssignment, + VectorDataType.FLOAT ); Model model = new Model( @@ -109,7 +112,8 @@ public void testGetModel() { description, error, nodeAssignment, - MethodComponentContext.EMPTY + MethodComponentContext.EMPTY, + VectorDataType.FLOAT ), null, modelID @@ -183,8 +187,8 @@ public void testRun_success() throws IOException, ExecutionException { modelContext, dimension, "", - "test-node" - + "test-node", + VectorDataType.FLOAT ); trainingJob.run(); @@ -262,8 +266,8 @@ public void testRun_failure_onGetTrainingDataAllocation() throws ExecutionExcept modelContext, dimension, "", - - "test-node" + "test-node", + VectorDataType.FLOAT ); trainingJob.run(); @@ -330,8 +334,8 @@ public void testRun_failure_onGetModelAnonymousAllocation() throws ExecutionExce modelContext, dimension, "", - - "test-node" + "test-node", + VectorDataType.FLOAT ); trainingJob.run(); @@ -397,7 +401,8 @@ public void testRun_failure_closedTrainingDataAllocation() throws ExecutionExcep mock(NativeMemoryEntryContext.AnonymousEntryContext.class), dimension, "", - "test-node" + "test-node", + VectorDataType.FLOAT ); trainingJob.run(); @@ -470,7 +475,8 @@ public void testRun_failure_notEnoughTrainingData() throws ExecutionException { modelContext, dimension, "", - "test-node" + "test-node", + VectorDataType.FLOAT ); trainingJob.run(); diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 860cd2efa..1a17559a8 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -6,6 +6,7 @@ package org.opensearch.knn; import com.google.common.primitives.Floats; +import com.google.common.primitives.Ints; import lombok.SneakyThrows; import lombok.extern.log4j.Log4j2; import org.apache.commons.lang.StringUtils; @@ -1000,6 +1001,28 @@ public void bulkIngestRandomVectors(String indexName, String fieldName, int numV } + /** + * Bulk ingest random binary vectors + * @param indexName index name + * @param fieldName field name + * @param numVectors number of vectors + * @param dimension vector dimension + */ + public void bulkIngestRandomBinaryVectors(String indexName, String fieldName, int numVectors, int dimension) throws IOException { + if (dimension % 8 != 0) { + throw new IllegalArgumentException("Dimension must be a multiple of 8"); + } + for (int i = 0; i < numVectors; i++) { + int binaryDimension = dimension / 8; + int[] vector = new int[binaryDimension]; + for (int j = 0; j < binaryDimension; j++) { + vector[j] = randomIntBetween(-128, 127); + } + + addKnnDoc(indexName, String.valueOf(i + 1), fieldName, Ints.asList(vector).toArray()); + } + } + /** * Bulk ingest random vectors with nested field * @@ -1337,6 +1360,18 @@ public Response trainModel( return client().performRequest(request); } + public Response trainModel(String modelId, XContentBuilder builder) throws IOException { + if (modelId == null) { + modelId = ""; + } else { + modelId = "/" + modelId; + } + + Request request = new Request("POST", "/_plugins/_knn/models" + modelId + "/_train"); + request.setJsonEntity(builder.toString()); + return client().performRequest(request); + } + /** * Retrieve the model *