diff --git a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java index ecb884ba5..77e993297 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java @@ -5,8 +5,10 @@ package org.opensearch.knn.index.query; +import com.google.common.base.Predicates; import lombok.AllArgsConstructor; import lombok.Builder; +import lombok.NonNull; import lombok.Value; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; @@ -65,7 +67,7 @@ public Map searchLeaf(final LeafReaderContext leafReaderContext, && exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) { return scoreAllDocs(iterator); } - return searchTopK(iterator, exactSearcherContext.getK()); + return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue()); } /** @@ -104,7 +106,8 @@ private Map scoreAllDocs(KNNIterator iterator) throws IOExceptio return docToScore; } - private Map searchTopCandidates(KNNIterator iterator, int limit, Predicate filterScore) throws IOException { + private Map searchTopCandidates(KNNIterator iterator, int limit, @NonNull Predicate filterScore) + throws IOException { // Creating min heap and init with MAX DocID and Score as -INF. final HitQueue queue = new HitQueue(limit, true); ScoreDoc topDoc = queue.top(); @@ -112,10 +115,7 @@ private Map searchTopCandidates(KNNIterator iterator, int limit, int docId; while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { final float currentScore = iterator.score(); - if (filterScore != null && Predicate.not(filterScore).test(currentScore)) { - continue; - } - if (currentScore > topDoc.score) { + if (filterScore.test(currentScore) && currentScore > topDoc.score) { topDoc.score = currentScore; topDoc.doc = docId; // As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we @@ -138,10 +138,6 @@ private Map searchTopCandidates(KNNIterator iterator, int limit, return docToScore; } - private Map searchTopK(KNNIterator iterator, int k) throws IOException { - return searchTopCandidates(iterator, k, null); - } - private Map filterDocsByMinScore(ExactSearcherContext context, KNNIterator iterator, float minScore) throws IOException { int maxResultWindow = context.getKnnQuery().getContext().getMaxResultWindow(); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index 7a71c44be..6cfdecb22 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -68,6 +68,7 @@ import java.nio.file.Path; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Locale; @@ -97,6 +98,8 @@ import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.common.KNNConstants.SPACE_TYPE; +import static org.opensearch.knn.index.SpaceType.INNER_PRODUCT; +import static org.opensearch.knn.index.SpaceType.L2; public class KNNWeightTests extends KNNTestCase { private static final String FIELD_NAME = "target_field"; @@ -171,7 +174,7 @@ public void tearDownAfterTest() { @SneakyThrows public void testQueryResultScoreNmslib() { - for (SpaceType space : List.of(SpaceType.L2, SpaceType.L1, SpaceType.COSINESIMIL, SpaceType.INNER_PRODUCT, SpaceType.LINF)) { + for (SpaceType space : List.of(L2, SpaceType.L1, SpaceType.COSINESIMIL, SpaceType.INNER_PRODUCT, SpaceType.LINF)) { testQueryScore(space::scoreTranslation, SEGMENT_FILES_NMSLIB, Map.of(SPACE_TYPE, space.getValue())); } } @@ -179,11 +182,11 @@ public void testQueryResultScoreNmslib() { @SneakyThrows public void testQueryResultScoreFaiss() { testQueryScore( - SpaceType.L2::scoreTranslation, + L2::scoreTranslation, SEGMENT_FILES_FAISS, Map.of( SPACE_TYPE, - SpaceType.L2.getValue(), + L2.getValue(), KNN_ENGINE, KNNEngine.FAISS.getName(), PARAMETERS, @@ -221,7 +224,7 @@ public void testQueryResultScoreFaiss() { @SneakyThrows public void testQueryScoreForFaissWithModel() { - SpaceType spaceType = SpaceType.L2; + SpaceType spaceType = L2; final Function scoreTranslator = spaceType::scoreTranslation; final String modelId = "modelId"; jniServiceMockedStatic.when(() -> JNIService.queryIndex(anyLong(), any(), eq(K), isNull(), any(), any(), anyInt(), any())) @@ -293,7 +296,7 @@ public void testQueryScoreForFaissWithModel() { @SneakyThrows public void testQueryScoreForFaissWithNonExistingModel() throws IOException { - SpaceType spaceType = SpaceType.L2; + SpaceType spaceType = L2; final String modelId = "modelId"; final KNNQuery query = new KNNQuery(FIELD_NAME, QUERY_VECTOR, K, INDEX_NAME, (BitSetProducer) null); @@ -600,7 +603,7 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, - isBinary ? SpaceType.HAMMING.getValue() : SpaceType.L2.getValue() + isBinary ? SpaceType.HAMMING.getValue() : L2.getValue() ); when(reader.getFieldInfos()).thenReturn(fieldInfos); @@ -638,7 +641,7 @@ public void validateANNWithFilterQuery_whenDoingANN_thenSuccess(final boolean is } final List actualDocIds = new ArrayList<>(); - final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + final Map translatedScores = getTranslatedScores(L2::scoreTranslation); for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { actualDocIds.add(docId); assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f); @@ -715,7 +718,7 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, - isBinary ? SpaceType.HAMMING.getValue() : SpaceType.L2.getValue() + isBinary ? SpaceType.HAMMING.getValue() : L2.getValue() ); final FieldInfos fieldInfos = mock(FieldInfos.class); final FieldInfo fieldInfo = mock(FieldInfo.class); @@ -727,7 +730,7 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean if (isBinary) { when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.HAMMING.getValue()); } else { - when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.L2.getValue()); + when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(L2.getValue()); } when(fieldInfo.getName()).thenReturn(FIELD_NAME); @@ -763,6 +766,95 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean } } + @SneakyThrows + public void testRadialSearch_whenNoEngineFiles_thenPerformExactSearch() { + try (MockedStatic valuesFactoryMockedStatic = Mockito.mockStatic(KNNVectorValuesFactory.class)) { + final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f }; + final SpaceType spaceType = randomFrom(L2, INNER_PRODUCT); + final List dataVectors = Arrays.asList( + new float[] { 11.0f, 12.0f, 13.0f }, + new float[] { 14.0f, 15.0f, 16.0f }, + new float[] { 17.0f, 18.0f, 19.0f } + ); + final List expectedScores = dataVectors.stream() + .map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector)) + .collect(Collectors.toList()); + final Float score = Collections.min(expectedScores); + final float radius = KNNEngine.FAISS.scoreToRadialThreshold(score, spaceType); + final int maxResults = 1000; + final KNNQuery.Context context = mock(KNNQuery.Context.class); + when(context.getMaxResultWindow()).thenReturn(maxResults); + KNNWeight.initialize(null); + + final KNNQuery query = KNNQuery.builder() + .field(FIELD_NAME) + .queryVector(queryVector) + .radius(radius) + .indexName(INDEX_NAME) + .context(context) + .methodParameters(HNSW_METHOD_PARAMETERS) + .build(); + final KNNWeight knnWeight = new KNNWeight(query, 1.0f); + + final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + final SegmentReader reader = mock(SegmentReader.class); + when(leafReaderContext.reader()).thenReturn(reader); + + final FSDirectory directory = mock(FSDirectory.class); + when(reader.directory()).thenReturn(directory); + final SegmentInfo segmentInfo = new SegmentInfo( + directory, + Version.LATEST, + Version.LATEST, + SEGMENT_NAME, + 100, + false, + false, + KNNCodecVersion.current().getDefaultCodecDelegate(), + Map.of(), + new byte[StringHelper.ID_LENGTH], + Map.of(), + Sort.RELEVANCE + ); + segmentInfo.setFiles(Set.of()); + final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + + final Path path = mock(Path.class); + when(directory.getDirectory()).thenReturn(path); + final FieldInfos fieldInfos = mock(FieldInfos.class); + final FieldInfo fieldInfo = mock(FieldInfo.class); + when(reader.getFieldInfos()).thenReturn(fieldInfos); + when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + when(fieldInfo.attributes()).thenReturn( + Map.of( + SPACE_TYPE, + spaceType.getValue(), + KNN_ENGINE, + KNNEngine.FAISS.getName(), + PARAMETERS, + String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + ) + ); + when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(spaceType.getValue()); + KNNFloatVectorValues floatVectorValues = mock(KNNFloatVectorValues.class); + valuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(fieldInfo, reader)).thenReturn(floatVectorValues); + when(floatVectorValues.nextDoc()).thenReturn(0, 1, 2, NO_MORE_DOCS); + when(floatVectorValues.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2)); + + final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + assertNotNull(knnScorer); + final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + assertNotNull(docIdSetIterator); + assertEquals(docIdSetIterator.cost(), dataVectors.size()); + List actualScores = new ArrayList<>(); + while (docIdSetIterator.nextDoc() != NO_MORE_DOCS) { + actualScores.add(knnScorer.score()); + } + assertEquals(expectedScores, actualScores); + } + } + @SneakyThrows public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenSuccess() { ModelDao modelDao = mock(ModelDao.class); @@ -791,7 +883,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, - SpaceType.L2.name(), + L2.name(), PARAMETERS, String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") ); @@ -801,7 +893,7 @@ public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenS when(reader.getFieldInfos()).thenReturn(fieldInfos); when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); when(fieldInfo.attributes()).thenReturn(attributesMap); - when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.L2.name()); + when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(L2.name()); when(fieldInfo.getName()).thenReturn(FIELD_NAME); when(reader.getBinaryDocValues(FIELD_NAME)).thenReturn(binaryDocValues); when(binaryDocValues.advance(filterDocId)).thenReturn(filterDocId); @@ -861,7 +953,7 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, - SpaceType.L2.name(), + L2.name(), PARAMETERS, String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") ); @@ -871,7 +963,7 @@ public void testANNWithFilterQuery_whenExactSearchViaThresholdSetting_thenSucces when(reader.getFieldInfos()).thenReturn(fieldInfos); when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); when(fieldInfo.attributes()).thenReturn(attributesMap); - when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.L2.name()); + when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(L2.name()); when(fieldInfo.getName()).thenReturn(FIELD_NAME); when(reader.getBinaryDocValues(FIELD_NAME)).thenReturn(binaryDocValues); when(binaryDocValues.advance(0)).thenReturn(0); @@ -1043,7 +1135,7 @@ public void testANNWithParentsFilter_whenExactSearch_thenSuccess() { // Verify final List expectedScores = vectors.stream() - .map(vector -> SpaceType.L2.getKnnVectorSimilarityFunction().compare(QUERY_VECTOR, vector)) + .map(vector -> L2.getKnnVectorSimilarityFunction().compare(QUERY_VECTOR, vector)) .collect(Collectors.toList()); final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); assertEquals(1, docIdSetIterator.nextDoc()); @@ -1177,7 +1269,7 @@ public void testDoANNSearch_whenRadialIsDefined_thenCallJniRadiusQueryIndex() { when(fieldInfo.attributes()).thenReturn( Map.of( SPACE_TYPE, - SpaceType.L2.getValue(), + L2.getValue(), KNN_ENGINE, KNNEngine.FAISS.getName(), PARAMETERS, @@ -1204,7 +1296,7 @@ public void testDoANNSearch_whenRadialIsDefined_thenCallJniRadiusQueryIndex() { final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); final List actualDocIds = new ArrayList<>(); - final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + final Map translatedScores = getTranslatedScores(L2::scoreTranslation); for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { actualDocIds.add(docId); assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f); @@ -1213,6 +1305,84 @@ public void testDoANNSearch_whenRadialIsDefined_thenCallJniRadiusQueryIndex() { assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); } + // public void testDoANNSearch_whenRadialIsDefined_whenNoEngineFiles_thenCallExactSearch() throws IOException { + // final float[] queryVector = new float[] { 0.1f, 0.3f }; + // final float radius = 0.5f; + // final int maxResults = 1000; + // KNNQuery.Context context = mock(KNNQuery.Context.class); + // when(context.getMaxResultWindow()).thenReturn(maxResults); + // KNNWeight.initialize(null); + // + // final KNNQuery query = KNNQuery.builder() + // .field(FIELD_NAME) + // .queryVector(queryVector) + // .radius(radius) + // .indexName(INDEX_NAME) + // .context(context) + // .methodParameters(HNSW_METHOD_PARAMETERS) + // .build(); + // final float boost = (float) randomDoubleBetween(0, 10, true); + // final KNNWeight knnWeight = new KNNWeight(query, boost); + // + // final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class); + // final SegmentReader reader = mock(SegmentReader.class); + // when(leafReaderContext.reader()).thenReturn(reader); + // + // final FSDirectory directory = mock(FSDirectory.class); + // when(reader.directory()).thenReturn(directory); + // final SegmentInfo segmentInfo = new SegmentInfo( + // directory, + // Version.LATEST, + // Version.LATEST, + // SEGMENT_NAME, + // 100, + // true, + // false, + // KNNCodecVersion.current().getDefaultCodecDelegate(), + // Map.of(), + // new byte[StringHelper.ID_LENGTH], + // Map.of(), + // Sort.RELEVANCE + // ); + // segmentInfo.setFiles(Set.of()); + // final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]); + // when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo); + // + // final Path path = mock(Path.class); + // when(directory.getDirectory()).thenReturn(path); + // final FieldInfos fieldInfos = mock(FieldInfos.class); + // final FieldInfo fieldInfo = mock(FieldInfo.class); + // when(reader.getFieldInfos()).thenReturn(fieldInfos); + // when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); + // when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.L2.getValue()); + // when(fieldInfo.hasVectorValues()).thenReturn(true); + // when(fieldInfo.getVectorEncoding()).thenReturn(VectorEncoding.FLOAT32); + // when(fieldInfo.attributes()).thenReturn( + // Map.of( + // SPACE_TYPE, + // SpaceType.L2.getValue(), + // KNN_ENGINE, + // KNNEngine.FAISS.getName(), + // PARAMETERS, + // String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") + // ) + // ); + // + // final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext); + // assertNotNull(knnScorer); + // verifyNoInteractions(jniServiceMockedStatic); + // final DocIdSetIterator docIdSetIterator = knnScorer.iterator(); + // + // final List actualDocIds = new ArrayList<>(); + // final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + // for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { + // actualDocIds.add(docId); + // assertEquals(translatedScores.get(docId) * boost, knnScorer.score(), 0.01f); + // } + // assertEquals(docIdSetIterator.cost(), actualDocIds.size()); + // assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder())); + // } + private SegmentReader getMockedSegmentReader() { final SegmentReader reader = mock(SegmentReader.class); when(reader.maxDoc()).thenReturn(1); @@ -1250,13 +1420,13 @@ private SegmentReader getMockedSegmentReader() { KNN_ENGINE, KNNEngine.FAISS.getName(), SPACE_TYPE, - SpaceType.L2.name(), + L2.name(), PARAMETERS, String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32") ); final FieldInfo fieldInfo = mock(FieldInfo.class); when(fieldInfo.attributes()).thenReturn(attributesMap); - when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(SpaceType.L2.name()); + when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(L2.name()); when(fieldInfo.getName()).thenReturn(FIELD_NAME); // Prepare fieldInfos