Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Serialize all models into cluster metadata #1644

Merged
merged 1 commit into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Support script score when doc value is disabled [#1573](https://github.com/opensearch-project/k-NN/pull/1573)
* Implemented the Streaming Feature to stream vectors from Java to JNI layer to enable creation of larger segments for vector indices [#1604](https://github.com/opensearch-project/k-NN/pull/1604)
* Remove unnecessary toString conversion of vector field and add some minor optimizations in KNNCodec [#1613](https://github.com/opensearch-project/k-NN/pull/1613)
* Serialize all models into cluster metadata [#1499](https://github.com/opensearch-project/k-NN/pull/1499)
### Bug Fixes
* Add stored fields for knn_vector type [#1630](https://github.com/opensearch-project/k-NN/pull/1630)
### Infrastructure
Expand Down
6 changes: 4 additions & 2 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.jni.JNIService;

import java.io.File;
Expand Down Expand Up @@ -199,8 +200,8 @@ public static ValidationException validateKnnField(
}

ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (modelMetadata == null) {
exception.addValidationError(String.format("Model \"%s\" for field \"%s\" does not exist.", modelId, field));
if (!ModelUtil.isModelCreated(modelMetadata)) {
exception.addValidationError(String.format("Model \"%s\" for field \"%s\" is not created.", modelId, field));
return exception;
}

Expand Down Expand Up @@ -286,4 +287,5 @@ public static boolean isSharedIndexStateRequired(KNNEngine knnEngine, String mod
}
return JNIService.isSharedIndexStateRequired(indexAddr, knnEngine);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.index.mapper.ParseContext;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;

import java.io.IOException;

Expand Down Expand Up @@ -50,10 +51,10 @@ protected void parseCreateField(ParseContext context) throws IOException {
// model when ingestion starts.
ModelMetadata modelMetadata = this.modelDao.getMetadata(modelId);

if (modelMetadata == null) {
if (!ModelUtil.isModelCreated(modelMetadata)) {
throw new IllegalStateException(
String.format(
"Model \"%s\" from %s's mapping does not exist. Because the \"%s\" parameter is not updatable, this index will need to be recreated with a valid model.",
"Model \"%s\" from %s's mapping is not created. Because the \"%s\" parameter is not updatable, this index will need to be recreated with a valid model.",
modelId,
context.mapperService().index().getName(),
MODEL_ID
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryShardContext;
Expand Down Expand Up @@ -548,8 +549,8 @@ private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFie
}

ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (modelMetadata == null) {
throw new IllegalArgumentException(String.format("Model ID '%s' does not exist.", modelId));
if (!ModelUtil.isModelCreated(modelMetadata)) {
throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId));
}
return modelMetadata;
}
Expand Down
5 changes: 3 additions & 2 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.plugin.stats.KNNCounter;

Expand Down Expand Up @@ -213,8 +214,8 @@ private Map<Integer, Float> doANNSearch(final LeafReaderContext context, final B
String modelId = fieldInfo.getAttribute(MODEL_ID);
if (modelId != null) {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (modelMetadata == null) {
throw new RuntimeException("Model \"" + modelId + "\" does not exist.");
if (!ModelUtil.isModelCreated(modelMetadata)) {
throw new RuntimeException("Model \"" + modelId + "\" is not created.");
}

knnEngine = modelMetadata.getKnnEngine();
Expand Down
8 changes: 1 addition & 7 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -341,13 +341,7 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
);
}, listener::onFailure);

// After the model is indexed, update metadata only if the model is in CREATED state
ActionListener<IndexResponse> onIndexListener;
if (ModelState.CREATED.equals(model.getModelMetadata().getState())) {
onIndexListener = getUpdateModelMetadataListener(model.getModelMetadata(), onMetaListener);
} else {
onIndexListener = onMetaListener;
}
ActionListener<IndexResponse> onIndexListener = getUpdateModelMetadataListener(model.getModelMetadata(), onMetaListener);

// Create the model index if it does not already exist
Runnable indexModelRunnable = () -> indexRequestBuilder.execute(onIndexListener);
Expand Down
30 changes: 30 additions & 0 deletions src/main/java/org/opensearch/knn/indices/ModelUtil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.indices;

/**
* A utility class for models.
*/
public class ModelUtil {

public static boolean isModelPresent(ModelMetadata modelMetadata) {
return modelMetadata != null;
}

public static boolean isModelCreated(ModelMetadata modelMetadata) {
if (!isModelPresent(modelMetadata)) {
return false;
}
return modelMetadata.getState().equals(ModelState.CREATED);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,9 @@ protected void updateModelsNewCluster() throws IOException, InterruptedException
if (modelDao.isCreated()) {
List<String> modelIds = searchModelIds();
for (String modelId : modelIds) {
Model model = modelDao.get(modelId);
ModelMetadata modelMetadata = model.getModelMetadata();
ModelMetadata modelMetadata = getModelMetadata(modelId);
if (modelMetadata.getState().equals(ModelState.TRAINING)) {
updateModelStateAsFailed(model, "Training failed to complete as cluster crashed");
updateModelStateAsFailed(modelId, modelMetadata, "Training failed to complete as cluster crashed");
}
}
}
Expand All @@ -123,11 +122,10 @@ protected void updateModelsNodesRemoved(List<DiscoveryNode> removedNodes) throws
List<String> modelIds = searchModelIds();
for (DiscoveryNode removedNode : removedNodes) {
for (String modelId : modelIds) {
Model model = modelDao.get(modelId);
ModelMetadata modelMetadata = model.getModelMetadata();
ModelMetadata modelMetadata = getModelMetadata(modelId);
if (modelMetadata.getNodeAssignment().equals(removedNode.getEphemeralId())
&& modelMetadata.getState().equals(ModelState.TRAINING)) {
updateModelStateAsFailed(model, "Training failed to complete as node dropped");
updateModelStateAsFailed(modelId, modelMetadata, "Training failed to complete as node dropped");
}
}
}
Expand Down Expand Up @@ -158,9 +156,11 @@ public void onFailure(Exception e) {
return modelIds;
}

private void updateModelStateAsFailed(Model model, String msg) throws IOException {
model.getModelMetadata().setState(ModelState.FAILED);
model.getModelMetadata().setError(msg);
private void updateModelStateAsFailed(String modelId, ModelMetadata modelMetadata, String msg) throws IOException, ExecutionException,
InterruptedException {
modelMetadata.setState(ModelState.FAILED);
modelMetadata.setError(msg);
Model model = new Model(modelMetadata, null, modelId);
modelDao.update(model, new ActionListener<IndexResponse>() {
@Override
public void onResponse(IndexResponse indexResponse) {
Expand All @@ -173,4 +173,17 @@ public void onFailure(Exception e) {
}
});
}

/**
 * Looks up the metadata for the given model, preferring the copy returned by
 * {@code modelDao.getMetadata} and falling back to the full model returned by
 * {@code modelDao.get} when that lookup yields {@code null}.
 *
 * @param modelId id of the model whose metadata is requested
 * @return the metadata from {@code modelDao.getMetadata}, or, when that is {@code null}, the
 *         metadata attached to the model fetched through {@code modelDao.get}
 * @throws ExecutionException   declared for the fallback lookup (presumably raised by
 *                              {@code modelDao.get}; confirm against ModelDao)
 * @throws InterruptedException declared for the fallback lookup (presumably raised by
 *                              {@code modelDao.get}; confirm against ModelDao)
 */
private ModelMetadata getModelMetadata(String modelId) throws ExecutionException, InterruptedException {
    final ModelMetadata clusterModelMetadata = modelDao.getMetadata(modelId);
    if (clusterModelMetadata != null) {
        return clusterModelMetadata;
    }

    // Before 2.14 only models in the created state were written to the model metadata, so a
    // missing entry does not necessarily mean the model is unknown; ask the system index instead.
    log.info(
        "Model metadata is null in cluster metadata. This can happen for models training on nodes prior to OpenSearch version 2.14.0. Fetching model information from system index."
    );
    return modelDao.get(modelId).getModelMetadata();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.plugins.SearchPlugin;

import java.io.IOException;
Expand Down Expand Up @@ -683,6 +684,7 @@ public void testDoToQuery_FromModel() {
when(modelMetadata.getDimension()).thenReturn(4);
when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(modelMetadata.getSpaceType()).thenReturn(SpaceType.COSINESIMIL);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
KNNQueryBuilder.initialize(modelDao);
Expand Down Expand Up @@ -712,6 +714,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenDistanceThreshold
when(modelMetadata.getDimension()).thenReturn(4);
when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
KNNQueryBuilder.initialize(modelDao);
Expand Down Expand Up @@ -744,6 +747,7 @@ public void testDoToQuery_whenFromModel_whenDoRadiusSearch_whenScoreThreshold_th
when(modelMetadata.getDimension()).thenReturn(4);
when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(modelMetadata.getSpaceType()).thenReturn(SpaceType.L2);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
KNNQueryBuilder.initialize(modelDao);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.jni.JNIService;

import java.io.IOException;
Expand Down Expand Up @@ -167,6 +168,7 @@ public void testQueryScoreForFaissWithModel() {
ModelMetadata modelMetadata = mock(ModelMetadata.class);
when(modelMetadata.getKnnEngine()).thenReturn(KNNEngine.FAISS);
when(modelMetadata.getSpaceType()).thenReturn(spaceType);
when(modelMetadata.getState()).thenReturn(ModelState.CREATED);
when(modelDao.getMetadata(eq("modelId"))).thenReturn(modelMetadata);

KNNWeight.initialize(modelDao);
Expand Down Expand Up @@ -254,7 +256,7 @@ public void testQueryScoreForFaissWithNonExistingModel() throws IOException {
when(fieldInfo.getAttribute(eq(MODEL_ID))).thenReturn(modelId);

RuntimeException ex = expectThrows(RuntimeException.class, () -> knnWeight.scorer(leafReaderContext));
assertEquals(String.format("Model \"%s\" does not exist.", modelId), ex.getMessage());
assertEquals(String.format("Model \"%s\" is not created.", modelId), ex.getMessage());
}

@SneakyThrows
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,7 @@ public void testValidation_valid_trainingIndexBuiltFromModel() {
// Mock the model dao to return metadata for modelId to recognize it is a duplicate
ModelMetadata trainingFieldModelMetadata = mock(ModelMetadata.class);
when(trainingFieldModelMetadata.getDimension()).thenReturn(dimension);
when(trainingFieldModelMetadata.getState()).thenReturn(ModelState.CREATED);

ModelDao modelDao = mock(ModelDao.class);
when(modelDao.getMetadata(modelId)).thenReturn(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ public void testUpdateModelsNewCluster() throws IOException, InterruptedExceptio
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.isCreated()).thenReturn(true);
when(modelDao.get(modelId)).thenReturn(model);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
doAnswer(invocationOnMock -> {
SearchResponse searchResponse = mock(SearchResponse.class);
SearchHits searchHits = mock(SearchHits.class);
Expand Down Expand Up @@ -144,6 +145,7 @@ public void testUpdateModelsNodesRemoved() throws IOException, InterruptedExcept
ModelDao modelDao = mock(ModelDao.class);
when(modelDao.isCreated()).thenReturn(true);
when(modelDao.get(modelId)).thenReturn(model);
when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata);
DiscoveryNode node1 = mock(DiscoveryNode.class);
when(node1.getEphemeralId()).thenReturn("test-node-model-match");
DiscoveryNode node2 = mock(DiscoveryNode.class);
Expand Down
Loading