Skip to content

Commit

Permalink
KNNIterators should support with and without filters (#2155)
Browse files Browse the repository at this point in the history
* Rename class names to represent both and filter and non filter use cases
* Iterator should support with filters

Update VectorIterator and NesterVector Iterator to
iterate even if there is no filters provided to iterator.
Currently this is used by exact search to score either topk
docs or all docs when filter is provided by users.
However, in future we will be allowing exact search
even if there are no filters. Hence, decouple filter
and make it option to support both cases.

---------

Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB authored Sep 28, 2024
1 parent e0c3afe commit 6f6dd56
Show file tree
Hide file tree
Showing 13 changed files with 404 additions and 192 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Add short circuit if no live docs are in segments [#2059](https://github.com/opensearch-project/k-NN/pull/2059)
* Optimize reduceToTopK in ResultUtil by removing pre-filling and reducing peek calls [#2146](https://github.com/opensearch-project/k-NN/pull/2146)
* Update Default Rescore Context based on Dimension [#2149](https://github.com/opensearch-project/k-NN/pull/2149)
* KNNIterators should support with and without filters [#2155](https://github.com/opensearch-project/k-NN/pull/2155)
### Bug Fixes
* KNN80DocValues should only be considered for BinaryDocValues fields [#2147](https://github.com/opensearch-project/k-NN/pull/2147)
### Infrastructure
Expand Down
47 changes: 22 additions & 25 deletions src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.filtered.FilteredIdsKNNByteIterator;
import org.opensearch.knn.index.query.filtered.FilteredIdsKNNIterator;
import org.opensearch.knn.index.query.filtered.KNNIterator;
import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNByteIterator;
import org.opensearch.knn.index.query.filtered.NestedFilteredIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.VectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.KNNIterator;
import org.opensearch.knn.index.query.iterators.NestedByteVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.NestedVectorIdsKNNIterator;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
Expand All @@ -51,8 +51,9 @@ public class ExactSearcher {
*/
public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext)
throws IOException {
KNNIterator iterator = getMatchedKNNIterator(leafReaderContext, exactSearcherContext);
if (exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) {
KNNIterator iterator = getKNNIterator(leafReaderContext, exactSearcherContext);
if (exactSearcherContext.getMatchedDocs() != null
&& exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) {
return scoreAllDocs(iterator);
}
return searchTopK(iterator, exactSearcherContext.getK());
Expand Down Expand Up @@ -98,8 +99,7 @@ private Map<Integer, Float> searchTopK(KNNIterator iterator, int k) throws IOExc
return docToScore;
}

private KNNIterator getMatchedKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext)
throws IOException {
private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException {
final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
final BitSet matchedDocs = exactSearcherContext.getMatchedDocs();
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
Expand All @@ -108,20 +108,18 @@ private KNNIterator getMatchedKNNIterator(LeafReaderContext leafReaderContext, E

boolean isNestedRequired = exactSearcherContext.isParentHits() && knnQuery.getParentsFilter() != null;

if (VectorDataType.BINARY == knnQuery.getVectorDataType() && isNestedRequired) {
final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
return new NestedFilteredIdsKNNByteIterator(
matchedDocs,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
}

if (VectorDataType.BINARY == knnQuery.getVectorDataType()) {
final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
return new FilteredIdsKNNByteIterator(
if (isNestedRequired) {
return new NestedByteVectorIdsKNNIterator(
matchedDocs,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
}
return new ByteVectorIdsKNNIterator(
matchedDocs,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
Expand All @@ -142,7 +140,7 @@ private KNNIterator getMatchedKNNIterator(LeafReaderContext leafReaderContext, E

final KNNVectorValues<float[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
if (isNestedRequired) {
return new NestedFilteredIdsKNNIterator(
return new NestedVectorIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNFloatVectorValues) vectorValues,
Expand All @@ -152,8 +150,7 @@ private KNNIterator getMatchedKNNIterator(LeafReaderContext leafReaderContext, E
segmentLevelQuantizationInfo
);
}

return new FilteredIdsKNNIterator(
return new VectorIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNFloatVectorValues) vectorValues,
Expand All @@ -180,7 +177,7 @@ public static class ExactSearcherContext {
KNNQuery knnQuery;
/**
* whether the matchedDocs contains parent ids or child ids. This is relevant in the case of
* filtered nested search where the matchedDocs contain the parent ids and {@link NestedFilteredIdsKNNIterator}
* filtered nested search where the matchedDocs contain the parent ids and {@link NestedVectorIdsKNNIterator}
* needs to be used.
*/
boolean isParentHits;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.filtered;
package org.opensearch.knn.index.query.iterators;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;

Expand All @@ -17,30 +18,34 @@
* Inspired by DiversifyingChildrenFloatKnnVectorQuery in lucene
* https://github.com/apache/lucene/blob/7b8aece125aabff2823626d5b939abf4747f63a7/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java#L162
*
* The class is used in KNNWeight to score filtered KNN field by iterating filterIdsArray.
* The class is used in KNNWeight to score all docs, but, it iterates over filterIdsArray if filter is provided
*/
public class FilteredIdsKNNByteIterator implements KNNIterator {
// Array of doc ids to iterate
protected final BitSet filterIdsBitSet;
public class ByteVectorIdsKNNIterator implements KNNIterator {
protected final BitSetIterator bitSetIterator;
protected final byte[] queryVector;
protected final KNNBinaryVectorValues binaryVectorValues;
protected final SpaceType spaceType;
protected float currentScore = Float.NEGATIVE_INFINITY;
protected int docId;

public FilteredIdsKNNByteIterator(
final BitSet filterIdsBitSet,
public ByteVectorIdsKNNIterator(
@Nullable final BitSet filterIdsBitSet,
final byte[] queryVector,
final KNNBinaryVectorValues binaryVectorValues,
final SpaceType spaceType
) {
this.filterIdsBitSet = filterIdsBitSet;
this.bitSetIterator = new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length());
) throws IOException {
this.bitSetIterator = filterIdsBitSet == null ? null : new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length());
this.queryVector = queryVector;
this.binaryVectorValues = binaryVectorValues;
this.spaceType = spaceType;
this.docId = bitSetIterator.nextDoc();
// This cannot be moved inside nextDoc() method since it will break when we have nested field, where
// nextDoc should already be referring to next knnVectorValues
this.docId = getNextDocId();
}

public ByteVectorIdsKNNIterator(final byte[] queryVector, final KNNBinaryVectorValues binaryVectorValues, final SpaceType spaceType)
throws IOException {
this(null, queryVector, binaryVectorValues, spaceType);
}

/**
Expand All @@ -55,10 +60,10 @@ public int nextDoc() throws IOException {
if (docId == DocIdSetIterator.NO_MORE_DOCS) {
return DocIdSetIterator.NO_MORE_DOCS;
}
int doc = binaryVectorValues.advance(docId);
currentScore = computeScore();
docId = bitSetIterator.nextDoc();
return doc;
int currentDocId = docId;
docId = getNextDocId();
return currentDocId;
}

@Override
Expand All @@ -72,4 +77,16 @@ protected float computeScore() throws IOException {
// scores correspond to closer vectors.
return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector);
}

protected int getNextDocId() throws IOException {
if (bitSetIterator == null) {
return binaryVectorValues.nextDoc();
}
int nextDocID = this.bitSetIterator.nextDoc();
// For filter case, advance vector values to corresponding doc id from filter bit set
if (nextDocID != DocIdSetIterator.NO_MORE_DOCS) {
binaryVectorValues.advance(nextDocID);
}
return nextDocID;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.filtered;
package org.opensearch.knn.index.query.iterators;

import java.io.IOException;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,45 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.filtered;
package org.opensearch.knn.index.query.iterators;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;

import java.io.IOException;

/**
* This iterator iterates filterIdsArray to score. However, it dedupe docs per each parent doc
* This iterator iterates filterIdsArray to score if filter is provided else it iterates over all docs.
* However, it dedupe docs per each parent doc
* of which ID is set in parentBitSet and only return best child doc with the highest score.
*/
public class NestedFilteredIdsKNNByteIterator extends FilteredIdsKNNByteIterator {
public class NestedByteVectorIdsKNNIterator extends ByteVectorIdsKNNIterator {
private final BitSet parentBitSet;

public NestedFilteredIdsKNNByteIterator(
final BitSet filterIdsArray,
public NestedByteVectorIdsKNNIterator(
@Nullable final BitSet filterIdsArray,
final byte[] queryVector,
final KNNBinaryVectorValues binaryVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
) {
) throws IOException {
super(filterIdsArray, queryVector, binaryVectorValues, spaceType);
this.parentBitSet = parentBitSet;
}

public NestedByteVectorIdsKNNIterator(
final byte[] queryVector,
final KNNBinaryVectorValues binaryVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
) throws IOException {
super(null, queryVector, binaryVectorValues, spaceType);
this.parentBitSet = parentBitSet;
}

/**
* Advance to the next best child doc per parent and update score with the best score among child docs from the parent.
* DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs
Expand All @@ -46,14 +58,18 @@ public int nextDoc() throws IOException {
int currentParent = parentBitSet.nextSetBit(docId);
int bestChild = -1;

// In order to traverse all children for given parent, we have to use docId < parentId, because,
// kNNVectorValues will not have parent id since DocId is unique per segment. For ex: let's say for doc id 1, there is one child
// and for doc id 5, there are three children. In that case knnVectorValues iterator will have [0, 2, 3, 4]
// and parentBitSet will have [1,5]
// Hence, we have to iterate till docId from knnVectorValues is less than parentId instead of till equal to parentId
while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) {
binaryVectorValues.advance(docId);
float score = computeScore();
if (score > currentScore) {
bestChild = docId;
currentScore = score;
}
docId = bitSetIterator.nextDoc();
docId = getNextDocId();
}

return bestChild;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,42 +3,53 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.filtered;
package org.opensearch.knn.index.query.iterators;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.query.SegmentLevelQuantizationInfo;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;

import java.io.IOException;

/**
* This iterator iterates filterIdsArray to score. However, it dedupe docs per each parent doc
* This iterator iterates filterIdsArray to score if filter is provided else it iterates over all docs.
* However, it dedupe docs per each parent doc
* of which ID is set in parentBitSet and only return best child doc with the highest score.
*/
public class NestedFilteredIdsKNNIterator extends FilteredIdsKNNIterator {
public class NestedVectorIdsKNNIterator extends VectorIdsKNNIterator {
private final BitSet parentBitSet;

NestedFilteredIdsKNNIterator(
final BitSet filterIdsArray,
public NestedVectorIdsKNNIterator(
@Nullable final BitSet filterIdsArray,
final float[] queryVector,
final KNNFloatVectorValues knnFloatVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
) {
) throws IOException {
this(filterIdsArray, queryVector, knnFloatVectorValues, spaceType, parentBitSet, null, null);
}

public NestedFilteredIdsKNNIterator(
final BitSet filterIdsArray,
public NestedVectorIdsKNNIterator(
final float[] queryVector,
final KNNFloatVectorValues knnFloatVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
) throws IOException {
this(null, queryVector, knnFloatVectorValues, spaceType, parentBitSet, null, null);
}

public NestedVectorIdsKNNIterator(
@Nullable final BitSet filterIdsArray,
final float[] queryVector,
final KNNFloatVectorValues knnFloatVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet,
final byte[] quantizedVector,
final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo
) {
) throws IOException {
super(filterIdsArray, queryVector, knnFloatVectorValues, spaceType, quantizedVector, segmentLevelQuantizationInfo);
this.parentBitSet = parentBitSet;
}
Expand All @@ -59,14 +70,18 @@ public int nextDoc() throws IOException {
int currentParent = parentBitSet.nextSetBit(docId);
int bestChild = -1;

// In order to traverse all children for given parent, we have to use docId < parentId, because,
// kNNVectorValues will not have parent id since DocId is unique per segment. For ex: let's say for doc id 1, there is one child
// and for doc id 5, there are three children. In that case knnVectorValues iterator will have [0, 2, 3, 4]
// and parentBitSet will have [1,5]
// Hence, we have to iterate till docId from knnVectorValues is less than parentId instead of till equal to parentId
while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) {
knnFloatVectorValues.advance(docId);
float score = computeScore();
if (score > currentScore) {
bestChild = docId;
currentScore = score;
}
docId = bitSetIterator.nextDoc();
docId = getNextDocId();
}

return bestChild;
Expand Down
Loading

0 comments on commit 6f6dd56

Please sign in to comment.