Skip to content

Commit

Permalink
Fix UT failures and add UT for search connector transport action
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Jul 9, 2023
1 parent 323226f commit 9f47bf8
Show file tree
Hide file tree
Showing 15 changed files with 248 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.logging.log4j.util.Strings;
Expand Down Expand Up @@ -155,13 +154,18 @@ private void doRegister(MLRegisterModelInput registerModelInput, ActionListener<
if (Boolean.TRUE.equals(r)) {
registerModel(registerModelInput, listener);
} else {
listener.onFailure(new IllegalArgumentException("You don't have permission to use the connector provided, connector id: " + registerModelInput.getConnectorId()));
listener
.onFailure(
new IllegalArgumentException(
"You don't have permission to use the connector provided, connector id: "
+ registerModelInput.getConnectorId()
)
);
}
}, e -> {
log
.error(
"You don't have permission to use the connector provided, connector id: "
+ registerModelInput.getConnectorId(),
"You don't have permission to use the connector provided, connector id: " + registerModelInput.getConnectorId(),
e
);
listener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,17 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
if (searchHits.length == 0) {
deleteConnector(deleteRequest, connectorId, actionListener);
} else {
log.error(searchHits.length + " models are still using this connector, please delete or update the models first!");
actionListener.onFailure(new MLValidationException(searchHits.length + " models are still using this connector, please delete or update the models first!"));
log
.error(
searchHits.length + " models are still using this connector, please delete or update the models first!"
);
actionListener
.onFailure(
new MLValidationException(
searchHits.length
+ " models are still using this connector, please delete or update the models first!"
)
);
}
}, e -> {
log.error("Failed to delete ML connector: " + connectorId, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLConn
actionListener.onFailure(e);
}
} else {
actionListener.onFailure(new IllegalArgumentException("Failed to find connector with the provided connector id: " + connectorId));
actionListener
.onFailure(new IllegalArgumentException("Failed to find connector with the provided connector id: " + connectorId));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

package org.opensearch.ml.action.remote;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.opensearch.action.ActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
Expand All @@ -15,7 +19,6 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
Expand All @@ -27,10 +30,6 @@

import lombok.extern.log4j.Log4j2;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

@Log4j2
public class SearchConnectorTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

package org.opensearch.ml.breaker;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_JVM_HEAP_MEM_THRESHOLD;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.monitor.jvm.JvmService;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_JVM_HEAP_MEM_THRESHOLD;

/**
* A circuit breaker for memory usage.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,15 @@ public String[] getDeployedModels() {
*/
public String[] getLocalDeployedModels() {
return modelCaches
.entrySet()
.stream()
.filter(entry -> (entry.getValue().getModelState() == MLModelState.DEPLOYED && entry.getValue().getFunctionName() != FunctionName.REMOTE))
.map(entry -> entry.getKey())
.collect(Collectors.toList())
.toArray(new String[0]);
.entrySet()
.stream()
.filter(
entry -> (entry.getValue().getModelState() == MLModelState.DEPLOYED
&& entry.getValue().getFunctionName() != FunctionName.REMOTE)
)
.map(entry -> entry.getKey())
.collect(Collectors.toList())
.toArray(new String[0]);
}

/**
Expand Down
26 changes: 7 additions & 19 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@
import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly;
import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL;
import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY;
import static org.opensearch.ml.stats.ActionName.REGISTER;
import static org.opensearch.ml.stats.MLActionLevelStat.ML_ACTION_REQUEST_COUNT;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
Expand Down Expand Up @@ -105,7 +105,6 @@
import org.opensearch.ml.engine.MLExecutable;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.engine.exceptions.MetaDataException;
import org.opensearch.ml.engine.utils.FileUtils;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.profile.MLModelProfile;
Expand Down Expand Up @@ -209,12 +208,10 @@ public MLModelManager(
.addSettingsUpdateConsumer(ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, it -> maxDeployTasksPerNode = it);

this.masterKey = ML_COMMONS_MASTER_SECRET_KEY.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_MASTER_SECRET_KEY, it -> {
masterKey = it;
mlEngine.setMasterKey(masterKey);
});
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MASTER_SECRET_KEY, it -> {
masterKey = it;
mlEngine.setMasterKey(masterKey);
});
}

public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, ActionListener<String> listener) {
Expand Down Expand Up @@ -371,11 +368,7 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa
}
}

private void indexRemoteModel(
MLRegisterModelInput registerModelInput,
MLTask mlTask,
String modelVersion
) {
private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion) {
String taskId = mlTask.getTaskId();
FunctionName functionName = mlTask.getFunctionName();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
Expand Down Expand Up @@ -435,11 +428,7 @@ private void indexRemoteModel(
}
}

private void uploadModel(
MLRegisterModelInput registerModelInput,
MLTask mlTask,
String modelVersion
) throws PrivilegedActionException {
private void uploadModel(MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion) throws PrivilegedActionException {
if (registerModelInput.getUrl() != null) {
registerModelFromUrl(registerModelInput, mlTask, modelVersion);
} else if (registerModelInput.getFunctionName() == FunctionName.REMOTE || registerModelInput.getConnectorId() != null) {
Expand Down Expand Up @@ -1164,5 +1153,4 @@ public boolean isModelRunningOnNode(String modelId) {
return modelCacheHelper.isModelRunningOnNode(modelId);
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.util.Map;

import org.apache.lucene.search.TotalHits;
import org.junit.Before;
Expand Down Expand Up @@ -104,7 +103,18 @@ public void setup() {

Metadata metadata = mock(Metadata.class);
when(metadata.hasIndex(anyString())).thenReturn(true);
ClusterState testState = new ClusterState(new ClusterName("mock"), 123l, "111111", metadata, null, null, null, ImmutableOpenMap.of(), 0, false);
ClusterState testState = new ClusterState(
new ClusterName("mock"),
123l,
"111111",
metadata,
null,
null,
null,
ImmutableOpenMap.of(),
0,
false
);
when(clusterService.state()).thenReturn(testState);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;

import com.google.common.collect.ImmutableList;
import java.util.List;

import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
Expand Down Expand Up @@ -62,7 +63,7 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import java.util.List;
import com.google.common.collect.ImmutableList;

public class TransportRegisterModelActionTests extends OpenSearchTestCase {
@Rule
Expand Down Expand Up @@ -127,11 +128,8 @@ public class TransportRegisterModelActionTests extends OpenSearchTestCase {

private String trustedUrlRegex = "^(https?|ftp|file)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]";

private static final List<String> TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList.of(
"^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$",
"^https://api\\.openai\\.com/.*$",
"^https://api\\.cohere\\.ai/.*$"
);
private static final List<String> TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList
.of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$");

@Mock
private ModelAccessControlHelper modelAccessControlHelper;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.remote;

import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.ActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.ConfigConstants;
import org.opensearch.commons.authuser.User;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;


import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

public class SearchConnectorTransportActionTests extends OpenSearchTestCase {

@Mock
Client client;

@Mock
TransportService transportService;

@Mock
ActionFilters actionFilters;

@Mock
SearchRequest searchRequest;

SearchSourceBuilder searchSourceBuilder;

@Mock
FetchSourceContext fetchSourceContext;

@Mock
ActionListener<SearchResponse> actionListener;

@Mock
ThreadPool threadPool;

@Mock
private Task task;
SearchConnectorTransportAction searchConnectorTransportAction;
ThreadContext threadContext;
@Mock
private ConnectorAccessControlHelper connectorAccessControlHelper;

@Before
public void setup() {
MockitoAnnotations.openMocks(this);
searchConnectorTransportAction = new SearchConnectorTransportAction(transportService, actionFilters, client, connectorAccessControlHelper);

Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);

searchSourceBuilder = new SearchSourceBuilder();
searchSourceBuilder.fetchSource(fetchSourceContext);
when(searchRequest.source()).thenReturn(searchSourceBuilder);
when(fetchSourceContext.includes()).thenReturn(new String[]{});
when(fetchSourceContext.excludes()).thenReturn(new String[]{});
}

public void test_doExecute_connectorAccessControlNotEnabled_searchSuccess() {
String userString = "admin|role-1|all_access";
threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userString);
when(connectorAccessControlHelper.skipConnectorAccessControl(any(User.class))).thenReturn(true);
SearchResponse searchResponse = mock(SearchResponse.class);
doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(searchResponse);
return null;
}).when(client).search(any(SearchRequest.class), isA(ActionListener.class));
searchConnectorTransportAction.doExecute(task, searchRequest, actionListener);
verify(actionListener).onResponse(any(SearchResponse.class));
}

public void test_doExecute_connectorAccessControlEnabled_searchSuccess() {
when(connectorAccessControlHelper.skipConnectorAccessControl(any(User.class))).thenReturn(false);
SearchResponse searchResponse = mock(SearchResponse.class);
doAnswer(invocation -> {
ActionListener<SearchResponse> actionListener = invocation.getArgument(1);
actionListener.onResponse(searchResponse);
return null;
}).when(client).search(any(SearchRequest.class), isA(ActionListener.class));
searchConnectorTransportAction.doExecute(task, searchRequest, actionListener);
verify(actionListener).onResponse(any(SearchResponse.class));
}

public void test_doExecute_exception() {
when(searchRequest.source()).thenThrow(new RuntimeException("runtime exception"));
searchConnectorTransportAction.doExecute(task, searchRequest, actionListener);
verify(actionListener).onFailure(any(RuntimeException.class));
}


}
Loading

0 comments on commit 9f47bf8

Please sign in to comment.