Skip to content

Commit

Permalink
Adds integration test for nmslib integratio with ef_search
Browse files Browse the repository at this point in the history
Signed-off-by: Tejas Shah <[email protected]>
  • Loading branch information
shatejas committed Jun 7, 2024
1 parent d636476 commit 6da2592
Showing 1 changed file with 89 additions and 17 deletions.
106 changes: 89 additions & 17 deletions src/test/java/org/opensearch/knn/index/NmslibIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@

import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Floats;
import lombok.SneakyThrows;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.junit.BeforeClass;
import org.opensearch.client.Response;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.client.ResponseException;
Expand Down Expand Up @@ -52,6 +54,62 @@ public static void setUpClass() throws IOException {
testData = new TestUtils.TestData(testIndexVectors.getPath(), testQueries.getPath());
}

public void testInvalidMethodParameters() throws Exception {
String indexName = "test-index-1";
String fieldName = "test-field-1";
Integer dimension = testData.indexData.vectors[0].length;
KNNMethod hnswMethod = KNNEngine.NMSLIB.getMethod(KNNConstants.METHOD_HNSW);
SpaceType spaceType = SpaceType.L1;

// Create an index
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(fieldName)
.field("type", "knn_vector")
.field("dimension", dimension)
.startObject(KNNConstants.KNN_METHOD)
.field(KNNConstants.NAME, hnswMethod.getMethodComponent().getName())
.field(KNNConstants.METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue())
.field(KNNConstants.KNN_ENGINE, KNNEngine.NMSLIB.getName())
.startObject(KNNConstants.PARAMETERS)
.field(KNNConstants.METHOD_PARAMETER_M, 32)
.field(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, 100)
.endObject()
.endObject()
.endObject()
.endObject()
.endObject();

final Map<String, Object> mappingMap = xContentBuilderToMap(builder);
String mapping = builder.toString();

createKnnIndex(indexName, mapping);
assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName)));

// Index the test data
// Adding only doc to cut on integ test time
addKnnDoc(
indexName,
Integer.toString(testData.indexData.docs[0]),
fieldName,
Floats.asList(testData.indexData.vectors[0]).toArray()
);

expectThrows(IllegalArgumentException.class, () -> searchKNNIndex(indexName, KNNQueryBuilder.builder()
.k(10)
.methodParameters(Map.of("foo", "bar"))
.vector(testData.queries[0])
.fieldName(fieldName)
.build(), 10));
expectThrows(IllegalArgumentException.class, () -> searchKNNIndex(indexName, KNNQueryBuilder.builder()
.k(10)
.methodParameters(Map.of("ef_search", "bar"))
.vector(testData.queries[0])
.fieldName(fieldName)
.build(), 10));
}

public void testEndToEnd() throws Exception {
String indexName = "test-index-1";
String fieldName = "test-field-1";
Expand Down Expand Up @@ -104,23 +162,11 @@ public void testEndToEnd() throws Exception {
refreshAllIndices();
assertEquals(testData.indexData.docs.length, getDocCount(indexName));

int k = 10;
for (int i = 0; i < testData.queries.length; i++) {
Response response = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, testData.queries[i], k), k);
String responseBody = EntityUtils.toString(response.getEntity());
List<KNNResult> knnResults = parseSearchResponse(responseBody, fieldName);
assertEquals(k, knnResults.size());

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = knnResults.get(j).getVector();
assertEquals(
KNNEngine.NMSLIB.score(KNNScoringUtil.l1Norm(testData.queries[i], primitiveArray), spaceType),
actualScores.get(j),
0.0001
);
}
}
//search index
//without method parameters
validateSearch(indexName, fieldName, spaceType, null);
// With valid method params
validateSearch(indexName, fieldName, spaceType, Map.of("ef_search", 50));

// Delete index
deleteKNNIndex(indexName);
Expand All @@ -138,6 +184,32 @@ public void testEndToEnd() throws Exception {
fail("Graphs are not getting evicted");
}

@SneakyThrows
private void validateSearch(final String indexName,
final String fieldName,
SpaceType spaceType,
final Map<String, Object> methodParams) {
int k = 10;
for (int i = 0; i < testData.queries.length; i++) {
Response response = searchKNNIndex(indexName, KNNQueryBuilder.builder()
.fieldName(fieldName).vector(testData.queries[i]).k(k).methodParameters(methodParams)
.build(), k);
String responseBody = EntityUtils.toString(response.getEntity());
List<KNNResult> knnResults = parseSearchResponse(responseBody, fieldName);
assertEquals(k, knnResults.size());

List<Float> actualScores = parseSearchResponseScore(responseBody, fieldName);
for (int j = 0; j < k; j++) {
float[] primitiveArray = knnResults.get(j).getVector();
assertEquals(
KNNEngine.NMSLIB.score(KNNScoringUtil.l1Norm(testData.queries[i], primitiveArray), spaceType),
actualScores.get(j),
0.0001
);
}
}
}

public void testAddDoc() throws Exception {
createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2));
Float[] vector = { 6.0f, 6.0f };
Expand Down

0 comments on commit 6da2592

Please sign in to comment.