Skip to content

Commit

Permalink
Refactor test
Browse files Browse the repository at this point in the history
Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Oct 3, 2024
1 parent 54cdfe4 commit e51ae11
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index.query;

import com.google.common.annotations.VisibleForTesting;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
Expand Down Expand Up @@ -95,8 +96,13 @@ public KNNWeight(KNNQuery query, float boost, Weight filterWeight) {
}

public static void initialize(ModelDao modelDao) {
initialize(modelDao, new ExactSearcher(modelDao));
}

@VisibleForTesting
static void initialize(ModelDao modelDao, ExactSearcher exactSearcher) {
KNNWeight.modelDao = modelDao;
KNNWeight.DEFAULT_EXACT_SEARCHER = new ExactSearcher(modelDao);
KNNWeight.DEFAULT_EXACT_SEARCHER = exactSearcher;
}

@Override
Expand Down
78 changes: 78 additions & 0 deletions src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.knn.KNNRestTestCase.INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER;
Expand Down Expand Up @@ -762,6 +763,83 @@ public void validateANNWithFilterQuery_whenExactSearch_thenSuccess(final boolean
}
}

@SneakyThrows
public void testRadialSearch_whenNoEngineFiles_thenPerformExactSearch() {
ExactSearcher mockedExactSearcher = mock(ExactSearcher.class);
final float[] queryVector = new float[] { 0.1f, 2.0f, 3.0f };
final SpaceType spaceType = randomFrom(SpaceType.L2, SpaceType.INNER_PRODUCT);
KNNWeight.initialize(null, mockedExactSearcher);
final KNNQuery query = KNNQuery.builder()
.field(FIELD_NAME)
.queryVector(queryVector)
.indexName(INDEX_NAME)
.methodParameters(HNSW_METHOD_PARAMETERS)
.build();
final KNNWeight knnWeight = new KNNWeight(query, 1.0f);

final LeafReaderContext leafReaderContext = mock(LeafReaderContext.class);
final SegmentReader reader = mock(SegmentReader.class);
when(leafReaderContext.reader()).thenReturn(reader);

final FSDirectory directory = mock(FSDirectory.class);
when(reader.directory()).thenReturn(directory);
final SegmentInfo segmentInfo = new SegmentInfo(
directory,
Version.LATEST,
Version.LATEST,
SEGMENT_NAME,
100,
false,
false,
KNNCodecVersion.current().getDefaultCodecDelegate(),
Map.of(),
new byte[StringHelper.ID_LENGTH],
Map.of(),
Sort.RELEVANCE
);
segmentInfo.setFiles(Set.of());
final SegmentCommitInfo segmentCommitInfo = new SegmentCommitInfo(segmentInfo, 0, 0, 0, 0, 0, new byte[StringHelper.ID_LENGTH]);
when(reader.getSegmentInfo()).thenReturn(segmentCommitInfo);

final Path path = mock(Path.class);
when(directory.getDirectory()).thenReturn(path);
final FieldInfos fieldInfos = mock(FieldInfos.class);
final FieldInfo fieldInfo = mock(FieldInfo.class);
when(reader.getFieldInfos()).thenReturn(fieldInfos);
when(fieldInfos.fieldInfo(FIELD_NAME)).thenReturn(fieldInfo);
when(fieldInfo.attributes()).thenReturn(
Map.of(
SPACE_TYPE,
spaceType.getValue(),
KNN_ENGINE,
KNNEngine.FAISS.getName(),
PARAMETERS,
String.format(Locale.ROOT, "{\"%s\":\"%s\"}", INDEX_DESCRIPTION_PARAMETER, "HNSW32")
)
);
final ExactSearcher.ExactSearcherContext exactSearchContext = ExactSearcher.ExactSearcherContext.builder()
.isParentHits(true)
// setting to true, so that if quantization details are present we want to do search on the quantized
// vectors as this flow is used in first pass of search.
.useQuantizedVectorsForSearch(true)
.knnQuery(query)
.build();
when(mockedExactSearcher.searchLeaf(leafReaderContext, exactSearchContext)).thenReturn(DOC_ID_TO_SCORES);
final KNNScorer knnScorer = (KNNScorer) knnWeight.scorer(leafReaderContext);
assertNotNull(knnScorer);
final DocIdSetIterator docIdSetIterator = knnScorer.iterator();
final List<Integer> actualDocIds = new ArrayList<>();
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
actualDocIds.add(docId);
assertEquals(DOC_ID_TO_SCORES.get(docId), knnScorer.score(), 0.00000001f);
}
assertEquals(docIdSetIterator.cost(), actualDocIds.size());
assertTrue(Comparators.isInOrder(actualDocIds, Comparator.naturalOrder()));
// verify JNI Service is not called
jniServiceMockedStatic.verifyNoInteractions();
verify(mockedExactSearcher).searchLeaf(leafReaderContext, exactSearchContext);
}

@SneakyThrows
public void testANNWithFilterQuery_whenExactSearchAndThresholdComputations_thenSuccess() {
ModelDao modelDao = mock(ModelDao.class);
Expand Down

0 comments on commit e51ae11

Please sign in to comment.