From d6fdc0be74d3e25aa8691591673408954e2d93e0 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Mon, 18 Mar 2024 10:24:31 -0700 Subject: [PATCH] Integrate index state sharing into mem management (#1545) Adds the ability to share index state amongst indices during index load operations into the plugins memory management system. Introduces a manager of the shared state that will properly manage the lifecycle of the shared state. There was a bug in clear cache that had to be fixed to get this change working as well. Previously, only one index file per clear cache would be freed. This fixes that logic to clear everything. Added unit tests and an integration test to confirm functionality. In addition, modified recall integration tests to get more coverage on the different algo configs. Along with this, had to fix a few things around the computation of recall for non-l2 space types. Signed-off-by: John Mazanec --- CHANGELOG.md | 1 + .../org/opensearch/knn/index/IndexUtil.java | 16 + .../opensearch/knn/index/KNNIndexShard.java | 72 ++- .../index/memory/NativeMemoryAllocation.java | 31 + .../memory/NativeMemoryEntryContext.java | 32 ++ .../memory/NativeMemoryLoadStrategy.java | 19 +- .../knn/index/memory/SharedIndexState.java | 21 + .../index/memory/SharedIndexStateManager.java | 150 +++++ .../opensearch/knn/index/query/KNNWeight.java | 3 +- .../org/opensearch/knn/index/FaissIT.java | 112 ++++ .../opensearch/knn/index/IndexUtilTests.java | 35 ++ .../knn/index/KNNIndexShardTests.java | 39 +- .../memory/SharedIndexStateManagerTests.java | 62 ++ .../index/memory/SharedIndexStateTests.java | 29 + .../opensearch/knn/recall/RecallTestsIT.java | 528 +++++++++++++++++- .../org/opensearch/knn/KNNRestTestCase.java | 16 +- .../java/org/opensearch/knn/TestUtils.java | 11 + 17 files changed, 1099 insertions(+), 78 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/memory/SharedIndexState.java create mode 100644 src/main/java/org/opensearch/knn/index/memory/SharedIndexStateManager.java create mode 100644 src/test/java/org/opensearch/knn/index/memory/SharedIndexStateManagerTests.java create mode 100644 src/test/java/org/opensearch/knn/index/memory/SharedIndexStateTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index c5f52a431..9f815f792 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Disable sdc table for HNSWPQ read-only indices [#1518](https://github.com/opensearch-project/k-NN/pull/1518) * Switch SpaceType.INNERPRODUCT's vector similarity function to MAXIMUM_INNER_PRODUCT [#1532](https://github.com/opensearch-project/k-NN/pull/1532) * Add patch to fix arm segfault in nmslib during ingestion [#1541](https://github.com/opensearch-project/k-NN/pull/1541) +* Share ivfpq-l2 table allocations across indices on load [#1558](https://github.com/opensearch-project/k-NN/pull/1558) ### Infrastructure * Manually install zlib for win CI [#1513](https://github.com/opensearch-project/k-NN/pull/1513) * Update k-NN build artifact script to enable SIMD on ARM for Faiss [#1543](https://github.com/opensearch-project/k-NN/pull/1543) diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index 1b385319a..adfa611c7 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -23,6 +23,7 @@ import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.jni.JNIService; import java.io.File; import java.util.Collections; @@ -266,4 +267,19 @@ public static boolean isVersionOnOrAfterMinRequiredVersion(Version version, Stri } return version.onOrAfter(minimalRequiredVersion); } + + /** + * Checks if index requires shared state + * + * @param knnEngine The knnEngine associated with the index + * @param modelId The modelId associated with the index + * @param indexAddr Address to check if loaded index requires shared state + * @return true if state can be shared; false otherwise + */ + public static boolean isSharedIndexStateRequired(KNNEngine knnEngine, String modelId, long indexAddr) { + if (StringUtils.isEmpty(modelId)) { + return false; + } + return JNIService.isSharedIndexStateRequired(indexAddr, knnEngine); + } } diff --git a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java index fec93a734..71f6596dc 100644 --- a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java +++ b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java @@ -5,6 +5,9 @@ package org.opensearch.knn.index; +import com.google.common.annotations.VisibleForTesting; +import lombok.AllArgsConstructor; +import lombok.Getter; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FilterLeafReader; @@ -24,13 +27,14 @@ import java.io.IOException; import java.nio.file.Path; +import java.util.ArrayList; import java.util.Collection; -import java.util.HashMap; -import java.util.Map; +import java.util.List; import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; import static org.opensearch.knn.index.IndexUtil.getParametersAtLoading; import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFilePrefix; @@ -41,8 +45,8 @@ */ @Log4j2 public class KNNIndexShard { - private IndexShard indexShard; - private NativeMemoryCacheManager nativeMemoryCacheManager; + private final IndexShard indexShard; + private final NativeMemoryCacheManager nativeMemoryCacheManager; private static final String INDEX_SHARD_CLEAR_CACHE_SEARCHER = "knn-clear-cache"; /** @@ -83,14 +87,19 @@ public String getIndexName() { public void warmup() throws IOException { log.info("[KNN] Warming up index: [{}]", getIndexName()); try (Engine.Searcher searcher = indexShard.acquireSearcher("knn-warmup")) { - getAllEnginePaths(searcher.getIndexReader()).forEach((key, value) -> { + getAllEngineFileContexts(searcher.getIndexReader()).forEach((engineFileContext) -> { try { nativeMemoryCacheManager.get( new NativeMemoryEntryContext.IndexEntryContext( - key, + engineFileContext.getIndexPath(), NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), - getParametersAtLoading(value, KNNEngine.getEngineNameFromPath(key), getIndexName()), - getIndexName() + getParametersAtLoading( + engineFileContext.getSpaceType(), + KNNEngine.getEngineNameFromPath(engineFileContext.getIndexPath()), + getIndexName() + ), + getIndexName(), + engineFileContext.getModelId() ), true ); @@ -117,7 +126,9 @@ public void clearCache() { indexAllocation.writeLock(); log.info("[KNN] Evicting index from cache: [{}]", indexName); try (Engine.Searcher searcher = indexShard.acquireSearcher(INDEX_SHARD_CLEAR_CACHE_SEARCHER)) { - getAllEnginePaths(searcher.getIndexReader()).forEach((key, value) -> nativeMemoryCacheManager.invalidate(key)); + getAllEngineFileContexts(searcher.getIndexReader()).forEach( + (engineFileContext) -> nativeMemoryCacheManager.invalidate(engineFileContext.getIndexPath()) + ); } catch (IOException ex) { log.error("[KNN] Failed to evict index from cache: [{}]", indexName, ex); throw new RuntimeException(ex); @@ -128,22 +139,23 @@ public void clearCache() { } /** - * For the given shard, get all of its engine paths + * For the given shard, get all of its engine file context objects * - * @param indexReader IndexReader to read the file paths for the shard - * @return List of engine file Paths + * @param indexReader IndexReader to read the information for each segment in the shard + * @return List of engine contexts * @throws IOException Thrown when the SegmentReader is attempting to read the segments files */ - public Map getAllEnginePaths(IndexReader indexReader) throws IOException { - Map engineFiles = new HashMap<>(); + @VisibleForTesting + List getAllEngineFileContexts(IndexReader indexReader) throws IOException { + List engineFiles = new ArrayList<>(); for (KNNEngine knnEngine : KNNEngine.getEnginesThatCreateCustomSegmentFiles()) { - engineFiles.putAll(getEnginePaths(indexReader, knnEngine)); + engineFiles.addAll(getEngineFileContexts(indexReader, knnEngine)); } return engineFiles; } - private Map getEnginePaths(IndexReader indexReader, KNNEngine knnEngine) throws IOException { - Map engineFiles = new HashMap<>(); + List getEngineFileContexts(IndexReader indexReader, KNNEngine knnEngine) throws IOException { + List engineFiles = new ArrayList<>(); for (LeafReaderContext leafReaderContext : indexReader.leaves()) { SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(leafReaderContext.reader()); @@ -158,15 +170,17 @@ private Map getEnginePaths(IndexReader indexReader, KNNEngine // was L2. So, if Space Type is not present, just fall back to L2 String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue()); SpaceType spaceType = SpaceType.getSpace(spaceTypeName); + String modelId = fieldInfo.attributes().getOrDefault(MODEL_ID, null); - engineFiles.putAll( - getEnginePaths( + engineFiles.addAll( + getEngineFileContexts( reader.getSegmentInfo().files(), reader.getSegmentInfo().info.name, fieldInfo.name, fileExtension, shardPath, - spaceType + spaceType, + modelId ) ); } @@ -175,13 +189,15 @@ private Map getEnginePaths(IndexReader indexReader, KNNEngine return engineFiles; } - protected Map getEnginePaths( + @VisibleForTesting + List getEngineFileContexts( Collection files, String segmentName, String fieldName, String fileExtension, Path shardPath, - SpaceType spaceType + SpaceType spaceType, + String modelId ) { String prefix = buildEngineFilePrefix(segmentName); String suffix = buildEngineFileSuffix(fieldName, fileExtension); @@ -189,6 +205,16 @@ protected Map getEnginePaths( .filter(fileName -> fileName.startsWith(prefix)) .filter(fileName -> fileName.endsWith(suffix)) .map(fileName -> shardPath.resolve(fileName).toString()) - .collect(Collectors.toMap(fileName -> fileName, fileName -> spaceType)); + .map(fileName -> new EngineFileContext(spaceType, modelId, fileName)) + .collect(Collectors.toList()); + } + + @AllArgsConstructor + @Getter + @VisibleForTesting + static class EngineFileContext { + private final SpaceType spaceType; + private final String modelId; + private final String indexPath; } } diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java index 416980759..286e6265c 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryAllocation.java @@ -91,6 +91,7 @@ class IndexAllocation implements NativeMemoryAllocation { private final String openSearchIndexName; private final ReadWriteLock readWriteLock; private final WatcherHandle watcherHandle; + private final SharedIndexState sharedIndexState; /** * Constructor @@ -111,6 +112,31 @@ class IndexAllocation implements NativeMemoryAllocation { String indexPath, String openSearchIndexName, WatcherHandle watcherHandle + ) { + this(executorService, memoryAddress, size, knnEngine, indexPath, openSearchIndexName, watcherHandle, null); + } + + /** + * Constructor + * + * @param executorService Executor service used to close the allocation + * @param memoryAddress Pointer in memory to the index + * @param size Size this index consumes in kilobytes + * @param knnEngine KNNEngine associated with the index allocation + * @param indexPath File path to index + * @param openSearchIndexName Name of OpenSearch index this index is associated with + * @param watcherHandle Handle for watching index file + * @param sharedIndexState Shared index state. If not shared state present, pass null. + */ + IndexAllocation( + ExecutorService executorService, + long memoryAddress, + int size, + KNNEngine knnEngine, + String indexPath, + String openSearchIndexName, + WatcherHandle watcherHandle, + SharedIndexState sharedIndexState ) { this.executor = executorService; this.closed = false; @@ -121,6 +147,7 @@ class IndexAllocation implements NativeMemoryAllocation { this.readWriteLock = new ReentrantReadWriteLock(); this.size = size; this.watcherHandle = watcherHandle; + this.sharedIndexState = sharedIndexState; } @Override @@ -145,6 +172,10 @@ private void cleanup() { if (memoryAddress != 0) { JNIService.free(memoryAddress, knnEngine); } + + if (sharedIndexState != null) { + SharedIndexStateManager.getInstance().release(sharedIndexState); + } } @Override diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java index 13f8dae10..7f14a2341 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryEntryContext.java @@ -12,6 +12,7 @@ package org.opensearch.knn.index.memory; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Nullable; import org.opensearch.knn.index.IndexUtil; import java.io.IOException; @@ -63,6 +64,8 @@ public static class IndexEntryContext extends NativeMemoryEntryContext parameters; + @Nullable + private final String modelId; /** * Constructor @@ -77,11 +80,31 @@ public IndexEntryContext( NativeMemoryLoadStrategy.IndexLoadStrategy indexLoadStrategy, Map parameters, String openSearchIndexName + ) { + this(indexPath, indexLoadStrategy, parameters, openSearchIndexName, null); + } + + /** + * Constructor + * + * @param indexPath path to index file. Also used as key in cache. + * @param indexLoadStrategy strategy to load index into memory + * @param parameters load time parameters + * @param openSearchIndexName opensearch index associated with index + * @param modelId model to be loaded. If none available, pass null + */ + public IndexEntryContext( + String indexPath, + NativeMemoryLoadStrategy.IndexLoadStrategy indexLoadStrategy, + Map parameters, + String openSearchIndexName, + String modelId ) { super(indexPath); this.indexLoadStrategy = indexLoadStrategy; this.openSearchIndexName = openSearchIndexName; this.parameters = parameters; + this.modelId = modelId; } @Override @@ -112,6 +135,15 @@ public Map getParameters() { return parameters; } + /** + * Getter + * + * @return return model ID for the index. null if no model is in use + */ + public String getModelId() { + return modelId; + } + private static class IndexSizeCalculator implements Function { static IndexSizeCalculator INSTANCE = new IndexSizeCalculator(); diff --git a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java index 568cc892b..cb7dafdfc 100644 --- a/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java +++ b/src/main/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategy.java @@ -11,7 +11,9 @@ package org.opensearch.knn.index.memory; +import lombok.extern.log4j.Log4j2; import org.opensearch.core.action.ActionListener; +import org.opensearch.knn.index.IndexUtil; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.training.TrainingDataConsumer; @@ -41,6 +43,7 @@ public interface NativeMemoryLoadStrategy, @@ -92,17 +95,25 @@ public NativeMemoryAllocation.IndexAllocation load(NativeMemoryEntryContext.Inde fileWatcher.init(); KNNEngine knnEngine = KNNEngine.getEngineNameFromPath(indexPath.toString()); - long memoryAddress = JNIService.loadIndex(indexPath.toString(), indexEntryContext.getParameters(), knnEngine); - final WatcherHandle watcherHandle = resourceWatcherService.add(fileWatcher); + long indexAddress = JNIService.loadIndex(indexPath.toString(), indexEntryContext.getParameters(), knnEngine); + SharedIndexState sharedIndexState = null; + String modelId = indexEntryContext.getModelId(); + if (IndexUtil.isSharedIndexStateRequired(knnEngine, modelId, indexAddress)) { + log.info("Index with model: \"{}\" requires shared state. Retrieving shared state.", modelId); + sharedIndexState = SharedIndexStateManager.getInstance().get(indexAddress, modelId, knnEngine); + JNIService.setSharedIndexState(indexAddress, sharedIndexState.getSharedIndexStateAddress(), knnEngine); + } + final WatcherHandle watcherHandle = resourceWatcherService.add(fileWatcher); return new NativeMemoryAllocation.IndexAllocation( executor, - memoryAddress, + indexAddress, indexEntryContext.calculateSizeInKB(), knnEngine, indexPath.toString(), indexEntryContext.getOpenSearchIndexName(), - watcherHandle + watcherHandle, + sharedIndexState ); } diff --git a/src/main/java/org/opensearch/knn/index/memory/SharedIndexState.java b/src/main/java/org/opensearch/knn/index/memory/SharedIndexState.java new file mode 100644 index 000000000..2ffadb22e --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/memory/SharedIndexState.java @@ -0,0 +1,21 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.memory; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.knn.index.util.KNNEngine; + +/** + * Class stores information about the shared memory allocations between loaded native indices. + */ +@RequiredArgsConstructor +@Getter +public class SharedIndexState { + private final long sharedIndexStateAddress; + private final String modelId; + private final KNNEngine knnEngine; +} diff --git a/src/main/java/org/opensearch/knn/index/memory/SharedIndexStateManager.java b/src/main/java/org/opensearch/knn/index/memory/SharedIndexStateManager.java new file mode 100644 index 000000000..896113834 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/memory/SharedIndexStateManager.java @@ -0,0 +1,150 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.memory; + +import com.google.common.annotations.VisibleForTesting; +import lombok.Getter; +import lombok.extern.log4j.Log4j2; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.jni.JNIService; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +/** + * Class manages allocations that can be shared between native indices. No locking is required. + * Once a caller obtain an instance of a {@link org.opensearch.knn.index.memory.SharedIndexState}, it is guaranteed to + * be valid until it is returned. {@link org.opensearch.knn.index.memory.SharedIndexState} are reference counted + * internally. Once the reference count goes to 0, it will be freed. + */ +@Log4j2 +class SharedIndexStateManager { + // Map storing the shared index state with key being the modelId. + private final ConcurrentHashMap sharedIndexStateCache; + private final ReadWriteLock readWriteLock; + + private static SharedIndexStateManager INSTANCE; + + // TODO: Going to refactor away from doing this in the future. For now, keeping for simplicity. + public static synchronized SharedIndexStateManager getInstance() { + if (INSTANCE == null) { + INSTANCE = new SharedIndexStateManager(); + } + return INSTANCE; + } + + /** + * Constructor + */ + @VisibleForTesting + SharedIndexStateManager() { + this.sharedIndexStateCache = new ConcurrentHashMap<>(); + this.readWriteLock = new ReentrantReadWriteLock(); + } + + /** + * Return a {@link SharedIndexState} associated with the key. If no value exists, it will attempt to create it. + * Once returned, the {@link SharedIndexState} will be valid until + * {@link SharedIndexStateManager#release(SharedIndexState)} is called. Caller must ensure that this is + * called after it is done using it. + * + * In order to create the shared state, it will use the indexAddress passed in to create the shared state from + * using {@link org.opensearch.knn.jni.JNIService#initSharedIndexState(long, KNNEngine)}. + * + * @param indexAddress Address of index to initialize the shared state from + * @param knnEngine engine index belongs to + * @return ShareModelContext + */ + public SharedIndexState get(long indexAddress, String modelId, KNNEngine knnEngine) { + this.readWriteLock.readLock().lock(); + try { + // This can be done safely with readLock because the ConcurrentHasMap.computeIfAbsent guarantees: + // + // "If the specified key is not already associated with a value, attempts to compute its value using the given + // mapping function and enters it into this map unless null. The entire method invocation is performed + // atomically, so the function is applied at most once per key. Some attempted update operations on this map + // by other threads may be blocked while computation is in progress, so the computation should be short and + // simple, and must not attempt to update any other mappings of this map." + // + // Ref: + // https://docs.oracle.com/javase/8/docs/api/java/util/concurrent/ConcurrentHashMap.html#computeIfAbsent-K-java.util.function.Function- + SharedIndexStateEntry entry = sharedIndexStateCache.computeIfAbsent(modelId, m -> { + log.info("Loading entry to shared index state cache for model {}", modelId); + long sharedIndexStateAddress = JNIService.initSharedIndexState(indexAddress, knnEngine); + return new SharedIndexStateEntry(new SharedIndexState(sharedIndexStateAddress, modelId, knnEngine)); + }); + entry.incRef(); + return entry.getSharedIndexState(); + } finally { + this.readWriteLock.readLock().unlock(); + } + } + + /** + * Indicate that the {@link SharedIndexState} is no longer being used. If nothing else is using it, it will be + * removed from the cache and evicted. + * + * After calling this method, {@link SharedIndexState} should no longer be used by calling thread. + * + * @param sharedIndexState to return to the system. + */ + public void release(SharedIndexState sharedIndexState) { + this.readWriteLock.writeLock().lock(); + + try { + SharedIndexStateEntry sharedIndexStateEntry; + if ((sharedIndexStateEntry = sharedIndexStateCache.get(sharedIndexState.getModelId())) == null) { + // This should not happen. Will log the error and return to prevent crash + log.error("Attempting to evict model from cache but it is not present: {}", sharedIndexState.getModelId()); + return; + } + + if (sharedIndexStateEntry.decRef() <= 0) { + log.info("Evicting entry from shared index state cache for key {}", sharedIndexState.getModelId()); + sharedIndexStateCache.remove(sharedIndexState.getModelId()); + JNIService.freeSharedIndexState(sharedIndexState.getSharedIndexStateAddress(), sharedIndexState.getKnnEngine()); + } + } finally { + this.readWriteLock.writeLock().unlock(); + } + } + + private static final class SharedIndexStateEntry { + @Getter + private final SharedIndexState sharedIndexState; + private final AtomicLong referenceCount; + + /** + * Constructor + * + * @param sharedIndexState sharedIndexStateContext being wrapped + */ + private SharedIndexStateEntry(SharedIndexState sharedIndexState) { + this.sharedIndexState = sharedIndexState; + this.referenceCount = new AtomicLong(0); + } + + /** + * Increases reference count by 1 + * + * @return ++referenceCount + */ + private long incRef() { + return referenceCount.incrementAndGet(); + } + + /** + * Decrease reference count by 1 + * + * @return --referenceCount + */ + private long decRef() { + return referenceCount.decrementAndGet(); + } + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 7e2fa19cc..1c4d0a646 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -256,7 +256,8 @@ private Map doANNSearch(final LeafReaderContext context, final B indexPath.toString(), NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName()), - knnQuery.getIndexName() + knnQuery.getIndexName(), + modelId ), true ); diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 15c7119a8..b8149e4f7 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -40,6 +40,7 @@ import java.util.TreeMap; import java.util.stream.Collectors; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; @@ -51,6 +52,7 @@ import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.NAME; @@ -497,6 +499,116 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed( fail("Graphs are not getting evicted"); } + /** + * This test confirms that sharing index state for IVFPQ-l2 indices functions properly. The main functionality that + * needs to be confirmed is that once an index gets deleted, it will not cause a failure for the non-deleted index. + * + * The workflow will be: + * 1. Create a model + * 2. Create two indices index from the model + * 3. Load the native index files from the first index + * 4. Assert search works + * 5. Load the native index files (which will reuse the shared state from the initial index) + * 6. Assert search works on the second index + * 7. Delete the first index and wait + * 8. Assert search works on the second index + */ + @SneakyThrows + public void testSharedIndexState_whenOneIndexDeleted_thenSecondIndexIsStillSearchable() { + String firstIndexName = "test-index-1"; + String secondIndexName = "test-index-2"; + String trainingIndexName = "training-index"; + + String modelId = "test-model"; + String modelDescription = "ivfpql2 model for testing shared state"; + + int dimension = testData.indexData.vectors[0].length; + SpaceType spaceType = SpaceType.L2; + int ivfNlist = 4; + int ivfNprobes = 4; + int pqCodeSize = 8; + int pqM = 1; + int docCount = 100; + + // training data needs to be at least equal to the number of centroids for PQ + // which is 2^8 = 256. 8 because thats the only valid code_size for HNSWPQ + int trainingDataCount = 256; + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NPROBES, ivfNprobes) + .field(METHOD_PARAMETER_NLIST, ivfNlist) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_PQ) + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_M, pqCodeSize) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, pqM) + .endObject() + .endObject() + .endObject() + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + createBasicKnnIndex(trainingIndexName, FIELD_NAME, dimension); + ingestDataAndTrainModel(modelId, trainingIndexName, FIELD_NAME, dimension, modelDescription, in, trainingDataCount); + assertTrainingSucceeds(modelId, 360, 1000); + + createIndexFromModelAndIngestDocuments(firstIndexName, modelId, docCount); + createIndexFromModelAndIngestDocuments(secondIndexName, modelId, docCount); + + doKnnWarmup(List.of(firstIndexName)); + validateSearchWorkflow(firstIndexName, testData.queries, 10); + doKnnWarmup(List.of(secondIndexName)); + validateSearchWorkflow(secondIndexName, testData.queries, 10); + deleteKNNIndex(firstIndexName); + // wait for all index files to be cleaned up from original index. empirically determined to take 25 seconds. + // will give 15 second buffer from that + Thread.sleep(1000 * 45); + validateSearchWorkflow(secondIndexName, testData.queries, 10); + deleteModel(modelId); + } + + @SneakyThrows + private void createIndexFromModelAndIngestDocuments(String indexName, String modelId, int docCount) { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("model_id", modelId) + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + createKnnIndex(indexName, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + + for (int i = 0; i < Math.min(testData.indexData.docs.length, docCount); i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + FIELD_NAME, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + refreshAllNonSystemIndices(); + assertEquals(Math.min(testData.indexData.docs.length, docCount), getDocCount(indexName)); + } + + @SneakyThrows + private void validateSearchWorkflow(String indexName, float[][] queries, int k) { + for (float[] query : queries) { + Response response = searchKNNIndex(indexName, new KNNQueryBuilder(FIELD_NAME, query, k), k); + String responseBody = EntityUtils.toString(response.getEntity()); + List knnResults = parseSearchResponse(responseBody, FIELD_NAME); + assertEquals(k, knnResults.size()); + } + } + public void testDocUpdate() throws IOException { String indexName = "test-index-1"; String fieldName = "test-field-1"; diff --git a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java index c8b29b6ef..00493b293 100644 --- a/src/test/java/org/opensearch/knn/index/IndexUtilTests.java +++ b/src/test/java/org/opensearch/knn/index/IndexUtilTests.java @@ -12,6 +12,8 @@ package org.opensearch.knn.index; import com.google.common.collect.ImmutableMap; +import org.junit.BeforeClass; +import org.mockito.MockedStatic; import org.opensearch.Version; import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.metadata.IndexMetadata; @@ -24,12 +26,16 @@ import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.jni.JNIService; import java.util.Map; import java.util.Objects; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.when; import static org.opensearch.knn.common.KNNConstants.HNSW_ALGO_EF_SEARCH; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; @@ -37,6 +43,15 @@ import static org.opensearch.knn.index.KNNSettings.KNN_ALGO_PARAM_EF_SEARCH; public class IndexUtilTests extends KNNTestCase { + + private static MockedStatic jniServiceMockedStatic; + private static final long TEST_INDEX_ADDRESS = 0; + + @BeforeClass + public static void setUpClass() { + jniServiceMockedStatic = mockStatic(JNIService.class); + } + public void testGetLoadParameters() { // Test faiss to ensure that space type gets set properly SpaceType spaceType1 = SpaceType.COSINESIMIL; @@ -206,4 +221,24 @@ public void testValidateKnnField_EmptyIndexMetadata() { assert (Objects.requireNonNull(e).getMessage().matches("Validation Failed: 1: Invalid index. Index does not contain a mapping;")); } + + public void testIsShareableStateContainedInIndex_whenIndexNotModelBased_thenReturnFalse() { + String modelId = null; + KNNEngine knnEngine = KNNEngine.FAISS; + assertFalse(IndexUtil.isSharedIndexStateRequired(knnEngine, modelId, TEST_INDEX_ADDRESS)); + } + + public void testIsShareableStateContainedInIndex_whenFaissHNSWIsUsed_thenReturnFalse() { + jniServiceMockedStatic.when(() -> JNIService.isSharedIndexStateRequired(anyLong(), any())).thenReturn(false); + String modelId = "test-model"; + KNNEngine knnEngine = KNNEngine.FAISS; + assertFalse(IndexUtil.isSharedIndexStateRequired(knnEngine, modelId, TEST_INDEX_ADDRESS)); + } + + public void testIsShareableStateContainedInIndex_whenJNIIsSharedIndexStateRequiredIsTrue_thenReturnTrue() { + jniServiceMockedStatic.when(() -> JNIService.isSharedIndexStateRequired(anyLong(), any())).thenReturn(true); + String modelId = "test-model"; + KNNEngine knnEngine = KNNEngine.FAISS; + assertTrue(IndexUtil.isSharedIndexStateRequired(knnEngine, modelId, TEST_INDEX_ADDRESS)); + } } diff --git a/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java b/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java index e476969e1..7b6f96d5a 100644 --- a/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNIndexShardTests.java @@ -17,9 +17,7 @@ import java.io.IOException; import java.nio.file.Path; import java.nio.file.Paths; -import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; @@ -98,38 +96,35 @@ public void testWarmup_shardNotPresentInCache() throws InterruptedException, Exe assertEquals(2, NativeMemoryCacheManager.getInstance().getIndicesCacheStats().get(testIndexName).get(GRAPH_COUNT)); } - public void testGetHNSWPaths() throws IOException, ExecutionException, InterruptedException { + public void testGetAllEngineFileContexts() throws IOException, ExecutionException, InterruptedException { IndexService indexService = createKNNIndex(testIndexName); createKnnIndexMapping(testIndexName, testFieldName, dimensions); - IndexShard indexShard; - KNNIndexShard knnIndexShard; - Engine.Searcher searcher; - Map hnswPaths; - indexShard = indexService.iterator().next(); - knnIndexShard = new KNNIndexShard(indexShard); + IndexShard indexShard = indexService.iterator().next(); + KNNIndexShard knnIndexShard = new KNNIndexShard(indexShard); - searcher = indexShard.acquireSearcher("test-hnsw-paths-1"); - hnswPaths = knnIndexShard.getAllEnginePaths(searcher.getIndexReader()); - assertEquals(0, hnswPaths.size()); + Engine.Searcher searcher = indexShard.acquireSearcher("test-hnsw-paths-1"); + List engineFileContexts = knnIndexShard.getAllEngineFileContexts(searcher.getIndexReader()); + assertEquals(0, engineFileContexts.size()); searcher.close(); addKnnDoc(testIndexName, "1", testFieldName, new Float[] { 2.5F, 3.5F }); searcher = indexShard.acquireSearcher("test-hnsw-paths-2"); - hnswPaths = knnIndexShard.getAllEnginePaths(searcher.getIndexReader()); - assertEquals(1, hnswPaths.size()); - List paths = new ArrayList<>(hnswPaths.keySet()); + engineFileContexts = knnIndexShard.getAllEngineFileContexts(searcher.getIndexReader()); + assertEquals(1, engineFileContexts.size()); + List paths = engineFileContexts.stream().map(KNNIndexShard.EngineFileContext::getIndexPath).collect(Collectors.toList()); assertTrue(paths.get(0).contains("hnsw") || paths.get(0).contains("hnswc")); searcher.close(); } - public void testGetEnginePaths() { + public void testGetEngineFileContexts() { // Check that the correct engine paths are being returned by the KNNIndexShard String segmentName = "_0"; String fieldName = "test_field"; String fileExt = ".test"; SpaceType spaceType = SpaceType.L2; + String modelId = "test-model"; Set includedFileNames = ImmutableSet.of( String.format("%s_111_%s%s", segmentName, fieldName, fileExt), @@ -148,10 +143,18 @@ public void testGetEnginePaths() { KNNIndexShard knnIndexShard = new KNNIndexShard(null); Path path = Paths.get(""); - Map included = knnIndexShard.getEnginePaths(files, segmentName, fieldName, fileExt, path, spaceType); + List included = knnIndexShard.getEngineFileContexts( + files, + segmentName, + fieldName, + fileExt, + path, + spaceType, + modelId + ); assertEquals(includedFileNames.size(), included.size()); - included.keySet().forEach(o -> assertTrue(includedFileNames.contains(o))); + included.stream().map(KNNIndexShard.EngineFileContext::getIndexPath).forEach(o -> assertTrue(includedFileNames.contains(o))); } @SneakyThrows diff --git a/src/test/java/org/opensearch/knn/index/memory/SharedIndexStateManagerTests.java b/src/test/java/org/opensearch/knn/index/memory/SharedIndexStateManagerTests.java new file mode 100644 index 000000000..daf02c611 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/memory/SharedIndexStateManagerTests.java @@ -0,0 +1,62 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.memory; + +import org.junit.BeforeClass; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.jni.JNIService; + +import static org.mockito.Mockito.mockStatic; + +public class SharedIndexStateManagerTests extends KNNTestCase { + private static MockedStatic jniServiceMockedStatic; + private final static long TEST_SHARED_TABLE_ADDRESS = 123; + private final static long TEST_INDEX_ADDRESS = 1234; + private final static String TEST_MODEL_ID = "test-model-id"; + private final static KNNEngine TEST_KNN_ENGINE = KNNEngine.DEFAULT; + + @BeforeClass + public static void setUpClass() { + jniServiceMockedStatic = mockStatic(JNIService.class); + jniServiceMockedStatic.when(() -> JNIService.freeSharedIndexState(TEST_SHARED_TABLE_ADDRESS, TEST_KNN_ENGINE)) + .then(invocation -> null); + jniServiceMockedStatic.when(() -> JNIService.initSharedIndexState(TEST_INDEX_ADDRESS, TEST_KNN_ENGINE)) + .thenReturn(TEST_SHARED_TABLE_ADDRESS); + } + + public void testGet_whenNormalWorkfloatApplied_thenSucceed() { + SharedIndexStateManager sharedIndexStateManager = new SharedIndexStateManager(); + SharedIndexState firstSharedIndexStateRetrieved = sharedIndexStateManager.get(TEST_INDEX_ADDRESS, TEST_MODEL_ID, TEST_KNN_ENGINE); + assertEquals(TEST_SHARED_TABLE_ADDRESS, firstSharedIndexStateRetrieved.getSharedIndexStateAddress()); + assertEquals(TEST_MODEL_ID, firstSharedIndexStateRetrieved.getModelId()); + assertEquals(TEST_KNN_ENGINE, firstSharedIndexStateRetrieved.getKnnEngine()); + + SharedIndexState secondSharedIndexStateRetrieved = sharedIndexStateManager.get(TEST_INDEX_ADDRESS, TEST_MODEL_ID, TEST_KNN_ENGINE); + assertEquals(TEST_SHARED_TABLE_ADDRESS, secondSharedIndexStateRetrieved.getSharedIndexStateAddress()); + assertEquals(TEST_MODEL_ID, secondSharedIndexStateRetrieved.getModelId()); + assertEquals(TEST_KNN_ENGINE, secondSharedIndexStateRetrieved.getKnnEngine()); + } + + public void testRelease_whenNormalWorkflowApplied_thenSucceed() { + SharedIndexStateManager sharedIndexStateManager = new SharedIndexStateManager(); + SharedIndexState firstSharedIndexStateRetrieved = sharedIndexStateManager.get(TEST_INDEX_ADDRESS, TEST_MODEL_ID, TEST_KNN_ENGINE); + SharedIndexState secondSharedIndexStateRetrieved = sharedIndexStateManager.get(TEST_INDEX_ADDRESS, TEST_MODEL_ID, TEST_KNN_ENGINE); + + sharedIndexStateManager.release(firstSharedIndexStateRetrieved); + jniServiceMockedStatic.verify(() -> JNIService.freeSharedIndexState(TEST_SHARED_TABLE_ADDRESS, TEST_KNN_ENGINE), Mockito.times(0)); + sharedIndexStateManager.release(secondSharedIndexStateRetrieved); + jniServiceMockedStatic.verify(() -> JNIService.freeSharedIndexState(TEST_SHARED_TABLE_ADDRESS, TEST_KNN_ENGINE), Mockito.times(1)); + } +} diff --git a/src/test/java/org/opensearch/knn/index/memory/SharedIndexStateTests.java b/src/test/java/org/opensearch/knn/index/memory/SharedIndexStateTests.java new file mode 100644 index 000000000..cddbec5c0 --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/memory/SharedIndexStateTests.java @@ -0,0 +1,29 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.knn.index.memory; + +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.util.KNNEngine; + +public class SharedIndexStateTests extends KNNTestCase { + + private static final String TEST_MODEL_ID = "test-model"; + private static final long TEST_SHARED_INDEX_STATE_ADDRESS = 22L; + private static final KNNEngine TEST_KNN_ENGINE = KNNEngine.DEFAULT; + + public void testSharedIndexState() { + SharedIndexState sharedIndexState = new SharedIndexState(TEST_SHARED_INDEX_STATE_ADDRESS, TEST_MODEL_ID, TEST_KNN_ENGINE); + assertEquals(TEST_MODEL_ID, sharedIndexState.getModelId()); + assertEquals(TEST_SHARED_INDEX_STATE_ADDRESS, sharedIndexState.getSharedIndexStateAddress()); + assertEquals(TEST_KNN_ENGINE, sharedIndexState.getKnnEngine()); + } +} diff --git a/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java b/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java index 889f3916f..c90afaa62 100644 --- a/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java +++ b/src/test/java/org/opensearch/knn/recall/RecallTestsIT.java @@ -11,54 +11,520 @@ package org.opensearch.knn.recall; +import lombok.SneakyThrows; +import org.junit.Before; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.TestUtils; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.util.KNNEngine; + import java.util.List; +import java.util.Map; import java.util.Set; + +import static org.opensearch.knn.common.KNNConstants.DIMENSION; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PQ; +import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_M; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.common.KNNConstants.TYPE; +import static org.opensearch.knn.common.KNNConstants.TYPE_KNN_VECTOR; +import static org.opensearch.knn.index.KNNSettings.KNN_ALGO_PARAM_EF_SEARCH; import static org.opensearch.knn.index.KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY; import static org.opensearch.knn.index.KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED; +/** + * Tests confirm that for the different supported configurations, recall is sound. The recall thresholds are + * conservatively and empirically determined to prevent flakiness. + * + * This test suite can take a long time to run. The primary reason is that training can take a long time for PQ. + * The parameters for PQ have been reduced significantly, but it still takes time. + */ public class RecallTestsIT extends KNNRestTestCase { - private final String testFieldName = "test-field"; - private final int dimensions = 50; - private final int docCount = 10000; - private final int queryCount = 100; - private final int k = 5; - private final double expRecallValue = 1.0; - - public void testRecallL2StandardData() throws Exception { - String testIndexStandard = "test-index-standard"; - - addDocs(testIndexStandard, testFieldName, dimensions, docCount, true); - float[][] indexVectors = getIndexVectorsFromIndex(testIndexStandard, testFieldName, docCount, dimensions); - float[][] queryVectors = TestUtils.getQueryVectors(queryCount, dimensions, docCount, true); - List> groundTruthValues = TestUtils.computeGroundTruthValues(indexVectors, queryVectors, SpaceType.L2, k); - List> searchResults = bulkSearch(testIndexStandard, testFieldName, queryVectors, k); - double recallValue = TestUtils.calculateRecallValue(searchResults, groundTruthValues, k); - assertEquals(expRecallValue, recallValue, 0.2); + private static final String PROPERTIES_FIELD = "properties"; + private final static String TEST_INDEX_PREFIX_NAME = "test_index"; + private final static String TEST_FIELD_NAME = "test_field"; + private final static String TRAIN_INDEX_NAME = "train_index"; + private final static String TRAIN_FIELD_NAME = "train_field"; + private final static String TEST_MODEL_ID = "test_model_id"; + private final static int TEST_DIMENSION = 32; + private final static int DOC_COUNT = 500; + private final static int QUERY_COUNT = 100; + private final static int TEST_K = 100; + private final static double PERFECT_RECALL = 1.0; + private final static int SHARD_COUNT = 1; + private final static int REPLICA_COUNT = 0; + private final static int MAX_SEGMENT_COUNT = 10; + + // Standard algorithm parameters + private final static int HNSW_M = 16; + private final static int HNSW_EF_CONSTRUCTION = 100; + private final static int HNSW_EF_SEARCH = TEST_K; // For consistency with lucene + private final static int IVF_NLIST = 4; + private final static int IVF_NPROBES = IVF_NLIST; // This equates to essentially a brute force search + private final static int PQ_CODE_SIZE = 8; // This is low and going to produce bad recall, but reduces build time + private final static int PQ_M = TEST_DIMENSION / 8; // Will give low recall, but required for test time + + // Setup ground truth for all tests once + private final static float[][] INDEX_VECTORS = TestUtils.getIndexVectors(DOC_COUNT, TEST_DIMENSION, true); + private final static float[][] QUERY_VECTORS = TestUtils.getQueryVectors(QUERY_COUNT, TEST_DIMENSION, DOC_COUNT, true); + private final static Map>> GROUND_TRUTH = Map.of( + SpaceType.L2, + TestUtils.computeGroundTruthValues(INDEX_VECTORS, QUERY_VECTORS, SpaceType.L2, TEST_K), + SpaceType.COSINESIMIL, + TestUtils.computeGroundTruthValues(INDEX_VECTORS, QUERY_VECTORS, SpaceType.COSINESIMIL, TEST_K), + SpaceType.INNER_PRODUCT, + TestUtils.computeGroundTruthValues(INDEX_VECTORS, QUERY_VECTORS, SpaceType.INNER_PRODUCT, TEST_K) + ); + + @SneakyThrows + @Before + public void setupClusterSettings() { + updateClusterSettings(KNN_ALGO_PARAM_INDEX_THREAD_QTY, 2); + updateClusterSettings(KNN_MEMORY_CIRCUIT_BREAKER_ENABLED, true); + } + + /** + * { + * "properties": { + * { + * "type": "knn_vector", + * "dimension": {DIMENSION}, + * "method": { + * "name":"hnsw", + * "engine":"nmslib", + * "space_type": "{SPACE_TYPE}", + * "parameters":{ + * "m":{HNSW_M}, + * "ef_construction": {HNSW_EF_CONSTRUCTION}, + * "ef_search": {HNSW_EF_SEARCH} + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testRecall_whenNmslibHnswFP32_thenRecallAbove75percent() { + List spaceTypes = List.of(SpaceType.L2, SpaceType.COSINESIMIL, SpaceType.INNER_PRODUCT); + for (SpaceType spaceType : spaceTypes) { + String indexName = createIndexName(KNNEngine.NMSLIB, spaceType); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(TEST_FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .startObject(KNN_METHOD) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.NMSLIB.getName()) + .field(NAME, METHOD_HNSW) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, HNSW_EF_CONSTRUCTION) + .field(METHOD_PARAMETER_M, HNSW_M) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + createIndexAndIngestDocs( + indexName, + TEST_FIELD_NAME, + Settings.builder() + .put("number_of_shards", SHARD_COUNT) + .put("number_of_replicas", REPLICA_COUNT) + .put("index.knn", true) + .put(KNN_ALGO_PARAM_EF_SEARCH, HNSW_EF_SEARCH) + .build(), + builder.toString() + ); + assertRecall(indexName, spaceType, 0.25f); + } } - public void testRecallL2RandomData() throws Exception { - String testIndexRandom = "test-index-random"; + /** + * { + * "properties": { + * { + * "type": "knn_vector", + * "dimension": {DIMENSION}, + * "method": { + * "name":"hnsw", + * "engine":"lucene", + * "space_type": "{SPACE_TYPE}", + * "parameters":{ + * "m":{HNSW_M}, + * "ef_construction": {HNSW_EF_CONSTRUCTION} + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testRecall_whenLuceneHnswFP32_thenRecallAbove75percent() { + List spaceTypes = List.of(SpaceType.L2, SpaceType.COSINESIMIL); + for (SpaceType spaceType : spaceTypes) { + String indexName = createIndexName(KNNEngine.LUCENE, spaceType); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(TEST_FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .startObject(KNN_METHOD) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.LUCENE.getName()) + .field(NAME, METHOD_HNSW) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, HNSW_EF_CONSTRUCTION) + .field(METHOD_PARAMETER_M, HNSW_M) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + createIndexAndIngestDocs(indexName, TEST_FIELD_NAME, getSettings(), builder.toString()); + assertRecall(indexName, spaceType, 0.25f); + } + } - addDocs(testIndexRandom, testFieldName, dimensions, docCount, false); - float[][] indexVectors = getIndexVectorsFromIndex(testIndexRandom, testFieldName, docCount, dimensions); - float[][] queryVectors = TestUtils.getQueryVectors(queryCount, dimensions, docCount, false); - List> groundTruthValues = TestUtils.computeGroundTruthValues(indexVectors, queryVectors, SpaceType.L2, k); - List> searchResults = bulkSearch(testIndexRandom, testFieldName, queryVectors, k); - double recallValue = TestUtils.calculateRecallValue(searchResults, groundTruthValues, k); - assertEquals(expRecallValue, recallValue, 0.2); + /** + * { + * "properties": { + * { + * "type": "knn_vector", + * "dimension": {TEST_DIMENSION}, + * "method": { + * "name":"hnsw", + * "engine":"faiss", + * "space_type": "{SPACE_TYPE}", + * "parameters":{ + * "m":{HNSW_M}, + * "ef_construction": {HNSW_EF_CONSTRUCTION}, + * "ef_search": {HNSW_EF_SEARCH}, + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testRecall_whenFaissHnswFP32_thenRecallAbove75percent() { + List spaceTypes = List.of(SpaceType.L2, SpaceType.INNER_PRODUCT); + for (SpaceType spaceType : spaceTypes) { + String indexName = createIndexName(KNNEngine.FAISS, spaceType); + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(TEST_FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .startObject(KNN_METHOD) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .field(NAME, METHOD_HNSW) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, HNSW_EF_CONSTRUCTION) + .field(METHOD_PARAMETER_M, HNSW_M) + .field(METHOD_PARAMETER_EF_SEARCH, HNSW_EF_SEARCH) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + createIndexAndIngestDocs(indexName, TEST_FIELD_NAME, getSettings(), builder.toString()); + assertRecall(indexName, spaceType, 0.25f); + } } - private void addDocs(String testIndex, String testField, int dimensions, int docCount, boolean isStandard) throws Exception { - createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMapping(testField, dimensions)); + /** + * Train context: + * { + * "method": { + * "name":"ivf", + * "engine":"faiss", + * "space_type": "{SPACE_TYPE}", + * "parameters":{ + * "nlist":{IVF_NLIST}, + * "nprobes": {IVF_NPROBES} + * } + * } + * } + * + * Index Mapping: + * { + * "properties": { + * { + * "type": "knn_vector", + * "model_id": {MODEL_ID} + * } + * } + * } + */ + @SneakyThrows + public void testRecall_whenFaissIVFFP32_thenRecallAbove75percent() { + List spaceTypes = List.of(SpaceType.L2, SpaceType.INNER_PRODUCT); + setupTrainingIndex(); + for (SpaceType spaceType : spaceTypes) { + String indexName = createIndexName(KNNEngine.FAISS, spaceType); - updateClusterSettings(KNN_ALGO_PARAM_INDEX_THREAD_QTY, 2); - updateClusterSettings(KNN_MEMORY_CIRCUIT_BREAKER_ENABLED, true); + // Train the model + XContentBuilder trainingBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, IVF_NLIST) + .field(METHOD_PARAMETER_NPROBES, IVF_NPROBES) + .endObject() + .endObject(); + trainModel( + TEST_MODEL_ID, + TRAIN_INDEX_NAME, + TRAIN_FIELD_NAME, + TEST_DIMENSION, + xContentBuilderToMap(trainingBuilder), + String.format("%s-%s", KNNEngine.FAISS.getName(), spaceType.getValue()) + ); + assertTrainingSucceeds(TEST_MODEL_ID, 100, 1000 * 5); + + // Build the index + createIndexAndIngestDocs(indexName, TEST_FIELD_NAME, getSettings(), getModelMapping()); + assertRecall(indexName, spaceType, 0.25f); + + // Delete the model + deleteModel(TEST_MODEL_ID); + } + } + + /** + * Train context: + * { + * "properties": { + * { + * "type": "knn_vector", + * "dimension": {TEST_DIMENSION}, + * "method": { + * "name":"hnsw", + * "engine":"faiss", + * "space_type": "{SPACE_TYPE}", + * "parameters":{ + * "m":{HNSW_M}, + * "ef_construction": {HNSW_EF_CONSTRUCTION}, + * "ef_search": {HNSW_EF_SEARCH}, + * } + * } + * } + * } + * } + * + * Index Mapping: + * { + * "properties": { + * { + * "type": "knn_vector", + * "model_id": {MODEL_ID} + * } + * } + * } + */ + @SneakyThrows + public void testRecall_whenFaissIVFPQFP32_thenRecallAbove50percent() { + List spaceTypes = List.of(SpaceType.L2, SpaceType.INNER_PRODUCT); + setupTrainingIndex(); + for (SpaceType spaceType : spaceTypes) { + String indexName = createIndexName(KNNEngine.FAISS, spaceType); - bulkAddKnnDocs(testIndex, testField, TestUtils.getIndexVectors(docCount, dimensions, isStandard), docCount); + // Train the model + XContentBuilder trainingBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_IVF) + .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, IVF_NLIST) + .field(METHOD_PARAMETER_NPROBES, IVF_NPROBES) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_PQ) + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, PQ_CODE_SIZE) + .field(ENCODER_PARAMETER_PQ_M, PQ_M) + .endObject() + .endObject() + .endObject() + .endObject(); + trainModel( + TEST_MODEL_ID, + TRAIN_INDEX_NAME, + TRAIN_FIELD_NAME, + TEST_DIMENSION, + xContentBuilderToMap(trainingBuilder), + String.format("%s-%s", KNNEngine.FAISS.getName(), spaceType.getValue()) + ); + assertTrainingSucceeds(TEST_MODEL_ID, 100, 1000 * 5); + + // Build the index + createIndexAndIngestDocs(indexName, TEST_FIELD_NAME, getSettings(), getModelMapping()); + assertRecall(indexName, spaceType, 0.5f); + + // Delete the model + deleteModel(TEST_MODEL_ID); + } + } + + /** + * Train context: + * { + * "properties": { + * { + * "type": "knn_vector", + * "dimension": {TEST_DIMENSION}, + * "method": { + * "name":"hnsw", + * "engine":"faiss", + * "space_type": "{SPACE_TYPE}", + * "parameters":{ + * "m":{HNSW_M}, + * "ef_construction": {HNSW_EF_CONSTRUCTION}, + * "ef_search": {HNSW_EF_SEARCH}, + * } + * } + * } + * } + * } + * + * Index Mapping: + * { + * "properties": { + * { + * "type": "knn_vector", + * "model_id": {MODEL_ID} + * } + * } + * } + */ + @SneakyThrows + public void testRecall_whenFaissHNSWPQFP32_thenRecallAbove50percent() { + List spaceTypes = List.of(SpaceType.L2, SpaceType.INNER_PRODUCT); + setupTrainingIndex(); + for (SpaceType spaceType : spaceTypes) { + String indexName = createIndexName(KNNEngine.FAISS, spaceType); + + // Train the model + XContentBuilder trainingBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, FAISS_NAME) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, HNSW_M) + .field(METHOD_PARAMETER_EF_SEARCH, HNSW_EF_SEARCH) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, HNSW_EF_CONSTRUCTION) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_PQ) + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, PQ_CODE_SIZE) + .field(ENCODER_PARAMETER_PQ_M, PQ_M) + .endObject() + .endObject() + .endObject() + .endObject(); + trainModel( + TEST_MODEL_ID, + TRAIN_INDEX_NAME, + TRAIN_FIELD_NAME, + TEST_DIMENSION, + xContentBuilderToMap(trainingBuilder), + String.format("%s-%s", KNNEngine.FAISS.getName(), spaceType.getValue()) + ); + assertTrainingSucceeds(TEST_MODEL_ID, 100, 1000 * 5); + + // Build the index + createIndexAndIngestDocs(indexName, TEST_FIELD_NAME, getSettings(), getModelMapping()); + assertRecall(indexName, spaceType, 0.5f); + + // Delete the model + deleteModel(TEST_MODEL_ID); + } + } + + @SneakyThrows + private void assertRecall(String testIndexName, SpaceType spaceType, float acceptableRecallFromPerfect) { + List> searchResults = bulkSearch(testIndexName, TEST_FIELD_NAME, QUERY_VECTORS, TEST_K); + double recallValue = TestUtils.calculateRecallValue(searchResults, GROUND_TRUTH.get(spaceType), TEST_K); + logger.info("Recall value = {}", recallValue); + assertEquals(PERFECT_RECALL, recallValue, acceptableRecallFromPerfect); } + private String createIndexName(KNNEngine knnEngine, SpaceType spaceType) { + return String.format("%s_%s_%s", TEST_INDEX_PREFIX_NAME, knnEngine.getName(), spaceType.getValue()); + } + + @SneakyThrows + private void createIndexAndIngestDocs(String indexName, String fieldName, Settings settings, String mapping) { + createKnnIndex(indexName, settings, mapping); + bulkAddKnnDocs(indexName, fieldName, INDEX_VECTORS, DOC_COUNT); + forceMergeKnnIndex(indexName, MAX_SEGMENT_COUNT); + } + + @SneakyThrows + private void setupTrainingIndex() { + XContentBuilder trainingIndexBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(TRAIN_FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(DIMENSION, TEST_DIMENSION) + .endObject() + .endObject() + .endObject(); + createIndexAndIngestDocs( + TRAIN_INDEX_NAME, + TRAIN_FIELD_NAME, + Settings.builder().put("number_of_shards", SHARD_COUNT).put("number_of_replicas", REPLICA_COUNT).build(), + trainingIndexBuilder.toString() + ); + } + + @SneakyThrows + private String getModelMapping() { + return XContentFactory.jsonBuilder() + .startObject() + .startObject(PROPERTIES_FIELD) + .startObject(TEST_FIELD_NAME) + .field(TYPE, TYPE_KNN_VECTOR) + .field(MODEL_ID, TEST_MODEL_ID) + .endObject() + .endObject() + .endObject() + .toString(); + } + + private Settings getSettings() { + return Settings.builder() + .put("number_of_shards", SHARD_COUNT) + .put("number_of_replicas", REPLICA_COUNT) + .put("index.knn", true) + .build(); + } } diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index d21ab57f7..0b6ae3a5e 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -6,6 +6,7 @@ package org.opensearch.knn; import com.google.common.primitives.Floats; +import lombok.SneakyThrows; import lombok.extern.log4j.Log4j2; import org.apache.commons.lang.StringUtils; import org.apache.hc.core5.http.io.entity.EntityUtils; @@ -455,6 +456,13 @@ public int getDocCount(String indexName) throws Exception { * Force merge KNN index segments */ protected void forceMergeKnnIndex(String index) throws Exception { + forceMergeKnnIndex(index, 1); + } + + /** + * Force merge KNN index segments + */ + protected void forceMergeKnnIndex(String index, int maxSegments) throws Exception { Request request = new Request("POST", "/" + index + "/_refresh"); Response response = client().performRequest(request); @@ -462,7 +470,7 @@ protected void forceMergeKnnIndex(String index) throws Exception { request = new Request("POST", "/" + index + "/_forcemerge"); - request.addParameter("max_num_segments", "1"); + request.addParameter("max_num_segments", String.valueOf(maxSegments)); request.addParameter("flush", "true"); response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); @@ -665,6 +673,12 @@ protected Response executeKnnStatRequest(List nodeIds, List stat return response; } + @SneakyThrows + protected void doKnnWarmup(List indices) { + Response response = knnWarmup(indices); + assertEquals(response.getStatusLine().getStatusCode(), 200); + } + /** * Warmup KNN Index */ diff --git a/src/testFixtures/java/org/opensearch/knn/TestUtils.java b/src/testFixtures/java/org/opensearch/knn/TestUtils.java index e1d15e2bd..0de11b2f2 100644 --- a/src/testFixtures/java/org/opensearch/knn/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/knn/TestUtils.java @@ -143,6 +143,12 @@ public static List> computeGroundTruthValues(float[][] indexVectors, pq = new PriorityQueue<>(k, new DistComparator()); for (int j = 0; j < indexVectors.length; j++) { float dist = computeDistFromSpaceType(spaceType, indexVectors[j], queryVectors[i]); + + // Need to invert distance for IP or COSINE because higher is better in these cases + if (spaceType == SpaceType.INNER_PRODUCT || spaceType == SpaceType.COSINESIMIL) { + dist *= -1; + } + pq = insertWithOverflow(pq, k, dist, j); } @@ -203,6 +209,11 @@ public static PriorityQueue computeGroundTruthValues(int k, SpaceTyp for (int id = 0; id < numDocs; id++) { float[] indexVector = idVectorProducer.getVector(id); float dist = computeDistFromSpaceType(spaceType, indexVector, queryVector); + // Need to invert distance for IP or COSINE because higher is better in these cases + if (spaceType == SpaceType.INNER_PRODUCT || spaceType == SpaceType.COSINESIMIL) { + dist *= -1; + } + pq = insertWithOverflow(pq, k, dist, id); } return pq;