diff --git a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java index ecb884ba5..77e993297 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java @@ -5,8 +5,10 @@ package org.opensearch.knn.index.query; +import com.google.common.base.Predicates; import lombok.AllArgsConstructor; import lombok.Builder; +import lombok.NonNull; import lombok.Value; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; @@ -65,7 +67,7 @@ public Map searchLeaf(final LeafReaderContext leafReaderContext, && exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) { return scoreAllDocs(iterator); } - return searchTopK(iterator, exactSearcherContext.getK()); + return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue()); } /** @@ -104,7 +106,8 @@ private Map scoreAllDocs(KNNIterator iterator) throws IOExceptio return docToScore; } - private Map searchTopCandidates(KNNIterator iterator, int limit, Predicate filterScore) throws IOException { + private Map searchTopCandidates(KNNIterator iterator, int limit, @NonNull Predicate filterScore) + throws IOException { // Creating min heap and init with MAX DocID and Score as -INF. final HitQueue queue = new HitQueue(limit, true); ScoreDoc topDoc = queue.top(); @@ -112,10 +115,7 @@ private Map searchTopCandidates(KNNIterator iterator, int limit, int docId; while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { final float currentScore = iterator.score(); - if (filterScore != null && Predicate.not(filterScore).test(currentScore)) { - continue; - } - if (currentScore > topDoc.score) { + if (filterScore.test(currentScore) && currentScore > topDoc.score) { topDoc.score = currentScore; topDoc.doc = docId; // As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we @@ -138,10 +138,6 @@ private Map searchTopCandidates(KNNIterator iterator, int limit, return docToScore; } - private Map searchTopK(KNNIterator iterator, int k) throws IOException { - return searchTopCandidates(iterator, k, null); - } - private Map filterDocsByMinScore(ExactSearcherContext context, KNNIterator iterator, float minScore) throws IOException { int maxResultWindow = context.getKnnQuery().getContext().getMaxResultWindow(); diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index c7862bae8..c494f7f1f 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -623,10 +623,11 @@ public void testHNSWSQFP16_whenGraphThresholdIsNegative_whenIndexed_thenSkipCrea // Assert we have the right number of documents in the index assertEquals(numDocs, getDocCount(indexName)); - // KNN Query should return empty result + final Response searchResponse = searchKNNIndex(indexName, buildSearchQuery(fieldName, 1, queryVector, null), 1); final List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName); - assertEquals(0, results.size()); + // expect result due to exact search + assertEquals(1, results.size()); deleteKNNIndex(indexName); validateGraphEviction(); @@ -682,7 +683,7 @@ public void testHNSWSQFP16_whenGraphThresholdIsMetDuringMerge_thenCreateGraph() // KNN Query should return empty result final Response searchResponse = searchKNNIndex(indexName, buildSearchQuery(fieldName, 1, queryVector, null), 1); final List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName); - assertEquals(0, results.size()); + assertEquals(1, results.size()); // update index setting to build graph and do force merge // update build vector data structure setting diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 7a71c44be..2540446eb 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -68,6 +68,7 @@ import java.nio.file.Path; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Locale; @@ -360,7 +361,6 @@ public void testScorer_whenNoVectorFieldsInDocument_thenEmptyScorerIsReturned() final Path path = mock(Path.class); when(directory.getDirectory()).thenReturn(path); final FieldInfos fieldInfos = mock(FieldInfos.class); - final FieldInfo fieldInfo = mock(FieldInfo.class); when(reader.getFieldInfos()).thenReturn(fieldInfos); // When no knn fields are available , field info for vector field will be null when(fieldInfos.fieldInfo(FIELD_NAME)).thenReturn(null); @@ -763,6 +763,95 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean } } + @SneakyThrows + public void testRadialSearch_whenNoEngineFiles_thenPerformExactSearch() { + try (MockedStatic valuesFactoryMockedStatic = Mockito.mockStatic(KNNVectorValuesFactory.class)) { + final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f }; + final SpaceType spaceType = randomFrom(SpaceType.L2, SpaceType.INNER_PRODUCT); + final List dataVectors = Arrays.asList( + new float[] { 11.0f, 12.0f, 13.0f }, + new float[] { 14.0f, 15.0f, 16.0f }, + new float[] { 17.0f, 18.0f, 19.0f } + ); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + final Float score = Collections.min(expectedScores); + final float radius = KNNEngine.FAISS.scoreToRadialThreshold(score, spaceType); + final int maxResults = 1000; + final KNNQuery.Context context = mock(KNNQuery.Context.class); + when(context.getMaxResultWindow()).thenReturn(maxResults); + KNNWeight.initialize(null); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .radius(radius) + .indexName(INDEX_NAME) + .context(context) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); + final KNNWeight knnWeight = new KNNWeight(query, 1.0f); + + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final FSDirectory directory = mock(FSDirectory.class); + when(reader.directory()).thenReturn(directory); + final SegmentInfo segmentInfo = new SegmentInfo( + directory, + Version.LATEST, + Version.LATEST, + SEGMENT_NAME, + 100, + false, + false, + KNNCodecVersion.current().getDefaultCodecDelegate(), + Map.of(), + new byte[StringHelper.ID_LENGTH], + Map.of(), + Sort.RELEVANCE + ); + segmentInfo.setFiles(Set.of()); + final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + + final Path path = mock(Path.class); + when(directory.getDirectory()).thenReturn(path); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn( + Map.of( + SPACE_TYPE, + spaceType.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ) + ); + when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(spaceType.getValue()); + KNNFloatVectorValues floatVectorValues = mock(KNNFloatVectorValues.class); + valuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(fieldInfo, reader)).thenReturn(floatVectorValues); + when(floatVectorValues.nextDoc()).thenReturn(0, 1, 2, NO_MORE_DOCS); + when(floatVectorValues.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(docIdSetIterator.cost(), dataVectors.size()); + List actualScores = new ArrayList<>(); + while (docIdSetIterator.nextDoc() != NO_MORE_DOCS) { + actualScores.add(knnScorer.score()); + } + assertEquals(expectedScores, actualScores); + } + } + @SneakyThrows public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenSuccess() { ModelDao modelDao = mock(ModelDao.class); diff --git a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java index cca40c55d..7784c4bf4 100644 --- a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java +++ b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java @@ -118,7 +118,10 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_ assertEquals(1, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size()); // update build vector data structure setting - updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0)); + updateIndexSettings( + INDEX_NAME, + Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, ALWAYS_BUILD_GRAPH) + ); forceMergeKnnIndex(INDEX_NAME, 1); int k = 100; @@ -133,7 +136,7 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_ } @SneakyThrows - public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() throws Exception { + public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() { // Create Index createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128, testData.indexData.docs.length); ingestTestData(INDEX_NAME, FIELD_NAME, false); @@ -141,7 +144,10 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_ assertEquals(1, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size()); // update build vector data structure setting - updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0)); + updateIndexSettings( + INDEX_NAME, + Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, ALWAYS_BUILD_GRAPH) + ); forceMergeKnnIndex(INDEX_NAME, 1); int k = 100; @@ -155,6 +161,17 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_ } } + @SneakyThrows + public void testFaissHnswBinary_whenRadialSearch_thenThrowException() { + // Create Index + createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 16); + + // Query + float[] queryVector = { (byte) 0b10001111, (byte) 0b10000000 }; + Exception e = expectThrows(Exception.class, () -> runRnnQuery(INDEX_NAME, FIELD_NAME, queryVector, 1, 4)); + assertTrue(e.getMessage(), e.getMessage().contains("Binary data type does not support radial search")); + } + private float getRecall(final Set truth, final Set result) { // Count the number of relevant documents retrieved result.retainAll(truth); @@ -167,6 +184,23 @@ private float getRecall(final Set truth, final Set result) { return (float) relevantRetrieved / totalRelevant; } + private List runRnnQuery( + final String indexName, + final String fieldName, + final float[] queryVector, + final float minScore, + final int size + ) throws Exception { + String query = KNNJsonQueryBuilder.builder() + .fieldName(fieldName) + .vector(ArrayUtils.toObject(queryVector)) + .minScore(minScore) + .build() + .getQueryString(); + Response response = searchKNNIndex(indexName, query, size); + return parseSearchResponse(EntityUtils.toString(response.getEntity()), fieldName); + } + private List runKnnQuery(final String indexName, final String fieldName, final float[] queryVector, final int k) throws Exception { String query = KNNJsonQueryBuilder.builder()