Skip to content

Commit

Permalink
fine tune code
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jul 12, 2023
1 parent e22f514 commit 5d7a898
Show file tree
Hide file tree
Showing 16 changed files with 26 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ public class MLEngine {
public static final String DEPLOY_MODEL_FOLDER = "deploy";
private final String MODEL_REPO = "https://artifacts.opensearch.org/models/ml-models";

private final Path mlUserConfigPath;
@Getter
private final Path mlConfigPath;

Expand All @@ -43,10 +42,9 @@ public class MLEngine {

private Encryptor encryptor;

public MLEngine(Path opensearchDataFolder, Path opensearchConfigFolder, Encryptor encryptor) {
public MLEngine(Path opensearchDataFolder, Encryptor encryptor) {
this.mlCachePath = opensearchDataFolder.resolve("ml_cache");
this.mlModelsCachePath = mlCachePath.resolve("models_cache");
this.mlUserConfigPath = opensearchConfigFolder.resolve("opensearch-ml");
this.mlConfigPath = mlCachePath.resolve("config");
this.encryptor = encryptor;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import com.amazonaws.encryptionsdk.CommitmentPolicy;
import com.amazonaws.encryptionsdk.CryptoResult;
import com.amazonaws.encryptionsdk.jce.JceMasterKey;
import com.google.common.annotations.VisibleForTesting;
import lombok.extern.log4j.Log4j2;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionListener;
Expand Down Expand Up @@ -43,8 +42,8 @@ public EncryptorImpl(ClusterService clusterService, Client client) {
this.client = client;
}

@VisibleForTesting
public EncryptorImpl() {
public EncryptorImpl(String masterKey) {
this.masterKey = masterKey;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,8 @@ public class MLEngineTest {

@Before
public void setUp() {
Encryptor encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000000");
mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), Path.of("/tmp/test" + UUID.randomUUID()), encryptor);
Encryptor encryptor = new EncryptorImpl("0000000000000000");
mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,8 @@ public void setUp() throws IOException, URISyntaxException {
System.setProperty("testMode", "true");

mlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
mlConfigPath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000001");
mlEngine = new MLEngine(mlCachePath, mlConfigPath, encryptor);
encryptor = new EncryptorImpl("0000000000000001");
mlEngine = new MLEngine(mlCachePath, encryptor);
modelConfig = MetricsCorrelationModelConfig.builder()
.modelType(MetricsCorrelation.MODEL_TYPE)
.allConfig(null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ public class AwsConnectorExecutorTest {
@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000001");
encryptor = new EncryptorImpl("0000000000000001");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ public class RemoteModelTest {
public void setUp() {
MockitoAnnotations.openMocks(this);
remoteModel = new RemoteModel();
encryptor = spy(new EncryptorImpl());
encryptor.setMasterKey("0000000000000001");
encryptor = spy(new EncryptorImpl("0000000000000001"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,8 @@ public void setup() throws URISyntaxException {
MockitoAnnotations.openMocks(this);
modelFormat = MLModelFormat.TORCH_SCRIPT;
modelId = "model_id";
encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000001");
mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), Path.of("/tmp/test_config"), encryptor);
encryptor = new EncryptorImpl("0000000000000001");
mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), encryptor);
modelHelper = new ModelHelper(mlEngine);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,8 @@ public class TextEmbeddingModelTest {
@Before
public void setUp() throws URISyntaxException {
mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID());
mlConfigPath = Path.of("/tmp/ml_config" + UUID.randomUUID());
encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000001");
mlEngine = new MLEngine(mlCachePath, mlConfigPath, encryptor);
encryptor = new EncryptorImpl("0000000000000001");
mlEngine = new MLEngine(mlCachePath, encryptor);
modelId = "test_model_id";
modelName = "test_model_name";
functionName = FunctionName.TEXT_EMBEDDING;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ public Collection<Object> createComponents(

Encryptor encryptor = new EncryptorImpl(clusterService, client);

mlEngine = new MLEngine(dataPath, configFile, encryptor);
mlEngine = new MLEngine(dataPath, encryptor);
nodeHelper = new DiscoveryNodeHelper(clusterService, settings);
modelCacheHelper = new MLModelCacheHelper(clusterService, settings);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,8 @@ public void setup() {
clusterSettings = new ClusterSettings(settings, new HashSet<>(Arrays.asList(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN)));
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000001");
mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor);
encryptor = new EncryptorImpl("0000000000000001");
mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor);
modelHelper = new ModelHelper(mlEngine);
when(mlDeployModelRequest.getModelId()).thenReturn("mockModelId");
when(mlDeployModelRequest.getModelNodeIds()).thenReturn(new String[] { "node1" });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
mlNode1 = new DiscoveryNode(mlNode1Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT);
mlNode2 = new DiscoveryNode(mlNode2Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT);
encryptor = spy(new EncryptorImpl());
encryptor = spy(new EncryptorImpl(null));
syncUpCron = new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler, encryptor);

testState = setupTestClusterState();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,8 @@ public class MLModelManagerTests extends OpenSearchTestCase {
public void setup() throws URISyntaxException {
String masterKey = "0000000000000001";
MockitoAnnotations.openMocks(this);
encryptor = new EncryptorImpl();
encryptor.setMasterKey(masterKey);
mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor);
encryptor = new EncryptorImpl(masterKey);
mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor);
settings = Settings.builder().put(ML_COMMONS_MAX_MODELS_PER_NODE.getKey(), 10).build();
settings = Settings.builder().put(ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE.getKey(), 10).build();
settings = Settings.builder().put(ML_COMMONS_MONITORING_REQUEST_COUNT.getKey(), 10).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,8 @@ public class MLExecuteTaskRunnerTests extends OpenSearchTestCase {
@Before
public void setup() {
MockitoAnnotations.openMocks(this);
encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000000");
mlEngine = new MLEngine(
Path.of("/tmp/djl-cache/" + randomAlphaOfLength(10)),
Path.of("/tmp/djl-cache/" + randomAlphaOfLength(10)),
encryptor
);
encryptor = new EncryptorImpl("0000000000000000");
mlEngine = new MLEngine(Path.of("/tmp/djl-cache/" + randomAlphaOfLength(10)), encryptor);
when(threadPool.executor(anyString())).thenReturn(executorService);
doAnswer(invocation -> {
Runnable runnable = invocation.getArgument(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,8 @@ public class MLPredictTaskRunnerTests extends OpenSearchTestCase {
@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000001");
mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor);
encryptor = new EncryptorImpl("0000000000000001");
mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor);
localNode = new DiscoveryNode("localNodeId", buildNewFakeTransportAddress(), Version.CURRENT);
remoteNode = new DiscoveryNode("remoteNodeId", buildNewFakeTransportAddress(), Version.CURRENT);
when(clusterService.localNode()).thenReturn(localNode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,8 @@ public class MLTrainAndPredictTaskRunnerTests extends OpenSearchTestCase {

@Before
public void setup() {
encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000001");
mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor);
encryptor = new EncryptorImpl("0000000000000001");
mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor);
settings = Settings.builder().build();
MockitoAnnotations.openMocks(this);
localNode = new DiscoveryNode("localNodeId", buildNewFakeTransportAddress(), Version.CURRENT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ public class MLTrainingTaskRunnerTests extends OpenSearchTestCase {
@Before
public void setup() {
MockitoAnnotations.openMocks(this);
encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000001");
encryptor = new EncryptorImpl("0000000000000001");
mlEngine = new MLEngine(
Path.of("/tmp/djl-cache_" + randomAlphaOfLength(10)),
Path.of("/tmp/djl-cache_" + randomAlphaOfLength(10)),
Expand Down

0 comments on commit 5d7a898

Please sign in to comment.