From 9dcb731a6b3fd7283fc66a82ee02ee8e485d8296 Mon Sep 17 00:00:00 2001 From: Sarat Vemulapalli Date: Wed, 5 Jul 2023 10:43:12 -0700 Subject: [PATCH] [2.x] Adding an integration test for redeploying a model (#1016) * Adding a failing integ test for redeploy model and fix breaking changes from OpenSearch core Signed-off-by: Sarat Vemulapalli * Adding model group ID changes for tests Signed-off-by: Sarat Vemulapalli * Fixing tests for ImmutableMap copy Signed-off-by: Sarat Vemulapalli * Commenting wait out task for model Signed-off-by: Sarat Vemulapalli * Adding a failing integ test for redeploy model and fix breaking changes from OpenSearch core Signed-off-by: Sarat Vemulapalli * Rebasing with 2.x Signed-off-by: Sarat Vemulapalli * Adding logs to debug the test in GHA Signed-off-by: Sarat Vemulapalli * GHA tests Signed-off-by: Sarat Vemulapalli * Still debugging Signed-off-by: Sarat Vemulapalli * Removing comment Signed-off-by: Sarat Vemulapalli * Removing unnecessary changes Signed-off-by: Sarat Vemulapalli * Removing logs Signed-off-by: Sarat Vemulapalli --------- Signed-off-by: Sarat Vemulapalli --- .../ml/rest/RestMLDeployModelActionIT.java | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMLDeployModelActionIT.java diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeployModelActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeployModelActionIT.java new file mode 100644 index 0000000000..fa542c7c0c --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLDeployModelActionIT.java @@ -0,0 +1,70 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.rest; + +import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; + +import java.io.IOException; +import java.util.Map; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; +import org.opensearch.ml.utils.TestHelper; + +public class RestMLDeployModelActionIT extends MLCommonsRestTestCase { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + private MLRegisterModelInput registerModelInput; + private MLRegisterModelGroupInput mlRegisterModelGroupInput; + private String modelGroupId; + + @Before + public void setup() throws IOException { + mlRegisterModelGroupInput = MLRegisterModelGroupInput.builder().name("testGroupID").description("This is test Group").build(); + registerModelGroup(client(), TestHelper.toJsonString(mlRegisterModelGroupInput), registerModelGroupResult -> { + this.modelGroupId = (String) registerModelGroupResult.get("model_group_id"); + }); + registerModelInput = createRegisterModelInput(modelGroupId); + } + + public void testReDeployModel() throws InterruptedException, IOException { + // Register Model + String taskId = registerModel(TestHelper.toJsonString(registerModelInput)); + waitForTask(taskId, MLTaskState.COMPLETED); + getTask(client(), taskId, response -> { + String model_id = (String) response.get(MODEL_ID_FIELD); + try { + // Deploy Model + String taskId1 = deployModel(model_id); + getTask(client(), taskId1, innerResponse -> { assertEquals(model_id, innerResponse.get(MODEL_ID_FIELD)); }); + waitForTask(taskId1, MLTaskState.COMPLETED); + + // Undeploy Model + Map undeployresponse = undeployModel(model_id); + for (Map.Entry entry : undeployresponse.entrySet()) { + Map stats = (Map) ((Map) entry.getValue()).get("stats"); + assertEquals("undeployed", stats.get(model_id)); + } + + // Deploy Model again + taskId1 = deployModel(model_id); + getTask(client(), taskId1, innerResponse -> { logger.info("Re-Deploy model {}", innerResponse); }); + waitForTask(taskId1, MLTaskState.COMPLETED); + + getModel(client(), model_id, model -> { + logger.info("Get Model after re-deploy {}", model); + assertEquals("DEPLOYED", model.get("model_state")); + }); + + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } +}