Skip to content

Commit

Permalink
Fix code review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Oct 3, 2024
1 parent 9637fb7 commit abc6e69
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 17 deletions.
16 changes: 6 additions & 10 deletions src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

package org.opensearch.knn.index.query;

import com.google.common.base.Predicates;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.NonNull;
import lombok.Value;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
Expand Down Expand Up @@ -65,7 +67,7 @@ public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext,
&& exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) {
return scoreAllDocs(iterator);
}
return searchTopK(iterator, exactSearcherContext.getK());
return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue());
}

/**
Expand Down Expand Up @@ -104,18 +106,16 @@ private Map<Integer, Float> scoreAllDocs(KNNIterator iterator) throws IOExceptio
return docToScore;
}

private Map<Integer, Float> searchTopCandidates(KNNIterator iterator, int limit, Predicate<Float> filterScore) throws IOException {
private Map<Integer, Float> searchTopCandidates(KNNIterator iterator, int limit, @NonNull Predicate<Float> filterScore)
throws IOException {
// Creating min heap and init with MAX DocID and Score as -INF.
final HitQueue queue = new HitQueue(limit, true);
ScoreDoc topDoc = queue.top();
final Map<Integer, Float> docToScore = new HashMap<>();
int docId;
while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
final float currentScore = iterator.score();
if (filterScore != null && Predicate.not(filterScore).test(currentScore)) {
continue;
}
if (currentScore > topDoc.score) {
if (filterScore.test(currentScore) && currentScore > topDoc.score) {
topDoc.score = currentScore;
topDoc.doc = docId;
// As the HitQueue is min heap, updating top will bring the doc with -INF score or worst score we
Expand All @@ -138,10 +138,6 @@ private Map<Integer, Float> searchTopCandidates(KNNIterator iterator, int limit,
return docToScore;
}

private Map<Integer, Float> searchTopK(KNNIterator iterator, int k) throws IOException {
return searchTopCandidates(iterator, k, null);
}

private Map<Integer, Float> filterDocsByMinScore(ExactSearcherContext context, KNNIterator iterator, float minScore)
throws IOException {
int maxResultWindow = context.getKnnQuery().getContext().getMaxResultWindow();
Expand Down
7 changes: 4 additions & 3 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -623,10 +623,11 @@ public void testHNSWSQFP16_whenGraphThresholdIsNegative_whenIndexed_thenSkipCrea

// Assert we have the right number of documents in the index
assertEquals(numDocs, getDocCount(indexName));
// KNN Query should return empty result

final Response searchResponse = searchKNNIndex(indexName, buildSearchQuery(fieldName, 1, queryVector, null), 1);
final List<KNNResult> results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName);
assertEquals(0, results.size());
// expect result due to exact search
assertEquals(1, results.size());

deleteKNNIndex(indexName);
validateGraphEviction();
Expand Down Expand Up @@ -682,7 +683,7 @@ public void testHNSWSQFP16_whenGraphThresholdIsMetDuringMerge_thenCreateGraph()
// KNN Query should return empty result
final Response searchResponse = searchKNNIndex(indexName, buildSearchQuery(fieldName, 1, queryVector, null), 1);
final List<KNNResult> results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName);
assertEquals(0, results.size());
assertEquals(1, results.size());

// update index setting to build graph and do force merge
// update build vector data structure setting
Expand Down
91 changes: 90 additions & 1 deletion src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
Expand Down Expand Up @@ -360,7 +361,6 @@ public void testScorer_whenNoVectorFieldsInDocument_thenEmptyScorerIsReturned()
final Path path = mock(Path.class);
when(directory.getDirectory()).thenReturn(path);
final FieldInfos fieldInfos = mock(FieldInfos.class);
final FieldInfo fieldInfo = mock(FieldInfo.class);
when(reader.getFieldInfos()).thenReturn(fieldInfos);
// When no knn fields are available , field info for vector field will be null
when(fieldInfos.fieldInfo(FIELD_NAME)).thenReturn(null);
Expand Down Expand Up @@ -763,6 +763,95 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean
}
}

@SneakyThrows
public void testRadialSearch_whenNoEngineFiles_thenPerformExactSearch() {
try (MockedStatic<KNNVectorValuesFactory> valuesFactoryMockedStatic = Mockito.mockStatic(KNNVectorValuesFactory.class)) {
final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f };
final SpaceType spaceType = randomFrom(SpaceType.L2, SpaceType.INNER_PRODUCT);
final List<float[]> dataVectors = Arrays.asList(
new float[] { 11.0f, 12.0f, 13.0f },
new float[] { 14.0f, 15.0f, 16.0f },
new float[] { 17.0f, 18.0f, 19.0f }
);
final List<Float> expectedScores = dataVectors.stream()
.map(vector -> spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector))
.collect(Collectors.toList());
final Float score = Collections.min(expectedScores);
final float radius = KNNEngine.FAISS.scoreToRadialThreshold(score, spaceType);
final int maxResults = 1000;
final KNNQuery.Context context = mock(KNNQuery.Context.class);
when(context.getMaxResultWindow()).thenReturn(maxResults);
KNNWeight.initialize(null);

final KNNQuery query = KNNQuery.builder()
.field(FIELD_NAME)
.queryVector(queryVector)
.radius(radius)
.indexName(INDEX_NAME)
.context(context)
.methodParameters(HNSW_METHOD_PARAMETERS)
.build();
final KNNWeight knnWeight = new KNNWeight(query, 1.0f);

final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
final SegmentReader reader = mock(SegmentReader.class);
when(leafReaderContext.reader()).thenReturn(reader);

final FSDirectory directory = mock(FSDirectory.class);
when(reader.directory()).thenReturn(directory);
final SegmentInfo segmentInfo = new SegmentInfo(
directory,
Version.LATEST,
Version.LATEST,
SEGMENT_NAME,
100,
false,
false,
KNNCodecVersion.current().getDefaultCodecDelegate(),
Map.of(),
new byte[StringHelper.ID_LENGTH],
Map.of(),
Sort.RELEVANCE
);
segmentInfo.setFiles(Set.of());
final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]);
when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo);

final Path path = mock(Path.class);
when(directory.getDirectory()).thenReturn(path);
final FieldInfos fieldInfos = mock(FieldInfos.class);
final FieldInfo fieldInfo = mock(FieldInfo.class);
when(reader.getFieldInfos()).thenReturn(fieldInfos);
when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo);
when(fieldInfo.attributes()).thenReturn(
Map.of(
SPACE_TYPE,
spaceType.getValue(),
KNN_ENGINE,
KNNEngine.FAISS.getName(),
PARAMETERS,
String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32")
)
);
when(fieldInfo.getAttribute(SPACE_TYPE)).thenReturn(spaceType.getValue());
KNNFloatVectorValues floatVectorValues = mock(KNNFloatVectorValues.class);
valuesFactoryMockedStatic.when(() -> KNNVectorValuesFactory.getVectorValues(fieldInfo, reader)).thenReturn(floatVectorValues);
when(floatVectorValues.nextDoc()).thenReturn(0, 1, 2, NO_MORE_DOCS);
when(floatVectorValues.getVector()).thenReturn(dataVectors.get(0), dataVectors.get(1), dataVectors.get(2));

final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext);
assertNotNull(knnScorer);
final DocIdSetIterator docIdSetIterator = knnScorer.iterator();
assertNotNull(docIdSetIterator);
assertEquals(docIdSetIterator.cost(), dataVectors.size());
List<Float> actualScores = new ArrayList<>();
while (docIdSetIterator.nextDoc() != NO_MORE_DOCS) {
actualScores.add(knnScorer.score());
}
assertEquals(expectedScores, actualScores);
}
}

@SneakyThrows
public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenSuccess() {
ModelDao modelDao = mock(ModelDao.class);
Expand Down
40 changes: 37 additions & 3 deletions src/test/java/org/opensearch/knn/integ/BinaryIndexIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_
assertEquals(1, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size());

// update build vector data structure setting
updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0));
updateIndexSettings(
INDEX_NAME,
Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, ALWAYS_BUILD_GRAPH)
);
forceMergeKnnIndex(INDEX_NAME, 1);

int k = 100;
Expand All @@ -133,15 +136,18 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsNegativeEndToEnd_
}

@SneakyThrows
public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() throws Exception {
public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_thenBuildGraphBasedOnSetting() {
// Create Index
createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 128, testData.indexData.docs.length);
ingestTestData(INDEX_NAME, FIELD_NAME, false);

assertEquals(1, runKnnQuery(INDEX_NAME, FIELD_NAME, testData.queries[0], 1).size());

// update build vector data structure setting
updateIndexSettings(INDEX_NAME, Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, 0));
updateIndexSettings(
INDEX_NAME,
Settings.builder().put(KNNSettings.INDEX_KNN_BUILD_VECTOR_DATA_STRUCTURE_THRESHOLD, ALWAYS_BUILD_GRAPH)
);
forceMergeKnnIndex(INDEX_NAME, 1);

int k = 100;
Expand All @@ -155,6 +161,17 @@ public void testFaissHnswBinary_whenBuildVectorGraphThresholdIsProvidedEndToEnd_
}
}

@SneakyThrows
public void testFaissHnswBinary_whenRadialSearch_thenThrowException() {
// Create Index
createKnnHnswBinaryIndex(KNNEngine.FAISS, INDEX_NAME, FIELD_NAME, 16);

// Query
float[] queryVector = { (byte) 0b10001111, (byte) 0b10000000 };
Exception e = expectThrows(Exception.class, () -> runRnnQuery(INDEX_NAME, FIELD_NAME, queryVector, 1, 4));
assertTrue(e.getMessage(), e.getMessage().contains("Binary data type does not support radial search"));
}

private float getRecall(final Set<String> truth, final Set<String> result) {
// Count the number of relevant documents retrieved
result.retainAll(truth);
Expand All @@ -167,6 +184,23 @@ private float getRecall(final Set<String> truth, final Set<String> result) {
return (float) relevantRetrieved / totalRelevant;
}

private List<KNNResult> runRnnQuery(
final String indexName,
final String fieldName,
final float[] queryVector,
final float minScore,
final int size
) throws Exception {
String query = KNNJsonQueryBuilder.builder()
.fieldName(fieldName)
.vector(ArrayUtils.toObject(queryVector))
.minScore(minScore)
.build()
.getQueryString();
Response response = searchKNNIndex(indexName, query, size);
return parseSearchResponse(EntityUtils.toString(response.getEntity()), fieldName);
}

private List<KNNResult> runKnnQuery(final String indexName, final String fieldName, final float[] queryVector, final int k)
throws Exception {
String query = KNNJsonQueryBuilder.builder()
Expand Down

0 comments on commit abc6e69

Please sign in to comment.