Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

init master key automatically #1075

Merged
merged 10 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/CommonValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ public class CommonValue {
public static final String UNDEPLOYED = "undeployed";
public static final String NOT_FOUND = "not_found";

public static final String MASTER_KEY = "master_key";
public static final String CREATE_TIME_FIELD = "create_time";

public static final String BOX_TYPE_KEY = "box_type";
//hot node
public static String HOT_BOX_TYPE = "hot";
Expand All @@ -37,6 +40,8 @@ public class CommonValue {
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1;
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 1;
public static final String ML_CONFIG_INDEX = ".plugins-ml-config";
public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 1;
public static final String USER_FIELD_MAPPING = " \""
+ CommonValue.USER
+ "\": {\n"
Expand Down Expand Up @@ -301,4 +306,19 @@ public class CommonValue {
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n"
+ " }\n"
+ "}";


public static final String ML_CONFIG_INDEX_MAPPING = "{\n"
+ " \"_meta\": {\"schema_version\": "
+ ML_CONFIG_INDEX_SCHEMA_VERSION
+ "},\n"
+ " \"properties\": {\n"
+ " \""
+ MASTER_KEY
+ "\": {\"type\": \"keyword\"},\n"
+ " \""
+ CREATE_TIME_FIELD
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n"
+ " }\n"
+ "}";
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.engine;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataframe.DataFrame;
Expand All @@ -18,29 +19,33 @@
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.Output;
import org.opensearch.ml.engine.encryptor.Encryptor;

import java.nio.file.Path;
import java.util.Locale;
import java.util.Map;

/**
* This is the interface to all ml algorithms.
*/
@Log4j2
public class MLEngine {

public static final String REGISTER_MODEL_FOLDER = "register";
public static final String DEPLOY_MODEL_FOLDER = "deploy";
private final String MODEL_REPO = "https://artifacts.opensearch.org/models/ml-models";

@Getter
private final Path mlConfigPath;

@Getter
private final Path mlCachePath;
private final Path mlModelsCachePath;

private final Encryptor encryptor;
private Encryptor encryptor;

public MLEngine(Path opensearchDataFolder, Encryptor encryptor) {
mlCachePath = opensearchDataFolder.resolve("ml_cache");
mlModelsCachePath = mlCachePath.resolve("models_cache");
this.mlCachePath = opensearchDataFolder.resolve("ml_cache");
this.mlModelsCachePath = mlCachePath.resolve("models_cache");
this.mlConfigPath = mlCachePath.resolve("config");
this.encryptor = encryptor;
}

Expand Down Expand Up @@ -195,7 +200,4 @@ public String encrypt(String credential) {
return encryptor.encrypt(credential);
}

public void setMasterKey(String masterKey) {
encryptor.setMasterKey(masterKey);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,6 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto
} else {
throw new IllegalArgumentException("Wrong input type");
}
Map<String, String> escapedParameters = new HashMap<>();
inputData.getParameters().entrySet().forEach(entry -> {
escapedParameters.put(entry.getKey(), escapeJava(entry.getValue()));
});
inputData.setParameters(escapedParameters);
return inputData;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package org.opensearch.ml.engine.encryptor;

import java.security.SecureRandom;
import java.util.Base64;

public interface Encryptor {

/**
Expand All @@ -29,4 +32,8 @@ public interface Encryptor {
* @param masterKey masterKey to be set.
*/
void setMasterKey(String masterKey);
String getMasterKey();

String generateMasterKey();

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,39 @@
import com.amazonaws.encryptionsdk.CommitmentPolicy;
import com.amazonaws.encryptionsdk.CryptoResult;
import com.amazonaws.encryptionsdk.jce.JceMasterKey;
import org.opensearch.ml.engine.exceptions.MetaDataException;
import lombok.extern.log4j.Log4j2;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.ml.common.exception.MLException;

import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;

@Log4j2
public class EncryptorImpl implements Encryptor {

private ClusterService clusterService;
private Client client;
private volatile String masterKey;

public EncryptorImpl(ClusterService clusterService, Client client) {
this.masterKey = null;
this.clusterService = clusterService;
this.client = client;
}

public EncryptorImpl(String masterKey) {
this.masterKey = masterKey;
}
Expand All @@ -28,9 +51,14 @@ public void setMasterKey(String masterKey) {
this.masterKey = masterKey;
}

@Override
public String getMasterKey() {
return masterKey;
}

@Override
public String encrypt(String plainText) {
checkMasterKey();
initMasterKey();
final AwsCrypto crypto = AwsCrypto.builder()
.withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt)
.build();
Expand All @@ -46,7 +74,7 @@ public String encrypt(String plainText) {

@Override
public String decrypt(String encryptedText) {
checkMasterKey();
initMasterKey();
final AwsCrypto crypto = AwsCrypto.builder()
.withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt)
.build();
Expand All @@ -60,14 +88,45 @@ public String decrypt(String encryptedText) {
return new String(decryptedResult.getResult());
}

private void checkMasterKey() {
if (masterKey == "0000000000000000" || masterKey == null) {
throw new MetaDataException("Please provide a masterKey for credential encryption! Example: PUT /_cluster/settings\n" +
"{\n" +
" \"persistent\" : {\n" +
" \"plugins.ml_commons.encryption.master_key\" : \"1234567x\" \n" +
" }\n" +
"}");
@Override
public String generateMasterKey() {
byte[] keyBytes = new byte[16];
new SecureRandom().nextBytes(keyBytes);
String base64Key = Base64.getEncoder().encodeToString(keyBytes);
return base64Key;
}

private void initMasterKey() {
if (masterKey != null) {
return;
}
AtomicReference<Exception> exceptionRef = new AtomicReference<>();

CountDownLatch latch = new CountDownLatch(1);
if (clusterService.state().metadata().hasIndex(ML_CONFIG_INDEX)) {
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
client.get(getRequest, new LatchedActionListener(ActionListener.<GetResponse>wrap(r -> {
if (r.isExists()) {
String masterKey = (String) r.getSourceAsMap().get(MASTER_KEY);
setMasterKey(masterKey);
} else {
exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet"));
}
}, e -> {
log.error("Failed to get ML encryption master key", e);
exceptionRef.set(e);
}), latch));
} else {
exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet"));
}

if (exceptionRef.get() != null) {
log.debug("Failed to init master key", exceptionRef.get());
if (exceptionRef.get() instanceof RuntimeException) {
throw (RuntimeException) exceptionRef.get();
} else {
throw new MLException(exceptionRef.get());
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
Expand Down Expand Up @@ -128,7 +129,8 @@ public class MetricsCorrelationTest {
ActionListener<MLDeployModelResponse> mlDeployModelResponseActionListener;
private MetricsCorrelation metricsCorrelation;
private MetricsCorrelationInput input, extendedInput;
private Path djlCachePath;
private Path mlCachePath;
private Path mlConfigPath;
private MLModel model;

private MetricsCorrelationModelConfig modelConfig;
Expand All @@ -144,7 +146,6 @@ public class MetricsCorrelationTest {

Map<String, Object> params = new HashMap<>();

@Mock
private Encryptor encryptor;

public MetricsCorrelationTest() {
Expand All @@ -155,8 +156,9 @@ public void setUp() throws IOException, URISyntaxException {

System.setProperty("testMode", "true");

djlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
mlEngine = new MLEngine(djlCachePath, encryptor);
mlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
encryptor = new EncryptorImpl("0000000000000001");
mlEngine = new MLEngine(mlCachePath, encryptor);
modelConfig = MetricsCorrelationModelConfig.builder()
.modelType(MetricsCorrelation.MODEL_TYPE)
.allConfig(null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc()
processInput_TextDocsInputDataSet_PreprocessFunction(
"{\"input\": ${parameters.input}}",
"{\"parameters\": { \"input\": [\"test_value1\", \"test_value2\"] } }",
"[\\\"test_value1\\\",\\\"test_value2\\\"]");
"[\"test_value1\",\"test_value2\"]");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;

import java.io.IOException;
import java.net.URISyntaxException;
Expand Down Expand Up @@ -50,12 +52,15 @@ public class ModelHelperTest {
@Mock
ActionListener<MLRegisterModelInput> registerModelListener;

Encryptor encryptor;

@Before
public void setup() throws URISyntaxException {
MockitoAnnotations.openMocks(this);
modelFormat = MLModelFormat.TORCH_SCRIPT;
modelId = "model_id";
mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), null);
encryptor = new EncryptorImpl("0000000000000001");
mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), encryptor);
modelHelper = new ModelHelper(mlEngine);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.engine.utils.FileUtils;

import java.io.File;
Expand Down Expand Up @@ -62,16 +63,18 @@ public class TextEmbeddingModelTest {
private ModelHelper modelHelper;
private Map<String, Object> params;
private TextEmbeddingModel textEmbeddingModel;
private Path djlCachePath;
private Path mlCachePath;
private Path mlConfigPath;
private TextDocsInputDataSet inputDataSet;
private int dimension = 384;
private MLEngine mlEngine;
private Encryptor encryptor;

@Before
public void setUp() throws URISyntaxException {
djlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
mlEngine = new MLEngine(djlCachePath, encryptor);
mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID());
encryptor = new EncryptorImpl("0000000000000001");
mlEngine = new MLEngine(mlCachePath, encryptor);
modelId = "test_model_id";
modelName = "test_model_name";
functionName = FunctionName.TEXT_EMBEDDING;
Expand Down Expand Up @@ -329,7 +332,7 @@ public void predict_BeforeInitingModel() {

@After
public void tearDown() {
FileUtils.deleteFileQuietly(djlCachePath);
FileUtils.deleteFileQuietly(mlCachePath);
}

private int findSentenceEmbeddingPosition(ModelTensors modelTensors) {
Expand Down
Loading
Loading