Skip to content

Commit

Permalink
Add search connector uts (#1054)
Browse files Browse the repository at this point in the history
* Fix UT failures

Signed-off-by: zane-neo <[email protected]>

* Fix UT failures and add UT for search connector transport action

Signed-off-by: zane-neo <[email protected]>

---------

Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo authored Jul 9, 2023
1 parent 56996d7 commit 7826b19
Show file tree
Hide file tree
Showing 18 changed files with 321 additions and 138 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import lombok.extern.log4j.Log4j2;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.AccessMode;

Expand All @@ -28,8 +29,9 @@ public class AwsConnector extends HttpConnector {
@Builder(builderMethodName = "awsConnectorBuilder")
public AwsConnector(String name, String description, String version, String protocol,
Map<String, String> parameters, Map<String, String> credential, List<ConnectorAction> actions,
List<String> backendRoles, AccessMode accessMode) {
super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode);
List<String> backendRoles, AccessMode accessMode, User owner
) {
super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode, owner);
validate();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public class HttpConnector extends AbstractConnector {
@Builder
public HttpConnector(String name, String description, String version, String protocol,
Map<String, String> parameters, Map<String, String> credential, List<ConnectorAction> actions,
List<String> backendRoles, AccessMode accessMode) {
List<String> backendRoles, AccessMode accessMode, User owner) {
this.name = name;
this.description = description;
this.version = version;
Expand All @@ -56,6 +56,7 @@ public HttpConnector(String name, String description, String version, String pro
this.actions = actions;
this.backendRoles = backendRoles;
this.access = accessMode;
this.owner = owner;
}

public HttpConnector(String protocol, XContentParser parser) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.logging.log4j.util.Strings;
Expand Down Expand Up @@ -155,13 +154,18 @@ private void doRegister(MLRegisterModelInput registerModelInput, ActionListener<
if (Boolean.TRUE.equals(r)) {
registerModel(registerModelInput, listener);
} else {
listener.onFailure(new IllegalArgumentException("You don't have permission to use the connector provided, connector id: " + registerModelInput.getConnectorId()));
listener
.onFailure(
new IllegalArgumentException(
"You don't have permission to use the connector provided, connector id: "
+ registerModelInput.getConnectorId()
)
);
}
}, e -> {
log
.error(
"You don't have permission to use the connector provided, connector id: "
+ registerModelInput.getConnectorId(),
"You don't have permission to use the connector provided, connector id: " + registerModelInput.getConnectorId(),
e
);
listener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,17 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
if (searchHits.length == 0) {
deleteConnector(deleteRequest, connectorId, actionListener);
} else {
log.error(searchHits.length + " models are still using this connector, please delete or update the models first!");
actionListener.onFailure(new MLValidationException(searchHits.length + " models are still using this connector, please delete or update the models first!"));
log
.error(
searchHits.length + " models are still using this connector, please delete or update the models first!"
);
actionListener
.onFailure(
new MLValidationException(
searchHits.length
+ " models are still using this connector, please delete or update the models first!"
)
);
}
}, e -> {
log.error("Failed to delete ML connector: " + connectorId, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLConn
actionListener.onFailure(e);
}
} else {
actionListener.onFailure(new IllegalArgumentException("Failed to find connector with the provided connector id: " + connectorId));
actionListener
.onFailure(new IllegalArgumentException("Failed to find connector with the provided connector id: " + connectorId));
}
}, e -> {
if (e instanceof IndexNotFoundException) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

package org.opensearch.ml.action.remote;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.opensearch.action.ActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
Expand All @@ -15,7 +19,6 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.ml.common.CommonValue;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
Expand All @@ -27,10 +30,6 @@

import lombok.extern.log4j.Log4j2;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

@Log4j2
public class SearchConnectorTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

package org.opensearch.ml.breaker;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_JVM_HEAP_MEM_THRESHOLD;

import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.monitor.jvm.JvmService;

import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_JVM_HEAP_MEM_THRESHOLD;

/**
* A circuit breaker for memory usage.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,15 @@ public String[] getDeployedModels() {
*/
public String[] getLocalDeployedModels() {
return modelCaches
.entrySet()
.stream()
.filter(entry -> (entry.getValue().getModelState() == MLModelState.DEPLOYED && entry.getValue().getFunctionName() != FunctionName.REMOTE))
.map(entry -> entry.getKey())
.collect(Collectors.toList())
.toArray(new String[0]);
.entrySet()
.stream()
.filter(
entry -> (entry.getValue().getModelState() == MLModelState.DEPLOYED
&& entry.getValue().getFunctionName() != FunctionName.REMOTE)
)
.map(entry -> entry.getKey())
.collect(Collectors.toList())
.toArray(new String[0]);
}

/**
Expand Down
26 changes: 7 additions & 19 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@
import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly;
import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL;
import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY;
import static org.opensearch.ml.stats.ActionName.REGISTER;
import static org.opensearch.ml.stats.MLActionLevelStat.ML_ACTION_REQUEST_COUNT;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
Expand Down Expand Up @@ -105,7 +105,6 @@
import org.opensearch.ml.engine.MLExecutable;
import org.opensearch.ml.engine.ModelHelper;
import org.opensearch.ml.engine.Predictable;
import org.opensearch.ml.engine.exceptions.MetaDataException;
import org.opensearch.ml.engine.utils.FileUtils;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.ml.profile.MLModelProfile;
Expand Down Expand Up @@ -209,12 +208,10 @@ public MLModelManager(
.addSettingsUpdateConsumer(ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, it -> maxDeployTasksPerNode = it);

this.masterKey = ML_COMMONS_MASTER_SECRET_KEY.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_MASTER_SECRET_KEY, it -> {
masterKey = it;
mlEngine.setMasterKey(masterKey);
});
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MASTER_SECRET_KEY, it -> {
masterKey = it;
mlEngine.setMasterKey(masterKey);
});
}

public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, ActionListener<String> listener) {
Expand Down Expand Up @@ -371,11 +368,7 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa
}
}

private void indexRemoteModel(
MLRegisterModelInput registerModelInput,
MLTask mlTask,
String modelVersion
) {
private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion) {
String taskId = mlTask.getTaskId();
FunctionName functionName = mlTask.getFunctionName();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
Expand Down Expand Up @@ -435,11 +428,7 @@ private void indexRemoteModel(
}
}

private void uploadModel(
MLRegisterModelInput registerModelInput,
MLTask mlTask,
String modelVersion
) throws PrivilegedActionException {
private void uploadModel(MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion) throws PrivilegedActionException {
if (registerModelInput.getUrl() != null) {
registerModelFromUrl(registerModelInput, mlTask, modelVersion);
} else if (registerModelInput.getFunctionName() == FunctionName.REMOTE || registerModelInput.getConnectorId() != null) {
Expand Down Expand Up @@ -1164,5 +1153,4 @@ public boolean isModelRunningOnNode(String modelId) {
return modelCacheHelper.isModelRunningOnNode(modelId);
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.util.Map;

import org.apache.lucene.search.TotalHits;
import org.junit.Before;
Expand Down Expand Up @@ -104,7 +103,18 @@ public void setup() {

Metadata metadata = mock(Metadata.class);
when(metadata.hasIndex(anyString())).thenReturn(true);
ClusterState testState = new ClusterState(new ClusterName("mock"), 123l, "111111", metadata, null, null, null, ImmutableOpenMap.of(), 0, false);
ClusterState testState = new ClusterState(
new ClusterName("mock"),
123l,
"111111",
metadata,
null,
null,
null,
ImmutableOpenMap.of(),
0,
false
);
when(clusterService.state()).thenReturn(testState);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX;
import static org.opensearch.ml.utils.TestHelper.clusterSetting;

import com.google.common.collect.ImmutableList;
import java.util.List;

import org.junit.Before;
import org.junit.Rule;
import org.junit.rules.ExpectedException;
Expand Down Expand Up @@ -62,7 +63,7 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import java.util.List;
import com.google.common.collect.ImmutableList;

public class TransportRegisterModelActionTests extends OpenSearchTestCase {
@Rule
Expand Down Expand Up @@ -127,11 +128,8 @@ public class TransportRegisterModelActionTests extends OpenSearchTestCase {

private String trustedUrlRegex = "^(https?|ftp|file)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]";

private static final List<String> TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList.of(
"^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$",
"^https://api\\.openai\\.com/.*$",
"^https://api\\.cohere\\.ai/.*$"
);
private static final List<String> TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList
.of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$");

@Mock
private ModelAccessControlHelper modelAccessControlHelper;
Expand Down
Loading

0 comments on commit 7826b19

Please sign in to comment.