From ecc992899c34341f962df5496b32491522925174 Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Wed, 22 May 2024 15:25:42 -0700 Subject: [PATCH] fix memory CB bugs Signed-off-by: Xun Zhang --- .../TransportPredictionTaskAction.java | 3 +++ .../ml/breaker/MemoryCircuitBreaker.java | 2 +- .../ml/task/MLPredictTaskRunner.java | 8 ++++++- .../TransportPredictionTaskActionTests.java | 23 +++++++++++++++++++ .../ml/breaker/MemoryCircuitBreakerTests.java | 18 +++++++++++++++ .../ml/model/MLModelCacheHelperTests.java | 6 ++++- .../ml/model/MLModelManagerTests.java | 5 ++-- .../ml/task/MLExecuteTaskRunnerTests.java | 10 ++++---- 8 files changed, 65 insertions(+), 10 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 94ed36214a..93d51dd1aa 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -25,6 +25,7 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.transport.MLTaskResponse; @@ -177,6 +178,8 @@ public void onResponse(MLModel mlModel) { ); } else if (e instanceof MLResourceNotFoundException) { wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND)); + } else if (e instanceof MLLimitExceededException) { + wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.SERVICE_UNAVAILABLE)); } else { wrappedListener .onFailure( diff --git 
a/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java b/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java index 5e045ae539..c1287ef481 100644 --- a/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java +++ b/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java @@ -50,6 +50,6 @@ public Short getThreshold() { @Override public boolean isOpen() { - return jvmService.stats().getMem().getHeapUsedPercent() > this.getThreshold(); + return getThreshold() < 100 && jvmService.stats().getMem().getHeapUsedPercent() > getThreshold(); } } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 101d9c9244..b5f8b46167 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -143,7 +143,13 @@ public void dispatchTask( if (clusterService.localNode().getId().equals(node.getId())) { log.debug("Execute ML predict request {} locally on node {}", request.getRequestID(), node.getId()); request.setDispatchTask(false); - executeTask(request, listener); + run( + // This is by design to NOT use mlPredictionTaskRequest.getMlInput().getAlgorithm() here + functionName, + request, + transportService, + listener + ); } else { log.debug("Execute ML predict request {} remotely on node {}", request.getRequestID(), node.getId()); request.setDispatchTask(false); diff --git a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java index aa7afdce6e..88d0262c4e 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java @@ -43,6 +43,7 
@@ import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.dataset.DataFrameInputDataset; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; @@ -235,6 +236,28 @@ public void testPrediction_MLResourceNotFoundException() { assertEquals("Testing MLResourceNotFoundException", argumentCaptor.getValue().getMessage()); } + public void testPrediction_MLLimitExceededException() { + when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model); + when(model.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(3); + listener.onFailure(new MLLimitExceededException("Memory Circuit Breaker is open, please check your resources!")); + return null; + }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any()); + + doAnswer(invocation -> { + ((ActionListener) invocation.getArguments()[3]).onResponse(null); + return null; + }).when(mlPredictTaskRunner).run(any(), any(), any(), any()); + + transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener); + + ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class); + verify(actionListener).onFailure(argumentCaptor.capture()); + assertEquals("Memory Circuit Breaker is open, please check your resources!", argumentCaptor.getValue().getMessage()); + } + public void testValidateInputSchemaSuccess() { RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet .builder() diff --git a/plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java 
b/plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java index cdd1f6fc22..8c7f6f41d4 100644 --- a/plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java @@ -84,4 +84,22 @@ public void testIsOpen_UpdatedByClusterSettings_ExceedMemoryThreshold() { settingsService.applySettings(newSettingsBuilder.build()); Assert.assertFalse(breaker.isOpen()); } + + @Test + public void testIsOpen_DisableMemoryCB() { + ClusterSettings settingsService = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + settingsService.registerSetting(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD); + when(clusterService.getClusterSettings()).thenReturn(settingsService); + + CircuitBreaker breaker = new MemoryCircuitBreaker(Settings.builder().build(), clusterService, jvmService); + + when(mem.getHeapUsedPercent()).thenReturn((short) 90); + Assert.assertTrue(breaker.isOpen()); + + when(mem.getHeapUsedPercent()).thenReturn((short) 100); + Settings.Builder newSettingsBuilder = Settings.builder(); + newSettingsBuilder.put("plugins.ml_commons.jvm_heap_memory_threshold", 100); + settingsService.applySettings(newSettingsBuilder.build()); + Assert.assertFalse(breaker.isOpen()); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java index 232290520d..dc0666b8a1 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java @@ -24,6 +24,7 @@ import org.junit.rules.ExpectedException; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.opensearch.cluster.ClusterManagerMetrics; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; 
@@ -64,13 +65,16 @@ public class MLModelCacheHelperTests extends OpenSearchTestCase { @Mock private TokenBucket rateLimiter; + @Mock + ClusterManagerMetrics clusterManagerMetrics; + @Before public void setup() { MockitoAnnotations.openMocks(this); maxMonitoringRequests = 10; settings = Settings.builder().put(ML_COMMONS_MONITORING_REQUEST_COUNT.getKey(), maxMonitoringRequests).build(); ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_MONITORING_REQUEST_COUNT); - clusterService = spy(new ClusterService(settings, clusterSettings, null)); + clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterManagerMetrics)); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); cacheHelper = new MLModelCacheHelper(clusterService, settings); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index d42fa9ca65..09981b985d 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -76,6 +76,7 @@ import org.opensearch.action.update.UpdateRequest; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterManagerMetrics; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -177,7 +178,7 @@ public class MLModelManagerTests extends OpenSearchTestCase { private ScriptService scriptService; @Mock - private MLTask pretrainedMLTask; + ClusterManagerMetrics clusterManagerMetrics; @Before public void setup() throws URISyntaxException { @@ -196,7 +197,7 @@ public void setup() throws URISyntaxException { ML_COMMONS_MONITORING_REQUEST_COUNT, ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE ); - clusterService = spy(new ClusterService(settings, clusterSettings, null)); 
+ clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterManagerMetrics)); xContentRegistry = NamedXContentRegistry.EMPTY; modelName = "model_name1"; diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java index 9011746797..34d4e2380f 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java @@ -28,6 +28,7 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterManagerMetrics; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Settings; @@ -48,7 +49,6 @@ import org.opensearch.ml.stats.suppliers.CounterSupplier; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; -import org.opensearch.transport.TransportService; public class MLExecuteTaskRunnerTests extends OpenSearchTestCase { @@ -70,14 +70,14 @@ public class MLExecuteTaskRunnerTests extends OpenSearchTestCase { @Mock MLCircuitBreakerService mlCircuitBreakerService; - @Mock - TransportService transportService; - @Mock ActionListener listener; @Mock DiscoveryNodeHelper nodeHelper; + @Mock + ClusterManagerMetrics clusterManagerMetrics; + @Rule public ExpectedException exceptionRule = ExpectedException.none(); @@ -115,7 +115,7 @@ public void setup() { ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL ); - clusterService = spy(new ClusterService(settings, clusterSettings, null)); + clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterManagerMetrics)); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); Map<Enum, MLStat<?>> stats = new ConcurrentHashMap<>();