Skip to content

Commit

Permalink
add feature flag for offline batch ingestion
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
  • Loading branch information
Zhangxunmt committed Sep 24, 2024
1 parent 6a6cac1 commit b06e298
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
import static org.opensearch.ml.utils.MLExceptionUtils.OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG;

import java.time.Instant;
import java.util.List;
Expand All @@ -35,6 +36,7 @@
import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse;
import org.opensearch.ml.engine.MLEngineClassLoader;
import org.opensearch.ml.engine.ingest.Ingestable;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.tasks.Task;
Expand All @@ -55,27 +57,33 @@ public class TransportBatchIngestionAction extends HandledTransportAction<Action
MLTaskManager mlTaskManager;
private final Client client;
private ThreadPool threadPool;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public TransportBatchIngestionAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
MLTaskManager mlTaskManager,
ThreadPool threadPool
ThreadPool threadPool,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLBatchIngestionAction.NAME, transportService, actionFilters, MLBatchIngestionRequest::new);
this.transportService = transportService;
this.client = client;
this.mlTaskManager = mlTaskManager;
this.threadPool = threadPool;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatchIngestionResponse> listener) {
MLBatchIngestionRequest mlBatchIngestionRequest = MLBatchIngestionRequest.fromActionRequest(request);
MLBatchIngestionInput mlBatchIngestionInput = mlBatchIngestionRequest.getMlBatchIngestionInput();
try {
if (!mlFeatureEnabledSetting.isOfflineBatchIngestionEnabled()) {
throw new IllegalStateException(OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG);
}
validateBatchIngestInput(mlBatchIngestionInput);
MLTask mlTask = MLTask
.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,8 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX,
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX,
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX,
MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED
MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED,
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED
);
return settings;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ private MLCommonsSettings() {}
public static final Setting<Boolean> ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED = Setting
.boolSetting("plugins.ml_commons.connector_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Boolean> ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED = Setting
.boolSetting("plugins.ml_commons.offline_batch_ingestion_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<List<String>> ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX = Setting
.listSetting(
"plugins.ml_commons.trusted_connector_endpoints_regex",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;

import java.util.concurrent.atomic.AtomicBoolean;
Expand All @@ -27,13 +28,15 @@ public class MLFeatureEnabledSetting {
private volatile AtomicBoolean isConnectorPrivateIpEnabled;

private volatile Boolean isControllerEnabled;
private volatile Boolean isBatchIngestionEnabled;

public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings);
isAgentFrameworkEnabled = ML_COMMONS_AGENT_FRAMEWORK_ENABLED.get(settings);
isLocalModelEnabled = ML_COMMONS_LOCAL_MODEL_ENABLED.get(settings);
isConnectorPrivateIpEnabled = new AtomicBoolean(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED.get(settings));
isControllerEnabled = ML_COMMONS_CONTROLLER_ENABLED.get(settings);
isBatchIngestionEnabled = ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED.get(settings);

clusterService
.getClusterSettings()
Expand All @@ -46,6 +49,9 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED, it -> isConnectorPrivateIpEnabled.set(it));
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_CONTROLLER_ENABLED, it -> isControllerEnabled = it);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED, it -> isBatchIngestionEnabled = it);
}

/**
Expand Down Expand Up @@ -84,4 +90,12 @@ public Boolean isControllerEnabled() {
return isControllerEnabled;
}

/**
* Whether the offline batch ingestion is enabled. If disabled, APIs in ml-commons will block offline batch ingestion.
* @return whether the feature is enabled.
*/
public Boolean isOfflineBatchIngestionEnabled() {
return isBatchIngestionEnabled;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ public class MLExceptionUtils {
"Agent Framework is currently disabled. To enable it, update the setting \"plugins.ml_commons.agent_framework_enabled\" to true.";
public static final String CONTROLLER_DISABLED_ERR_MSG =
"Controller is currently disabled. To enable it, update the setting \"plugins.ml_commons.controller_enabled\" to true.";
public static final String OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG =
"Offline batch ingestion is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_ingestion_enabled\" to true.";

public static String getRootCauseMessage(final Throwable throwable) {
String message = ExceptionUtils.getRootCauseMessage(throwable);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionRequest;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
Expand Down Expand Up @@ -73,6 +74,8 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
ThreadPool threadPool;
@Mock
ExecutorService executorService;
@Mock
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

private TransportBatchIngestionAction batchAction;
private MLBatchIngestionInput batchInput;
Expand All @@ -81,7 +84,14 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
@Before
public void setup() {
MockitoAnnotations.openMocks(this);
batchAction = new TransportBatchIngestionAction(transportService, actionFilters, client, mlTaskManager, threadPool);
batchAction = new TransportBatchIngestionAction(
transportService,
actionFilters,
client,
mlTaskManager,
threadPool,
mlFeatureEnabledSetting
);

Map<String, Object> fieldMap = new HashMap<>();
fieldMap.put("chapter", "$.content[0]");
Expand All @@ -106,6 +116,8 @@ public void setup() {
.dataSources(dataSource)
.build();
when(mlBatchIngestionRequest.getMlBatchIngestionInput()).thenReturn(batchInput);

when(mlFeatureEnabledSetting.isOfflineBatchIngestionEnabled()).thenReturn(true);
}

public void test_doExecute_success() {
Expand Down Expand Up @@ -181,6 +193,18 @@ public void test_doExecute_handleSuccessRate0() {
);
}

public void test_doExecute_batchIngestionDisabled() {
when(mlFeatureEnabledSetting.isOfflineBatchIngestionEnabled()).thenReturn(false);
batchAction.doExecute(task, mlBatchIngestionRequest, actionListener);

ArgumentCaptor<IllegalStateException> argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class);
verify(actionListener).onFailure(argumentCaptor.capture());
assertEquals(
"Offline batch ingestion is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_ingestion_enabled\" to true.",
argumentCaptor.getValue().getMessage()
);
}

public void test_doExecute_noDataSource() {
MLBatchIngestionInput batchInput = MLBatchIngestionInput
.builder()
Expand Down

0 comments on commit b06e298

Please sign in to comment.