From b6bb288366ef2689de0fe4028fbb8b6e8bab6734 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 11 Jul 2023 21:17:46 -0700 Subject: [PATCH] init master key automatically (#1075) * init master key automatically Signed-off-by: Yaliang Wu * remove unnecessary escape Signed-off-by: Yaliang Wu * fix failed ut Signed-off-by: Yaliang Wu * tune syncup jot interval Signed-off-by: Yaliang Wu * tune syncup jot interval Signed-off-by: Yaliang Wu * remove local config file code Signed-off-by: Yaliang Wu * set master key when init remote model Signed-off-by: Yaliang Wu * move init master key to encryptor Signed-off-by: Yaliang Wu * fine tune code Signed-off-by: Yaliang Wu * fine tune code Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu (cherry picked from commit 66672d954b357ecc26125bdefb4d0b08d8661d8e) --- .../org/opensearch/ml/common/CommonValue.java | 20 +++ .../org/opensearch/ml/engine/MLEngine.java | 16 +- .../algorithms/remote/ConnectorUtils.java | 5 - .../ml/engine/encryptor/Encryptor.java | 7 + .../ml/engine/encryptor/EncryptorImpl.java | 81 ++++++++-- .../MetricsCorrelationTest.java | 10 +- .../algorithms/remote/ConnectorUtilsTest.java | 2 +- .../text_embedding/ModelHelperTest.java | 7 +- .../TextEmbeddingModelTest.java | 11 +- .../engine/encryptor/EncryptorImplTest.java | 148 ++++++++++++++++++ .../MLCommonsClusterManagerEventListener.java | 8 +- .../opensearch/ml/cluster/MLSyncUpCron.java | 49 +++++- .../org/opensearch/ml/indices/MLIndex.java | 6 +- .../ml/indices/MLIndicesHandler.java | 4 + .../opensearch/ml/model/MLModelManager.java | 8 - .../ml/plugin/MachineLearningPlugin.java | 17 +- .../ml/settings/MLCommonsSettings.java | 5 +- .../TransportDeployModelActionTests.java | 2 +- .../ml/cluster/MLSyncUpCronTests.java | 63 +++++++- .../ml/model/MLModelManagerTests.java | 5 +- .../ml/task/MLPredictTaskRunnerTests.java | 2 +- .../MLTrainAndPredictTaskRunnerTests.java | 2 +- .../ml/task/MLTrainingTaskRunnerTests.java | 2 +- 23 files changed, 415 insertions(+), 65 deletions(-) create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index 9fc2294d3b..16554933b5 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -24,6 +24,9 @@ public class CommonValue { public static final String UNDEPLOYED = "undeployed"; public static final String NOT_FOUND = "not_found"; + public static final String MASTER_KEY = "master_key"; + public static final String CREATE_TIME_FIELD = "create_time"; + public static final String BOX_TYPE_KEY = "box_type"; //hot node public static String HOT_BOX_TYPE = "hot"; @@ -37,6 +40,8 @@ public class CommonValue { public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector"; public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1; public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 1; + public static final String ML_CONFIG_INDEX = ".plugins-ml-config"; + public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 1; public static final String USER_FIELD_MAPPING = " \"" + CommonValue.USER + "\": {\n" @@ -301,4 +306,19 @@ public class CommonValue { + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + " }\n" + "}"; + + + public static final String ML_CONFIG_INDEX_MAPPING = "{\n" + + " \"_meta\": {\"schema_version\": " + + ML_CONFIG_INDEX_SCHEMA_VERSION + + "},\n" + + " \"properties\": {\n" + + " \"" + + MASTER_KEY + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + CREATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n" + + " }\n" + + "}"; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java index b0ed953bd1..0c49e83bac 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java @@ -6,6 +6,7 @@ package org.opensearch.ml.engine; import lombok.Getter; +import lombok.extern.log4j.Log4j2; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.dataframe.DataFrame; @@ -18,7 +19,6 @@ import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.Output; import org.opensearch.ml.engine.encryptor.Encryptor; - import java.nio.file.Path; import java.util.Locale; import java.util.Map; @@ -26,21 +26,26 @@ /** * This is the interface to all ml algorithms. */ +@Log4j2 public class MLEngine { public static final String REGISTER_MODEL_FOLDER = "register"; public static final String DEPLOY_MODEL_FOLDER = "deploy"; private final String MODEL_REPO = "https://artifacts.opensearch.org/models/ml-models"; + @Getter + private final Path mlConfigPath; + @Getter private final Path mlCachePath; private final Path mlModelsCachePath; - private final Encryptor encryptor; + private Encryptor encryptor; public MLEngine(Path opensearchDataFolder, Encryptor encryptor) { - mlCachePath = opensearchDataFolder.resolve("ml_cache"); - mlModelsCachePath = mlCachePath.resolve("models_cache"); + this.mlCachePath = opensearchDataFolder.resolve("ml_cache"); + this.mlModelsCachePath = mlCachePath.resolve("models_cache"); + this.mlConfigPath = mlCachePath.resolve("config"); this.encryptor = encryptor; } @@ -195,7 +200,4 @@ public String encrypt(String credential) { return encryptor.encrypt(credential); } - public void setMasterKey(String masterKey) { - encryptor.setMasterKey(masterKey); - } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index f3bfed3c3e..7eccd6155d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -95,11 +95,6 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto } else { throw new IllegalArgumentException("Wrong input type"); } - Map escapedParameters = new HashMap<>(); - inputData.getParameters().entrySet().forEach(entry -> { - escapedParameters.put(entry.getKey(), escapeJava(entry.getValue())); - }); - inputData.setParameters(escapedParameters); return inputData; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java index df8e43d887..2316869ffd 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java @@ -5,6 +5,9 @@ package org.opensearch.ml.engine.encryptor; +import java.security.SecureRandom; +import java.util.Base64; + public interface Encryptor { /** @@ -29,4 +32,8 @@ public interface Encryptor { * @param masterKey masterKey to be set. */ void setMasterKey(String masterKey); + String getMasterKey(); + + String generateMasterKey(); + } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java index 3e9d9175b4..0778af444a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java @@ -9,16 +9,39 @@ import com.amazonaws.encryptionsdk.CommitmentPolicy; import com.amazonaws.encryptionsdk.CryptoResult; import com.amazonaws.encryptionsdk.jce.JceMasterKey; -import org.opensearch.ml.engine.exceptions.MetaDataException; +import lombok.extern.log4j.Log4j2; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.ActionListener; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.ml.common.exception.MLException; import javax.crypto.spec.SecretKeySpec; import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; import java.util.Base64; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; +import static org.opensearch.ml.common.CommonValue.MASTER_KEY; +import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; + +@Log4j2 public class EncryptorImpl implements Encryptor { + private ClusterService clusterService; + private Client client; private volatile String masterKey; + public EncryptorImpl(ClusterService clusterService, Client client) { + this.masterKey = null; + this.clusterService = clusterService; + this.client = client; + } + public EncryptorImpl(String masterKey) { this.masterKey = masterKey; } @@ -28,9 +51,14 @@ public void setMasterKey(String masterKey) { this.masterKey = masterKey; } + @Override + public String getMasterKey() { + return masterKey; + } + @Override public String encrypt(String plainText) { - checkMasterKey(); + initMasterKey(); final AwsCrypto crypto = AwsCrypto.builder() .withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt) .build(); @@ -46,7 +74,7 @@ public String encrypt(String plainText) { @Override public String decrypt(String encryptedText) { - checkMasterKey(); + initMasterKey(); final AwsCrypto crypto = AwsCrypto.builder() .withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt) .build(); @@ -60,14 +88,45 @@ public String decrypt(String encryptedText) { return new String(decryptedResult.getResult()); } - private void checkMasterKey() { - if (masterKey == "0000000000000000" || masterKey == null) { - throw new MetaDataException("Please provide a masterKey for credential encryption! Example: PUT /_cluster/settings\n" + - "{\n" + - " \"persistent\" : {\n" + - " \"plugins.ml_commons.encryption.master_key\" : \"1234567x\" \n" + - " }\n" + - "}"); + @Override + public String generateMasterKey() { + byte[] keyBytes = new byte[16]; + new SecureRandom().nextBytes(keyBytes); + String base64Key = Base64.getEncoder().encodeToString(keyBytes); + return base64Key; + } + + private void initMasterKey() { + if (masterKey != null) { + return; + } + AtomicReference exceptionRef = new AtomicReference<>(); + + CountDownLatch latch = new CountDownLatch(1); + if (clusterService.state().metadata().hasIndex(ML_CONFIG_INDEX)) { + GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY); + client.get(getRequest, new LatchedActionListener(ActionListener.wrap(r -> { + if (r.isExists()) { + String masterKey = (String) r.getSourceAsMap().get(MASTER_KEY); + setMasterKey(masterKey); + } else { + exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet")); + } + }, e -> { + log.error("Failed to get ML encryption master key", e); + exceptionRef.set(e); + }), latch)); + } else { + exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet")); + } + + if (exceptionRef.get() != null) { + log.debug("Failed to init master key", exceptionRef.get()); + if (exceptionRef.get() instanceof RuntimeException) { + throw (RuntimeException) exceptionRef.get(); + } else { + throw new MLException(exceptionRef.get()); + } } } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java index 32d1df3a01..35d782bc67 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java @@ -58,6 +58,7 @@ import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.aggregations.InternalAggregations; @@ -128,7 +129,8 @@ public class MetricsCorrelationTest { ActionListener mlDeployModelResponseActionListener; private MetricsCorrelation metricsCorrelation; private MetricsCorrelationInput input, extendedInput; - private Path djlCachePath; + private Path mlCachePath; + private Path mlConfigPath; private MLModel model; private MetricsCorrelationModelConfig modelConfig; @@ -144,7 +146,6 @@ public class MetricsCorrelationTest { Map params = new HashMap<>(); - @Mock private Encryptor encryptor; public MetricsCorrelationTest() { @@ -155,8 +156,9 @@ public void setUp() throws IOException, URISyntaxException { System.setProperty("testMode", "true"); - djlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID()); - mlEngine = new MLEngine(djlCachePath, encryptor); + mlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID()); + encryptor = new EncryptorImpl("0000000000000001"); + mlEngine = new MLEngine(mlCachePath, encryptor); modelConfig = MetricsCorrelationModelConfig.builder() .modelType(MetricsCorrelation.MODEL_TYPE) .allConfig(null) diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index 9c3057b3a5..857cbe997f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -84,7 +84,7 @@ public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc() processInput_TextDocsInputDataSet_PreprocessFunction( "{\"input\": ${parameters.input}}", "{\"parameters\": { \"input\": [\"test_value1\", \"test_value2\"] } }", - "[\\\"test_value1\\\",\\\"test_value2\\\"]"); + "[\"test_value1\",\"test_value2\"]"); } @Test diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java index 3d2043fc24..245afdfe7f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java @@ -19,6 +19,8 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.ModelHelper; +import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; import java.io.IOException; import java.net.URISyntaxException; @@ -50,12 +52,15 @@ public class ModelHelperTest { @Mock ActionListener registerModelListener; + Encryptor encryptor; + @Before public void setup() throws URISyntaxException { MockitoAnnotations.openMocks(this); modelFormat = MLModelFormat.TORCH_SCRIPT; modelId = "model_id"; - mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), null); + encryptor = new EncryptorImpl("0000000000000001"); + mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), encryptor); modelHelper = new ModelHelper(mlEngine); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java index e8c46d37a9..70d7fd2e35 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java @@ -27,6 +27,7 @@ import org.opensearch.ml.engine.MLEngine; import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.ml.engine.utils.FileUtils; import java.io.File; @@ -62,7 +63,8 @@ public class TextEmbeddingModelTest { private ModelHelper modelHelper; private Map params; private TextEmbeddingModel textEmbeddingModel; - private Path djlCachePath; + private Path mlCachePath; + private Path mlConfigPath; private TextDocsInputDataSet inputDataSet; private int dimension = 384; private MLEngine mlEngine; @@ -70,8 +72,9 @@ public class TextEmbeddingModelTest { @Before public void setUp() throws URISyntaxException { - djlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID()); - mlEngine = new MLEngine(djlCachePath, encryptor); + mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID()); + encryptor = new EncryptorImpl("0000000000000001"); + mlEngine = new MLEngine(mlCachePath, encryptor); modelId = "test_model_id"; modelName = "test_model_name"; functionName = FunctionName.TEXT_EMBEDDING; @@ -329,7 +332,7 @@ public void predict_BeforeInitingModel() { @After public void tearDown() { - FileUtils.deleteFileQuietly(djlCachePath); + FileUtils.deleteFileQuietly(mlCachePath); } private int findSentenceEmbeddingPosition(ModelTensors modelTensors) { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java new file mode 100644 index 0000000000..2e0980bc0f --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java @@ -0,0 +1,148 @@ +package org.opensearch.ml.engine.encryptor; + +import com.google.common.collect.ImmutableMap; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.Version; +import org.opensearch.action.ActionListener; +import org.opensearch.action.get.GetResponse; +import org.opensearch.client.Client; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; + +import java.time.Instant; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD; +import static org.opensearch.ml.common.CommonValue.MASTER_KEY; +import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; + +public class EncryptorImplTest { + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + ClusterState clusterState; + + String masterKey; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + masterKey = "0000000000000001"; + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(true); + when(response.getSourceAsMap()) + .thenReturn(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli())); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + + when(clusterService.state()).thenReturn(clusterState); + + Metadata metadata = new Metadata.Builder() + .indices(ImmutableMap + .builder() + .put(ML_CONFIG_INDEX, IndexMetadata.builder(ML_CONFIG_INDEX) + .settings(Settings.builder() + .put("index.number_of_shards", 1) + .put("index.number_of_replicas", 1) + .put("index.version.created", Version.CURRENT.id)) + .build()) + .build()).build(); + when(clusterState.metadata()).thenReturn(metadata); + } + + @Test + public void encrypt() { + Encryptor encryptor = new EncryptorImpl(clusterService, client); + Assert.assertNull(encryptor.getMasterKey()); + String encrypted = encryptor.encrypt("test"); + Assert.assertNotNull(encrypted); + Assert.assertEquals(masterKey, encryptor.getMasterKey()); + } + + @Test + public void decrypt() { + Encryptor encryptor = new EncryptorImpl(clusterService, client); + Assert.assertNull(encryptor.getMasterKey()); + String encrypted = encryptor.encrypt("test"); + String decrypted = encryptor.decrypt(encrypted); + Assert.assertEquals("test", decrypted); + Assert.assertEquals(masterKey, encryptor.getMasterKey()); + } + + @Test + public void encrypt_NullMasterKey_NullMasterKey_MasterKeyNotExistInIndex() { + exceptionRule.expect(ResourceNotFoundException.class); + exceptionRule.expectMessage("ML encryption master key not initialized yet"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(false); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + Encryptor encryptor = new EncryptorImpl(clusterService, client); + Assert.assertNull(encryptor.getMasterKey()); + encryptor.encrypt("test"); + } + + @Test + public void decrypt_NullMasterKey_GetMasterKey_Exception() { + exceptionRule.expect(RuntimeException.class); + exceptionRule.expectMessage("test error"); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("test error")); + return null; + }).when(client).get(any(), any()); + + Encryptor encryptor = new EncryptorImpl(clusterService, client); + Assert.assertNull(encryptor.getMasterKey()); + encryptor.decrypt("test"); + } + + @Test + public void decrypt_MLConfigIndexNotFound() { + exceptionRule.expect(ResourceNotFoundException.class); + exceptionRule.expectMessage("ML encryption master key not initialized yet"); + + Metadata metadata = new Metadata.Builder().indices(ImmutableMap.of()).build(); + when(clusterState.metadata()).thenReturn(metadata); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new RuntimeException("test error")); + return null; + }).when(client).get(any(), any()); + + Encryptor encryptor = new EncryptorImpl(clusterService, client); + Assert.assertNull(encryptor.getMasterKey()); + encryptor.decrypt("test"); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java index 050d475dd2..f4abfea4df 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterManagerEventListener.java @@ -14,6 +14,7 @@ import org.opensearch.common.component.LifecycleListener; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.threadpool.Scheduler; import org.opensearch.threadpool.ThreadPool; @@ -30,6 +31,7 @@ public class MLCommonsClusterManagerEventListener implements LocalNodeClusterMan private Scheduler.Cancellable syncModelRoutingCron; private DiscoveryNodeHelper nodeHelper; private final MLIndicesHandler mlIndicesHandler; + private final Encryptor encryptor; private volatile Integer jobInterval; @@ -39,7 +41,8 @@ public MLCommonsClusterManagerEventListener( Settings settings, ThreadPool threadPool, DiscoveryNodeHelper nodeHelper, - MLIndicesHandler mlIndicesHandler + MLIndicesHandler mlIndicesHandler, + Encryptor encryptor ) { this.clusterService = clusterService; this.client = client; @@ -47,6 +50,7 @@ public MLCommonsClusterManagerEventListener( this.clusterService.addListener(this); this.nodeHelper = nodeHelper; this.mlIndicesHandler = mlIndicesHandler; + this.encryptor = encryptor; this.jobInterval = ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS, it -> { @@ -67,7 +71,7 @@ private void startSyncModelRoutingCron() { if (jobInterval > 0) { syncModelRoutingCron = threadPool .scheduleWithFixedDelay( - new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler), + new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler, encryptor), TimeValue.timeValueSeconds(jobInterval), GENERAL_THREAD_POOL ); diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java index 0ba118bd29..95c5ec037c 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java @@ -5,6 +5,9 @@ package org.opensearch.ml.cluster; +import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD; +import static org.opensearch.ml.common.CommonValue.MASTER_KEY; +import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import java.time.Instant; @@ -20,7 +23,10 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.index.IndexRequest; import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.WriteRequest; import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.Client; import org.opensearch.cluster.node.DiscoveryNode; @@ -33,6 +39,7 @@ import org.opensearch.ml.common.transport.sync.MLSyncUpInput; import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest; +import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; @@ -50,19 +57,30 @@ public class MLSyncUpCron implements Runnable { private ClusterService clusterService; private DiscoveryNodeHelper nodeHelper; private MLIndicesHandler mlIndicesHandler; + private Encryptor encryptor; + private volatile Boolean mlConfigInited; @VisibleForTesting Semaphore updateModelStateSemaphore; - public MLSyncUpCron(Client client, ClusterService clusterService, DiscoveryNodeHelper nodeHelper, MLIndicesHandler mlIndicesHandler) { + public MLSyncUpCron( + Client client, + ClusterService clusterService, + DiscoveryNodeHelper nodeHelper, + MLIndicesHandler mlIndicesHandler, + Encryptor encryptor + ) { this.client = client; this.clusterService = clusterService; this.nodeHelper = nodeHelper; this.mlIndicesHandler = mlIndicesHandler; this.updateModelStateSemaphore = new Semaphore(1); + this.mlConfigInited = false; + this.encryptor = encryptor; } @Override public void run() { + initMLConfig(); if (!clusterService.state().metadata().indices().containsKey(ML_MODEL_INDEX)) { // no need to run sync up job if no model index return; @@ -71,6 +89,7 @@ public void run() { DiscoveryNode[] allNodes = nodeHelper.getAllNodes(); MLSyncUpInput gatherInfoInput = MLSyncUpInput.builder().getDeployedModels(true).build(); MLSyncUpNodesRequest gatherInfoRequest = new MLSyncUpNodesRequest(allNodes, gatherInfoInput); + // gather running model/tasks on nodes client.execute(MLSyncUpAction.INSTANCE, gatherInfoRequest, ActionListener.wrap(r -> { List responses = r.getNodes(); @@ -142,6 +161,34 @@ public void run() { }, e -> { log.error("Failed to sync model routing", e); })); } + @VisibleForTesting + void initMLConfig() { + if (mlConfigInited) { + return; + } + mlIndicesHandler.initMLConfigIndex(ActionListener.wrap(r -> { + GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY); + client.get(getRequest, ActionListener.wrap(getResponse -> { + if (!getResponse.isExists()) { + IndexRequest indexRequest = new IndexRequest(ML_CONFIG_INDEX).id(MASTER_KEY); + final String masterKey = encryptor.generateMasterKey(); + indexRequest.source(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli())); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + client.index(indexRequest, ActionListener.wrap(indexResponse -> { + log.info("ML configuration initialized successfully"); + encryptor.setMasterKey(masterKey); + mlConfigInited = true; + }, e -> { log.debug("Failed to save ML encryption master key", e); })); + } else { + final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY); + encryptor.setMasterKey(masterKey); + mlConfigInited = true; + log.info("ML configuration already initialized, no action needed"); + } + }, e -> { log.debug("Failed to get ML encryption master key", e); })); + }, e -> { log.debug("Failed to init ML config index", e); })); + } + @VisibleForTesting void refreshModelState(Map> modelWorkerNodes, Map> deployingModels) { if (!updateModelStateSemaphore.tryAcquire()) { diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java b/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java index 668306b763..b81682f07e 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java +++ b/plugin/src/main/java/org/opensearch/ml/indices/MLIndex.java @@ -5,6 +5,9 @@ package org.opensearch.ml.indices; +import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX_MAPPING; +import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX_SCHEMA_VERSION; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX_MAPPING; import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_SCHEMA_VERSION; @@ -22,7 +25,8 @@ public enum MLIndex { MODEL_GROUP(ML_MODEL_GROUP_INDEX, false, ML_MODEL_GROUP_INDEX_MAPPING, ML_MODEL_GROUP_INDEX_SCHEMA_VERSION), MODEL(ML_MODEL_INDEX, false, ML_MODEL_INDEX_MAPPING, ML_MODEL_INDEX_SCHEMA_VERSION), TASK(ML_TASK_INDEX, false, ML_TASK_INDEX_MAPPING, ML_TASK_INDEX_SCHEMA_VERSION), - CONNECTOR(ML_CONNECTOR_INDEX, false, ML_CONNECTOR_INDEX_MAPPING, ML_CONNECTOR_SCHEMA_VERSION); + CONNECTOR(ML_CONNECTOR_INDEX, false, ML_CONNECTOR_INDEX_MAPPING, ML_CONNECTOR_SCHEMA_VERSION), + CONFIG(ML_CONFIG_INDEX, false, ML_CONFIG_INDEX_MAPPING, ML_CONFIG_INDEX_SCHEMA_VERSION); private final String indexName; // whether we use an alias for the index diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java b/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java index 3235d27f29..12954a62a2 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java @@ -61,6 +61,10 @@ public void initMLConnectorIndex(ActionListener listener) { initMLIndexIfAbsent(MLIndex.CONNECTOR, listener); } + public void initMLConfigIndex(ActionListener listener) { + initMLIndexIfAbsent(MLIndex.CONFIG, listener); + } + public void initMLIndexIfAbsent(MLIndex index, ActionListener listener) { String indexName = index.getIndexName(); String mapping = index.getMapping(); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 02d3889464..07720d45f8 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -34,7 +34,6 @@ import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly; import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL; import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE; @@ -154,7 +153,6 @@ public class MLModelManager { private volatile Integer maxModelPerNode; private volatile Integer maxRegisterTasksPerNode; private volatile Integer maxDeployTasksPerNode; - private volatile String masterKey; public static final ImmutableSet MODEL_DONE_STATES = ImmutableSet .of( @@ -208,12 +206,6 @@ public MLModelManager( clusterService .getClusterSettings() .addSettingsUpdateConsumer(ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, it -> maxDeployTasksPerNode = it); - - this.masterKey = ML_COMMONS_MASTER_SECRET_KEY.get(settings); - clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MASTER_SECRET_KEY, it -> { - masterKey = it; - mlEngine.setMasterKey(masterKey); - }); } public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, ActionListener listener) { diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 6b525d8695..5870cbd3f1 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -7,8 +7,8 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY; +import java.nio.file.Path; import java.util.Collection; import java.util.List; import java.util.Map; @@ -276,17 +276,19 @@ public Collection createComponents( this.clusterService = clusterService; this.xContentRegistry = xContentRegistry; Settings settings = environment.settings(); - String masterKey = ML_COMMONS_MASTER_SECRET_KEY.get(clusterService.getSettings()); - Encryptor encryptor = new EncryptorImpl(masterKey); + Path dataPath = environment.dataFiles()[0]; + Path configFile = environment.configFile(); - mlEngine = new MLEngine(environment.dataFiles()[0], encryptor); + Encryptor encryptor = new EncryptorImpl(clusterService, client); + + mlEngine = new MLEngine(dataPath, encryptor); nodeHelper = new DiscoveryNodeHelper(clusterService, settings); modelCacheHelper = new MLModelCacheHelper(clusterService, settings); JvmService jvmService = new JvmService(environment.settings()); OsService osService = new OsService(environment.settings()); MLCircuitBreakerService mlCircuitBreakerService = new MLCircuitBreakerService(jvmService, osService, settings, clusterService) - .init(environment.dataFiles()[0]); + .init(dataPath); Map> stats = new ConcurrentHashMap<>(); // cluster level stats @@ -408,11 +410,13 @@ public Collection createComponents( settings, threadPool, nodeHelper, - mlIndicesHandler + mlIndicesHandler, + encryptor ); return ImmutableList .of( + encryptor, mlEngine, nodeHelper, modelCacheHelper, @@ -601,7 +605,6 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_ALLOW_MODEL_URL, MLCommonsSettings.ML_COMMONS_ALLOW_LOCAL_FILE_UPLOAD, MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED, - MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY, MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED, MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX ); diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 856b819380..9f1a5308a2 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -42,7 +42,7 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_SYNC_UP_JOB_INTERVAL_IN_SECONDS = Setting .intSetting( "plugins.ml_commons.sync_up_job_interval_in_seconds", - 3, + 10, 0, 86400, Setting.Property.NodeScope, @@ -111,9 +111,6 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED = Setting .boolSetting("plugins.ml_commons.model_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); - public static final Setting ML_COMMONS_MASTER_SECRET_KEY = Setting - .simpleString("plugins.ml_commons.encryption.master_key", "0000000000000000", Setting.Property.NodeScope, Setting.Property.Dynamic); - public static final Setting ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED = Setting .boolSetting("plugins.ml_commons.connector_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); diff --git a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java index 38dc5a605d..1cb0670e14 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java @@ -142,7 +142,7 @@ public void setup() { clusterSettings = new ClusterSettings(settings, new HashSet<>(Arrays.asList(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN))); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); - encryptor = new EncryptorImpl("0000000000000000"); + encryptor = new EncryptorImpl("0000000000000001"); mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); modelHelper = new ModelHelper(mlEngine); when(mlDeployModelRequest.getModelId()).thenReturn("mockModelId"); diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java index 417e9d2e77..5a6913f576 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/MLSyncUpCronTests.java @@ -10,10 +10,14 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD; +import static org.opensearch.ml.common.CommonValue.MASTER_KEY; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.utils.TestHelper.ML_ROLE; import static org.opensearch.ml.utils.TestHelper.setupTestClusterState; @@ -30,6 +34,7 @@ import java.util.concurrent.atomic.AtomicInteger; import org.apache.lucene.search.TotalHits; +import org.junit.Assert; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -37,6 +42,8 @@ import org.opensearch.Version; import org.opensearch.action.ActionListener; import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.search.SearchResponseSections; import org.opensearch.action.search.ShardSearchFailure; @@ -57,6 +64,9 @@ import org.opensearch.ml.common.transport.sync.MLSyncUpAction; import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse; import org.opensearch.ml.common.transport.sync.MLSyncUpNodesResponse; +import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.ml.engine.encryptor.EncryptorImpl; +import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.utils.TestHelper; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; @@ -66,6 +76,7 @@ import org.opensearch.search.suggest.Suggest; import org.opensearch.test.OpenSearchTestCase; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; public class MLSyncUpCronTests extends OpenSearchTestCase { @@ -76,6 +87,8 @@ public class MLSyncUpCronTests extends OpenSearchTestCase { private ClusterService clusterService; @Mock private DiscoveryNodeHelper nodeHelper; + @Mock + private MLIndicesHandler mlIndicesHandler; private DiscoveryNode mlNode1; private DiscoveryNode mlNode2; @@ -85,16 +98,64 @@ public class MLSyncUpCronTests extends OpenSearchTestCase { private final String mlNode2Id = "mlNode2"; private ClusterState testState; + private Encryptor encryptor; @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); mlNode1 = new DiscoveryNode(mlNode1Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); mlNode2 = new DiscoveryNode(mlNode2Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); - syncUpCron = new MLSyncUpCron(client, clusterService, nodeHelper, null); + encryptor = spy(new EncryptorImpl(null)); + syncUpCron = new MLSyncUpCron(client, clusterService, nodeHelper, mlIndicesHandler, encryptor); testState = setupTestClusterState(); when(clusterService.state()).thenReturn(testState); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(0); + actionListener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLConfigIndex(any()); + } + + public void testInitMlConfig_MasterKeyNotExist() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(false); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + IndexResponse indexResponse = mock(IndexResponse.class); + listener.onResponse(indexResponse); + return null; + }).when(client).index(any(), any()); + + syncUpCron.initMLConfig(); + Assert.assertNotNull(encryptor.encrypt("test")); + syncUpCron.initMLConfig(); + verify(encryptor, times(1)).setMasterKey(any()); + } + + public void testInitMlConfig_MasterKeyExists() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + GetResponse response = mock(GetResponse.class); + when(response.isExists()).thenReturn(true); + String masterKey = encryptor.generateMasterKey(); + when(response.getSourceAsMap()) + .thenReturn(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli())); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + syncUpCron.initMLConfig(); + Assert.assertNotNull(encryptor.encrypt("test")); + syncUpCron.initMLConfig(); + verify(encryptor, times(1)).setMasterKey(any()); } public void testRun_NoMLModelIndex() { diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 07c87c65fd..24fd02bf71 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -26,7 +26,6 @@ import static org.opensearch.ml.engine.ModelHelper.MODEL_SIZE_IN_BYTES; import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL; import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE; @@ -179,14 +178,12 @@ public void setup() throws URISyntaxException { settings = Settings.builder().put(ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE.getKey(), 10).build(); settings = Settings.builder().put(ML_COMMONS_MONITORING_REQUEST_COUNT.getKey(), 10).build(); settings = Settings.builder().put(ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE.getKey(), 10).build(); - settings = Settings.builder().put(ML_COMMONS_MASTER_SECRET_KEY.getKey(), masterKey).build(); ClusterSettings clusterSettings = clusterSetting( settings, ML_COMMONS_MAX_MODELS_PER_NODE, ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE, ML_COMMONS_MONITORING_REQUEST_COUNT, - ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, - ML_COMMONS_MASTER_SECRET_KEY + ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE ); clusterService = spy(new ClusterService(settings, clusterSettings, null)); xContentRegistry = NamedXContentRegistry.EMPTY; diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index 1002ee7d8f..530033197c 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -127,7 +127,7 @@ public class MLPredictTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl("0000000000000000"); + encryptor = new EncryptorImpl("0000000000000001"); mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); localNode = new DiscoveryNode("localNodeId", buildNewFakeTransportAddress(), Version.CURRENT); remoteNode = new DiscoveryNode("remoteNodeId", buildNewFakeTransportAddress(), Version.CURRENT); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java index 73df81252c..0714bf0234 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainAndPredictTaskRunnerTests.java @@ -101,7 +101,7 @@ public class MLTrainAndPredictTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { - encryptor = new EncryptorImpl("0000000000000000"); + encryptor = new EncryptorImpl("0000000000000001"); mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor); settings = Settings.builder().build(); MockitoAnnotations.openMocks(this); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java index f64faf59cc..6565a41b95 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTrainingTaskRunnerTests.java @@ -111,7 +111,7 @@ public class MLTrainingTaskRunnerTests extends OpenSearchTestCase { @Before public void setup() { MockitoAnnotations.openMocks(this); - encryptor = new EncryptorImpl("0000000000000000"); + encryptor = new EncryptorImpl("0000000000000001"); mlEngine = new MLEngine(Path.of("/tmp/djl-cache_" + randomAlphaOfLength(10)), encryptor); localNode = new DiscoveryNode("localNodeId", buildNewFakeTransportAddress(), Version.CURRENT); remoteNode = new DiscoveryNode("remoteNodeId", buildNewFakeTransportAddress(), Version.CURRENT);