From abdccc0850687ce96133736a4280490432ee2360 Mon Sep 17 00:00:00 2001 From: Xun Zhang Date: Tue, 11 Jul 2023 09:55:12 -0700 Subject: [PATCH] fix memory circuit breaker Signed-off-by: Xun Zhang --- plugin/build.gradle | 1 - .../ml/breaker/MLCircuitBreakerService.java | 2 +- .../ml/breaker/MLCircuitBreakerServiceTests.java | 12 ++++++++++-- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/plugin/build.gradle b/plugin/build.gradle index 176ecfe320..946b413c51 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -287,7 +287,6 @@ List jacocoExclusions = [ 'org.opensearch.ml.action.training.TrainingITTests', 'org.opensearch.ml.action.prediction.PredictionITTests', 'org.opensearch.ml.cluster.MLSyncUpCron', - 'org.opensearch.ml.breaker.MemoryCircuitBreaker', 'org.opensearch.ml.model.MLModelGroupManager', 'org.opensearch.ml.helper.ModelAccessControlHelper', 'org.opensearch.ml.action.models.DeleteModelTransportAction.2' diff --git a/plugin/src/main/java/org/opensearch/ml/breaker/MLCircuitBreakerService.java b/plugin/src/main/java/org/opensearch/ml/breaker/MLCircuitBreakerService.java index f56c7915e0..156c71b69c 100644 --- a/plugin/src/main/java/org/opensearch/ml/breaker/MLCircuitBreakerService.java +++ b/plugin/src/main/java/org/opensearch/ml/breaker/MLCircuitBreakerService.java @@ -74,7 +74,7 @@ public CircuitBreaker getBreaker(BreakerName name) { */ public MLCircuitBreakerService init(Path path) { // Register memory circuit breaker - registerBreaker(BreakerName.MEMORY, new MemoryCircuitBreaker(this.jvmService)); + registerBreaker(BreakerName.MEMORY, new MemoryCircuitBreaker(this.settings, this.clusterService, this.jvmService)); log.info("Registered ML memory breaker."); registerBreaker(BreakerName.DISK, new DiskCircuitBreaker(path.toString())); log.info("Registered ML disk breaker."); diff --git a/plugin/src/test/java/org/opensearch/ml/breaker/MLCircuitBreakerServiceTests.java b/plugin/src/test/java/org/opensearch/ml/breaker/MLCircuitBreakerServiceTests.java index 8e5e503c82..9ed06d0b3c 100644 --- a/plugin/src/test/java/org/opensearch/ml/breaker/MLCircuitBreakerServiceTests.java +++ b/plugin/src/test/java/org/opensearch/ml/breaker/MLCircuitBreakerServiceTests.java @@ -6,6 +6,7 @@ package org.opensearch.ml.breaker; import static org.mockito.Mockito.when; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_JVM_HEAP_MEM_THRESHOLD; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_NATIVE_MEM_THRESHOLD; import java.nio.file.Path; @@ -95,8 +96,15 @@ public void testClearBreakers() { @Test public void testInit() { - Settings settings = Settings.builder().put(ML_COMMONS_NATIVE_MEM_THRESHOLD.getKey(), 90).build(); - ClusterSettings clusterSettings = new ClusterSettings(settings, new HashSet<>(Arrays.asList(ML_COMMONS_NATIVE_MEM_THRESHOLD))); + Settings settings = Settings + .builder() + .put(ML_COMMONS_NATIVE_MEM_THRESHOLD.getKey(), 90) + .put(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD.getKey(), 95) + .build(); + ClusterSettings clusterSettings = new ClusterSettings( + settings, + new HashSet<>(Arrays.asList(ML_COMMONS_NATIVE_MEM_THRESHOLD, ML_COMMONS_JVM_HEAP_MEM_THRESHOLD)) + ); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); mlCircuitBreakerService = new MLCircuitBreakerService(jvmService, osService, settings, clusterService); Assert.assertNotNull(mlCircuitBreakerService.init(Path.of("/")));