Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport to main][2.x] Adding an integration test for redeploying a model (#1016) #1264

Merged
merged 1 commit into from
Aug 30, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.ml.rest;

import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD;

import java.io.IOException;
import java.util.Map;

import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.utils.TestHelper;

public class RestMLDeployModelActionIT extends MLCommonsRestTestCase {
@Rule
public ExpectedException exceptionRule = ExpectedException.none();
private MLRegisterModelInput registerModelInput;
private MLRegisterModelGroupInput mlRegisterModelGroupInput;
private String modelGroupId;

@Before
public void setup() throws IOException {
mlRegisterModelGroupInput = MLRegisterModelGroupInput.builder().name("testGroupID").description("This is test Group").build();
registerModelGroup(client(), TestHelper.toJsonString(mlRegisterModelGroupInput), registerModelGroupResult -> {
this.modelGroupId = (String) registerModelGroupResult.get("model_group_id");
});
registerModelInput = createRegisterModelInput(modelGroupId);
}

public void testReDeployModel() throws InterruptedException, IOException {
// Register Model
String taskId = registerModel(TestHelper.toJsonString(registerModelInput));
waitForTask(taskId, MLTaskState.COMPLETED);
getTask(client(), taskId, response -> {
String model_id = (String) response.get(MODEL_ID_FIELD);
try {
// Deploy Model
String taskId1 = deployModel(model_id);
getTask(client(), taskId1, innerResponse -> { assertEquals(model_id, innerResponse.get(MODEL_ID_FIELD)); });
waitForTask(taskId1, MLTaskState.COMPLETED);

// Undeploy Model
Map<String, Object> undeployresponse = undeployModel(model_id);
for (Map.Entry<String, Object> entry : undeployresponse.entrySet()) {
Map stats = (Map) ((Map) entry.getValue()).get("stats");
assertEquals("undeployed", stats.get(model_id));
}

// Deploy Model again
taskId1 = deployModel(model_id);
getTask(client(), taskId1, innerResponse -> { logger.info("Re-Deploy model {}", innerResponse); });
waitForTask(taskId1, MLTaskState.COMPLETED);

getModel(client(), model_id, model -> {
logger.info("Get Model after re-deploy {}", model);
assertEquals("DEPLOYED", model.get("model_state"));
});

} catch (Exception e) {
throw new RuntimeException(e);
}
});
}
}
Loading