Skip to content

Commit

Permalink
add feature flag for offline batch inference
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Sep 24, 2024
1 parent b06e298 commit 2a9228d
Show file tree
Hide file tree
Showing 11 changed files with 117 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT;
import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;

import java.util.HashMap;
Expand Down Expand Up @@ -51,8 +52,8 @@
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.script.ScriptService;
import org.opensearch.tasks.Task;
Expand All @@ -74,7 +75,7 @@ public class CancelBatchJobTransportAction extends HandledTransportAction<Action
MLModelManager mlModelManager;

MLTaskManager mlTaskManager;
MLModelCacheHelper modelCacheHelper;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public CancelBatchJobTransportAction(
Expand All @@ -87,7 +88,8 @@ public CancelBatchJobTransportAction(
ConnectorAccessControlHelper connectorAccessControlHelper,
EncryptorImpl encryptor,
MLTaskManager mlTaskManager,
MLModelManager mlModelManager
MLModelManager mlModelManager,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLCancelBatchJobAction.NAME, transportService, actionFilters, MLCancelBatchJobRequest::new);
this.client = client;
Expand All @@ -98,6 +100,7 @@ public CancelBatchJobTransportAction(
this.encryptor = encryptor;
this.mlTaskManager = mlTaskManager;
this.mlModelManager = mlModelManager;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
Expand All @@ -116,6 +119,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLCanc
MLTask mlTask = MLTask.parse(parser);

// check if function is remote and task is of type batch prediction
if (mlTask.getTaskType() == MLTaskType.BATCH_PREDICTION
&& !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
throw new IllegalStateException(BATCH_INFERENCE_DISABLED_ERR_MSG);
}
if (mlTask.getTaskType() == MLTaskType.BATCH_PREDICTION && mlTask.getFunctionName() == FunctionName.REMOTE) {
processRemoteBatchPrediction(mlTask, actionListener);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FIELD;
import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;

Expand Down Expand Up @@ -68,8 +69,8 @@
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.script.ScriptService;
import org.opensearch.tasks.Task;
Expand All @@ -91,7 +92,7 @@ public class GetTaskTransportAction extends HandledTransportAction<ActionRequest
MLModelManager mlModelManager;

MLTaskManager mlTaskManager;
MLModelCacheHelper modelCacheHelper;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

volatile List<String> remoteJobStatusFields;
volatile Pattern remoteJobCompletedStatusRegexPattern;
Expand All @@ -111,6 +112,7 @@ public GetTaskTransportAction(
EncryptorImpl encryptor,
MLTaskManager mlTaskManager,
MLModelManager mlModelManager,
MLFeatureEnabledSetting mlFeatureEnabledSetting,
Settings settings
) {
super(MLTaskGetAction.NAME, transportService, actionFilters, MLTaskGetRequest::new);
Expand All @@ -122,6 +124,7 @@ public GetTaskTransportAction(
this.encryptor = encryptor;
this.mlTaskManager = mlTaskManager;
this.mlModelManager = mlModelManager;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;

remoteJobStatusFields = ML_COMMONS_REMOTE_JOB_STATUS_FIELD.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_REMOTE_JOB_STATUS_FIELD, it -> remoteJobStatusFields = it);
Expand Down Expand Up @@ -178,6 +181,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
MLTask mlTask = MLTask.parse(parser);

// check if function is remote and task is of type batch prediction
if (mlTask.getTaskType() == MLTaskType.BATCH_PREDICTION
&& !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
throw new IllegalStateException(BATCH_INFERENCE_DISABLED_ERR_MSG);
}
if (mlTask.getTaskType() == MLTaskType.BATCH_PREDICTION && mlTask.getFunctionName() == FunctionName.REMOTE) {
processRemoteBatchPrediction(mlTask, taskId, actionListener);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,8 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX,
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX,
MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED,
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED,
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED
);
return settings;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import com.google.common.collect.ImmutableList;

public class RestMLGetTaskAction extends BaseRestHandler {
private static final String ML_GET_Task_ACTION = "ml_get_task_action";
private static final String ML_GET_TASK_ACTION = "ml_get_task_action";

/**
* Constructor
Expand All @@ -33,7 +33,7 @@ public RestMLGetTaskAction() {}

@Override
public String getName() {
return ML_GET_Task_ACTION;
return ML_GET_TASK_ACTION;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM;
Expand Down Expand Up @@ -131,6 +132,8 @@ MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
} else if (FunctionName.isDLModel(FunctionName.from(algorithm.toUpperCase())) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
} else if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
throw new IllegalStateException(BATCH_INFERENCE_DISABLED_ERR_MSG);
} else if (!ActionType.isValidActionInModelPrediction(actionType)) {
throw new IllegalArgumentException("Wrong action type in the rest request path!");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ private MLCommonsSettings() {}
public static final Setting<Boolean> ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED = Setting
.boolSetting("plugins.ml_commons.offline_batch_ingestion_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Boolean> ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED = Setting
.boolSetting("plugins.ml_commons.offline_batch_inference_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<List<String>> ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX = Setting
.listSetting(
"plugins.ml_commons.trusted_connector_endpoints_regex",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;

Expand All @@ -29,6 +30,7 @@ public class MLFeatureEnabledSetting {

private volatile Boolean isControllerEnabled;
private volatile Boolean isBatchIngestionEnabled;
private volatile Boolean isBatchInferenceEnabled;

public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings);
Expand All @@ -37,6 +39,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
isConnectorPrivateIpEnabled = new AtomicBoolean(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED.get(settings));
isControllerEnabled = ML_COMMONS_CONTROLLER_ENABLED.get(settings);
isBatchIngestionEnabled = ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED.get(settings);
isBatchInferenceEnabled = ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED.get(settings);

clusterService
.getClusterSettings()
Expand All @@ -52,6 +55,9 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED, it -> isBatchIngestionEnabled = it);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED, it -> isBatchInferenceEnabled = it);
}

/**
Expand Down Expand Up @@ -98,4 +104,11 @@ public Boolean isOfflineBatchIngestionEnabled() {
return isBatchIngestionEnabled;
}

/**
 * Indicates whether the offline batch inference feature flag is turned on.
 * When this returns {@code false}, ml-commons APIs reject offline batch inference requests.
 *
 * @return {@code true} if offline batch inference is enabled, {@code false} otherwise.
 */
public Boolean isOfflineBatchInferenceEnabled() {
    return this.isBatchInferenceEnabled;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ public class MLExceptionUtils {
"Remote Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.remote_inference_enabled\" to true.";
public static final String LOCAL_MODEL_DISABLED_ERR_MSG =
"Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true.";
public static final String BATCH_INFERENCE_DISABLED_ERR_MSG =
"Offline Batch Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_inference_enabled\" to true.";
public static final String AGENT_FRAMEWORK_DISABLED_ERR_MSG =
"Agent Framework is currently disabled. To enable it, update the setting \"plugins.ml_commons.agent_framework_enabled\" to true.";
public static final String CONTROLLER_DISABLED_ERR_MSG =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.script.ScriptService;
import org.opensearch.test.OpenSearchTestCase;
Expand Down Expand Up @@ -106,6 +107,9 @@ public class CancelBatchJobTransportActionTests extends OpenSearchTestCase {
@Mock
private MLTaskManager mlTaskManager;

@Mock
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

Expand Down Expand Up @@ -139,7 +143,8 @@ public void setup() throws IOException {
connectorAccessControlHelper,
encryptor,
mlTaskManager,
mlModelManager
mlModelManager,
mlFeatureEnabledSetting
)
);

Expand Down Expand Up @@ -182,7 +187,7 @@ public void setup() throws IOException {
listener.onResponse(connector);
return null;
}).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any());

when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(true);
}

public void testGetTask_NullResponse() {
Expand Down Expand Up @@ -221,6 +226,28 @@ public void testGetTask_IndexNotFoundException() {
assertEquals("Fail to find task", argumentCaptor.getValue().getMessage());
}

// Verifies that cancelling a BATCH_PREDICTION task fails with IllegalStateException
// (carrying the batch-inference-disabled message) when the offline batch inference
// feature flag is turned off.
public void testGetTask_FeatureFlagDisabled() throws IOException {
    // Minimal remote-job metadata for an in-progress batch transform task.
    Map<String, Object> remoteJob = new HashMap<>();
    remoteJob.put("Status", "IN PROGRESS");
    remoteJob.put("TransformJobName", "SM-offline-batch-transform13");

    GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob);

    // Stub the task-index lookup so the action sees the batch-prediction task above.
    doAnswer(invocation -> {
        ActionListener<GetResponse> actionListener = invocation.getArgument(1);
        actionListener.onResponse(getResponse);
        return null;
    }).when(client).get(any(), any());
    // Disable the feature flag (setup() enables it by default) to trigger the guard.
    when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(false);
    cancelBatchJobTransportAction.doExecute(null, mlCancelBatchJobRequest, actionListener);
    // The guard must surface as onFailure with the exact disabled-feature error message.
    ArgumentCaptor<IllegalStateException> argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class);
    verify(actionListener).onFailure(argumentCaptor.capture());
    assertEquals(
        "Offline Batch Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_inference_enabled\" to true.",
        argumentCaptor.getValue().getMessage()
    );
}

@Ignore
public void testGetTask_SuccessBatchPredictCancel() throws IOException {
Map<String, Object> remoteJob = new HashMap<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.script.ScriptService;
import org.opensearch.test.OpenSearchTestCase;
Expand Down Expand Up @@ -116,6 +117,9 @@ public class GetTaskTransportActionTests extends OpenSearchTestCase {
@Mock
private MLTaskManager mlTaskManager;

@Mock
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Rule
public ExpectedException exceptionRule = ExpectedException.none();

Expand Down Expand Up @@ -172,6 +176,7 @@ public void setup() throws IOException {
encryptor,
mlTaskManager,
mlModelManager,
mlFeatureEnabledSetting,
settings
)
);
Expand Down Expand Up @@ -215,7 +220,7 @@ public void setup() throws IOException {
listener.onResponse(connector);
return null;
}).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any());

when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(true);
}

public void testGetTask_NullResponse() {
Expand Down Expand Up @@ -299,6 +304,31 @@ public void test_BatchPredictStatus_NoConnector() throws IOException {
assertEquals("You don't have permission to access this connector", argumentCaptor.getValue().getMessage());
}

// Verifies that fetching the status of a BATCH_PREDICTION task fails with
// IllegalStateException (batch-inference-disabled message) when the offline
// batch inference feature flag is turned off.
public void test_BatchPredictStatus_FeatureFlagDisabled() throws IOException {
    // Minimal remote-job metadata for an in-progress batch transform task.
    Map<String, Object> remoteJob = new HashMap<>();
    remoteJob.put("Status", "IN PROGRESS");
    remoteJob.put("TransformJobName", "SM-offline-batch-transform13");

    // NOTE(review): this stub looks unused on the disabled-flag path (the guard
    // throws before connector access is checked) — confirm it is needed, since
    // strict Mockito stubbing would flag an unnecessary stub.
    when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(false);

    GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob);

    // Stub the task-index lookup so the action sees the batch-prediction task above.
    doAnswer(invocation -> {
        ActionListener<GetResponse> actionListener = invocation.getArgument(1);
        actionListener.onResponse(getResponse);
        return null;
    }).when(client).get(any(), any());
    // Disable the feature flag (setup() enables it by default) to trigger the guard.
    when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(false);

    getTaskTransportAction.doExecute(null, mlTaskGetRequest, actionListener);
    // The guard must surface as onFailure with the exact disabled-feature error message.
    ArgumentCaptor<IllegalStateException> argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class);
    verify(actionListener).onFailure(argumentCaptor.capture());
    assertEquals(
        "Offline Batch Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_inference_enabled\" to true.",
        argumentCaptor.getValue().getMessage()
    );
}

public void test_BatchPredictStatus_NoAccessToConnector() throws IOException {
Map<String, Object> remoteJob = new HashMap<>();
remoteJob.put("Status", "IN PROGRESS");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,26 @@ public void testPrepareRequest() throws Exception {

// Happy path: with the offline batch inference flag enabled, a batch predict REST
// request is parsed and forwarded to the client as an MLPredictionTaskRequest.
public void testPrepareBatchRequest() throws Exception {
    // Enable the feature flag so the batch request passes the guard.
    when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(true);
    RestRequest batchRequest = getBatchRestRequest();
    restMLPredictionAction.handleRequest(batchRequest, channel, client);
    // Capture the transport request sent to the client and validate its parsed ML input.
    ArgumentCaptor<MLPredictionTaskRequest> captor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class);
    verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), captor.capture(), any());
    verifyParsedBatchMLInput(captor.getValue().getMlInput());
}

// With the offline batch inference flag disabled, handling a batch predict REST
// request must fail fast with IllegalStateException and the disabled-feature message.
public void testPrepareBatchRequest_FeatureFlagDisabled() throws Exception {
    thrown.expect(IllegalStateException.class);
    thrown
        .expectMessage(
            "Offline Batch Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_inference_enabled\" to true."
        );

    // Turn the feature flag off, then issue a batch predict request.
    when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(false);
    RestRequest batchRequest = getBatchRestRequest();
    restMLPredictionAction.handleRequest(batchRequest, channel, client);
}

public void testPrepareBatchRequest_WrongActionType() throws Exception {
thrown.expect(IllegalArgumentException.class);
thrown.expectMessage("Wrong Action Type");
Expand Down

0 comments on commit 2a9228d

Please sign in to comment.