Skip to content

Commit

Permalink
set master key when init remote model
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>
  • Loading branch information
ylwu-amzn committed Jul 12, 2023
1 parent 15d5553 commit fe22f46
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@

import com.google.common.annotations.VisibleForTesting;
import lombok.extern.log4j.Log4j2;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.core.xcontent.NamedXContentRegistry;
Expand All @@ -23,6 +28,11 @@
import org.opensearch.script.ScriptService;

import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;

import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;

@Log4j2
@Function(FunctionName.REMOTE)
Expand Down Expand Up @@ -77,11 +87,40 @@ public boolean isModelReady() {
public void initModel(MLModel model, Map<String, Object> params, Encryptor encryptor) {
try {
Connector connector = model.getConnector().cloneConnector();
connector.decrypt((credential) -> encryptor.decrypt(credential));

ClusterService clusterService = (ClusterService) params.get(CLUSTER_SERVICE);
Client client = (Client) params.get(CLIENT);
CountDownLatch latch = new CountDownLatch(1);
AtomicReference<Exception> exceptionRef = new AtomicReference<>();
if (encryptor.getMasterKey() == null) {
if (clusterService.state().metadata().hasIndex(ML_CONFIG_INDEX)) {
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
client.get(getRequest, new LatchedActionListener(ActionListener.< GetResponse >wrap(r-> {
if (r.isExists()) {
String masterKey = (String)r.getSourceAsMap().get(MASTER_KEY);
encryptor.setMasterKey(masterKey);
} else {
exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet"));
}
}, e-> {
log.error("Failed to get ML encryption master key", e);
exceptionRef.set(e);
}), latch));
} else {
exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet"));
}
}

if (exceptionRef.get() != null) {
throw exceptionRef.get();
}
if (encryptor.getMasterKey() != null) {
connector.decrypt((credential) -> encryptor.decrypt(credential));
}
this.connectorExecutor = MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
this.connectorExecutor.setScriptService((ScriptService) params.get(SCRIPT_SERVICE));
this.connectorExecutor.setClusterService((ClusterService) params.get(CLUSTER_SERVICE));
this.connectorExecutor.setClient((Client) params.get(CLIENT));
this.connectorExecutor.setClusterService(clusterService);
this.connectorExecutor.setClient(client);
this.connectorExecutor.setXContentRegistry((NamedXContentRegistry) params.get(XCONTENT_REGISTRY));
} catch (RuntimeException e) {
log.error("Failed to init remote model", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public interface Encryptor {
* @param masterKey masterKey to be set.
*/
void setMasterKey(String masterKey);
String getMasterKey();

String generateMasterKey();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ public void setMasterKey(String masterKey) {
this.masterKey = masterKey;
}

@Override
public String getMasterKey() {
return masterKey;
}

@Override
public String encrypt(String plainText) {
checkMasterKey();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.Version;
import org.opensearch.action.ActionListener;
import org.opensearch.action.get.GetResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.metadata.Metadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
Expand All @@ -22,19 +32,38 @@
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;

import java.time.Instant;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.CommonValue.CREATE_TIME_FIELD;
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;
import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLIENT;
import static org.opensearch.ml.engine.algorithms.remote.RemoteModel.CLUSTER_SERVICE;

public class RemoteModelTest {

@Mock
MLInput mlInput;

@Mock
Client client;

@Mock
ClusterService clusterService;

@Mock
ClusterState clusterState;

@Mock
MLModel mlModel;

Expand All @@ -44,12 +73,46 @@ public class RemoteModelTest {
RemoteModel remoteModel;
Encryptor encryptor;

String masterKey;

Map<String, Object> params;
private static final AtomicInteger portGenerator = new AtomicInteger();

@Before
public void setUp() {
MockitoAnnotations.openMocks(this);
remoteModel = new RemoteModel();
encryptor = spy(new EncryptorImpl());
encryptor.setMasterKey("0000000000000001");
masterKey = "0000000000000001";
encryptor.setMasterKey(masterKey);
params = new HashMap<>();
params.put(CLIENT, client);
params.put(CLUSTER_SERVICE, clusterService);

doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
GetResponse response = mock(GetResponse.class);
when(response.isExists()).thenReturn(true);
when(response.getSourceAsMap())
.thenReturn(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli()));
listener.onResponse(response);
return null;
}).when(client).get(any(), any());


when(clusterService.state()).thenReturn(clusterState);

Metadata metadata = new Metadata.Builder()
.indices(ImmutableMap
.<String, IndexMetadata>builder()
.put(ML_CONFIG_INDEX, IndexMetadata.builder(ML_CONFIG_INDEX)
.settings(Settings.builder()
.put("index.number_of_shards", 1)
.put("index.number_of_replicas", 1)
.put("index.version.created", Version.CURRENT.id))
.build())
.build()).build();
when(clusterState.metadata()).thenReturn(metadata);
}

@Test
Expand Down Expand Up @@ -112,6 +175,80 @@ public void initModel_WithHeader() {
Assert.assertNull(remoteModel.getConnectorExecutor());
}

@Test
public void initModel_WithHeader_NullMasterKey_MasterKeyExistInIndex() {
Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
when(mlModel.getConnector()).thenReturn(connector);
Encryptor encryptor = new EncryptorImpl();
remoteModel.initModel(mlModel, params, encryptor);
Map<String, String> decryptedHeaders = connector.getDecryptedHeaders();
RemoteConnectorExecutor executor = remoteModel.getConnectorExecutor();
Assert.assertNotNull(executor);
Assert.assertNull(decryptedHeaders);
Assert.assertNotNull(executor.getConnector().getDecryptedHeaders());
Assert.assertEquals(1, executor.getConnector().getDecryptedHeaders().size());
Assert.assertEquals("Bearer test_api_key", executor.getConnector().getDecryptedHeaders().get("Authorization"));

remoteModel.close();
Assert.assertNull(remoteModel.getConnectorExecutor());
}

@Test
public void initModel_WithHeader_NullMasterKey_MasterKeyNotExistInIndex() {
exceptionRule.expect(ResourceNotFoundException.class);
exceptionRule.expectMessage("ML encryption master key not initialized yet");

doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
GetResponse response = mock(GetResponse.class);
when(response.isExists()).thenReturn(false);
listener.onResponse(response);
return null;
}).when(client).get(any(), any());

Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
when(mlModel.getConnector()).thenReturn(connector);
Encryptor encryptor = new EncryptorImpl();
remoteModel.initModel(mlModel, params, encryptor);
}

@Test
public void initModel_WithHeader_GetMasterKey_Exception() {
exceptionRule.expect(RuntimeException.class);
exceptionRule.expectMessage("test error");

doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
listener.onFailure(new RuntimeException("test error"));
return null;
}).when(client).get(any(), any());

Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
when(mlModel.getConnector()).thenReturn(connector);
Encryptor encryptor = new EncryptorImpl();
remoteModel.initModel(mlModel, params, encryptor);
}

@Test
public void initModel_WithHeader_IndexNotFound() {
exceptionRule.expect(ResourceNotFoundException.class);
exceptionRule.expectMessage("ML encryption master key not initialized yet");

Metadata metadata = new Metadata.Builder().indices(ImmutableMap.of()).build();
when(clusterState.metadata()).thenReturn(metadata);

doAnswer(invocation -> {
ActionListener<GetResponse> listener = invocation.getArgument(1);
listener.onFailure(new RuntimeException("test error"));
return null;
}).when(client).get(any(), any());

Connector connector = createConnector(ImmutableMap.of("Authorization", "Bearer ${credential.key}"));
when(mlModel.getConnector()).thenReturn(connector);
Encryptor encryptor = new EncryptorImpl();
remoteModel.initModel(mlModel, params, encryptor);
}

private Connector createConnector(Map<String, String> headers) {
ConnectorAction predictAction = ConnectorAction.builder()
.actionType(ConnectorAction.ActionType.PREDICT)
Expand Down

0 comments on commit fe22f46

Please sign in to comment.