diff --git a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java index 7f8144f99..193cba8c1 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java @@ -51,8 +51,9 @@ public class ExactSearcher { */ public Map searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext) throws IOException { - KNNIterator iterator = getMatchedKNNIterator(leafReaderContext, exactSearcherContext); - if (exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) { + KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext); + if (exactSearcherContext.getMatchedDocs() != null + && exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) { return scoreAllDocs(iterator); } return searchTopK(iterator, exactSearcherContext.getK()); @@ -98,8 +99,7 @@ private Map searchTopK(KNNIterator iterator, int k) throws IOExc return docToScore; } - private KNNIterator getMatchedKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) - throws IOException { + private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException { final KNNQuery knnQuery = exactSearcherContext.getKnnQuery(); final BitSet matchedDocs = exactSearcherContext.getMatchedDocs(); final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader()); @@ -108,19 +108,17 @@ private KNNIterator getMatchedKNNIterator(LeafReaderContext leafReaderContext, E boolean isNestedRequired = exactSearcherContext.isParentHits() && knnQuery.getParentsFilter() != null; - if (VectorDataType.BINARY == knnQuery.getVectorDataType() && isNestedRequired) { - final KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); - return new NestedByteVectorIdsKNNIterator( - matchedDocs, - knnQuery.getByteQueryVector(), - (KNNBinaryVectorValues) vectorValues, - spaceType, - knnQuery.getParentsFilter().getBitSet(leafReaderContext) - ); - } - if (VectorDataType.BINARY == knnQuery.getVectorDataType()) { final KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); + if (isNestedRequired) { + return new NestedByteVectorIdsKNNIterator( + matchedDocs, + knnQuery.getByteQueryVector(), + (KNNBinaryVectorValues) vectorValues, + spaceType, + knnQuery.getParentsFilter().getBitSet(leafReaderContext) + ); + } return new ByteVectorIdsKNNIterator( matchedDocs, knnQuery.getByteQueryVector(), @@ -152,7 +150,6 @@ private KNNIterator getMatchedKNNIterator(LeafReaderContext leafReaderContext, E segmentLevelQuantizationInfo ); } - return new VectorIdsKNNIterator( matchedDocs, knnQuery.getQueryVector(), diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIterator.java index 61b3c01b1..520542990 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/ByteVectorIdsKNNIterator.java @@ -38,6 +38,8 @@ public ByteVectorIdsKNNIterator( this.queryVector = queryVector; this.binaryVectorValues = binaryVectorValues; this.spaceType = spaceType; + // This cannot be moved inside nextDoc() method since it will break when we have nested field, where + // nextDoc should already be referring to next knnVectorValues this.docId = getNextDocId(); } diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java index fbb7e2979..aa799c2b6 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/NestedByteVectorIdsKNNIterator.java @@ -58,6 +58,11 @@ public int nextDoc() throws IOException { int currentParent = parentBitSet.nextSetBit(docId); int bestChild = -1; + // In order to traverse all children for given parent, we have to use docId < parentId, because, + // kNNVectorValues will not have parent id since DocId is unique per segment. For ex: let's say for doc id 1, there is one child + // and for doc id 5, there are three children. In that case knnVectorValues iterator will have [0, 2, 3, 4] + // and parentBitSet will have [1,5] + // Hence, we have to iterate till docId from knnVectorValues is less than parentId instead of till equal to parentId while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) { if (bitSetIterator != null) { binaryVectorValues.advance(docId); diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIterator.java index a30fc47df..de6f4ac3b 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/NestedVectorIdsKNNIterator.java @@ -70,6 +70,11 @@ public int nextDoc() throws IOException { int currentParent = parentBitSet.nextSetBit(docId); int bestChild = -1; + // In order to traverse all children for given parent, we have to use docId < parentId, because, + // kNNVectorValues will not have parent id since DocId is unique per segment. For ex: let's say for doc id 1, there is one child + // and for doc id 5, there are three children. In that case knnVectorValues iterator will have [0, 2, 3, 4] + // and parentBitSet will have [1,5] + // Hence, we have to iterate till docId from knnVectorValues is less than parentId instead of till equal to parentId while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) { if (bitSetIterator != null) { knnFloatVectorValues.advance(docId); diff --git a/src/main/java/org/opensearch/knn/index/query/iterators/VectorIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/iterators/VectorIdsKNNIterator.java index 764bbfdbf..024a5616d 100644 --- a/src/main/java/org/opensearch/knn/index/query/iterators/VectorIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/iterators/VectorIdsKNNIterator.java @@ -58,18 +58,13 @@ public VectorIdsKNNIterator( this.queryVector = queryVector; this.knnFloatVectorValues = knnFloatVectorValues; this.spaceType = spaceType; + // This cannot be moved inside nextDoc() method since it will break when we have nested field, where + // nextDoc should already be referring to next knnVectorValues this.docId = getNextDocId(); this.quantizedQueryVector = quantizedQueryVector; this.segmentLevelQuantizationInfo = segmentLevelQuantizationInfo; } - protected int getNextDocId() throws IOException { - if (bitSetIterator != null) { - return bitSetIterator.nextDoc(); - } - return knnFloatVectorValues.nextDoc(); - } - /** * Advance to the next doc and update score value with score of the next doc. * DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs @@ -107,4 +102,11 @@ protected float computeScore() throws IOException { return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector); } } + + protected int getNextDocId() throws IOException { + if (bitSetIterator != null) { + return bitSetIterator.nextDoc(); + } + return knnFloatVectorValues.nextDoc(); + } }