From 2a9228d217777b1243848388a9f4e06f724301d0 Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Tue, 24 Sep 2024 15:31:29 -0700 Subject: [PATCH] add feature flag for offline batch inference Signed-off-by: Xun Zhang --- .../tasks/CancelBatchJobTransportAction.java | 13 ++++++-- .../action/tasks/GetTaskTransportAction.java | 11 +++++-- .../ml/plugin/MachineLearningPlugin.java | 3 +- .../ml/rest/RestMLGetTaskAction.java | 4 +-- .../ml/rest/RestMLPredictionAction.java | 3 ++ .../ml/settings/MLCommonsSettings.java | 3 ++ .../ml/settings/MLFeatureEnabledSetting.java | 13 ++++++++ .../opensearch/ml/utils/MLExceptionUtils.java | 2 ++ .../CancelBatchJobTransportActionTests.java | 31 ++++++++++++++++-- .../tasks/GetTaskTransportActionTests.java | 32 ++++++++++++++++++- .../ml/rest/RestMLPredictionActionTests.java | 13 ++++++++ 11 files changed, 117 insertions(+), 11 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java index 6a7fd617ae..95e43ca929 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT; +import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry; import java.util.HashMap; @@ -51,8 +52,8 @@ import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.helper.ConnectorAccessControlHelper; -import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.task.MLTaskManager; import org.opensearch.script.ScriptService; import org.opensearch.tasks.Task; @@ -74,7 +75,7 @@ public class CancelBatchJobTransportAction extends HandledTransportAction remoteJobStatusFields; volatile Pattern remoteJobCompletedStatusRegexPattern; @@ -111,6 +112,7 @@ public GetTaskTransportAction( EncryptorImpl encryptor, MLTaskManager mlTaskManager, MLModelManager mlModelManager, + MLFeatureEnabledSetting mlFeatureEnabledSetting, Settings settings ) { super(MLTaskGetAction.NAME, transportService, actionFilters, MLTaskGetRequest::new); @@ -122,6 +124,7 @@ public GetTaskTransportAction( this.encryptor = encryptor; this.mlTaskManager = mlTaskManager; this.mlModelManager = mlModelManager; + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; remoteJobStatusFields = ML_COMMONS_REMOTE_JOB_STATUS_FIELD.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_REMOTE_JOB_STATUS_FIELD, it -> remoteJobStatusFields = it); @@ -178,6 +181,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener> getSettings() { MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX, MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX, MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED, - MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED + MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED, + MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED ); return settings; } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetTaskAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetTaskAction.java index aeee474864..5284d42a32 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLGetTaskAction.java @@ -24,7 +24,7 @@ import com.google.common.collect.ImmutableList; public class RestMLGetTaskAction extends BaseRestHandler { - private static final String ML_GET_Task_ACTION = "ml_get_task_action"; + private static final String ML_GET_TASK_ACTION = "ml_get_task_action"; /** * Constructor @@ -33,7 +33,7 @@ public RestMLGetTaskAction() {} @Override public String getName() { - return ML_GET_Task_ACTION; + return ML_GET_TASK_ACTION; } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java index 72b841eb7b..68c0146ab2 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLPredictionAction.java @@ -7,6 +7,7 @@ import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG; import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM; @@ -131,6 +132,8 @@ MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG); } else if (FunctionName.isDLModel(FunctionName.from(algorithm.toUpperCase())) && !mlFeatureEnabledSetting.isLocalModelEnabled()) { throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG); + } else if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) { + throw new IllegalStateException(BATCH_INFERENCE_DISABLED_ERR_MSG); } else if (!ActionType.isValidActionInModelPrediction(actionType)) { throw new IllegalArgumentException("Wrong action type in the rest request path!"); } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index a8e6a0867d..5b0e110d52 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -139,6 +139,9 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED = Setting .boolSetting("plugins.ml_commons.offline_batch_ingestion_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED = Setting + .boolSetting("plugins.ml_commons.offline_batch_inference_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic); + public static final Setting> ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX = Setting .listSetting( "plugins.ml_commons.trusted_connector_endpoints_regex", diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java index b77e8fcf66..93159125de 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLFeatureEnabledSetting.java @@ -11,6 +11,7 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED; @@ -29,6 +30,7 @@ public class MLFeatureEnabledSetting { private volatile Boolean isControllerEnabled; private volatile Boolean isBatchIngestionEnabled; + private volatile Boolean isBatchInferenceEnabled; public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) { isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings); @@ -37,6 +39,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) isConnectorPrivateIpEnabled = new AtomicBoolean(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED.get(settings)); isControllerEnabled = ML_COMMONS_CONTROLLER_ENABLED.get(settings); isBatchIngestionEnabled = ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED.get(settings); + isBatchInferenceEnabled = ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED.get(settings); clusterService .getClusterSettings() @@ -52,6 +55,9 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) clusterService .getClusterSettings() .addSettingsUpdateConsumer(ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED, it -> isBatchIngestionEnabled = it); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED, it -> isBatchInferenceEnabled = it); } /** @@ -98,4 +104,11 @@ public Boolean isOfflineBatchIngestionEnabled() { return isBatchIngestionEnabled; } + /** + * Whether the offline batch inference is enabled. If disabled, APIs in ml-commons will block offline batch inference. + * @return whether the feature is enabled. + */ + public Boolean isOfflineBatchInferenceEnabled() { + return isBatchInferenceEnabled; + } } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java index 2d8ed1084c..7a056c762c 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MLExceptionUtils.java @@ -22,6 +22,8 @@ public class MLExceptionUtils { "Remote Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.remote_inference_enabled\" to true."; public static final String LOCAL_MODEL_DISABLED_ERR_MSG = "Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true."; + public static final String BATCH_INFERENCE_DISABLED_ERR_MSG = + "Offline Batch Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_inference_enabled\" to true."; public static final String AGENT_FRAMEWORK_DISABLED_ERR_MSG = "Agent Framework is currently disabled. To enable it, update the setting \"plugins.ml_commons.agent_framework_enabled\" to true."; public static final String CONTROLLER_DISABLED_ERR_MSG = diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java index 99d9fbf8a1..0c6939ea77 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportActionTests.java @@ -61,6 +61,7 @@ import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.task.MLTaskManager; import org.opensearch.script.ScriptService; import org.opensearch.test.OpenSearchTestCase; @@ -106,6 +107,9 @@ public class CancelBatchJobTransportActionTests extends OpenSearchTestCase { @Mock private MLTaskManager mlTaskManager; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -139,7 +143,8 @@ public void setup() throws IOException { connectorAccessControlHelper, encryptor, mlTaskManager, - mlModelManager + mlModelManager, + mlFeatureEnabledSetting ) ); @@ -182,7 +187,7 @@ public void setup() throws IOException { listener.onResponse(connector); return null; }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any()); - + when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(true); } public void testGetTask_NullResponse() { @@ -221,6 +226,28 @@ public void testGetTask_IndexNotFoundException() { assertEquals("Fail to find task", argumentCaptor.getValue().getMessage()); } + public void testGetTask_FeatureFlagDisabled() throws IOException { + Map remoteJob = new HashMap<>(); + remoteJob.put("Status", "IN PROGRESS"); + remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); + + GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(false); + cancelBatchJobTransportAction.doExecute(null, mlCancelBatchJobRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Offline Batch Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_inference_enabled\" to true.", + argumentCaptor.getValue().getMessage() + ); + } + @Ignore public void testGetTask_SuccessBatchPredictCancel() throws IOException { Map remoteJob = new HashMap<>(); diff --git a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java index 1c9a1c449a..25c43eb9b6 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/tasks/GetTaskTransportActionTests.java @@ -71,6 +71,7 @@ import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.helper.ConnectorAccessControlHelper; import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.task.MLTaskManager; import org.opensearch.script.ScriptService; import org.opensearch.test.OpenSearchTestCase; @@ -116,6 +117,9 @@ public class GetTaskTransportActionTests extends OpenSearchTestCase { @Mock private MLTaskManager mlTaskManager; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -172,6 +176,7 @@ public void setup() throws IOException { encryptor, mlTaskManager, mlModelManager, + mlFeatureEnabledSetting, settings ) ); @@ -215,7 +220,7 @@ public void setup() throws IOException { listener.onResponse(connector); return null; }).when(connectorAccessControlHelper).getConnector(eq(client), anyString(), any()); - + when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(true); } public void testGetTask_NullResponse() { @@ -299,6 +304,31 @@ public void test_BatchPredictStatus_NoConnector() throws IOException { assertEquals("You don't have permission to access this connector", argumentCaptor.getValue().getMessage()); } + public void test_BatchPredictStatus_FeatureFlagDisabled() throws IOException { + Map remoteJob = new HashMap<>(); + remoteJob.put("Status", "IN PROGRESS"); + remoteJob.put("TransformJobName", "SM-offline-batch-transform13"); + + when(connectorAccessControlHelper.validateConnectorAccess(eq(client), any())).thenReturn(false); + + GetResponse getResponse = prepareMLTask(FunctionName.REMOTE, MLTaskType.BATCH_PREDICTION, remoteJob); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(getResponse); + return null; + }).when(client).get(any(), any()); + when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(false); + + getTaskTransportAction.doExecute(null, mlTaskGetRequest, actionListener); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Offline Batch Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_inference_enabled\" to true.", + argumentCaptor.getValue().getMessage() + ); + } + public void test_BatchPredictStatus_NoAccessToConnector() throws IOException { Map remoteJob = new HashMap<>(); remoteJob.put("Status", "IN PROGRESS"); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java index 001b3709a8..c90f765ed0 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLPredictionActionTests.java @@ -157,6 +157,7 @@ public void testPrepareRequest() throws Exception { public void testPrepareBatchRequest() throws Exception { RestRequest request = getBatchRestRequest(); + when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(true); restMLPredictionAction.handleRequest(request, channel, client); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); verify(client, times(1)).execute(eq(MLPredictionTaskAction.INSTANCE), argumentCaptor.capture(), any()); @@ -164,6 +165,18 @@ public void testPrepareBatchRequest() throws Exception { verifyParsedBatchMLInput(mlInput); } + public void testPrepareBatchRequest_FeatureFlagDisabled() throws Exception { + thrown.expect(IllegalStateException.class); + thrown + .expectMessage( + "Offline Batch Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_inference_enabled\" to true." + ); + + RestRequest request = getBatchRestRequest(); + when(mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()).thenReturn(false); + restMLPredictionAction.handleRequest(request, channel, client); + } + public void testPrepareBatchRequest_WrongActionType() throws Exception { thrown.expect(IllegalArgumentException.class); thrown.expectMessage("Wrong Action Type");