diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index 5a0910d9a..572dd19e2 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -27,6 +27,7 @@ #include #include #include +#include // Defines type of IDSelector enum FilterIdsSelectorType{ diff --git a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 1824ccfa4..388567939 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -496,14 +496,32 @@ std::vector generateRandomBinaryData(int dim, int numVectors) { return data; } +// Function to check if an index is binary +bool isBinaryIndex(const faiss::Index* index) { + return index->metric_type == faiss::METRIC_Canberra; +} + +bool isIndexBinaryIVF(faiss::Index * index) { + faiss::Index * candidateIndex = index; + if (auto indexIDMap = dynamic_cast(index)) { + candidateIndex = indexIDMap->index; + } + + if (auto indexBinaryIVF = dynamic_cast(candidateIndex)) { + return true; + } + + return false; +} + TEST(FaissBinaryIVFIndexTest, BasicIVFSearch) { // Dimension of the vectors, should be a multiple of 8. - int d = 256; + int d = 8; // Number of database vectors, training vectors, and query vectors - int nb = 1000; // Database vectors - int nt = 500; // Training vectors - int nq = 10; // Query vectors + int nb = 500; // Database vectors + int nt = 400; // Training vectors + int nq = 1; // Query vectors // Generate binary data for db, training, and queries std::vector db = generateRandomBinaryData(d, nb); @@ -514,7 +532,7 @@ TEST(FaissBinaryIVFIndexTest, BasicIVFSearch) { faiss::IndexBinaryFlat quantizer(d); // Number of clusters - int nlist = 100; + int nlist = 10; // Initializing index faiss::IndexBinaryIVF index(&quantizer, d, nlist); diff --git a/src/main/java/org/opensearch/knn/index/VectorDataType.java b/src/main/java/org/opensearch/knn/index/VectorDataType.java index 98b767f8d..c757fb024 100644 --- a/src/main/java/org/opensearch/knn/index/VectorDataType.java +++ b/src/main/java/org/opensearch/knn/index/VectorDataType.java @@ -62,6 +62,18 @@ public float[] getVectorFromBytesRef(BytesRef binaryValue) { return vectorSerializer.byteToFloatArray(byteStream); } + }, + BINARY("binary") { + + @Override + public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) { + throw new UnsupportedOperationException("Binary vector data type is not supported for lucene engine"); + } + + @Override + public float[] getVectorFromBytesRef(BytesRef binaryValue) { + throw new UnsupportedOperationException("Binary vector data type is not supported for lucene engine"); + } }; public static final String SUPPORTED_VECTOR_DATA_TYPES = Arrays.stream(VectorDataType.values()) diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java index 7e697fed7..d8bbc1860 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java @@ -133,6 +133,10 @@ public static class Builder extends ParametrizedFieldMapper.Builder { m -> toType(m).vectorDataType ); +// .setSerializer((n, c, o) -> { +// n.field(c, o.getValue()); +// }, m -> m.getValue()); + /** * modelId provides a way for a user to generate the underlying library indices from an already serialized * model template index. If this parameter is set, it will take precedence. This parameter is only relevant for @@ -268,7 +272,7 @@ public KNNVectorFieldMapper build(BuilderContext context) { // Validates and throws exception if data_type field is set in the index mapping // using any VectorDataType (other than float, which is default) because other // VectorDataTypes are only supported for lucene engine. - validateVectorDataTypeWithEngine(vectorDataType); +// validateVectorDataTypeWithEngine(vectorDataType); return new MethodFieldMapper( name, @@ -319,7 +323,7 @@ public KNNVectorFieldMapper build(BuilderContext context) { // Validates and throws exception if index.knn is set to true in the index settings // using any VectorDataType (other than float, which is default) because we are using NMSLIB engine for LegacyFieldMapper // and it only supports float VectorDataType - validateVectorDataTypeWithKnnIndexSetting(context.indexSettings().getAsBoolean(KNN_INDEX, false), vectorDataType); +// validateVectorDataTypeWithKnnIndexSetting(context.indexSettings().getAsBoolean(KNN_INDEX, false), vectorDataType); return new LegacyFieldMapper( name, @@ -588,7 +592,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s final byte[] array = bytesArrayOptional.get(); spaceType.validateVector(array); context.doc().addAll(getFieldsForByteVector(array, fieldType)); - } else if (VectorDataType.FLOAT == vectorDataType) { + } else if (VectorDataType.BINARY == vectorDataType) { Optional floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext); if (floatsArrayOptional.isEmpty()) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java b/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java index d67ffc73c..92a493848 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java +++ b/src/main/java/org/opensearch/knn/index/mapper/LegacyFieldMapper.java @@ -15,12 +15,7 @@ import org.opensearch.knn.index.util.IndexHyperParametersUtil; import org.opensearch.knn.index.util.KNNEngine; -import static org.opensearch.knn.common.KNNConstants.DIMENSION; -import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_CONSTRUCTION; -import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M; -import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; -import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; -import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.*; /** * Field mapper for original implementation. It defaults to using nmslib as the engine and retrieves parameters from index settings. @@ -63,6 +58,7 @@ public class LegacyFieldMapper extends KNNVectorFieldMapper { this.fieldType.putAttribute(DIMENSION, String.valueOf(dimension)); this.fieldType.putAttribute(SPACE_TYPE, spaceType); this.fieldType.putAttribute(KNN_ENGINE, KNNEngine.NMSLIB.getName()); +// this.fieldType.putAttribute(VECTOR_DATA_TYPE_FIELD, mappedFieldType.getVectorDataType().getValue()); // These are extra just for legacy this.fieldType.putAttribute(HNSW_ALGO_M, m); diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java index 7f14a2341..b5ddff1e2 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java @@ -14,6 +14,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Nullable; import org.opensearch.knn.index.IndexUtil; +import org.opensearch.knn.index.VectorDataType; import java.io.IOException; import java.util.Map; @@ -169,6 +170,7 @@ public static class TrainingDataEntryContext extends NativeMemoryEntryContext indexParameters, int dimension, long trainVectorsPointer); + public static native byte[] trainIndex(Map indexParameters, int dimension, long trainVectorsPointer, VectorDataType vectorDataType); /** *

diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 20c418819..cc732e394 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -12,6 +12,7 @@ package org.opensearch.knn.jni; import org.apache.commons.lang.ArrayUtils; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.util.KNNEngine; @@ -240,9 +241,9 @@ public static void freeSharedIndexState(long shareIndexStateAddr, KNNEngine knnE * @param knnEngine engine to perform the training * @return bytes array of trained template index */ - public static byte[] trainIndex(Map indexParameters, int dimension, long trainVectorsPointer, KNNEngine knnEngine) { + public static byte[] trainIndex(Map indexParameters, int dimension, long trainVectorsPointer, KNNEngine knnEngine, VectorDataType vectorDataType) { if (KNNEngine.FAISS == knnEngine) { - return FaissService.trainIndex(indexParameters, dimension, trainVectorsPointer); + return FaissService.trainIndex(indexParameters, dimension, trainVectorsPointer, vectorDataType); } throw new IllegalArgumentException(String.format("TrainIndex not supported for provided engine : %s", knnEngine.getName())); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java index 33b420e2c..4ada939b9 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelTransportAction.java @@ -11,11 +11,16 @@ package org.opensearch.knn.plugin.transport; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.MappingMetadata; +import org.opensearch.cluster.metadata.Metadata; import org.opensearch.core.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; @@ -26,8 +31,12 @@ import org.opensearch.transport.TransportService; import java.io.IOException; +import java.util.Map; +import java.util.Objects; import java.util.concurrent.ExecutionException; +import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD; + /** * Transport action that trains a model and serializes it to model system index */ @@ -44,6 +53,32 @@ public TrainingModelTransportAction(TransportService transportService, ActionFil @Override protected void doExecute(Task task, TrainingModelRequest request, ActionListener listener) { + ClusterState clusterState = clusterService.state(); + Metadata metadata = clusterState.getMetadata(); + IndexMetadata indexMetadata = metadata.index(request.getTrainingIndex()); + MappingMetadata mappingMetadata = indexMetadata.mapping(); + Map mapping = mappingMetadata.getSourceAsMap(); + + Map properties = (Map) mappingMetadata.getSourceAsMap().get("properties"); + + VectorDataType vectorDataType; + + // check if mapping have data_type field +// if (properties.containsKey(VECTOR_DATA_TYPE_FIELD)) { +// vectorDataType = VectorDataType.valueOf((String) mapping.get(VECTOR_DATA_TYPE_FIELD)); +// } else { +// // if not, default to float +// vectorDataType = VectorDataType.FLOAT; +// } + + if (mapping.containsKey("properties") && + ((Map) mapping.get("properties")).containsKey(request.getTrainingField()) && + ((Map) ((Map) mapping.get("properties")).get(request.getTrainingField())).containsKey(VECTOR_DATA_TYPE_FIELD)) { + vectorDataType = VectorDataType.valueOf((String) ((Map) ((Map) mapping.get("properties")).get(request.getTrainingField())).get(VECTOR_DATA_TYPE_FIELD)); + } else { + vectorDataType = VectorDataType.FLOAT; + } + NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = new NativeMemoryEntryContext.TrainingDataEntryContext( request.getTrainingDataSizeInKB(), request.getTrainingIndex(), @@ -51,7 +86,8 @@ protected void doExecute(Task task, TrainingModelRequest request, ActionListener NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance(), clusterService, request.getMaximumVectorCount(), - request.getSearchSize() + request.getSearchSize(), + vectorDataType ); // Allocation representing size model will occupy in memory during training diff --git a/src/main/java/org/opensearch/knn/training/TrainingJob.java b/src/main/java/org/opensearch/knn/training/TrainingJob.java index aa2786c0a..bc27c0d78 100644 --- a/src/main/java/org/opensearch/knn/training/TrainingJob.java +++ b/src/main/java/org/opensearch/knn/training/TrainingJob.java @@ -186,7 +186,8 @@ public void run() { trainParameters, model.getModelMetadata().getDimension(), trainingDataAllocation.getMemoryAddress(), - model.getModelMetadata().getKnnEngine() + model.getModelMetadata().getKnnEngine(), + trainingDataEntryContext.getVectorDataType() ); // Once training finishes, update model