Skip to content

Commit

Permalink
init master key automatically
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jul 11, 2023
1 parent 6294a76 commit 4bd3c44
Show file tree
Hide file tree
Showing 25 changed files with 294 additions and 69 deletions.
20 changes: 20 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ public class CommonValue {
public static final String UNDEPLOYED = "undeployed";
public static final String NOT_FOUND = "not_found";

public static final String MASTER_KEY = "master_key";
public static final String CREATE_TIME_FIELD = "create_time";

public static final String BOX_TYPE_KEY = "box_type";
//hot node
public static String HOT_BOX_TYPE = "hot";
Expand All @@ -37,6 +40,8 @@ public class CommonValue {
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1;
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 1;
public static final String ML_CONFIG_INDEX = ".plugins-ml-config";
public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 1;
public static final String USER_FIELD_MAPPING = " \""
+ CommonValue.USER
+ "\": {\n"
Expand Down Expand Up @@ -301,4 +306,19 @@ public class CommonValue {
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n"
+ " }\n"
+ "}";


public static final String ML_CONFIG_INDEX_MAPPING = "{\n"
+ " \"_meta\": {\"schema_version\": "
+ ML_CONFIG_INDEX_SCHEMA_VERSION
+ "},\n"
+ " \"properties\": {\n"
+ " \""
+ MASTER_KEY
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ CREATE_TIME_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n"
+ " }\n"
+ "}";
}
1 change: 1 addition & 0 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ dependencies {
implementation 'com.amazonaws:aws-encryption-sdk-java:2.4.0'
implementation 'com.jayway.jsonpath:json-path:2.8.0'
implementation group: 'org.json', name: 'json', version: '20230227'
implementation group: 'org.yaml', name: 'snakeyaml', version: '2.0'
}

configurations.all {
Expand Down
66 changes: 58 additions & 8 deletions ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,96 @@

package org.opensearch.ml.engine;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.stream.JsonReader;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.Output;
import org.opensearch.ml.engine.encryptor.Encryptor;

import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.yaml.snakeyaml.DumperOptions;
import org.yaml.snakeyaml.Yaml;

import java.io.FileInputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.nio.file.Files;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

import static org.opensearch.ml.common.CommonValue.MASTER_KEY;

/**
* This is the interface to all ml algorithms.
*/
@Log4j2
public class MLEngine {

public static final String REGISTER_MODEL_FOLDER = "register";
public static final String DEPLOY_MODEL_FOLDER = "deploy";
private final String MODEL_REPO = "https://artifacts.opensearch.org/models/ml-models";

private final Path mlUserConfigPath;
@Getter
private final Path mlConfigPath;

@Getter
private final Path mlCachePath;
private final Path mlModelsCachePath;

private final Encryptor encryptor;
private Encryptor encryptor;

public MLEngine(Path opensearchDataFolder, Encryptor encryptor) {
mlCachePath = opensearchDataFolder.resolve("ml_cache");
mlModelsCachePath = mlCachePath.resolve("models_cache");
public MLEngine(Path opensearchDataFolder, Path opensearchConfigFolder, Encryptor encryptor) {
this.mlCachePath = opensearchDataFolder.resolve("ml_cache");
this.mlModelsCachePath = mlCachePath.resolve("models_cache");
this.mlUserConfigPath = opensearchConfigFolder.resolve("opensearch-ml");
this.mlConfigPath = mlCachePath.resolve("config");
this.encryptor = encryptor;
initMasterKey();
}

private synchronized void initMasterKey() {
try {
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
Path userConfigFilePath = mlUserConfigPath.resolve("security_config.json");
Map<String, String> config = null;
if (Files.exists(userConfigFilePath)) {
try (FileInputStream fis = new FileInputStream(userConfigFilePath.toFile());) {
Yaml yaml = new Yaml();
config = yaml.load(fis);
}
}
if (config == null) {
config = new HashMap<>();
}

if (config.containsKey(MASTER_KEY)) {
encryptor.setMasterKey(config.get(MASTER_KEY));
}
return null;
});
} catch (Exception e) {
log.error("Failed to save master key", e);
throw new MLException(e);
}
}

public String getPrebuiltModelMetaListPath() {
Expand Down Expand Up @@ -195,7 +248,4 @@ public String encrypt(String credential) {
return encryptor.encrypt(credential);
}

public void setMasterKey(String masterKey) {
encryptor.setMasterKey(masterKey);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package org.opensearch.ml.engine.encryptor;

import java.security.SecureRandom;
import java.util.Base64;

public interface Encryptor {

/**
Expand All @@ -29,4 +32,7 @@ public interface Encryptor {
* @param masterKey masterKey to be set.
*/
void setMasterKey(String masterKey);

String generateMasterKey();

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@

import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.Base64;

public class EncryptorImpl implements Encryptor {

private volatile String masterKey;

public EncryptorImpl(String masterKey) {
this.masterKey = masterKey;
public EncryptorImpl() {
this.masterKey = null;
}

@Override
Expand Down Expand Up @@ -60,14 +61,17 @@ public String decrypt(String encryptedText) {
return new String(decryptedResult.getResult());
}

@Override
public String generateMasterKey() {
byte[] keyBytes = new byte[16];
new SecureRandom().nextBytes(keyBytes);
String base64Key = Base64.getEncoder().encodeToString(keyBytes);
return base64Key;
}

private void checkMasterKey() {
if (masterKey == "0000000000000000" || masterKey == null) {
throw new MetaDataException("Please provide a masterKey for credential encryption! Example: PUT /_cluster/settings\n" +
"{\n" +
" \"persistent\" : {\n" +
" \"plugins.ml_commons.encryption.master_key\" : \"1234567x\" \n" +
" }\n" +
"}");
if (masterKey == null) {
throw new MetaDataException("Encryption key not created yet.");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ public class MLEngineTest {

@Before
public void setUp() {
Encryptor encryptor = new EncryptorImpl("0000000000000000");
mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor);
Encryptor encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000000");
mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), Path.of("/tmp/test" + UUID.randomUUID()), encryptor);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
Expand Down Expand Up @@ -128,7 +129,8 @@ public class MetricsCorrelationTest {
ActionListener<MLDeployModelResponse> mlDeployModelResponseActionListener;
private MetricsCorrelation metricsCorrelation;
private MetricsCorrelationInput input, extendedInput;
private Path djlCachePath;
private Path mlCachePath;
private Path mlConfigPath;
private MLModel model;

private MetricsCorrelationModelConfig modelConfig;
Expand All @@ -144,7 +146,6 @@ public class MetricsCorrelationTest {

Map<String, Object> params = new HashMap<>();

@Mock
private Encryptor encryptor;

public MetricsCorrelationTest() {
Expand All @@ -155,8 +156,11 @@ public void setUp() throws IOException, URISyntaxException {

System.setProperty("testMode", "true");

djlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
mlEngine = new MLEngine(djlCachePath, encryptor);
mlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
mlConfigPath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000001");
mlEngine = new MLEngine(mlCachePath, mlConfigPath, encryptor);
modelConfig = MetricsCorrelationModelConfig.builder()
.modelType(MetricsCorrelation.MODEL_TYPE)
.allConfig(null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ public class AwsConnectorExecutorTest {
@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
encryptor = new EncryptorImpl("0000000000000001");
encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000001");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ public class RemoteModelTest {
public void setUp() {
MockitoAnnotations.openMocks(this);
remoteModel = new RemoteModel();
encryptor = spy(new EncryptorImpl("0000000000000001"));
encryptor = spy(new EncryptorImpl());
encryptor.setMasterKey("0000000000000001");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;

import java.io.IOException;
import java.net.URISyntaxException;
Expand Down Expand Up @@ -50,12 +52,16 @@ public class ModelHelperTest {
@Mock
ActionListener<MLRegisterModelInput> registerModelListener;

Encryptor encryptor;

@Before
public void setup() throws URISyntaxException {
MockitoAnnotations.openMocks(this);
modelFormat = MLModelFormat.TORCH_SCRIPT;
modelId = "model_id";
mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), null);
encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000001");
mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), Path.of("/tmp/test_config"), encryptor);
modelHelper = new ModelHelper(mlEngine);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.engine.utils.FileUtils;

import java.io.File;
Expand Down Expand Up @@ -62,16 +63,20 @@ public class TextEmbeddingModelTest {
private ModelHelper modelHelper;
private Map<String, Object> params;
private TextEmbeddingModel textEmbeddingModel;
private Path djlCachePath;
private Path mlCachePath;
private Path mlConfigPath;
private TextDocsInputDataSet inputDataSet;
private int dimension = 384;
private MLEngine mlEngine;
private Encryptor encryptor;

@Before
public void setUp() throws URISyntaxException {
djlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
mlEngine = new MLEngine(djlCachePath, encryptor);
mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID());
mlConfigPath = Path.of("/tmp/ml_config" + UUID.randomUUID());
encryptor = new EncryptorImpl();
encryptor.setMasterKey("0000000000000001");
mlEngine = new MLEngine(mlCachePath, mlConfigPath, encryptor);
modelId = "test_model_id";
modelName = "test_model_name";
functionName = FunctionName.TEXT_EMBEDDING;
Expand Down Expand Up @@ -329,7 +334,7 @@ public void predict_BeforeInitingModel() {

@After
public void tearDown() {
FileUtils.deleteFileQuietly(djlCachePath);
FileUtils.deleteFileQuietly(mlCachePath);
}

private int findSentenceEmbeddingPosition(ModelTensors modelTensors) {
Expand Down
Loading

0 comments on commit 4bd3c44

Please sign in to comment.