diff --git a/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java index 3c6f6c0694..cf6c670eaf 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/AwsConnector.java @@ -10,6 +10,7 @@ import lombok.extern.log4j.Log4j2; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.commons.authuser.User; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.AccessMode; @@ -28,8 +29,9 @@ public class AwsConnector extends HttpConnector { @Builder(builderMethodName = "awsConnectorBuilder") public AwsConnector(String name, String description, String version, String protocol, Map parameters, Map credential, List actions, - List backendRoles, AccessMode accessMode) { - super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode); + List backendRoles, AccessMode accessMode, User owner + ) { + super(name, description, version, protocol, parameters, credential, actions, backendRoles, accessMode, owner); validate(); } diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index dc7b20c5e2..9714a92994 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -46,7 +46,7 @@ public class HttpConnector extends AbstractConnector { @Builder public HttpConnector(String name, String description, String version, String protocol, Map parameters, Map credential, List actions, - List backendRoles, AccessMode accessMode) { + List backendRoles, AccessMode accessMode, User owner) { this.name = name; this.description = description; this.version = version; @@ -56,6 +56,7 @@ public HttpConnector(String name, String description, String version, String pro this.actions = actions; this.backendRoles = backendRoles; this.access = accessMode; + this.owner = owner; } public HttpConnector(String protocol, XContentParser parser) throws IOException { diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index a9dca29939..d7de1f0a79 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -15,7 +15,6 @@ import java.time.Instant; import java.util.Arrays; import java.util.List; -import java.util.regex.Matcher; import java.util.regex.Pattern; import org.apache.logging.log4j.util.Strings; @@ -155,13 +154,18 @@ private void doRegister(MLRegisterModelInput registerModelInput, ActionListener< if (Boolean.TRUE.equals(r)) { registerModel(registerModelInput, listener); } else { - listener.onFailure(new IllegalArgumentException("You don't have permission to use the connector provided, connector id: " + registerModelInput.getConnectorId())); + listener + .onFailure( + new IllegalArgumentException( + "You don't have permission to use the connector provided, connector id: " + + registerModelInput.getConnectorId() + ) + ); } }, e -> { log .error( - "You don't have permission to use the connector provided, connector id: " - + registerModelInput.getConnectorId(), + "You don't have permission to use the connector provided, connector id: " + registerModelInput.getConnectorId(), e ); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/remote/DeleteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/remote/DeleteConnectorTransportAction.java index 0a18bc92ff..36c5aeb341 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/remote/DeleteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/remote/DeleteConnectorTransportAction.java @@ -72,8 +72,17 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { log.error("Failed to delete ML connector: " + connectorId, e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/remote/GetConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/remote/GetConnectorTransportAction.java index bc9f9782d3..854623db7e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/remote/GetConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/remote/GetConnectorTransportAction.java @@ -87,7 +87,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (e instanceof IndexNotFoundException) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/remote/SearchConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/remote/SearchConnectorTransportAction.java index 17ed3c0430..70d349e06b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/remote/SearchConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/remote/SearchConnectorTransportAction.java @@ -5,6 +5,10 @@ package org.opensearch.ml.action.remote; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + import org.opensearch.action.ActionListener; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -15,7 +19,6 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.ml.common.CommonValue; -import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; import org.opensearch.ml.helper.ConnectorAccessControlHelper; @@ -27,10 +30,6 @@ import lombok.extern.log4j.Log4j2; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; - @Log4j2 public class SearchConnectorTransportAction extends HandledTransportAction { diff --git a/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java b/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java index 6697ab4bf3..8b467c2452 100644 --- a/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java +++ b/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java @@ -5,12 +5,12 @@ package org.opensearch.ml.breaker; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_JVM_HEAP_MEM_THRESHOLD; + import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.monitor.jvm.JvmService; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_JVM_HEAP_MEM_THRESHOLD; - /** * A circuit breaker for memory usage. */ diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 512dfc8920..74dbc26d61 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -157,12 +157,15 @@ public String[] getDeployedModels() { */ public String[] getLocalDeployedModels() { return modelCaches - .entrySet() - .stream() - .filter(entry -> (entry.getValue().getModelState() == MLModelState.DEPLOYED && entry.getValue().getFunctionName() != FunctionName.REMOTE)) - .map(entry -> entry.getKey()) - .collect(Collectors.toList()) - .toArray(new String[0]); + .entrySet() + .stream() + .filter( + entry -> (entry.getValue().getModelState() == MLModelState.DEPLOYED + && entry.getValue().getFunctionName() != FunctionName.REMOTE) + ) + .map(entry -> entry.getKey()) + .collect(Collectors.toList()) + .toArray(new String[0]); } /** diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index e0ea78fcdd..a41c1037f8 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -34,10 +34,10 @@ import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly; import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL; import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY; import static org.opensearch.ml.stats.ActionName.REGISTER; import static org.opensearch.ml.stats.MLActionLevelStat.ML_ACTION_REQUEST_COUNT; import static org.opensearch.ml.utils.MLExceptionUtils.logException; @@ -105,7 +105,6 @@ import org.opensearch.ml.engine.MLExecutable; import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.engine.Predictable; -import org.opensearch.ml.engine.exceptions.MetaDataException; import org.opensearch.ml.engine.utils.FileUtils; import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.profile.MLModelProfile; @@ -209,12 +208,10 @@ public MLModelManager( .addSettingsUpdateConsumer(ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, it -> maxDeployTasksPerNode = it); this.masterKey = ML_COMMONS_MASTER_SECRET_KEY.get(settings); - clusterService - .getClusterSettings() - .addSettingsUpdateConsumer(ML_COMMONS_MASTER_SECRET_KEY, it -> { - masterKey = it; - mlEngine.setMasterKey(masterKey); - }); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MASTER_SECRET_KEY, it -> { + masterKey = it; + mlEngine.setMasterKey(masterKey); + }); } public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, ActionListener listener) { @@ -371,11 +368,7 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa } } - private void indexRemoteModel( - MLRegisterModelInput registerModelInput, - MLTask mlTask, - String modelVersion - ) { + private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion) { String taskId = mlTask.getTaskId(); FunctionName functionName = mlTask.getFunctionName(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { @@ -435,11 +428,7 @@ private void indexRemoteModel( } } - private void uploadModel( - MLRegisterModelInput registerModelInput, - MLTask mlTask, - String modelVersion - ) throws PrivilegedActionException { + private void uploadModel(MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion) throws PrivilegedActionException { if (registerModelInput.getUrl() != null) { registerModelFromUrl(registerModelInput, mlTask, modelVersion); } else if (registerModelInput.getFunctionName() == FunctionName.REMOTE || registerModelInput.getConnectorId() != null) { @@ -1164,5 +1153,4 @@ public boolean isModelRunningOnNode(String modelId) { return modelCacheHelper.isModelRunningOnNode(modelId); } - } diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java index 8bc6a306c4..db5371ae50 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java @@ -16,7 +16,6 @@ import static org.mockito.Mockito.when; import java.io.IOException; -import java.util.Map; import org.apache.lucene.search.TotalHits; import org.junit.Before; @@ -104,7 +103,18 @@ public void setup() { Metadata metadata = mock(Metadata.class); when(metadata.hasIndex(anyString())).thenReturn(true); - ClusterState testState = new ClusterState(new ClusterName("mock"), 123l, "111111", metadata, null, null, null, ImmutableOpenMap.of(), 0, false); + ClusterState testState = new ClusterState( + new ClusterName("mock"), + 123l, + "111111", + metadata, + null, + null, + null, + ImmutableOpenMap.of(), + 0, + false + ); when(clusterService.state()).thenReturn(testState); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index f60a13c80f..ebd3d9e98a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -18,7 +18,8 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; import static org.opensearch.ml.utils.TestHelper.clusterSetting; -import com.google.common.collect.ImmutableList; +import java.util.List; + import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; @@ -62,7 +63,7 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import java.util.List; +import com.google.common.collect.ImmutableList; public class TransportRegisterModelActionTests extends OpenSearchTestCase { @Rule @@ -127,11 +128,8 @@ public class TransportRegisterModelActionTests extends OpenSearchTestCase { private String trustedUrlRegex = "^(https?|ftp|file)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]"; - private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList.of( - "^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", - "^https://api\\.openai\\.com/.*$", - "^https://api\\.cohere\\.ai/.*$" - ); + private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList + .of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$"); @Mock private ModelAccessControlHelper modelAccessControlHelper; diff --git a/plugin/src/test/java/org/opensearch/ml/action/remote/SearchConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/remote/SearchConnectorTransportActionTests.java new file mode 100644 index 0000000000..74d7b83afc --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/remote/SearchConnectorTransportActionTests.java @@ -0,0 +1,120 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.remote; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.authuser.User; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class SearchConnectorTransportActionTests extends OpenSearchTestCase { + + @Mock + Client client; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + SearchRequest searchRequest; + + SearchSourceBuilder searchSourceBuilder; + + @Mock + FetchSourceContext fetchSourceContext; + + @Mock + ActionListener actionListener; + + @Mock + ThreadPool threadPool; + + @Mock + private Task task; + SearchConnectorTransportAction searchConnectorTransportAction; + ThreadContext threadContext; + @Mock + private ConnectorAccessControlHelper connectorAccessControlHelper; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + searchConnectorTransportAction = new SearchConnectorTransportAction(transportService, actionFilters, client, connectorAccessControlHelper); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.fetchSource(fetchSourceContext); + when(searchRequest.source()).thenReturn(searchSourceBuilder); + when(fetchSourceContext.includes()).thenReturn(new String[]{}); + when(fetchSourceContext.excludes()).thenReturn(new String[]{}); + } + + public void test_doExecute_connectorAccessControlNotEnabled_searchSuccess() { + String userString = "admin|role-1|all_access"; + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userString); + when(connectorAccessControlHelper.skipConnectorAccessControl(any(User.class))).thenReturn(true); + SearchResponse searchResponse = mock(SearchResponse.class); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + searchConnectorTransportAction.doExecute(task, searchRequest, actionListener); + verify(actionListener).onResponse(any(SearchResponse.class)); + } + + public void test_doExecute_connectorAccessControlEnabled_searchSuccess() { + when(connectorAccessControlHelper.skipConnectorAccessControl(any(User.class))).thenReturn(false); + SearchResponse searchResponse = mock(SearchResponse.class); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + searchConnectorTransportAction.doExecute(task, searchRequest, actionListener); + verify(actionListener).onResponse(any(SearchResponse.class)); + } + + public void test_doExecute_exception() { + when(searchRequest.source()).thenThrow(new RuntimeException("runtime exception")); + searchConnectorTransportAction.doExecute(task, searchRequest, actionListener); + verify(actionListener).onFailure(any(RuntimeException.class)); + } + + +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/remote/TransportCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/remote/TransportCreateConnectorActionTests.java index f0d0fea7cc..85fc567e6e 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/remote/TransportCreateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/remote/TransportCreateConnectorActionTests.java @@ -78,7 +78,6 @@ public class TransportCreateConnectorActionTests extends OpenSearchTestCase { @Mock private MLCreateConnectorRequest request; - @Mock private MLCreateConnectorInput input; @Mock @@ -94,16 +93,16 @@ public class TransportCreateConnectorActionTests extends OpenSearchTestCase { private Settings settings; - private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList.of( - "^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", - "^https://api\\.openai\\.com/.*$", - "^https://api\\.cohere\\.ai/.*$" - ); + private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList + .of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$"); @Before public void setup() { MockitoAnnotations.openMocks(this); - settings = Settings.builder().putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES).build(); + settings = Settings + .builder() + .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) + .build(); ClusterSettings clusterSettings = clusterSetting( settings, ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, @@ -121,7 +120,6 @@ public void setup() { clusterService, mlModelManager ); - when(request.getMlCreateConnectorInput()).thenReturn(input); Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), true).build(); threadContext = new ThreadContext(settings); threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); @@ -139,16 +137,23 @@ public void setup() { .url("https://${parameters.endpoint}/v1/completions") .build() ); - when(input.getActions()).thenReturn(actions); Map parameters = ImmutableMap.of("endpoint", "api.openai.com"); - when(input.getParameters()).thenReturn(parameters); + Map credential = ImmutableMap.of("access_key", "mockKey", "secret_key", "mockSecret"); + input = MLCreateConnectorInput + .builder() + .actions(actions) + .parameters(parameters) + .protocol(ConnectorProtocols.HTTP) + .credential(credential) + .build(); + when(request.getMlCreateConnectorInput()).thenReturn(input); } public void test_execute_connectorAccessControl_notEnabled_success() { when(connectorAccessControlHelper.accessControlNotEnabled(any(User.class))).thenReturn(true); - when(input.getAddAllBackendRoles()).thenReturn(null); - when(input.getBackendRoles()).thenReturn(null); + input.setAddAllBackendRoles(null); + input.setBackendRoles(null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); @@ -167,8 +172,8 @@ public void test_execute_connectorAccessControl_notEnabled_success() { public void test_execute_connectorAccessControl_notEnabled_withPermissionInfo_exception() { when(connectorAccessControlHelper.accessControlNotEnabled(any(User.class))).thenReturn(true); - when(input.getAddAllBackendRoles()).thenReturn(true); - when(input.getBackendRoles()).thenReturn(null); + input.setBackendRoles(null); + input.setAddAllBackendRoles(true); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); @@ -192,8 +197,8 @@ public void test_execute_connectorAccessControl_notEnabled_withPermissionInfo_ex public void test_execute_connectorAccessControlEnabled_success() { when(connectorAccessControlHelper.accessControlNotEnabled(any(User.class))).thenReturn(false); - when(input.getAddAllBackendRoles()).thenReturn(false); - when(input.getBackendRoles()).thenReturn(ImmutableList.of("role1", "role2")); + input.setAddAllBackendRoles(false); + input.setBackendRoles(ImmutableList.of("role1", "role2")); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); @@ -210,10 +215,10 @@ public void test_execute_connectorAccessControlEnabled_success() { verify(actionListener).onResponse(any(MLCreateConnectorResponse.class)); } - public void test_execute_connectorAccessControlEnabled_missingPermissionInfo_exception() { + public void test_execute_connectorAccessControlEnabled_missingPermissionInfo_defaultToPrivate() { when(connectorAccessControlHelper.accessControlNotEnabled(any(User.class))).thenReturn(false); - when(input.getAddAllBackendRoles()).thenReturn(null); - when(input.getBackendRoles()).thenReturn(null); + input.setAddAllBackendRoles(null); + input.setBackendRoles(null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); @@ -227,19 +232,14 @@ public void test_execute_connectorAccessControlEnabled_missingPermissionInfo_exc return null; }).when(client).index(any(IndexRequest.class), isA(ActionListener.class)); action.doExecute(task, request, actionListener); - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); - verify(actionListener).onFailure(argumentCaptor.capture()); - assertEquals( - "You must specify at least one backend role or make the connector public/private for registering it.", - argumentCaptor.getValue().getMessage() - ); + verify(actionListener).onResponse(any(MLCreateConnectorResponse.class)); } public void test_execute_connectorAccessControlEnabled_adminSpecifyAllBackendRoles_exception() { when(connectorAccessControlHelper.accessControlNotEnabled(any(User.class))).thenReturn(false); when(connectorAccessControlHelper.isAdmin(any(User.class))).thenReturn(true); - when(input.getAddAllBackendRoles()).thenReturn(true); - when(input.getBackendRoles()).thenReturn(null); + input.setAddAllBackendRoles(true); + input.setBackendRoles(null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); @@ -262,9 +262,9 @@ public void test_execute_connectorAccessControlEnabled_adminSpecifyAllBackendRol public void test_execute_connectorAccessControlEnabled_specifyBackendRolesForPublicConnector_exception() { when(connectorAccessControlHelper.accessControlNotEnabled(any(User.class))).thenReturn(false); - when(input.getAddAllBackendRoles()).thenReturn(true); - when(input.getAccess()).thenReturn(AccessMode.PUBLIC); - when(input.getBackendRoles()).thenReturn(null); + input.setAddAllBackendRoles(true); + input.setAccess(AccessMode.PUBLIC); + input.setBackendRoles(null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); @@ -296,8 +296,8 @@ public void test_execute_connectorAccessControlEnabled_userNoBackendRoles_except when(threadPool.getThreadContext()).thenReturn(threadContext); when(connectorAccessControlHelper.accessControlNotEnabled(any(User.class))).thenReturn(false); - when(input.getAddAllBackendRoles()).thenReturn(true); - when(input.getBackendRoles()).thenReturn(null); + input.setAddAllBackendRoles(true); + input.setBackendRoles(null); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); @@ -337,8 +337,8 @@ public void test_execute_connectorAccessControlEnabled_parameterConflict_excepti when(threadPool.getThreadContext()).thenReturn(threadContext); when(connectorAccessControlHelper.accessControlNotEnabled(any(User.class))).thenReturn(false); - when(input.getAddAllBackendRoles()).thenReturn(true); - when(input.getBackendRoles()).thenReturn(ImmutableList.of("role1", "role2")); + input.setAddAllBackendRoles(true); + input.setBackendRoles(ImmutableList.of("role1", "role2")); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); @@ -381,8 +381,8 @@ public void test_execute_connectorAccessControlEnabled_specifyNotBelongedRole_ex when(threadPool.getThreadContext()).thenReturn(threadContext); when(connectorAccessControlHelper.accessControlNotEnabled(any(User.class))).thenReturn(false); - when(input.getAddAllBackendRoles()).thenReturn(false); - when(input.getBackendRoles()).thenReturn(ImmutableList.of("role1", "role2")); + input.setAddAllBackendRoles(false); + input.setBackendRoles(ImmutableList.of("role1", "role2")); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); @@ -414,8 +414,8 @@ public void test_execute_connectorAccessControlEnabled_specifyNotBelongedRole_ex public void test_execute_dryRun_connector_creation() { when(connectorAccessControlHelper.accessControlNotEnabled(any(User.class))).thenReturn(false); - when(input.getAddAllBackendRoles()).thenReturn(false); - when(input.getBackendRoles()).thenReturn(ImmutableList.of("role1", "role2")); + input.setAddAllBackendRoles(false); + input.setBackendRoles(ImmutableList.of("role1", "role2")); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(0); @@ -454,13 +454,12 @@ public void test_execute_URL_notMatchingExpression_exception() { .description(randomAlphaOfLength(10)) .version("1") .protocol(ConnectorProtocols.HTTP) - .parameters(ImmutableMap.of("k1", "v1", "k2", "v2")) .actions(actions) .build(); MLCreateConnectorRequest request = new MLCreateConnectorRequest(mlCreateConnectorInput); Map parameters = ImmutableMap.of("endpoint", "api.openai1.com"); - when(input.getParameters()).thenReturn(parameters); + mlCreateConnectorInput.setParameters(parameters); TransportCreateConnectorAction action = new TransportCreateConnectorAction( transportService, actionFilters, @@ -476,7 +475,7 @@ public void test_execute_URL_notMatchingExpression_exception() { ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); verify(actionListener).onFailure(argumentCaptor.capture()); assertEquals( - "Connector URL is not matching the trusted connector endpoint regex, regex is: ^https://(runtime\\.sagemaker\\..*\\.amazonaws\\.com/|api.openai.com|api.cohere.ai).*$,URL is: https://${parameters.endpoint}/v1/completions", + "Connector URL is not matching the trusted connector endpoint regex, URL is: https://api.openai1.com/v1/completions", argumentCaptor.getValue().getMessage() ); } diff --git a/plugin/src/test/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployerTests.java b/plugin/src/test/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployerTests.java index f566593cf3..cdb50648dc 100644 --- a/plugin/src/test/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployerTests.java @@ -18,11 +18,8 @@ import java.nio.file.Path; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.function.Consumer; - -import com.google.common.collect.ImmutableMap; import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.mockito.Mock; @@ -635,7 +632,10 @@ private MLModel buildModelWithJsonFile(String file) throws Exception { private void mockClusterDataNodes(ClusterService clusterService) { ClusterState clusterState = mock(ClusterState.class); DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class); - ImmutableOpenMap dataNodes = ImmutableOpenMap.builder().fPut("dataNodeId", mock(DiscoveryNode.class)).build(); + ImmutableOpenMap dataNodes = ImmutableOpenMap + .builder() + .fPut("dataNodeId", mock(DiscoveryNode.class)) + .build(); when(discoveryNodes.getDataNodes()).thenReturn(dataNodes); when(discoveryNodes.getSize()).thenReturn(2); // a ml node join cluster. when(clusterState.nodes()).thenReturn(discoveryNodes); diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/DiscoveryNodeHelperTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/DiscoveryNodeHelperTests.java index 6485964682..83f283b89b 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/DiscoveryNodeHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/DiscoveryNodeHelperTests.java @@ -17,7 +17,6 @@ import java.io.IOException; import java.util.Arrays; import java.util.HashSet; -import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -130,7 +129,18 @@ public void setup() throws IOException { .add(mlNode1) .add(mlNode2) .build(); - clusterState = new ClusterState(new ClusterName(clusterName), 123l, "111111", null, null, nodes, null, ImmutableOpenMap.of(), 0, false); + clusterState = new ClusterState( + new ClusterName(clusterName), + 123l, + "111111", + null, + null, + nodes, + null, + ImmutableOpenMap.of(), + 0, + false + ); when(clusterService.state()).thenReturn(clusterState); discoveryNodeHelper = new DiscoveryNodeHelper(clusterService, settings); @@ -159,7 +169,18 @@ public void testGetEligibleNodes_DataNode() { mockSettings(false, nonExistingNodeName); DiscoveryNodeHelper discoveryNodeHelper = new DiscoveryNodeHelper(clusterService, settings); DiscoveryNodes nodes = DiscoveryNodes.builder().add(clusterManagerNode).add(dataNode1).add(dataNode2).add(warmDataNode1).build(); - clusterState = new ClusterState(new ClusterName(clusterName), 123l, "111111", null, null, nodes, null, ImmutableOpenMap.of(), 0, false); + clusterState = new ClusterState( + new ClusterName(clusterName), + 123l, + "111111", + null, + null, + nodes, + null, + ImmutableOpenMap.of(), + 0, + false + ); when(clusterService.state()).thenReturn(clusterState); DiscoveryNode[] eligibleNodes = discoveryNodeHelper.getEligibleNodes(); diff --git a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java index 684e8c0d8c..933cd4d825 100644 --- a/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/helper/ConnectorAccessControlHelperTests.java @@ -40,6 +40,8 @@ import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.CommonValue; +import org.opensearch.ml.common.connector.ConnectorProtocols; +import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; @@ -91,72 +93,72 @@ public void setup() { } public void test_hasPermission_user_null_return_true() { - DetachedConnector detachedConnector = mock(DetachedConnector.class); - boolean hasPermission = connectorAccessControlHelper.hasPermission(null, detachedConnector); + HttpConnector httpConnector = mock(HttpConnector.class); + boolean hasPermission = connectorAccessControlHelper.hasPermission(null, httpConnector); assertTrue(hasPermission); } public void test_hasPermission_connectorAccessControl_not_enabled_return_true() { - DetachedConnector detachedConnector = mock(DetachedConnector.class); + HttpConnector httpConnector = mock(HttpConnector.class); Settings settings = Settings.builder().put(ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED.getKey(), false).build(); ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); ConnectorAccessControlHelper connectorAccessControlHelper = new ConnectorAccessControlHelper(clusterService, settings); - boolean hasPermission = connectorAccessControlHelper.hasPermission(user, detachedConnector); + boolean hasPermission = connectorAccessControlHelper.hasPermission(user, httpConnector); assertTrue(hasPermission); } public void test_hasPermission_connectorOwner_is_null_return_true() { - DetachedConnector detachedConnector = mock(DetachedConnector.class); - when(detachedConnector.getOwner()).thenReturn(null); - boolean hasPermission = connectorAccessControlHelper.hasPermission(user, detachedConnector); + HttpConnector httpConnector = mock(HttpConnector.class); + when(httpConnector.getOwner()).thenReturn(null); + boolean hasPermission = connectorAccessControlHelper.hasPermission(user, httpConnector); assertTrue(hasPermission); } public void test_hasPermission_user_is_admin_return_true() { User user = User.parse("admin|role-1|all_access"); - boolean hasPermission = connectorAccessControlHelper.hasPermission(user, mock(DetachedConnector.class)); + boolean hasPermission = connectorAccessControlHelper.hasPermission(user, mock(HttpConnector.class)); assertTrue(hasPermission); } public void test_hasPermission_connector_isPublic_return_true() { - DetachedConnector detachedConnector = mock(DetachedConnector.class); - when(detachedConnector.getAccess()).thenReturn(AccessMode.PUBLIC); - boolean hasPermission = connectorAccessControlHelper.hasPermission(user, detachedConnector); + HttpConnector httpConnector = mock(HttpConnector.class); + when(httpConnector.getAccess()).thenReturn(AccessMode.PUBLIC); + boolean hasPermission = connectorAccessControlHelper.hasPermission(user, httpConnector); assertTrue(hasPermission); } public void test_hasPermission_connector_isPrivate_userIsOwner_return_true() { - DetachedConnector detachedConnector = mock(DetachedConnector.class); - when(detachedConnector.getAccess()).thenReturn(AccessMode.PRIVATE); - when(detachedConnector.getOwner()).thenReturn(user); - boolean hasPermission = connectorAccessControlHelper.hasPermission(user, detachedConnector); + HttpConnector httpConnector = mock(HttpConnector.class); + when(httpConnector.getAccess()).thenReturn(AccessMode.PRIVATE); + when(httpConnector.getOwner()).thenReturn(user); + boolean hasPermission = connectorAccessControlHelper.hasPermission(user, httpConnector); assertTrue(hasPermission); } public void test_hasPermission_connector_isPrivate_userIsNotOwner_return_false() { - DetachedConnector detachedConnector = mock(DetachedConnector.class); - when(detachedConnector.getAccess()).thenReturn(AccessMode.PRIVATE); + HttpConnector httpConnector = mock(HttpConnector.class); + when(httpConnector.getAccess()).thenReturn(AccessMode.PRIVATE); User user1 = User.parse(USER_STRING); - when(detachedConnector.getOwner()).thenReturn(user); - boolean hasPermission = connectorAccessControlHelper.hasPermission(user1, detachedConnector); + when(httpConnector.getOwner()).thenReturn(user); + boolean hasPermission = connectorAccessControlHelper.hasPermission(user1, httpConnector); assertFalse(hasPermission); } public void test_hasPermission_connector_isRestricted_userHasBackendRole_return_true() { - DetachedConnector detachedConnector = mock(DetachedConnector.class); - when(detachedConnector.getAccess()).thenReturn(AccessMode.RESTRICTED); - when(detachedConnector.getBackendRoles()).thenReturn(ImmutableList.of("role-1")); - boolean hasPermission = connectorAccessControlHelper.hasPermission(user, detachedConnector); + HttpConnector httpConnector = mock(HttpConnector.class); + when(httpConnector.getAccess()).thenReturn(AccessMode.RESTRICTED); + when(httpConnector.getBackendRoles()).thenReturn(ImmutableList.of("role-1")); + boolean hasPermission = connectorAccessControlHelper.hasPermission(user, httpConnector); assertTrue(hasPermission); } public void test_hasPermission_connector_isRestricted_userNotHasBackendRole_return_false() { - DetachedConnector detachedConnector = mock(DetachedConnector.class); - when(detachedConnector.getAccess()).thenReturn(AccessMode.RESTRICTED); - when(detachedConnector.getBackendRoles()).thenReturn(ImmutableList.of("role-3")); - when(detachedConnector.getOwner()).thenReturn(user); - boolean hasPermission = connectorAccessControlHelper.hasPermission(user, detachedConnector); + HttpConnector httpConnector = mock(HttpConnector.class); + when(httpConnector.getAccess()).thenReturn(AccessMode.RESTRICTED); + when(httpConnector.getBackendRoles()).thenReturn(ImmutableList.of("role-3")); + when(httpConnector.getOwner()).thenReturn(user); + boolean hasPermission = connectorAccessControlHelper.hasPermission(user, httpConnector); assertFalse(hasPermission); } @@ -278,17 +280,18 @@ public void test_addUserBackendRolesFilter_nonBoolQuery() { } private GetResponse createGetResponse(List backendRoles) { - DetachedConnector detachedConnector = DetachedConnector + HttpConnector httpConnector = HttpConnector .builder() .name("testConnector") - .description("This is test connector") + .protocol(ConnectorProtocols.HTTP) .owner(user) + .description("This is test connector") .backendRoles(Optional.ofNullable(backendRoles).orElse(ImmutableList.of("role-1"))) - .access(AccessMode.RESTRICTED) + .accessMode(AccessMode.RESTRICTED) .build(); XContentBuilder content = null; try { - content = detachedConnector.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); + content = httpConnector.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java index 9a81100f57..32d9b921f7 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java @@ -83,7 +83,18 @@ public void setup() { Set mlRoleSet = ImmutableSet.of(ML_ROLE); mlNode = new DiscoveryNode("mlNode", buildNewFakeTransportAddress(), new HashMap<>(), mlRoleSet, Version.CURRENT); DiscoveryNodes nodes = DiscoveryNodes.builder().add(dataNode1).add(dataNode2).build(); - testState = new ClusterState(new ClusterName(clusterName), 123l, "111111", null, null, nodes, null, ImmutableOpenMap.of(), 0, false); + testState = new ClusterState( + new ClusterName(clusterName), + 123l, + "111111", + null, + null, + nodes, + null, + ImmutableOpenMap.of(), + 0, + false + ); when(clusterService.state()).thenReturn(testState); doAnswer(invocation -> { @@ -146,7 +157,18 @@ public void testGetEligibleNodes_DataNodeOnly() { @Ignore public void testGetEligibleNodes_MlAndDataNodes() { DiscoveryNodes nodes = DiscoveryNodes.builder().add(dataNode1).add(dataNode2).add(mlNode).build(); - testState = new ClusterState(new ClusterName(clusterName), 123l, "111111", null, null, nodes, null, ImmutableOpenMap.of(), 0, false); + testState = new ClusterState( + new ClusterName(clusterName), + 123l, + "111111", + null, + null, + nodes, + null, + ImmutableOpenMap.of(), + 0, + false + ); when(clusterService.state()).thenReturn(testState); DiscoveryNode[] eligibleNodes = nodeHelper.getEligibleNodes(); diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index dff7d45486..8fb98de50c 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -64,7 +64,6 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.Index; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; @@ -321,7 +320,10 @@ public static ClusterState state( final Settings.Builder existingSettings = Settings.builder().put(indexSettings).put(IndexMetadata.SETTING_INDEX_UUID, "test2UUID"); IndexMetadata indexMetaData = IndexMetadata.builder(indexName).settings(existingSettings).putMapping(mapping).build(); - final ImmutableOpenMap indices = ImmutableOpenMap.builder().fPut(indexName, indexMetaData).build(); + final ImmutableOpenMap indices = ImmutableOpenMap + .builder() + .fPut(indexName, indexMetaData) + .build(); return ClusterState.builder(name).metadata(Metadata.builder().indices(indices).build()).build(); } @@ -373,10 +375,11 @@ public static ClusterState setupTestClusterState() { .put("index.version.created", Version.CURRENT.id) ) .build(); - ImmutableOpenMap indices = ImmutableOpenMap.builder().fPut(ML_MODEL_INDEX, indexMetadata).build(); - Metadata metadata = new Metadata.Builder() - .indices(indices) + ImmutableOpenMap indices = ImmutableOpenMap + .builder() + .fPut(ML_MODEL_INDEX, indexMetadata) .build(); + Metadata metadata = new Metadata.Builder().indices(indices).build(); return new ClusterState( new ClusterName("test cluster"), 123l,