diff --git a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java index 8e5849abb..ecb884ba5 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java @@ -21,6 +21,7 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator; +import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.NestedBinaryVectorIdsKNNIterator; import org.opensearch.knn.index.query.iterators.VectorIdsKNNIterator; @@ -36,7 +37,9 @@ import java.io.IOException; import java.util.HashMap; +import java.util.Locale; import java.util.Map; +import java.util.function.Predicate; @Log4j2 @AllArgsConstructor @@ -55,6 +58,9 @@ public class ExactSearcher { public Map searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext) throws IOException { KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext); + if (exactSearcherContext.getKnnQuery().getRadius() != null) { + return doRadialSearch(leafReaderContext, exactSearcherContext, iterator); + } if (exactSearcherContext.getMatchedDocs() != null && exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) { return scoreAllDocs(iterator); @@ -62,6 +68,33 @@ public Map searchLeaf(final LeafReaderContext leafReaderContext, return searchTopK(iterator, exactSearcherContext.getK()); } + /** + * Perform radial search by comparing scores with min score. Currently, FAISS from native engine supports radial search. + * Hence, we assume that Radius from knnQuery is always distance, and we convert it to score since we do exact search uses scores + * to filter out the documents that does not have given min score. + * @param leafReaderContext + * @param exactSearcherContext + * @param iterator {@link KNNIterator} + * @return Map of docId and score + * @throws IOException exception raised by iterator during traversal + */ + private Map doRadialSearch( + LeafReaderContext leafReaderContext, + ExactSearcherContext exactSearcherContext, + KNNIterator iterator + ) throws IOException { + final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader()); + final KNNQuery knnQuery = exactSearcherContext.getKnnQuery(); + final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); + final KNNEngine engine = FieldInfoExtractor.extractKNNEngine(fieldInfo); + if (KNNEngine.FAISS != engine) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "Engine [%s] does not support radial search", engine)); + } + final SpaceType spaceType = FieldInfoExtractor.getSpaceType(modelDao, fieldInfo); + final float minScore = spaceType.scoreTranslation(knnQuery.getRadius()); + return filterDocsByMinScore(exactSearcherContext, iterator, minScore); + } + private Map scoreAllDocs(KNNIterator iterator) throws IOException { final Map docToScore = new HashMap<>(); int docId; @@ -71,15 +104,19 @@ private Map scoreAllDocs(KNNIterator iterator) throws IOExceptio return docToScore; } - private Map searchTopK(KNNIterator iterator, int k) throws IOException { + private Map searchTopCandidates(KNNIterator iterator, int limit, Predicate filterScore) throws IOException { // Creating min heap and init with MAX DocID and Score as -INF. - final HitQueue queue = new HitQueue(k, true); + final HitQueue queue = new HitQueue(limit, true); ScoreDoc topDoc = queue.top(); final Map docToScore = new HashMap<>(); int docId; while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { - if (iterator.score() > topDoc.score) { - topDoc.score = iterator.score(); + final float currentScore = iterator.score(); + if (filterScore != null && Predicate.not(filterScore).test(currentScore)) { + continue; + } + if (currentScore > topDoc.score) { + topDoc.score = currentScore; topDoc.doc = docId; // As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we // have seen till now on top. @@ -98,10 +135,20 @@ private Map searchTopK(KNNIterator iterator, int k) throws IOExc final ScoreDoc doc = queue.pop(); docToScore.put(doc.doc, doc.score); } - return docToScore; } + private Map searchTopK(KNNIterator iterator, int k) throws IOException { + return searchTopCandidates(iterator, k, null); + } + + private Map filterDocsByMinScore(ExactSearcherContext context, KNNIterator iterator, float minScore) + throws IOException { + int maxResultWindow = context.getKnnQuery().getContext().getMaxResultWindow(); + Predicate scoreGreaterThanOrEqualToMinScore = score -> score >= minScore; + return searchTopCandidates(iterator, maxResultWindow, scoreGreaterThanOrEqualToMinScore); + } + private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException { final KNNQuery knnQuery = exactSearcherContext.getKnnQuery(); final BitSet matchedDocs = exactSearcherContext.getMatchedDocs(); diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index d5cd80934..d43a97c43 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -204,8 +204,8 @@ private int[] bitSetToIntArray(final BitSet bitSet) { private Map doExactSearch(final LeafReaderContext context, final BitSet acceptedDocs, int k) throws IOException { final ExactSearcherContextBuilder exactSearcherContextBuilder = ExactSearcher.ExactSearcherContext.builder() - .k(k) .isParentHits(true) + .k(k) // setting to true, so that if quantization details are present we want to do search on the quantized // vectors as this flow is used in first pass of search. .useQuantizedVectorsForSearch(true) @@ -398,12 +398,9 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) { filterIdsCount, KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()) ); - if (knnQuery.getRadius() != null) { - return false; - } int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()); // Refer this GitHub around more details https://github.com/opensearch-project/k-NN/issues/1049 on the logic - if (filterIdsCount <= knnQuery.getK()) { + if (knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK()) { return true; } // See user has defined Exact Search filtered threshold. if yes, then use that setting. diff --git a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java index 99152ef6b..b5166866c 100644 --- a/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java +++ b/src/main/java/org/opensearch/knn/index/query/RNNQueryFactory.java @@ -88,6 +88,7 @@ public static Query create(RNNQueryFactory.CreateQueryRequest createQueryRequest .indexName(indexName) .parentsFilter(parentFilter) .radius(radius) + .vectorDataType(vectorDataType) .methodParameters(methodParameters) .context(knnQueryContext) .filterQuery(filterQuery) diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index eec520a63..c7862bae8 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -92,6 +92,7 @@ public class FaissIT extends KNNRestTestCase { private static final String INTEGER_FIELD_NAME = "int_field"; private static final String FILED_TYPE_INTEGER = "integer"; private static final String NON_EXISTENT_INTEGER_FIELD_NAME = "nonexistent_int_field"; + public static final int NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD = -1; static TestUtils.TestData testData; @@ -1826,6 +1827,111 @@ public void testIVF_whenBinaryFormat_whenIVF_thenSuccess() { validateGraphEviction(); } + @SneakyThrows + public void testEndToEnd_whenDoRadiusSearch_whenNoGraphFileIsCreated_whenDistanceThreshold_thenSucceed() { + final SpaceType spaceType = SpaceType.L2; + + final List mValues = ImmutableList.of(16, 32, 64, 128); + final List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + final List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + final Integer dimension = testData.indexData.vectors[0].length; + final Settings knnIndexSettings = buildKNNIndexSettings(NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD); + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + createKnnIndex(INDEX_NAME, knnIndexSettings, builder.toString()); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + INDEX_NAME, + Integer.toString(testData.indexData.docs[i]), + FIELD_NAME, + Floats.asList(testData.indexData.vectors[i]).toArray() + ); + } + + // Assert we have the right number of documents + refreshAllNonSystemIndices(); + assertEquals(testData.indexData.docs.length, getDocCount(INDEX_NAME)); + + final float distance = 300000000000f; + final List> resultsFromDistance = validateRadiusSearchResults( + INDEX_NAME, + FIELD_NAME, + testData.queries, + distance, + null, + spaceType, + null, + null + ); + assertFalse(resultsFromDistance.isEmpty()); + resultsFromDistance.forEach(result -> { assertFalse(result.isEmpty()); }); + final float score = spaceType.scoreTranslation(distance); + final List> resultsFromScore = validateRadiusSearchResults( + INDEX_NAME, + FIELD_NAME, + testData.queries, + null, + score, + spaceType, + null, + null + ); + assertFalse(resultsFromScore.isEmpty()); + resultsFromScore.forEach(result -> { assertFalse(result.isEmpty()); }); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + + @SneakyThrows + public void testRadialQueryWithFilter_whenNoGraphIsCreated_thenSuccess() { + setupKNNIndexForFilterQuery(buildKNNIndexSettings(NEVER_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD)); + + final float[][] searchVector = new float[][] { { 3.3f, 3.0f, 5.0f } }; + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery("color", "red"); + List expectedDocIds = Arrays.asList(DOC_ID_3); + + float distance = 15f; + List> queryResult = validateRadiusSearchResults( + INDEX_NAME, + FIELD_NAME, + searchVector, + distance, + null, + SpaceType.L2, + termQueryBuilder, + null + ); + + assertEquals(1, queryResult.get(0).size()); + assertEquals(expectedDocIds.get(0), queryResult.get(0).get(0).getDocId()); + + // Delete index + deleteKNNIndex(INDEX_NAME); + } + @SneakyThrows public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful() { XContentBuilder builder = XContentFactory.jsonBuilder() @@ -1898,6 +2004,10 @@ public void testQueryWithFilter_whenNonExistingFieldUsedInFilter_thenSuccessful( } protected void setupKNNIndexForFilterQuery() throws Exception { + setupKNNIndexForFilterQuery(getKNNDefaultIndexSettings()); + } + + protected void setupKNNIndexForFilterQuery(Settings settings) throws Exception { // Create Mappings XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() @@ -1915,7 +2025,7 @@ protected void setupKNNIndexForFilterQuery() throws Exception { .endObject(); final String mapping = builder.toString(); - createKnnIndex(INDEX_NAME, mapping); + createKnnIndex(INDEX_NAME, settings, mapping); addKnnDocWithAttributes( DOC_ID_1, diff --git a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java index bf8168d37..c6e5c8fd4 100644 --- a/src/test/java/org/opensearch/knn/index/OpenSearchIT.java +++ b/src/test/java/org/opensearch/knn/index/OpenSearchIT.java @@ -814,6 +814,92 @@ public void testKNNIndex_whenBuildVectorDataStructureIsLessThanDocCount_thenBuil deleteKNNIndex(indexName); } + /* + For this testcase, we will create index with setting build_vector_data_structure_threshold as -1, then index few documents, perform knn search, + then, confirm hits because of exact search though there are no graph. In next step, update setting to 0, force merge segment to 1, perform knn search and confirm expected + hits are returned. + */ + public void testKNNIndex_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSettingUsingRadialSearch() + throws Exception { + final String indexName = "test-index-1"; + final String fieldName1 = "test-field-1"; + final String fieldName2 = "test-field-2"; + + final Integer dimension = testData.indexData.vectors[0].length; + final Settings knnIndexSettings = buildKNNIndexSettings(-1); + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName1) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .startObject(fieldName2) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNNConstants.KNN_METHOD) + .field(KNNConstants.NAME, KNNConstants.METHOD_HNSW) + .field(KNNConstants.KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(KNNConstants.PARAMETERS) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + createKnnIndex(indexName, knnIndexSettings, builder.toString()); + + // Index the test data + for (int i = 0; i < testData.indexData.docs.length; i++) { + addKnnDoc( + indexName, + Integer.toString(testData.indexData.docs[i]), + ImmutableList.of(fieldName1, fieldName2), + ImmutableList.of( + Floats.asList(testData.indexData.vectors[i]).toArray(), + Floats.asList(testData.indexData.vectors[i]).toArray() + ) + ); + } + + refreshAllIndices(); + // Assert we have the right number of documents in the index + assertEquals(testData.indexData.docs.length, getDocCount(indexName)); + + final List nmslibNeighbors = getResults(indexName, fieldName1, testData.queries[0], 1); + assertEquals("unexpected neighbors are returned", nmslibNeighbors.size(), nmslibNeighbors.size()); + + final List faissNeighbors = getResults(indexName, fieldName2, testData.queries[0], 1); + assertEquals("unexpected neighbors are returned", faissNeighbors.size(), faissNeighbors.size()); + + // update build vector data structure setting + updateIndexSettings(indexName, Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0)); + forceMergeKnnIndex(indexName, 1); + + final int k = 10; + for (int i = 0; i < testData.queries.length; i++) { + // Search nmslib field + final Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName1, testData.queries[i], k), k); + final String responseBody = EntityUtils.toString(response.getEntity()); + final List nmslibValidNeighbors = parseSearchResponse(responseBody, fieldName1); + assertEquals(k, nmslibValidNeighbors.size()); + // Search faiss field + final List faissValidNeighbors = getResults(indexName, fieldName2, testData.queries[i], k); + assertEquals(k, faissValidNeighbors.size()); + } + + // Delete index + deleteKNNIndex(indexName); + } + private List getResults(final String indexName, final String fieldName, final float[] vector, final int k) throws IOException, ParseException { final Response searchResponseField = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, vector, k), k); diff --git a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java index eed2772b4..cca40c55d 100644 --- a/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java +++ b/src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java @@ -155,17 +155,6 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_ } } - @SneakyThrows - public void testFaissHnswBinary_whenRadialSearch_thenThrowException() { - // Create Index - createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 16); - - // Query - float[] queryVector = { (byte) 0b10001111, (byte) 0b10000000 }; - Exception e = expectThrows(Exception.class, () -> runRnnQuery(INDEX_NAME, FIELD_NAME, queryVector, 1, 4)); - assertTrue(e.getMessage(), e.getMessage().contains("Binary data type does not support radial search")); - } - private float getRecall(final Set truth, final Set result) { // Count the number of relevant documents retrieved result.retainAll(truth); @@ -178,23 +167,6 @@ private float getRecall(final Set truth, final Set result) { return (float) relevantRetrieved / totalRelevant; } - private List runRnnQuery( - final String indexName, - final String fieldName, - final float[] queryVector, - final float minScore, - final int size - ) throws Exception { - String query = KNNJsonQueryBuilder.builder() - .fieldName(fieldName) - .vector(ArrayUtils.toObject(queryVector)) - .minScore(minScore) - .build() - .getQueryString(); - Response response = searchKNNIndex(indexName, query, size); - return parseSearchResponse(EntityUtils.toString(response.getEntity()), fieldName); - } - private List runKnnQuery(final String indexName, final String fieldName, final float[] queryVector, final int k) throws Exception { String query = KNNJsonQueryBuilder.builder()