diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d9d35ef0..5daf3b564 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.13...2.x) ### Features ### Enhancements +* Make the HitQueue size more appropriate for exact search [#1549](https://github.com/opensearch-project/k-NN/pull/1549) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 1c4d0a646..06bf96d63 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -117,7 +117,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { * This improves the recall. */ if (filterWeight != null && canDoExactSearch(cardinality)) { - docIdsToScoreMap.putAll(doExactSearch(context, filterBitSet)); + docIdsToScoreMap.putAll(doExactSearch(context, filterBitSet, cardinality)); } else { Map annResults = doANNSearch(context, filterBitSet, cardinality); if (annResults == null) { @@ -131,7 +131,7 @@ public Scorer scorer(LeafReaderContext context) throws IOException { annResults.size(), cardinality ); - annResults = doExactSearch(context, filterBitSet); + annResults = doExactSearch(context, filterBitSet, cardinality); } docIdsToScoreMap.putAll(annResults); } @@ -309,10 +309,10 @@ private Map doANNSearch(final LeafReaderContext context, final B .collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType))); } - private Map doExactSearch(final LeafReaderContext leafReaderContext, final BitSet filterIdsBitSet) { + private Map doExactSearch(final LeafReaderContext leafReaderContext, final BitSet filterIdsBitSet, int cardinality) { try { // Creating min heap and init with MAX DocID and Score as -INF. - final HitQueue queue = new HitQueue(this.knnQuery.getK(), true); + final HitQueue queue = new HitQueue(Math.min(this.knnQuery.getK(), cardinality), true); ScoreDoc topDoc = queue.top(); final Map docToScore = new HashMap<>(); FilteredIdsKNNIterator iterator = getFilteredKNNIterator(leafReaderContext, filterIdsBitSet);