From ddf742f958269ddd733303eb63ed5416c59fe24c Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 8 Nov 2022 01:06:09 +0000 Subject: [PATCH] fix running tasks when circuit breaker is open Signed-off-by: Yaliang Wu --- .../org/opensearch/ml/model/MLModelManager.java | 14 +++++++------- .../opensearch/ml/model/MLModelManagerTests.java | 9 ++------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 594071d26a..ace03cf34b 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -168,16 +168,16 @@ public MLModelManager( * @param mlTask ML task */ public void uploadMLModel(MLUploadInput uploadInput, MLTask mlTask) { - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); - String errorMsg = checkAndAddRunningTask(mlTask, maxUploadTasksPerNode); - if (errorMsg != null) { - mlTaskManager.updateMLTaskDirectly(mlTask.getTaskId(), ImmutableMap.of(STATE_FIELD, FAILED, ERROR_FIELD, errorMsg)); - throw new MLLimitExceededException(errorMsg); - } String taskId = mlTask.getTaskId(); FunctionName functionName = mlTask.getFunctionName(); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + String errorMsg = checkAndAddRunningTask(mlTask, maxUploadTasksPerNode); + if (errorMsg != null) { + mlTaskManager.updateMLTaskDirectly(mlTask.getTaskId(), ImmutableMap.of(STATE_FIELD, FAILED, ERROR_FIELD, errorMsg)); + throw new MLLimitExceededException(errorMsg); + } + mlStats.createCounterStatIfAbsent(functionName, UPLOAD, ML_ACTION_REQUEST_COUNT).increment(); mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); String modelName = uploadInput.getModelName(); diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 494d6d421b..0224c963cc 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -75,7 +75,6 @@ import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.breaker.MLCircuitBreakerService; import org.opensearch.ml.common.dataset.MLInputDataType; -import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; @@ -248,20 +247,16 @@ public void setup() throws URISyntaxException { public void testUploadMLModel_ExceedMaxRunningTask() { String error = "exceed max running task limit"; - expectedEx.expect(MLLimitExceededException.class); - expectedEx.expectMessage(error); when(mlTaskManager.checkLimitAndAddRunningTask(any(), any())).thenReturn(error); modelManager.uploadMLModel(uploadInput, mlTask); - verify(mlTaskManager, never()).updateMLTaskDirectly(eq(mlTask.getTaskId()), any()); + verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); } public void testUploadMLModel_CircuitBreakerOpen() { - expectedEx.expect(MLLimitExceededException.class); - expectedEx.expectMessage("Disk Circuit Breaker is open, please check your resources!"); when(mlTaskManager.checkLimitAndAddRunningTask(any(), any())).thenReturn(null); when(mlCircuitBreakerService.checkOpenCB()).thenReturn("Disk Circuit Breaker"); modelManager.uploadMLModel(uploadInput, mlTask); - verify(mlTaskManager, never()).updateMLTaskDirectly(eq(mlTask.getTaskId()), any()); + verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); } public void testUploadMLModel_InitModelIndexFailure() {