Skip to content

Commit

Permalink
Pass in vector data type in train api
Browse files Browse the repository at this point in the history
Signed-off-by: Junqiu Lei <[email protected]>
  • Loading branch information
junqiu-lei committed Jun 25, 2024
1 parent 2757338 commit 3266695
Show file tree
Hide file tree
Showing 10 changed files with 103 additions and 20 deletions.
1 change: 1 addition & 0 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <jni.h>
#include <string>
#include <vector>
#include <faiss/IndexBinaryIVF.h>

// Defines type of IDSelector
enum FilterIdsSelectorType{
Expand Down
28 changes: 23 additions & 5 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,14 +496,32 @@ std::vector<uint8_t> generateRandomBinaryData(int dim, int numVectors) {
return data;
}

// Reports whether the given index is considered binary by this test file.
// NOTE(review): the decision is based solely on metric_type being
// METRIC_Canberra, not on the index's concrete type -- confirm that this is
// the intended marker for binary indices.
bool isBinaryIndex(const faiss::Index* index) {
    const bool usesCanberraMetric = (faiss::METRIC_Canberra == index->metric_type);
    return usesCanberraMetric;
}

// Returns true when the given index -- or, if it is an IndexIDMap, the index
// it wraps -- can be dynamic_cast to faiss::IndexBinaryIVF; false otherwise.
// NOTE(review): presumably IndexBinaryIVF must be reachable from faiss::Index
// in the class hierarchy for this cast to ever succeed -- confirm.
bool isIndexBinaryIVF(faiss::Index * index) {
    // Look through an ID-map wrapper, if present, to the index it delegates to.
    auto * idMap = dynamic_cast<faiss::IndexIDMap *>(index);
    faiss::Index * unwrapped = (idMap != nullptr) ? idMap->index : index;

    return dynamic_cast<faiss::IndexBinaryIVF *>(unwrapped) != nullptr;
}

TEST(FaissBinaryIVFIndexTest, BasicIVFSearch) {
// Dimension of the vectors, should be a multiple of 8.
int d = 256;
int d = 8;

// Number of database vectors, training vectors, and query vectors
int nb = 1000; // Database vectors
int nt = 500; // Training vectors
int nq = 10; // Query vectors
int nb = 500; // Database vectors
int nt = 400; // Training vectors
int nq = 1; // Query vectors

// Generate binary data for db, training, and queries
std::vector<uint8_t> db = generateRandomBinaryData(d, nb);
Expand All @@ -514,7 +532,7 @@ TEST(FaissBinaryIVFIndexTest, BasicIVFSearch) {
faiss::IndexBinaryFlat quantizer(d);

// Number of clusters
int nlist = 100;
int nlist = 10;

// Initializing index
faiss::IndexBinaryIVF index(&quantizer, d, nlist);
Expand Down
12 changes: 12 additions & 0 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,18 @@ public float[] getVectorFromBytesRef(BytesRef binaryValue) {
return vectorSerializer.byteToFloatArray(byteStream);
}

},
BINARY("binary") {

    /**
     * Not available for the binary data type; this implementation always throws.
     *
     * @param dimension                unused
     * @param vectorSimilarityFunction unused
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunction vectorSimilarityFunction) {
        final String reason = "Binary vector data type is not supported for lucene engine";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Not available for the binary data type; this implementation always throws.
     *
     * @param binaryValue unused
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    public float[] getVectorFromBytesRef(BytesRef binaryValue) {
        final String reason = "Binary vector data type is not supported for lucene engine";
        throw new UnsupportedOperationException(reason);
    }
};

public static final String SUPPORTED_VECTOR_DATA_TYPES = Arrays.stream(VectorDataType.values())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ public static class Builder extends ParametrizedFieldMapper.Builder {
m -> toType(m).vectorDataType
);

// .setSerializer((n, c, o) -> {
// n.field(c, o.getValue());
// }, m -> m.getValue());

/**
* modelId provides a way for a user to generate the underlying library indices from an already serialized
* model template index. If this parameter is set, it will take precedence. This parameter is only relevant for
Expand Down Expand Up @@ -268,7 +272,7 @@ public KNNVectorFieldMapper build(BuilderContext context) {
// Validates and throws exception if data_type field is set in the index mapping
// using any VectorDataType (other than float, which is default) because other
// VectorDataTypes are only supported for lucene engine.
validateVectorDataTypeWithEngine(vectorDataType);
// validateVectorDataTypeWithEngine(vectorDataType);

return new MethodFieldMapper(
name,
Expand Down Expand Up @@ -319,7 +323,7 @@ public KNNVectorFieldMapper build(BuilderContext context) {
// Validates and throws exception if index.knn is set to true in the index settings
// using any VectorDataType (other than float, which is default) because we are using NMSLIB engine for LegacyFieldMapper
// and it only supports float VectorDataType
validateVectorDataTypeWithKnnIndexSetting(context.indexSettings().getAsBoolean(KNN_INDEX, false), vectorDataType);
// validateVectorDataTypeWithKnnIndexSetting(context.indexSettings().getAsBoolean(KNN_INDEX, false), vectorDataType);

return new LegacyFieldMapper(
name,
Expand Down Expand Up @@ -588,7 +592,7 @@ protected void parseCreateField(ParseContext context, int dimension, SpaceType s
final byte[] array = bytesArrayOptional.get();
spaceType.validateVector(array);
context.doc().addAll(getFieldsForByteVector(array, fieldType));
} else if (VectorDataType.FLOAT == vectorDataType) {
} else if (VectorDataType.BINARY == vectorDataType) {
Optional<float[]> floatsArrayOptional = getFloatsFromContext(context, dimension, methodComponentContext);

if (floatsArrayOptional.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@
import org.opensearch.knn.index.util.IndexHyperParametersUtil;
import org.opensearch.knn.index.util.KNNEngine;

import static org.opensearch.knn.common.KNNConstants.DIMENSION;
import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_CONSTRUCTION;
import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_M;
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE;
import static org.opensearch.knn.common.KNNConstants.*;

/**
* Field mapper for original implementation. It defaults to using nmslib as the engine and retrieves parameters from index settings.
Expand Down Expand Up @@ -63,6 +58,7 @@ public class LegacyFieldMapper extends KNNVectorFieldMapper {
this.fieldType.putAttribute(DIMENSION, String.valueOf(dimension));
this.fieldType.putAttribute(SPACE_TYPE, spaceType);
this.fieldType.putAttribute(KNN_ENGINE, KNNEngine.NMSLIB.getName());
// this.fieldType.putAttribute(VECTOR_DATA_TYPE_FIELD, mappedFieldType.getVectorDataType().getValue());

// These are extra just for legacy
this.fieldType.putAttribute(HNSW_ALGO_M, m);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.VectorDataType;

import java.io.IOException;
import java.util.Map;
Expand Down Expand Up @@ -169,6 +170,7 @@ public static class TrainingDataEntryContext extends NativeMemoryEntryContext<Na
private final String trainFieldName;
private final int maxVectorCount;
private final int searchSize;
private final VectorDataType vectorDataType;

/**
* Constructor
Expand All @@ -188,7 +190,8 @@ public TrainingDataEntryContext(
NativeMemoryLoadStrategy.TrainingLoadStrategy trainingLoadStrategy,
ClusterService clusterService,
int maxVectorCount,
int searchSize
int searchSize,
VectorDataType vectorDataType
) {
super(generateKey(trainIndexName, trainFieldName));
this.size = size;
Expand All @@ -198,6 +201,7 @@ public TrainingDataEntryContext(
this.clusterService = clusterService;
this.maxVectorCount = maxVectorCount;
this.searchSize = searchSize;
this.vectorDataType = vectorDataType;
}

@Override
Expand Down Expand Up @@ -255,6 +259,15 @@ public ClusterService getClusterService() {
return clusterService;
}

/**
 * Returns the vector data type this context was constructed with.
 *
 * @return the {@link VectorDataType} held by this context
 */
public VectorDataType getVectorDataType() {
    return this.vectorDataType;
}

// Builds the entry key by concatenating, in order: KEY_PREFIX, the training
// index name, DELIMETER and the training field name.
private static String generateKey(String trainIndexName, String trainFieldName) {
    final String indexAndField = trainIndexName + DELIMETER + trainFieldName;
    return KEY_PREFIX + indexAndField;
}
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/org/opensearch/knn/jni/FaissService.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
package org.opensearch.knn.jni;

import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.KNNQueryResult;
import org.opensearch.knn.index.util.KNNEngine;

Expand Down Expand Up @@ -176,7 +177,7 @@ public static native KNNQueryResult[] queryIndexWithFilter(
* @param trainVectorsPointer pointer to where training vectors are stored in native memory
* @param vectorDataType data type of the training vectors
* @return bytes array of trained template index
*/
public static native byte[] trainIndex(Map<String, Object> indexParameters, int dimension, long trainVectorsPointer);
public static native byte[] trainIndex(Map<String, Object> indexParameters, int dimension, long trainVectorsPointer, VectorDataType vectorDataType);

/**
* <p>
Expand Down
5 changes: 3 additions & 2 deletions src/main/java/org/opensearch/knn/jni/JNIService.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
package org.opensearch.knn.jni;

import org.apache.commons.lang.ArrayUtils;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.KNNQueryResult;
import org.opensearch.knn.index.util.KNNEngine;

Expand Down Expand Up @@ -240,9 +241,9 @@ public static void freeSharedIndexState(long shareIndexStateAddr, KNNEngine knnE
* @param knnEngine engine to perform the training
* @param vectorDataType data type of the training vectors
* @return bytes array of trained template index
*/
public static byte[] trainIndex(Map<String, Object> indexParameters, int dimension, long trainVectorsPointer, KNNEngine knnEngine) {
public static byte[] trainIndex(Map<String, Object> indexParameters, int dimension, long trainVectorsPointer, KNNEngine knnEngine, VectorDataType vectorDataType) {
if (KNNEngine.FAISS == knnEngine) {
return FaissService.trainIndex(indexParameters, dimension, trainVectorsPointer);
return FaissService.trainIndex(indexParameters, dimension, trainVectorsPointer, vectorDataType);
}

throw new IllegalArgumentException(String.format("TrainIndex not supported for provided engine : %s", knnEngine.getName()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@

package org.opensearch.knn.plugin.transport;

import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.metadata.MappingMetadata;
import org.opensearch.cluster.metadata.Metadata;
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryEntryContext;
import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy;
Expand All @@ -26,8 +31,12 @@
import org.opensearch.transport.TransportService;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;

import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;

/**
* Transport action that trains a model and serializes it to model system index
*/
Expand All @@ -44,14 +53,41 @@ public TrainingModelTransportAction(TransportService transportService, ActionFil
@Override
protected void doExecute(Task task, TrainingModelRequest request, ActionListener<TrainingModelResponse> listener) {

ClusterState clusterState = clusterService.state();
Metadata metadata = clusterState.getMetadata();
IndexMetadata indexMetadata = metadata.index(request.getTrainingIndex());
MappingMetadata mappingMetadata = indexMetadata.mapping();
Map<String, Object> mapping = mappingMetadata.getSourceAsMap();

Map<String, Object> properties = (Map<String, Object>) mappingMetadata.getSourceAsMap().get("properties");

VectorDataType vectorDataType;

// Check whether the mapping has a data_type field
// if (properties.containsKey(VECTOR_DATA_TYPE_FIELD)) {
// vectorDataType = VectorDataType.valueOf((String) mapping.get(VECTOR_DATA_TYPE_FIELD));
// } else {
// // if not, default to float
// vectorDataType = VectorDataType.FLOAT;
// }

if (mapping.containsKey("properties") &&
((Map<String, Object>) mapping.get("properties")).containsKey(request.getTrainingField()) &&
((Map<String, Object>) ((Map<String, Object>) mapping.get("properties")).get(request.getTrainingField())).containsKey(VECTOR_DATA_TYPE_FIELD)) {
vectorDataType = VectorDataType.valueOf((String) ((Map<String, Object>) ((Map<String, Object>) mapping.get("properties")).get(request.getTrainingField())).get(VECTOR_DATA_TYPE_FIELD));
} else {
vectorDataType = VectorDataType.FLOAT;
}

NativeMemoryEntryContext.TrainingDataEntryContext trainingDataEntryContext = new NativeMemoryEntryContext.TrainingDataEntryContext(
request.getTrainingDataSizeInKB(),
request.getTrainingIndex(),
request.getTrainingField(),
NativeMemoryLoadStrategy.TrainingLoadStrategy.getInstance(),
clusterService,
request.getMaximumVectorCount(),
request.getSearchSize()
request.getSearchSize(),
vectorDataType
);

// Allocation representing size model will occupy in memory during training
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/org/opensearch/knn/training/TrainingJob.java
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ public void run() {
trainParameters,
model.getModelMetadata().getDimension(),
trainingDataAllocation.getMemoryAddress(),
model.getModelMetadata().getKnnEngine()
model.getModelMetadata().getKnnEngine(),
trainingDataEntryContext.getVectorDataType()
);

// Once training finishes, update model
Expand Down

0 comments on commit 3266695

Please sign in to comment.