diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index e98c00197..5e30ad5a1 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -23,6 +23,7 @@ import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.jni.JNIService; import java.io.File; import java.util.Collections; @@ -263,4 +264,19 @@ public static boolean isVersionOnOrAfterMinRequiredVersion(Version version, Stri } return version.onOrAfter(minimalRequiredVersion); } + + /** + * Checks if index requires shared state + * + * @param knnEngine The knnEngine associated with the index + * @param modelId The modelId associated with the index + * @param indexAddr Address to check if loaded index requires shared state + * @return true if state can be shared; false otherwise + */ + public static boolean isSharedIndexStateRequired(KNNEngine knnEngine, String modelId, long indexAddr) { + if (StringUtils.isEmpty(modelId) || KNNEngine.FAISS != knnEngine) { + return false; + } + return JNIService.isSharedIndexStateRequired(indexAddr, knnEngine); + } } diff --git a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java index fec93a734..71f6596dc 100644 --- a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java +++ b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java @@ -5,6 +5,9 @@ package org.opensearch.knn.index; +import com.google.common.annotations.VisibleForTesting; +import lombok.AllArgsConstructor; +import lombok.Getter; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FilterLeafReader; @@ -24,13 +27,14 @@ import java.io.IOException; import java.nio.file.Path; +import java.util.ArrayList; import java.util.Collection; -import 
java.util.HashMap; -import java.util.Map; +import java.util.List; import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; import static org.opensearch.knn.index.IndexUtil.getParametersAtLoading; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFilePrefix; @@ -41,8 +45,8 @@ */ @Log4j2 public class KNNIndexShard { - private IndexShard indexShard; - private NativeMemoryCacheManager nativeMemoryCacheManager; + private final IndexShard indexShard; + private final NativeMemoryCacheManager nativeMemoryCacheManager; private static final String INDEX_SHARD_CLEAR_CACHE_SEARCHER = "knn-clear-cache"; /** @@ -83,14 +87,19 @@ public String getIndexName() { public void warmup() throws IOException { log.info("[KNN] Warming up index: [{}]", getIndexName()); try (Engine.Searcher searcher = indexShard.acquireSearcher("knn-warmup")) { - getAllEnginePaths(searcher.getIndexReader()).forEach((key, value) -> { + getAllEngineFileContexts(searcher.getIndexReader()).forEach((engineFileContext) -> { try { nativeMemoryCacheManager.get( new NativeMemoryEntryContext.IndexEntryContext( - key, + engineFileContext.getIndexPath(), NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), - getParametersAtLoading(value, KNNEngine.getEngineNameFromPath(key), getIndexName()), - getIndexName() + getParametersAtLoading( + engineFileContext.getSpaceType(), + KNNEngine.getEngineNameFromPath(engineFileContext.getIndexPath()), + getIndexName() + ), + getIndexName(), + engineFileContext.getModelId() ), true ); @@ -117,7 +126,9 @@ public void clearCache() { indexAllocation.writeLock(); log.info("[KNN] Evicting index from cache: [{}]", indexName); try (Engine.Searcher searcher = indexShard.acquireSearcher(INDEX_SHARD_CLEAR_CACHE_SEARCHER)) { - getAllEnginePaths(searcher.getIndexReader()).forEach((key, 
value) -> nativeMemoryCacheManager.invalidate(key)); + getAllEngineFileContexts(searcher.getIndexReader()).forEach( + (engineFileContext) -> nativeMemoryCacheManager.invalidate(engineFileContext.getIndexPath()) + ); } catch (IOException ex) { log.error("[KNN] Failed to evict index from cache: [{}]", indexName, ex); throw new RuntimeException(ex); @@ -128,22 +139,23 @@ public void clearCache() { } /** - * For the given shard, get all of its engine paths + * For the given shard, get all of its engine file context objects * - * @param indexReader IndexReader to read the file paths for the shard - * @return List of engine file Paths + * @param indexReader IndexReader to read the information for each segment in the shard + * @return List of engine contexts * @throws IOException Thrown when the SegmentReader is attempting to read the segments files */ - public Map getAllEnginePaths(IndexReader indexReader) throws IOException { - Map engineFiles = new HashMap<>(); + @VisibleForTesting + List getAllEngineFileContexts(IndexReader indexReader) throws IOException { + List engineFiles = new ArrayList<>(); for (KNNEngine knnEngine : KNNEngine.getEnginesThatCreateCustomSegmentFiles()) { - engineFiles.putAll(getEnginePaths(indexReader, knnEngine)); + engineFiles.addAll(getEngineFileContexts(indexReader, knnEngine)); } return engineFiles; } - private Map getEnginePaths(IndexReader indexReader, KNNEngine knnEngine) throws IOException { - Map engineFiles = new HashMap<>(); + List getEngineFileContexts(IndexReader indexReader, KNNEngine knnEngine) throws IOException { + List engineFiles = new ArrayList<>(); for (LeafReaderContext leafReaderContext : indexReader.leaves()) { SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader()); @@ -158,15 +170,17 @@ private Map getEnginePaths(IndexReader indexReader, KNNEngine // was L2. 
So, if Space Type is not present, just fall back to L2 String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue()); SpaceType spaceType = SpaceType.getSpace(spaceTypeName); + String modelId = fieldInfo.attributes().getOrDefault(MODEL_ID, null); - engineFiles.putAll( - getEnginePaths( + engineFiles.addAll( + getEngineFileContexts( reader.getSegmentInfo().files(), reader.getSegmentInfo().info.name, fieldInfo.name, fileExtension, shardPath, - spaceType + spaceType, + modelId ) ); } @@ -175,13 +189,15 @@ private Map getEnginePaths(IndexReader indexReader, KNNEngine return engineFiles; } - protected Map getEnginePaths( + @VisibleForTesting + List getEngineFileContexts( Collection files, String segmentName, String fieldName, String fileExtension, Path shardPath, - SpaceType spaceType + SpaceType spaceType, + String modelId ) { String prefix = buildEngineFilePrefix(segmentName); String suffix = buildEngineFileSuffix(fieldName, fileExtension); @@ -189,6 +205,16 @@ protected Map getEnginePaths( .filter(fileName -> fileName.startsWith(prefix)) .filter(fileName -> fileName.endsWith(suffix)) .map(fileName -> shardPath.resolve(fileName).toString()) - .collect(Collectors.toMap(fileName -> fileName, fileName -> spaceType)); + .map(fileName -> new EngineFileContext(spaceType, modelId, fileName)) + .collect(Collectors.toList()); + } + + @AllArgsConstructor + @Getter + @VisibleForTesting + static class EngineFileContext { + private final SpaceType spaceType; + private final String modelId; + private final String indexPath; } } diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java index 416980759..286e6265c 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java @@ -91,6 +91,7 @@ class IndexAllocation implements 
NativeMemoryAllocation { private final String openSearchIndexName; private final ReadWriteLock readWriteLock; private final WatcherHandle watcherHandle; + private final SharedIndexState sharedIndexState; /** * Constructor @@ -111,6 +112,31 @@ class IndexAllocation implements NativeMemoryAllocation { String indexPath, String openSearchIndexName, WatcherHandle watcherHandle + ) { + this(executorService, memoryAddress, size, knnEngine, indexPath, openSearchIndexName, watcherHandle, null); + } + + /** + * Constructor + * + * @param executorService Executor service used to close the allocation + * @param memoryAddress Pointer in memory to the index + * @param size Size this index consumes in kilobytes + * @param knnEngine KNNEngine associated with the index allocation + * @param indexPath File path to index + * @param openSearchIndexName Name of OpenSearch index this index is associated with + * @param watcherHandle Handle for watching index file + * @param sharedIndexState Shared index state. If no shared state is present, pass null.
+ */ + IndexAllocation( + ExecutorService executorService, + long memoryAddress, + int size, + KNNEngine knnEngine, + String indexPath, + String openSearchIndexName, + WatcherHandle watcherHandle, + SharedIndexState sharedIndexState ) { this.executor = executorService; this.closed = false; @@ -121,6 +147,7 @@ class IndexAllocation implements NativeMemoryAllocation { this.readWriteLock = new ReentrantReadWriteLock(); this.size = size; this.watcherHandle = watcherHandle; + this.sharedIndexState = sharedIndexState; } @Override @@ -145,6 +172,10 @@ private void cleanup() { if (memoryAddress != 0) { JNIService.free(memoryAddress, knnEngine); } + + if (sharedIndexState != null) { + SharedIndexStateManager.getInstance().release(sharedIndexState); + } } @Override diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java index 13f8dae10..7f14a2341 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java @@ -12,6 +12,7 @@ package org.opensearch.knn.index.memory; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Nullable; import org.opensearch.knn.index.IndexUtil; import java.io.IOException; @@ -63,6 +64,8 @@ public static class IndexEntryContext extends NativeMemoryEntryContext parameters; + @Nullable + private final String modelId; /** * Constructor @@ -77,11 +80,31 @@ public IndexEntryContext( NativeMemoryLoadStrategy.IndexLoadStrategy indexLoadStrategy, Map parameters, String openSearchIndexName + ) { + this(indexPath, indexLoadStrategy, parameters, openSearchIndexName, null); + } + + /** + * Constructor + * + * @param indexPath path to index file. Also used as key in cache. 
+ * @param indexLoadStrategy strategy to load index into memory + * @param parameters load time parameters + * @param openSearchIndexName opensearch index associated with index + * @param modelId model to be loaded. If none available, pass null + */ + public IndexEntryContext( + String indexPath, + NativeMemoryLoadStrategy.IndexLoadStrategy indexLoadStrategy, + Map parameters, + String openSearchIndexName, + String modelId ) { super(indexPath); this.indexLoadStrategy = indexLoadStrategy; this.openSearchIndexName = openSearchIndexName; this.parameters = parameters; + this.modelId = modelId; } @Override @@ -112,6 +135,15 @@ public Map getParameters() { return parameters; } + /** + * Getter + * + * @return return model ID for the index. null if no model is in use + */ + public String getModelId() { + return modelId; + } + private static class IndexSizeCalculator implements Function { static IndexSizeCalculator INSTANCE = new IndexSizeCalculator(); diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java index 568cc892b..cb7dafdfc 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java @@ -11,7 +11,9 @@ package org.opensearch.knn.index.memory; +import lombok.extern.log4j.Log4j2; import org.opensearch.core.action.ActionListener; +import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.training.TrainingDataConsumer; @@ -41,6 +43,7 @@ public interface NativeMemoryLoadStrategy, @@ -92,17 +95,25 @@ public NativeMemoryAllocation.IndexAllocation load(NativeMemoryEntryContext.Inde fileWatcher.init(); KNNEngine knnEngine = KNNEngine.getEngineNameFromPath(indexPath.toString()); - long memoryAddress = JNIService.loadIndex(indexPath.toString(), 
indexEntryContext.getParameters(), knnEngine); - final WatcherHandle watcherHandle = resourceWatcherService.add(fileWatcher); + long indexAddress = JNIService.loadIndex(indexPath.toString(), indexEntryContext.getParameters(), knnEngine); + SharedIndexState sharedIndexState = null; + String modelId = indexEntryContext.getModelId(); + if (IndexUtil.isSharedIndexStateRequired(knnEngine, modelId, indexAddress)) { + log.info("Index with model: \"{}\" requires shared state. Retrieving shared state.", modelId); + sharedIndexState = SharedIndexStateManager.getInstance().get(indexAddress, modelId, knnEngine); + JNIService.setSharedIndexState(indexAddress, sharedIndexState.getSharedIndexStateAddress(), knnEngine); + } + final WatcherHandle watcherHandle = resourceWatcherService.add(fileWatcher); return new NativeMemoryAllocation.IndexAllocation( executor, - memoryAddress, + indexAddress, indexEntryContext.calculateSizeInKB(), knnEngine, indexPath.toString(), indexEntryContext.getOpenSearchIndexName(), - watcherHandle + watcherHandle, + sharedIndexState ); } diff --git a/src/main/java/org/opensearch/knn/index/memory/SharedIndexState.java b/src/main/java/org/opensearch/knn/index/memory/SharedIndexState.java new file mode 100644 index 000000000..2ffadb22e --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/memory/SharedIndexState.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.memory; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.knn.index.util.KNNEngine; + +/** + * Class stores information about the shared memory allocations between loaded native indices. 
+ */ +@RequiredArgsConstructor +@Getter +public class SharedIndexState { + private final long sharedIndexStateAddress; + private final String modelId; + private final KNNEngine knnEngine; +} diff --git a/src/main/java/org/opensearch/knn/index/memory/SharedIndexStateManager.java b/src/main/java/org/opensearch/knn/index/memory/SharedIndexStateManager.java new file mode 100644 index 000000000..bdf7f14fc --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/memory/SharedIndexStateManager.java @@ -0,0 +1,145 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.memory; + +import com.google.common.annotations.VisibleForTesting; +import lombok.Getter; +import lombok.extern.log4j.Log4j2; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.jni.JNIService; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +/** + * Class manages allocations that can be shared between native indices. No external locking is required. + * Once a caller obtains an instance of a {@link org.opensearch.knn.index.memory.SharedIndexState}, it is guaranteed to + * be valid until it is returned. {@link org.opensearch.knn.index.memory.SharedIndexState} are reference counted + * internally. Once the reference count goes to 0, it will be freed. + */ +@Log4j2 +class SharedIndexStateManager { + // Map storing the shared index state with key being the modelId. + private final ConcurrentHashMap sharedIndexStateCache; + private final ReadWriteLock readWriteLock; + + private static SharedIndexStateManager INSTANCE; + + // TODO: Going to refactor away from doing this in the future. For now, keeping for simplicity.
+ public static synchronized SharedIndexStateManager getInstance() { + if (INSTANCE == null) { + INSTANCE = new SharedIndexStateManager(); + } + return INSTANCE; + } + + /** + * Constructor + */ + @VisibleForTesting + SharedIndexStateManager() { + this.sharedIndexStateCache = new ConcurrentHashMap<>(); + this.readWriteLock = new ReentrantReadWriteLock(); + } + + /** + * Return a {@link SharedIndexState} associated with the key. If no value exists, it will attempt to create it. + * Once returned, the {@link SharedIndexState} will be valid until + * {@link SharedIndexStateManager#release(SharedIndexState)} is called. Caller must ensure that this is + * called after it is done using it. + * + * In order to create the shared state, it will initialize it from the indexAddress passed in + * using {@link org.opensearch.knn.jni.JNIService#initSharedIndexState(long, KNNEngine)}. + * + * @param indexAddress Address of index to initialize the shared state from + * @param knnEngine engine index belongs to + * @return SharedIndexState associated with the modelId + */ + public SharedIndexState get(long indexAddress, String modelId, KNNEngine knnEngine) { + this.readWriteLock.readLock().lock(); + // This can be done safely with readLock because the ConcurrentHashMap.computeIfAbsent guarantees: + // + // "If the specified key is not already associated with a value, attempts to compute its value using the given + // mapping function and enters it into this map unless null. The entire method invocation is performed + // atomically, so the function is applied at most once per key. Some attempted update operations on this map + // by other threads may be blocked while computation is in progress, so the computation should be short and + // simple, and must not attempt to update any other mappings of this map."
+ // + // Ref: + // https://docs.oracle.com/javase/8/docs/api/java/util/concurrent/ConcurrentHashMap.html#computeIfAbsent-K-java.util.function.Function- + SharedIndexStateEntry entry = sharedIndexStateCache.computeIfAbsent(modelId, m -> { + log.info("Loading entry to shared index state cache for model {}", modelId); + long sharedIndexStateAddress = JNIService.initSharedIndexState(indexAddress, knnEngine); + return new SharedIndexStateEntry(new SharedIndexState(sharedIndexStateAddress, modelId, knnEngine)); + }); + entry.incRef(); + this.readWriteLock.readLock().unlock(); + return entry.getSharedIndexState(); + } + + /** + * Indicate that the {@link SharedIndexState} is no longer being used. If nothing else is using it, it will be + * removed from the cache and evicted. + * + * After calling this method, {@link SharedIndexState} should no longer be used by calling thread. + * + * @param sharedIndexState to return to the system. + */ + public void release(SharedIndexState sharedIndexState) { + this.readWriteLock.writeLock().lock(); + + if (!sharedIndexStateCache.containsKey(sharedIndexState.getModelId())) { + // This should not happen. 
Will log the error and return to prevent crash + log.error("Attempting to evict model from cache but it is not present: {}", sharedIndexState.getModelId()); + this.readWriteLock.writeLock().unlock(); + return; + } + + long refCount = sharedIndexStateCache.get(sharedIndexState.getModelId()).decRef(); + if (refCount <= 0) { + log.info("Evicting entry from shared index state cache for key {}", sharedIndexState.getModelId()); + sharedIndexStateCache.remove(sharedIndexState.getModelId()); + JNIService.freeSharedIndexState(sharedIndexState.getSharedIndexStateAddress(), sharedIndexState.getKnnEngine()); + } + this.readWriteLock.writeLock().unlock(); + } + + private static final class SharedIndexStateEntry { + @Getter + private final SharedIndexState sharedIndexState; + private final AtomicLong referenceCount; + + /** + * Constructor + * + * @param sharedIndexState sharedIndexStateContext being wrapped + */ + private SharedIndexStateEntry(SharedIndexState sharedIndexState) { + this.sharedIndexState = sharedIndexState; + this.referenceCount = new AtomicLong(0); + } + + /** + * Increases reference count by 1 + * + * @return ++referenceCount + */ + private long incRef() { + return referenceCount.incrementAndGet(); + } + + /** + * Decrease reference count by 1 + * + * @return --referenceCount + */ + private long decRef() { + return referenceCount.decrementAndGet(); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 7e2fa19cc..1c4d0a646 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -256,7 +256,8 @@ private Map doANNSearch(final LeafReaderContext context, final B indexPath.toString(), NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName()), - knnQuery.getIndexName() + knnQuery.getIndexName(), + modelId ), true ); 
diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 9da067fde..4bf4e73a6 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -39,6 +39,7 @@ import java.util.TreeMap; import java.util.stream.Collectors; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; @@ -50,6 +51,7 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.NAME; @@ -493,6 +495,116 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed( fail("Graphs are not getting evicted"); } + /** + * This test confirms that sharing index state for IVFPQ-l2 indices functions properly. The main functionality that + * needs to be confirmed is that once an index gets deleted, it will not cause a failure for the non-deleted index. + * + * The workflow will be: + * 1. Create a model + * 2. Create two indices from the model + * 3. Load the native index files from the first index + * 4. Assert search works on the first index + * 5. Load the native index files from the second index (which will reuse the shared state from the initial index) + * 6. Assert search works on the second index + * 7. Delete the first index and wait + * 8.
Assert search works on the second index + */ + @SneakyThrows + public void testSharedIndexState_whenOneIndexDeleted_thenSecondIndexIsStillSearchable() { + String firstIndexName = "test-index-1"; + String secondIndexName = "test-index-2"; + String trainingIndexName = "training-index"; + + String modelId = "test-model"; + String modelDescription = "ivfpql2 model for testing shared state"; + + int dimension = testData.indexData.vectors[0].length; + SpaceType spaceType = SpaceType.L2; + int ivfNlist = 4; + int ivfNprobes = 4; + int pqCodeSize = 8; + int pqM = 1; + int docCount = 100; + + // training data needs to be at least equal to the number of centroids for PQ + // which is 2^8 = 256. 8 because thats the only valid code_size for HNSWPQ + int trainingDataCount = 256; + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NPROBES, ivfNprobes) + .field(METHOD_PARAMETER_NLIST, ivfNlist) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_PQ) + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_M, pqCodeSize) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, pqM) + .endObject() + .endObject() + .endObject() + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + createBasicKnnIndex(trainingIndexName, FIELD_NAME, dimension); + ingestDataAndTrainModel(modelId, trainingIndexName, FIELD_NAME, dimension, modelDescription, in, trainingDataCount); + assertTrainingSucceeds(modelId, 360, 1000); + + createIndexFromModelAndIngestDocuments(firstIndexName, modelId, docCount); + createIndexFromModelAndIngestDocuments(secondIndexName, modelId, docCount); + + doKnnWarmup(List.of(firstIndexName)); + validateSearchWorkflow(firstIndexName, testData.queries, 10); + doKnnWarmup(List.of(secondIndexName)); + validateSearchWorkflow(secondIndexName, testData.queries, 10); + 
deleteKNNIndex(firstIndexName); + // wait for all index files to be cleaned up from original index. empirically determined to take 25 seconds. + // the 45 second sleep below leaves a buffer on top of that + Thread.sleep(1000 * 45); + validateSearchWorkflow(secondIndexName, testData.queries, 10); + deleteModel(modelId); + } + + @SneakyThrows + private void createIndexFromModelAndIngestDocuments(String indexName, String modelId, int docCount) { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("model_id", modelId) + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + createKnnIndex(indexName, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + + for (int i = 0; i < Math.min(testData.indexData.docs.length, docCount); i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + FIELD_NAME, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + refreshAllNonSystemIndices(); + assertEquals(Math.min(testData.indexData.docs.length, docCount), getDocCount(indexName)); + } + + @SneakyThrows + private void validateSearchWorkflow(String indexName, float[][] queries, int k) { + for (float[] query : queries) { + Response response = searchKNNIndex(indexName, new KNNQueryBuilder(FIELD_NAME, query, k), k); + String responseBody = EntityUtils.toString(response.getEntity()); + List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + assertEquals(k, knnResults.size()); + } + } + public void testDocUpdate() throws IOException { String indexName = "test-index-1"; String fieldName = "test-field-1"; diff --git a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java index c8b29b6ef..0eb136687 100644 ---
a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java @@ -12,6 +12,8 @@ package org.opensearch.knn.index; import com.google.common.collect.ImmutableMap; +import org.junit.BeforeClass; +import org.mockito.MockedStatic; import org.opensearch.Version; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.IndexMetadata; @@ -24,12 +26,16 @@ import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.jni.JNIService; import java.util.Map; import java.util.Objects; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.when; import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; @@ -37,6 +43,15 @@ import static org.opensearch.knn.index.KNNSettings.KNN_ALGO_PARAM_EF_SEARCH; public class IndexUtilTests extends KNNTestCase { + + private static MockedStatic jniServiceMockedStatic; + private static final long TEST_INDEX_ADDRESS = 0; + + @BeforeClass + public static void setUpClass() { + jniServiceMockedStatic = mockStatic(JNIService.class); + } + public void testGetLoadParameters() { // Test faiss to ensure that space type gets set properly SpaceType spaceType1 = SpaceType.COSINESIMIL; @@ -206,4 +221,30 @@ public void testValidateKnnField_EmptyIndexMetadata() { assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Invalid index. 
Index does not contain a mapping;")); } + + public void testIsShareableStateContainedInIndex_whenIndexNotModelBased_thenReturnFalse() { + String modelId = null; + KNNEngine knnEngine = KNNEngine.FAISS; + assertFalse(IndexUtil.isSharedIndexStateRequired(knnEngine, modelId, TEST_INDEX_ADDRESS)); + } + + public void testIsShareableStateContainedInIndex_whenEngineIsNotFaiss_thenReturnFalse() { + String modelId = "test-model"; + KNNEngine knnEngine = KNNEngine.NMSLIB; + assertFalse(IndexUtil.isSharedIndexStateRequired(knnEngine, modelId, TEST_INDEX_ADDRESS)); + } + + public void testIsShareableStateContainedInIndex_whenFaissHNSWIsUsed_thenReturnFalse() { + jniServiceMockedStatic.when(() -> JNIService.isSharedIndexStateRequired(anyLong(), any())).thenReturn(false); + String modelId = "test-model"; + KNNEngine knnEngine = KNNEngine.FAISS; + assertFalse(IndexUtil.isSharedIndexStateRequired(knnEngine, modelId, TEST_INDEX_ADDRESS)); + } + + public void testIsShareableStateContainedInIndex_whenJNIIsSharedIndexStateRequiredIsTrue_thenReturnTrue() { + jniServiceMockedStatic.when(() -> JNIService.isSharedIndexStateRequired(anyLong(), any())).thenReturn(true); + String modelId = "test-model"; + KNNEngine knnEngine = KNNEngine.FAISS; + assertTrue(IndexUtil.isSharedIndexStateRequired(knnEngine, modelId, TEST_INDEX_ADDRESS)); + } } diff --git a/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java b/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java index e476969e1..7b6f96d5a 100644 --- a/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java @@ -17,9 +17,7 @@ import java.io.IOException; import java.nio.file.Path; import java.nio.file.Paths; -import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; @@ -98,38 +96,35 @@ public void testWarmup_shardNotPresentInCache() 
throws InterruptedException, Exe assertEquals(2, NativeMemoryCacheManager.getInstance().getIndicesCacheStats().get(testIndexName).get(GRAPH_COUNT)); } - public void testGetHNSWPaths() throws IOException, ExecutionException, InterruptedException { + public void testGetAllEngineFileContexts() throws IOException, ExecutionException, InterruptedException { IndexService indexService = createKNNIndex(testIndexName); createKnnIndexMapping(testIndexName, testFieldName, dimensions); - IndexShard indexShard; - KNNIndexShard knnIndexShard; - Engine.Searcher searcher; - Map hnswPaths; - indexShard = indexService.iterator().next(); - knnIndexShard = new KNNIndexShard(indexShard); + IndexShard indexShard = indexService.iterator().next(); + KNNIndexShard knnIndexShard = new KNNIndexShard(indexShard); - searcher = indexShard.acquireSearcher("test-hnsw-paths-1"); - hnswPaths = knnIndexShard.getAllEnginePaths(searcher.getIndexReader()); - assertEquals(0, hnswPaths.size()); + Engine.Searcher searcher = indexShard.acquireSearcher("test-hnsw-paths-1"); + List engineFileContexts = knnIndexShard.getAllEngineFileContexts(searcher.getIndexReader()); + assertEquals(0, engineFileContexts.size()); searcher.close(); addKnnDoc(testIndexName, "1", testFieldName, new Float[] { 2.5F, 3.5F }); searcher = indexShard.acquireSearcher("test-hnsw-paths-2"); - hnswPaths = knnIndexShard.getAllEnginePaths(searcher.getIndexReader()); - assertEquals(1, hnswPaths.size()); - List paths = new ArrayList<>(hnswPaths.keySet()); + engineFileContexts = knnIndexShard.getAllEngineFileContexts(searcher.getIndexReader()); + assertEquals(1, engineFileContexts.size()); + List paths = engineFileContexts.stream().map(KNNIndexShard.EngineFileContext::getIndexPath).collect(Collectors.toList()); assertTrue(paths.get(0).contains("hnsw") || paths.get(0).contains("hnswc")); searcher.close(); } - public void testGetEnginePaths() { + public void testGetEngineFileContexts() { // Check that the correct engine paths are being returned 
by the KNNIndexShard String segmentName = "_0"; String fieldName = "test_field"; String fileExt = ".test"; SpaceType spaceType = SpaceType.L2; + String modelId = "test-model"; Set includedFileNames = ImmutableSet.of( String.format("%s_111_%s%s", segmentName, fieldName, fileExt), @@ -148,10 +143,18 @@ public void testGetEnginePaths() { KNNIndexShard knnIndexShard = new KNNIndexShard(null); Path path = Paths.get(""); - Map included = knnIndexShard.getEnginePaths(files, segmentName, fieldName, fileExt, path, spaceType); + List included = knnIndexShard.getEngineFileContexts( + files, + segmentName, + fieldName, + fileExt, + path, + spaceType, + modelId + ); assertEquals(includedFileNames.size(), included.size()); - included.keySet().forEach(o -> assertTrue(includedFileNames.contains(o))); + included.stream().map(KNNIndexShard.EngineFileContext::getIndexPath).forEach(o -> assertTrue(includedFileNames.contains(o))); } @SneakyThrows diff --git a/src/test/java/org/opensearch/knn/index/memory/SharedIndexStateManagerTests.java b/src/test/java/org/opensearch/knn/index/memory/SharedIndexStateManagerTests.java new file mode 100644 index 000000000..daf02c611 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/memory/SharedIndexStateManagerTests.java @@ -0,0 +1,62 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.knn.index.memory; + +import org.junit.BeforeClass; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.jni.JNIService; + +import static org.mockito.Mockito.mockStatic; + +public class SharedIndexStateManagerTests extends KNNTestCase { + private static MockedStatic jniServiceMockedStatic; + private final static long TEST_SHARED_TABLE_ADDRESS = 123; + private final static long TEST_INDEX_ADDRESS = 1234; + private final static String TEST_MODEL_ID = "test-model-id"; + private final static KNNEngine TEST_KNN_ENGINE = KNNEngine.DEFAULT; + + @BeforeClass + public static void setUpClass() { + jniServiceMockedStatic = mockStatic(JNIService.class); + jniServiceMockedStatic.when(() -> JNIService.freeSharedIndexState(TEST_SHARED_TABLE_ADDRESS, TEST_KNN_ENGINE)) + .then(invocation -> null); + jniServiceMockedStatic.when(() -> JNIService.initSharedIndexState(TEST_INDEX_ADDRESS, TEST_KNN_ENGINE)) + .thenReturn(TEST_SHARED_TABLE_ADDRESS); + } + + public void testGet_whenNormalWorkfloatApplied_thenSucceed() { + SharedIndexStateManager sharedIndexStateManager = new SharedIndexStateManager(); + SharedIndexState firstSharedIndexStateRetrieved = sharedIndexStateManager.get(TEST_INDEX_ADDRESS, TEST_MODEL_ID, TEST_KNN_ENGINE); + assertEquals(TEST_SHARED_TABLE_ADDRESS, firstSharedIndexStateRetrieved.getSharedIndexStateAddress()); + assertEquals(TEST_MODEL_ID, firstSharedIndexStateRetrieved.getModelId()); + assertEquals(TEST_KNN_ENGINE, firstSharedIndexStateRetrieved.getKnnEngine()); + + SharedIndexState secondSharedIndexStateRetrieved = sharedIndexStateManager.get(TEST_INDEX_ADDRESS, TEST_MODEL_ID, TEST_KNN_ENGINE); + assertEquals(TEST_SHARED_TABLE_ADDRESS, secondSharedIndexStateRetrieved.getSharedIndexStateAddress()); + assertEquals(TEST_MODEL_ID, secondSharedIndexStateRetrieved.getModelId()); + assertEquals(TEST_KNN_ENGINE, 
secondSharedIndexStateRetrieved.getKnnEngine()); + } + + public void testRelease_whenNormalWorkflowApplied_thenSucceed() { + SharedIndexStateManager sharedIndexStateManager = new SharedIndexStateManager(); + SharedIndexState firstSharedIndexStateRetrieved = sharedIndexStateManager.get(TEST_INDEX_ADDRESS, TEST_MODEL_ID, TEST_KNN_ENGINE); + SharedIndexState secondSharedIndexStateRetrieved = sharedIndexStateManager.get(TEST_INDEX_ADDRESS, TEST_MODEL_ID, TEST_KNN_ENGINE); + + sharedIndexStateManager.release(firstSharedIndexStateRetrieved); + jniServiceMockedStatic.verify(() -> JNIService.freeSharedIndexState(TEST_SHARED_TABLE_ADDRESS, TEST_KNN_ENGINE), Mockito.times(0)); + sharedIndexStateManager.release(secondSharedIndexStateRetrieved); + jniServiceMockedStatic.verify(() -> JNIService.freeSharedIndexState(TEST_SHARED_TABLE_ADDRESS, TEST_KNN_ENGINE), Mockito.times(1)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/memory/SharedIndexStateTests.java b/src/test/java/org/opensearch/knn/index/memory/SharedIndexStateTests.java new file mode 100644 index 000000000..cddbec5c0 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/memory/SharedIndexStateTests.java @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. 
+ */ + +package org.opensearch.knn.index.memory; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.util.KNNEngine; + +public class SharedIndexStateTests extends KNNTestCase { + + private static final String TEST_MODEL_ID = "test-model"; + private static final long TEST_SHARED_INDEX_STATE_ADDRESS = 22L; + private static final KNNEngine TEST_KNN_ENGINE = KNNEngine.DEFAULT; + + public void testSharedIndexState() { + SharedIndexState sharedIndexState = new SharedIndexState(TEST_SHARED_INDEX_STATE_ADDRESS, TEST_MODEL_ID, TEST_KNN_ENGINE); + assertEquals(TEST_MODEL_ID, sharedIndexState.getModelId()); + assertEquals(TEST_SHARED_INDEX_STATE_ADDRESS, sharedIndexState.getSharedIndexStateAddress()); + assertEquals(TEST_KNN_ENGINE, sharedIndexState.getKnnEngine()); + } +} diff --git a/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java b/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java index 889f3916f..c90afaa62 100644 --- a/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java +++ b/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java @@ -11,54 +11,520 @@ package org.opensearch.knn.recall; +import lombok.SneakyThrows; +import org.junit.Before; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.TestUtils; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.util.KNNEngine; + import java.util.List; +import java.util.Map; import java.util.Set; + +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ; +import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; +import static 
org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.TYPE; +import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR; +import static org.opensearch.knn.index.KNNSettings.KNN_ALGO_PARAM_EF_SEARCH; import static org.opensearch.knn.index.KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY; import static org.opensearch.knn.index.KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED; +/** + * Tests confirm that for the different supported configurations, recall is sound. The recall thresholds are + * conservatively and empirically determined to prevent flakiness. + * + * This test suite can take a long time to run. The primary reason is that training can take a long time for PQ. + * The parameters for PQ have been reduced significantly, but it still takes time. 
+ */ public class RecallTestsIT extends KNNRestTestCase { - private final String testFieldName = "test-field"; - private final int dimensions = 50; - private final int docCount = 10000; - private final int queryCount = 100; - private final int k = 5; - private final double expRecallValue = 1.0; - - public void testRecallL2StandardData() throws Exception { - String testIndexStandard = "test-index-standard"; - - addDocs(testIndexStandard, testFieldName, dimensions, docCount, true); - float[][] indexVectors = getIndexVectorsFromIndex(testIndexStandard, testFieldName, docCount, dimensions); - float[][] queryVectors = TestUtils.getQueryVectors(queryCount, dimensions, docCount, true); - List> groundTruthValues = TestUtils.computeGroundTruthValues(indexVectors, queryVectors, SpaceType.L2, k); - List> searchResults = bulkSearch(testIndexStandard, testFieldName, queryVectors, k); - double recallValue = TestUtils.calculateRecallValue(searchResults, groundTruthValues, k); - assertEquals(expRecallValue, recallValue, 0.2); + private static final String PROPERTIES_FIELD = "properties"; + private final static String TEST_INDEX_PREFIX_NAME = "test_index"; + private final static String TEST_FIELD_NAME = "test_field"; + private final static String TRAIN_INDEX_NAME = "train_index"; + private final static String TRAIN_FIELD_NAME = "train_field"; + private final static String TEST_MODEL_ID = "test_model_id"; + private final static int TEST_DIMENSION = 32; + private final static int DOC_COUNT = 500; + private final static int QUERY_COUNT = 100; + private final static int TEST_K = 100; + private final static double PERFECT_RECALL = 1.0; + private final static int SHARD_COUNT = 1; + private final static int REPLICA_COUNT = 0; + private final static int MAX_SEGMENT_COUNT = 10; + + // Standard algorithm parameters + private final static int HNSW_M = 16; + private final static int HNSW_EF_CONSTRUCTION = 100; + private final static int HNSW_EF_SEARCH = TEST_K; // For consistency with lucene + 
private final static int IVF_NLIST = 4; + private final static int IVF_NPROBES = IVF_NLIST; // This equates to essentially a brute force search + private final static int PQ_CODE_SIZE = 8; // This is low and going to produce bad recall, but reduces build time + private final static int PQ_M = TEST_DIMENSION / 8; // Will give low recall, but required for test time + + // Setup ground truth for all tests once + private final static float[][] INDEX_VECTORS = TestUtils.getIndexVectors(DOC_COUNT, TEST_DIMENSION, true); + private final static float[][] QUERY_VECTORS = TestUtils.getQueryVectors(QUERY_COUNT, TEST_DIMENSION, DOC_COUNT, true); + private final static Map>> GROUND_TRUTH = Map.of( + SpaceType.L2, + TestUtils.computeGroundTruthValues(INDEX_VECTORS, QUERY_VECTORS, SpaceType.L2, TEST_K), + SpaceType.COSINESIMIL, + TestUtils.computeGroundTruthValues(INDEX_VECTORS, QUERY_VECTORS, SpaceType.COSINESIMIL, TEST_K), + SpaceType.INNER_PRODUCT, + TestUtils.computeGroundTruthValues(INDEX_VECTORS, QUERY_VECTORS, SpaceType.INNER_PRODUCT, TEST_K) + ); + + @SneakyThrows + @Before + public void setupClusterSettings() { + updateClusterSettings(KNN_ALGO_PARAM_INDEX_THREAD_QTY, 2); + updateClusterSettings(KNN_MEMORY_CIRCUIT_BREAKER_ENABLED, true); + } + + /** + * { + * "properties": { + * { + * "type": "knn_vector", + * "dimension": {DIMENSION}, + * "method": { + * "name":"hnsw", + * "engine":"nmslib", + * "space_type": "{SPACE_TYPE}", + * "parameters":{ + * "m":{HNSW_M}, + * "ef_construction": {HNSW_EF_CONSTRUCTION}, + * "ef_search": {HNSW_EF_SEARCH} + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testRecall_whenNmslibHnswFP32_thenRecallAbove75percent() { + List spaceTypes = List.of(SpaceType.L2, SpaceType.COSINESIMIL, SpaceType.INNER_PRODUCT); + for (SpaceType spaceType : spaceTypes) { + String indexName = createIndexName(KNNEngine.NMSLIB, spaceType); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + 
.startObject(TEST_FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .startObject(KNN_METHOD) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.NMSLIB.getName()) + .field(NAME, METHOD_HNSW) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, HNSW_EF_CONSTRUCTION) + .field(METHOD_PARAMETER_M, HNSW_M) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + createIndexAndIngestDocs( + indexName, + TEST_FIELD_NAME, + Settings.builder() + .put("number_of_shards", SHARD_COUNT) + .put("number_of_replicas", REPLICA_COUNT) + .put("index.knn", true) + .put(KNN_ALGO_PARAM_EF_SEARCH, HNSW_EF_SEARCH) + .build(), + builder.toString() + ); + assertRecall(indexName, spaceType, 0.25f); + } } - public void testRecallL2RandomData() throws Exception { - String testIndexRandom = "test-index-random"; + /** + * { + * "properties": { + * { + * "type": "knn_vector", + * "dimension": {DIMENSION}, + * "method": { + * "name":"hnsw", + * "engine":"lucene", + * "space_type": "{SPACE_TYPE}", + * "parameters":{ + * "m":{HNSW_M}, + * "ef_construction": {HNSW_EF_CONSTRUCTION} + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testRecall_whenLuceneHnswFP32_thenRecallAbove75percent() { + List spaceTypes = List.of(SpaceType.L2, SpaceType.COSINESIMIL); + for (SpaceType spaceType : spaceTypes) { + String indexName = createIndexName(KNNEngine.LUCENE, spaceType); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(TEST_FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .startObject(KNN_METHOD) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.LUCENE.getName()) + .field(NAME, METHOD_HNSW) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, HNSW_EF_CONSTRUCTION) + .field(METHOD_PARAMETER_M, HNSW_M) + .endObject() + 
.endObject() + .endObject() + .endObject() + .endObject(); + createIndexAndIngestDocs(indexName, TEST_FIELD_NAME, getSettings(), builder.toString()); + assertRecall(indexName, spaceType, 0.25f); + } + } - addDocs(testIndexRandom, testFieldName, dimensions, docCount, false); - float[][] indexVectors = getIndexVectorsFromIndex(testIndexRandom, testFieldName, docCount, dimensions); - float[][] queryVectors = TestUtils.getQueryVectors(queryCount, dimensions, docCount, false); - List> groundTruthValues = TestUtils.computeGroundTruthValues(indexVectors, queryVectors, SpaceType.L2, k); - List> searchResults = bulkSearch(testIndexRandom, testFieldName, queryVectors, k); - double recallValue = TestUtils.calculateRecallValue(searchResults, groundTruthValues, k); - assertEquals(expRecallValue, recallValue, 0.2); + /** + * { + * "properties": { + * { + * "type": "knn_vector", + * "dimension": {TEST_DIMENSION}, + * "method": { + * "name":"hnsw", + * "engine":"faiss", + * "space_type": "{SPACE_TYPE}", + * "parameters":{ + * "m":{HNSW_M}, + * "ef_construction": {HNSW_EF_CONSTRUCTION}, + * "ef_search": {HNSW_EF_SEARCH}, + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testRecall_whenFaissHnswFP32_thenRecallAbove75percent() { + List spaceTypes = List.of(SpaceType.L2, SpaceType.INNER_PRODUCT); + for (SpaceType spaceType : spaceTypes) { + String indexName = createIndexName(KNNEngine.FAISS, spaceType); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(TEST_FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .startObject(KNN_METHOD) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .field(NAME, METHOD_HNSW) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, HNSW_EF_CONSTRUCTION) + .field(METHOD_PARAMETER_M, HNSW_M) + .field(METHOD_PARAMETER_EF_SEARCH, HNSW_EF_SEARCH) + .endObject() + 
.endObject() + .endObject() + .endObject() + .endObject(); + createIndexAndIngestDocs(indexName, TEST_FIELD_NAME, getSettings(), builder.toString()); + assertRecall(indexName, spaceType, 0.25f); + } } - private void addDocs(String testIndex, String testField, int dimensions, int docCount, boolean isStandard) throws Exception { - createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMapping(testField, dimensions)); + /** + * Train context: + * { + * "method": { + * "name":"ivf", + * "engine":"faiss", + * "space_type": "{SPACE_TYPE}", + * "parameters":{ + * "nlist":{IVF_NLIST}, + * "nprobes": {IVF_NPROBES} + * } + * } + * } + * + * Index Mapping: + * { + * "properties": { + * { + * "type": "knn_vector", + * "model_id": {MODEL_ID} + * } + * } + * } + */ + @SneakyThrows + public void testRecall_whenFaissIVFFP32_thenRecallAbove75percent() { + List spaceTypes = List.of(SpaceType.L2, SpaceType.INNER_PRODUCT); + setupTrainingIndex(); + for (SpaceType spaceType : spaceTypes) { + String indexName = createIndexName(KNNEngine.FAISS, spaceType); - updateClusterSettings(KNN_ALGO_PARAM_INDEX_THREAD_QTY, 2); - updateClusterSettings(KNN_MEMORY_CIRCUIT_BREAKER_ENABLED, true); + // Train the model + XContentBuilder trainingBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, IVF_NLIST) + .field(METHOD_PARAMETER_NPROBES, IVF_NPROBES) + .endObject() + .endObject(); + trainModel( + TEST_MODEL_ID, + TRAIN_INDEX_NAME, + TRAIN_FIELD_NAME, + TEST_DIMENSION, + xContentBuilderToMap(trainingBuilder), + String.format("%s-%s", KNNEngine.FAISS.getName(), spaceType.getValue()) + ); + assertTrainingSucceeds(TEST_MODEL_ID, 100, 1000 * 5); + + // Build the index + createIndexAndIngestDocs(indexName, TEST_FIELD_NAME, getSettings(), getModelMapping()); + assertRecall(indexName, spaceType, 0.25f); + + // 
Delete the model + deleteModel(TEST_MODEL_ID); + } + } + + /** + * Train context: + * { + * "properties": { + * { + * "type": "knn_vector", + * "dimension": {TEST_DIMENSION}, + * "method": { + * "name":"hnsw", + * "engine":"faiss", + * "space_type": "{SPACE_TYPE}", + * "parameters":{ + * "m":{HNSW_M}, + * "ef_construction": {HNSW_EF_CONSTRUCTION}, + * "ef_search": {HNSW_EF_SEARCH}, + * } + * } + * } + * } + * } + * + * Index Mapping: + * { + * "properties": { + * { + * "type": "knn_vector", + * "model_id": {MODEL_ID} + * } + * } + * } + */ + @SneakyThrows + public void testRecall_whenFaissIVFPQFP32_thenRecallAbove50percent() { + List spaceTypes = List.of(SpaceType.L2, SpaceType.INNER_PRODUCT); + setupTrainingIndex(); + for (SpaceType spaceType : spaceTypes) { + String indexName = createIndexName(KNNEngine.FAISS, spaceType); - bulkAddKnnDocs(testIndex, testField, TestUtils.getIndexVectors(docCount, dimensions, isStandard), docCount); + // Train the model + XContentBuilder trainingBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, IVF_NLIST) + .field(METHOD_PARAMETER_NPROBES, IVF_NPROBES) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_PQ) + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, PQ_CODE_SIZE) + .field(ENCODER_PARAMETER_PQ_M, PQ_M) + .endObject() + .endObject() + .endObject() + .endObject(); + trainModel( + TEST_MODEL_ID, + TRAIN_INDEX_NAME, + TRAIN_FIELD_NAME, + TEST_DIMENSION, + xContentBuilderToMap(trainingBuilder), + String.format("%s-%s", KNNEngine.FAISS.getName(), spaceType.getValue()) + ); + assertTrainingSucceeds(TEST_MODEL_ID, 100, 1000 * 5); + + // Build the index + createIndexAndIngestDocs(indexName, TEST_FIELD_NAME, getSettings(), getModelMapping()); + assertRecall(indexName, spaceType, 0.5f); + + // Delete the model + 
deleteModel(TEST_MODEL_ID); + } + } + + /** + * Train context: + * { + * "properties": { + * { + * "type": "knn_vector", + * "dimension": {TEST_DIMENSION}, + * "method": { + * "name":"hnsw", + * "engine":"faiss", + * "space_type": "{SPACE_TYPE}", + * "parameters":{ + * "m":{HNSW_M}, + * "ef_construction": {HNSW_EF_CONSTRUCTION}, + * "ef_search": {HNSW_EF_SEARCH}, + * } + * } + * } + * } + * } + * + * Index Mapping: + * { + * "properties": { + * { + * "type": "knn_vector", + * "model_id": {MODEL_ID} + * } + * } + * } + */ + @SneakyThrows + public void testRecall_whenFaissHNSWPQFP32_thenRecallAbove50percent() { + List spaceTypes = List.of(SpaceType.L2, SpaceType.INNER_PRODUCT); + setupTrainingIndex(); + for (SpaceType spaceType : spaceTypes) { + String indexName = createIndexName(KNNEngine.FAISS, spaceType); + + // Train the model + XContentBuilder trainingBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, HNSW_M) + .field(METHOD_PARAMETER_EF_SEARCH, HNSW_EF_SEARCH) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, HNSW_EF_CONSTRUCTION) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_PQ) + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, PQ_CODE_SIZE) + .field(ENCODER_PARAMETER_PQ_M, PQ_M) + .endObject() + .endObject() + .endObject() + .endObject(); + trainModel( + TEST_MODEL_ID, + TRAIN_INDEX_NAME, + TRAIN_FIELD_NAME, + TEST_DIMENSION, + xContentBuilderToMap(trainingBuilder), + String.format("%s-%s", KNNEngine.FAISS.getName(), spaceType.getValue()) + ); + assertTrainingSucceeds(TEST_MODEL_ID, 100, 1000 * 5); + + // Build the index + createIndexAndIngestDocs(indexName, TEST_FIELD_NAME, getSettings(), getModelMapping()); + assertRecall(indexName, spaceType, 0.5f); + + // Delete the model + deleteModel(TEST_MODEL_ID); + } + } + + @SneakyThrows + private void 
assertRecall(String testIndexName, SpaceType spaceType, float acceptableRecallFromPerfect) { + List> searchResults = bulkSearch(testIndexName, TEST_FIELD_NAME, QUERY_VECTORS, TEST_K); + double recallValue = TestUtils.calculateRecallValue(searchResults, GROUND_TRUTH.get(spaceType), TEST_K); + logger.info("Recall value = {}", recallValue); + assertEquals(PERFECT_RECALL, recallValue, acceptableRecallFromPerfect); } + private String createIndexName(KNNEngine knnEngine, SpaceType spaceType) { + return String.format("%s_%s_%s", TEST_INDEX_PREFIX_NAME, knnEngine.getName(), spaceType.getValue()); + } + + @SneakyThrows + private void createIndexAndIngestDocs(String indexName, String fieldName, Settings settings, String mapping) { + createKnnIndex(indexName, settings, mapping); + bulkAddKnnDocs(indexName, fieldName, INDEX_VECTORS, DOC_COUNT); + forceMergeKnnIndex(indexName, MAX_SEGMENT_COUNT); + } + + @SneakyThrows + private void setupTrainingIndex() { + XContentBuilder trainingIndexBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(TRAIN_FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .endObject() + .endObject() + .endObject(); + createIndexAndIngestDocs( + TRAIN_INDEX_NAME, + TRAIN_FIELD_NAME, + Settings.builder().put("number_of_shards", SHARD_COUNT).put("number_of_replicas", REPLICA_COUNT).build(), + trainingIndexBuilder.toString() + ); + } + + @SneakyThrows + private String getModelMapping() { + return XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(TEST_FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(MODEL_ID, TEST_MODEL_ID) + .endObject() + .endObject() + .endObject() + .toString(); + } + + private Settings getSettings() { + return Settings.builder() + .put("number_of_shards", SHARD_COUNT) + .put("number_of_replicas", REPLICA_COUNT) + .put("index.knn", true) + .build(); + } } diff --git 
a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 1d984927f..02537400d 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -6,6 +6,7 @@ package org.opensearch.knn; import com.google.common.primitives.Floats; +import lombok.SneakyThrows; import lombok.extern.log4j.Log4j2; import org.apache.commons.lang.StringUtils; import org.apache.hc.core5.http.io.entity.EntityUtils; @@ -427,6 +428,13 @@ public int getDocCount(String indexName) throws Exception { * Force merge KNN index segments */ protected void forceMergeKnnIndex(String index) throws Exception { + forceMergeKnnIndex(index, 1); + } + + /** + * Force merge KNN index segments + */ + protected void forceMergeKnnIndex(String index, int maxSegments) throws Exception { Request request = new Request("POST", "/" + index + "/_refresh"); Response response = client().performRequest(request); @@ -434,7 +442,7 @@ protected void forceMergeKnnIndex(String index) throws Exception { request = new Request("POST", "/" + index + "/_forcemerge"); - request.addParameter("max_num_segments", "1"); + request.addParameter("max_num_segments", String.valueOf(maxSegments)); request.addParameter("flush", "true"); response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); @@ -637,6 +645,12 @@ protected Response executeKnnStatRequest(List nodeIds, List stat return response; } + @SneakyThrows + protected void doKnnWarmup(List indices) { + Response response = knnWarmup(indices); + assertEquals(response.getStatusLine().getStatusCode(), 200); + } + /** * Warmup KNN Index */ diff --git a/src/testFixtures/java/org/opensearch/knn/TestUtils.java b/src/testFixtures/java/org/opensearch/knn/TestUtils.java index e1d15e2bd..47870804b 100644 --- 
a/src/testFixtures/java/org/opensearch/knn/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/knn/TestUtils.java @@ -70,9 +70,9 @@ public class TestUtils { SpaceType.LINF, KNNScoringUtil::lInfNorm, SpaceType.COSINESIMIL, - KNNScoringUtil::cosinesimil, + (a, b) -> -1 * KNNScoringUtil.cosinesimil(a, b), SpaceType.INNER_PRODUCT, - KNNScoringUtil::innerProduct + (a, b) -> -1 * KNNScoringUtil.innerProduct(a, b) ); public static final String KNN_BWC_PREFIX = "knn-bwc-";