Skip to content

Commit

Permalink
stash context before accessing ml config index; increase master key s…
Browse files Browse the repository at this point in the history
…ize to 32 (#1092)

Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn authored Jul 12, 2023
1 parent 66672d9 commit d9d1190
Show file tree
Hide file tree
Showing 16 changed files with 92 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.action.get.GetResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.ml.common.exception.MLException;

import javax.crypto.spec.SecretKeySpec;
Expand Down Expand Up @@ -62,9 +63,9 @@ public String encrypt(String plainText) {
final AwsCrypto crypto = AwsCrypto.builder()
.withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt)
.build();

byte[] bytes = Base64.getDecoder().decode(masterKey);
JceMasterKey jceMasterKey
= JceMasterKey.getInstance(new SecretKeySpec(masterKey.getBytes(), "AES"), "Custom", "",
= JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "",
"AES/GCM/NoPadding");

final CryptoResult<byte[], JceMasterKey> encryptResult = crypto.encryptData(jceMasterKey,
Expand All @@ -79,8 +80,9 @@ public String decrypt(String encryptedText) {
.withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt)
.build();

byte[] bytes = Base64.getDecoder().decode(masterKey);
JceMasterKey jceMasterKey
= JceMasterKey.getInstance(new SecretKeySpec(masterKey.getBytes(), "AES"), "Custom", "",
= JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "",
"AES/GCM/NoPadding");

final CryptoResult<byte[], JceMasterKey> decryptedResult
Expand All @@ -90,7 +92,7 @@ public String decrypt(String encryptedText) {

@Override
public String generateMasterKey() {
byte[] keyBytes = new byte[16];
byte[] keyBytes = new byte[32];
new SecureRandom().nextBytes(keyBytes);
String base64Key = Base64.getEncoder().encodeToString(keyBytes);
return base64Key;
Expand All @@ -104,18 +106,20 @@ private void initMasterKey() {

CountDownLatch latch = new CountDownLatch(1);
if (clusterService.state().metadata().hasIndex(ML_CONFIG_INDEX)) {
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
client.get(getRequest, new LatchedActionListener(ActionListener.<GetResponse>wrap(r -> {
if (r.isExists()) {
String masterKey = (String) r.getSourceAsMap().get(MASTER_KEY);
setMasterKey(masterKey);
} else {
exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet"));
}
}, e -> {
log.error("Failed to get ML encryption master key", e);
exceptionRef.set(e);
}), latch));
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
client.get(getRequest, new LatchedActionListener(ActionListener.<GetResponse>wrap(r -> {
if (r.isExists()) {
String masterKey = (String) r.getSourceAsMap().get(MASTER_KEY);
setMasterKey(masterKey);
} else {
exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet"));
}
}, e -> {
log.error("Failed to get ML encryption master key", e);
exceptionRef.set(e);
}), latch));
}
} else {
exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public class MLEngineTest {

@Before
public void setUp() {
Encryptor encryptor = new EncryptorImpl("0000000000000000");
Encryptor encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ public void setUp() throws IOException, URISyntaxException {
System.setProperty("testMode", "true");

mlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
encryptor = new EncryptorImpl("0000000000000001");
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
mlEngine = new MLEngine(mlCachePath, encryptor);
modelConfig = MetricsCorrelationModelConfig.builder()
.modelType(MetricsCorrelation.MODEL_TYPE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public class AwsConnectorExecutorTest {
@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
encryptor = new EncryptorImpl("0000000000000001");
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public class RemoteModelTest {
public void setUp() {
MockitoAnnotations.openMocks(this);
remoteModel = new RemoteModel();
encryptor = spy(new EncryptorImpl("0000000000000001"));
encryptor = spy(new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public void setup() throws URISyntaxException {
MockitoAnnotations.openMocks(this);
modelFormat = MLModelFormat.TORCH_SCRIPT;
modelId = "model_id";
encryptor = new EncryptorImpl("0000000000000001");
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), encryptor);
modelHelper = new ModelHelper(mlEngine);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public class TextEmbeddingModelTest {
@Before
public void setUp() throws URISyntaxException {
mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID());
encryptor = new EncryptorImpl("0000000000000001");
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
mlEngine = new MLEngine(mlCachePath, encryptor);
modelId = "test_model_id";
modelName = "test_model_name";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import org.opensearch.cluster.metadata.Metadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.ConfigConstants;
import org.opensearch.threadpool.ThreadPool;

import java.time.Instant;

Expand All @@ -43,10 +46,15 @@ public class EncryptorImplTest {

String masterKey;

@Mock
ThreadPool threadPool;
ThreadContext threadContext;
final String USER_STRING = "myuser|role1,role2|myTenant";

@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
masterKey = "0000000000000001";
masterKey = "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=";

doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
Expand All @@ -72,6 +80,12 @@ public void setUp() {
.build())
.build()).build();
when(clusterState.metadata()).thenReturn(metadata);

Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
}

@Test
Expand All @@ -83,6 +97,17 @@ public void encrypt() {
Assert.assertEquals(masterKey, encryptor.getMasterKey());
}

@Test
public void encrypt_DifferentMasterKey() {
Encryptor encryptor = new EncryptorImpl(masterKey);
Assert.assertNotNull(encryptor.getMasterKey());
String encrypted1 = encryptor.encrypt("test");

encryptor.setMasterKey(encryptor.generateMasterKey());
String encrypted2 = encryptor.encrypt("test");
Assert.assertNotEquals(encrypted1, encrypted2);
}

@Test
public void decrypt() {
Encryptor encryptor = new EncryptorImpl(clusterService, client);
Expand Down
35 changes: 19 additions & 16 deletions plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.ml.common.MLModel;
Expand Down Expand Up @@ -168,24 +169,26 @@ void initMLConfig() {
}
mlIndicesHandler.initMLConfigIndex(ActionListener.wrap(r -> {
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
client.get(getRequest, ActionListener.wrap(getResponse -> {
if (!getResponse.isExists()) {
IndexRequest indexRequest = new IndexRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
final String masterKey = encryptor.generateMasterKey();
indexRequest.source(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli()));
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.index(indexRequest, ActionListener.wrap(indexResponse -> {
log.info("ML configuration initialized successfully");
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getRequest, ActionListener.wrap(getResponse -> {
if (!getResponse.isExists()) {
IndexRequest indexRequest = new IndexRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
final String masterKey = encryptor.generateMasterKey();
indexRequest.source(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli()));
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
client.index(indexRequest, ActionListener.wrap(indexResponse -> {
log.info("ML configuration initialized successfully");
encryptor.setMasterKey(masterKey);
mlConfigInited = true;
}, e -> { log.debug("Failed to save ML encryption master key", e); }));
} else {
final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY);
encryptor.setMasterKey(masterKey);
mlConfigInited = true;
}, e -> { log.debug("Failed to save ML encryption master key", e); }));
} else {
final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY);
encryptor.setMasterKey(masterKey);
mlConfigInited = true;
log.info("ML configuration already initialized, no action needed");
}
}, e -> { log.debug("Failed to get ML encryption master key", e); }));
log.info("ML configuration already initialized, no action needed");
}
}, e -> { log.debug("Failed to get ML encryption master key", e); }));
}
}, e -> { log.debug("Failed to init ML config index", e); }));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ public void setup() {
clusterSettings = new ClusterSettings(settings, new HashSet<>(Arrays.asList(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN)));
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);

encryptor = new EncryptorImpl("0000000000000001");
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor);
modelHelper = new ModelHelper(mlEngine);
when(mlDeployModelRequest.getModelId()).thenReturn("mockModelId");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.transport.TransportAddress;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.ConfigConstants;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
Expand All @@ -75,12 +78,12 @@
import org.opensearch.search.profile.SearchProfileShardResults;
import org.opensearch.search.suggest.Suggest;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;

public class MLSyncUpCronTests extends OpenSearchTestCase {

@Mock
private Client client;
@Mock
Expand All @@ -100,6 +103,11 @@ public class MLSyncUpCronTests extends OpenSearchTestCase {
private ClusterState testState;
private Encryptor encryptor;

@Mock
ThreadPool threadPool;
ThreadContext threadContext;
final String USER_STRING = "myuser|role1,role2|myTenant";

@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
Expand All @@ -116,6 +124,12 @@ public void setup() throws IOException {
actionListener.onResponse(true);
return null;
}).when(mlIndicesHandler).initMLConfigIndex(any());

Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);
}

public void testInitMlConfig_MasterKeyNotExist() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ public class MLModelManagerTests extends OpenSearchTestCase {

@Before
public void setup() throws URISyntaxException {
String masterKey = "0000000000000001";
String masterKey = "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=";
MockitoAnnotations.openMocks(this);
encryptor = new EncryptorImpl(masterKey);
mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public class MLExecuteTaskRunnerTests extends OpenSearchTestCase {
@Before
public void setup() {
MockitoAnnotations.openMocks(this);
encryptor = new EncryptorImpl("0000000000000000");
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
mlEngine = new MLEngine(Path.of("/tmp/djl-cache/" + randomAlphaOfLength(10)), encryptor);
when(threadPool.executor(anyString())).thenReturn(executorService);
doAnswer(invocation -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public class MLPredictTaskRunnerTests extends OpenSearchTestCase {
@Before
public void setup() throws IOException {
MockitoAnnotations.openMocks(this);
encryptor = new EncryptorImpl("0000000000000001");
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor);
localNode = new DiscoveryNode("localNodeId", buildNewFakeTransportAddress(), Version.CURRENT);
remoteNode = new DiscoveryNode("remoteNodeId", buildNewFakeTransportAddress(), Version.CURRENT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ public class MLTrainAndPredictTaskRunnerTests extends OpenSearchTestCase {

@Before
public void setup() {
encryptor = new EncryptorImpl("0000000000000001");
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor);
settings = Settings.builder().build();
MockitoAnnotations.openMocks(this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ public class MLTrainingTaskRunnerTests extends OpenSearchTestCase {
@Before
public void setup() {
MockitoAnnotations.openMocks(this);
encryptor = new EncryptorImpl("0000000000000001");
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
mlEngine = new MLEngine(Path.of("/tmp/djl-cache_" + randomAlphaOfLength(10)), encryptor);
localNode = new DiscoveryNode("localNodeId", buildNewFakeTransportAddress(), Version.CURRENT);
remoteNode = new DiscoveryNode("remoteNodeId", buildNewFakeTransportAddress(), Version.CURRENT);
Expand Down

0 comments on commit d9d1190

Please sign in to comment.