diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 594071d26a..ba73a01faa 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -168,16 +168,16 @@ public MLModelManager( * @param mlTask ML task */ public void uploadMLModel(MLUploadInput uploadInput, MLTask mlTask) { - mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); - String errorMsg = checkAndAddRunningTask(mlTask, maxUploadTasksPerNode); - if (errorMsg != null) { - mlTaskManager.updateMLTaskDirectly(mlTask.getTaskId(), ImmutableMap.of(STATE_FIELD, FAILED, ERROR_FIELD, errorMsg)); - throw new MLLimitExceededException(errorMsg); - } String taskId = mlTask.getTaskId(); FunctionName functionName = mlTask.getFunctionName(); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); + String errorMsg = checkAndAddRunningTask(mlTask, maxUploadTasksPerNode); + if (errorMsg != null) { + mlTaskManager.updateMLTaskDirectly(mlTask.getTaskId(), ImmutableMap.of(STATE_FIELD, FAILED, ERROR_FIELD, errorMsg)); + throw new MLLimitExceededException(errorMsg); + } + mlStats.createCounterStatIfAbsent(functionName, UPLOAD, ML_ACTION_REQUEST_COUNT).increment(); mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).increment(); String modelName = uploadInput.getModelName(); @@ -501,7 +501,7 @@ private void retrieveModelChunks(MLModel mlModelMeta, ActionListener liste }, e -> { stopNow.set(true); semaphore.release(); - log.error("Failed to model and chunks", e); + log.error("Failed to retrieve model chunk " + modelChunkId, e); if (retrievedChunks.get() == totalChunks - 1) { listener.onFailure(new MLResourceNotFoundException("Fail to find model chunk " + modelChunkId)); } diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 494d6d421b..0224c963cc 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -75,7 +75,6 @@ import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.breaker.MLCircuitBreakerService; import org.opensearch.ml.common.dataset.MLInputDataType; -import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.model.MLModelConfig; import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.MLModelState; @@ -248,20 +247,16 @@ public void setup() throws URISyntaxException { public void testUploadMLModel_ExceedMaxRunningTask() { String error = "exceed max running task limit"; - expectedEx.expect(MLLimitExceededException.class); - expectedEx.expectMessage(error); when(mlTaskManager.checkLimitAndAddRunningTask(any(), any())).thenReturn(error); modelManager.uploadMLModel(uploadInput, mlTask); - verify(mlTaskManager, never()).updateMLTaskDirectly(eq(mlTask.getTaskId()), any()); + verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); } public void testUploadMLModel_CircuitBreakerOpen() { - expectedEx.expect(MLLimitExceededException.class); - expectedEx.expectMessage("Disk Circuit Breaker is open, please check your resources!"); when(mlTaskManager.checkLimitAndAddRunningTask(any(), any())).thenReturn(null); when(mlCircuitBreakerService.checkOpenCB()).thenReturn("Disk Circuit Breaker"); modelManager.uploadMLModel(uploadInput, mlTask); - verify(mlTaskManager, never()).updateMLTaskDirectly(eq(mlTask.getTaskId()), any()); + verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean()); } public void testUploadMLModel_InitModelIndexFailure() {