From 9f47bf8d2303fe273498876c100c076ae407b0e8 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Sun, 9 Jul 2023 18:44:13 +0800 Subject: [PATCH] Fix UT failures and add UT for search connector transport action Signed-off-by: zane-neo --- .../TransportRegisterModelAction.java | 12 +- .../DeleteConnectorTransportAction.java | 13 +- .../remote/GetConnectorTransportAction.java | 3 +- .../SearchConnectorTransportAction.java | 9 +- .../ml/breaker/MemoryCircuitBreaker.java | 4 +- .../ml/model/MLModelCacheHelper.java | 15 ++- .../opensearch/ml/model/MLModelManager.java | 26 +--- .../SearchModelTransportActionTests.java | 14 +- .../TransportRegisterModelActionTests.java | 12 +- .../SearchConnectorTransportActionTests.java | 120 ++++++++++++++++++ .../TransportCreateConnectorActionTests.java | 15 ++- .../MLModelAutoReDeployerTests.java | 8 +- .../ml/cluster/DiscoveryNodeHelperTests.java | 27 +++- .../ml/task/MLTaskDispatcherTests.java | 26 +++- .../org/opensearch/ml/utils/TestHelper.java | 13 +- 15 files changed, 248 insertions(+), 69 deletions(-) create mode 100644 plugin/src/test/java/org/opensearch/ml/action/remote/SearchConnectorTransportActionTests.java diff --git a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java index a9dca29939..d7de1f0a79 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/register/TransportRegisterModelAction.java @@ -15,7 +15,6 @@ import java.time.Instant; import java.util.Arrays; import java.util.List; -import java.util.regex.Matcher; import java.util.regex.Pattern; import org.apache.logging.log4j.util.Strings; @@ -155,13 +154,18 @@ private void doRegister(MLRegisterModelInput registerModelInput, ActionListener< if (Boolean.TRUE.equals(r)) { registerModel(registerModelInput, listener); } else { - listener.onFailure(new IllegalArgumentException("You don't have permission to use the connector provided, connector id: " + registerModelInput.getConnectorId())); + listener + .onFailure( + new IllegalArgumentException( + "You don't have permission to use the connector provided, connector id: " + + registerModelInput.getConnectorId() + ) + ); } }, e -> { log .error( - "You don't have permission to use the connector provided, connector id: " - + registerModelInput.getConnectorId(), + "You don't have permission to use the connector provided, connector id: " + registerModelInput.getConnectorId(), e ); listener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/remote/DeleteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/remote/DeleteConnectorTransportAction.java index 0a18bc92ff..36c5aeb341 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/remote/DeleteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/remote/DeleteConnectorTransportAction.java @@ -72,8 +72,17 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { log.error("Failed to delete ML connector: " + connectorId, e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/remote/GetConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/remote/GetConnectorTransportAction.java index bc9f9782d3..854623db7e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/remote/GetConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/remote/GetConnectorTransportAction.java @@ -87,7 +87,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (e instanceof IndexNotFoundException) { diff --git a/plugin/src/main/java/org/opensearch/ml/action/remote/SearchConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/remote/SearchConnectorTransportAction.java index 17ed3c0430..70d349e06b 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/remote/SearchConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/remote/SearchConnectorTransportAction.java @@ -5,6 +5,10 @@ package org.opensearch.ml.action.remote; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + import org.opensearch.action.ActionListener; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -15,7 +19,6 @@ import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.authuser.User; import org.opensearch.ml.common.CommonValue; -import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.transport.connector.MLConnectorSearchAction; import org.opensearch.ml.helper.ConnectorAccessControlHelper; @@ -27,10 +30,6 @@ import lombok.extern.log4j.Log4j2; -import java.util.Arrays; -import java.util.List; -import java.util.stream.Collectors; - @Log4j2 public class SearchConnectorTransportAction extends HandledTransportAction { diff --git a/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java b/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java index 6697ab4bf3..8b467c2452 100644 --- a/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java +++ b/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java @@ -5,12 +5,12 @@ package org.opensearch.ml.breaker; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_JVM_HEAP_MEM_THRESHOLD; + import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.monitor.jvm.JvmService; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_JVM_HEAP_MEM_THRESHOLD; - /** * A circuit breaker for memory usage. */ diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java index 512dfc8920..74dbc26d61 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelCacheHelper.java @@ -157,12 +157,15 @@ public String[] getDeployedModels() { */ public String[] getLocalDeployedModels() { return modelCaches - .entrySet() - .stream() - .filter(entry -> (entry.getValue().getModelState() == MLModelState.DEPLOYED && entry.getValue().getFunctionName() != FunctionName.REMOTE)) - .map(entry -> entry.getKey()) - .collect(Collectors.toList()) - .toArray(new String[0]); + .entrySet() + .stream() + .filter( + entry -> (entry.getValue().getModelState() == MLModelState.DEPLOYED + && entry.getValue().getFunctionName() != FunctionName.REMOTE) + ) + .map(entry -> entry.getKey()) + .collect(Collectors.toList()) + .toArray(new String[0]); } /** diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index e0ea78fcdd..a41c1037f8 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -34,10 +34,10 @@ import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly; import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL; import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MASTER_SECRET_KEY; import static org.opensearch.ml.stats.ActionName.REGISTER; import static org.opensearch.ml.stats.MLActionLevelStat.ML_ACTION_REQUEST_COUNT; import static org.opensearch.ml.utils.MLExceptionUtils.logException; @@ -105,7 +105,6 @@ import org.opensearch.ml.engine.MLExecutable; import org.opensearch.ml.engine.ModelHelper; import org.opensearch.ml.engine.Predictable; -import org.opensearch.ml.engine.exceptions.MetaDataException; import org.opensearch.ml.engine.utils.FileUtils; import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.profile.MLModelProfile; @@ -209,12 +208,10 @@ public MLModelManager( .addSettingsUpdateConsumer(ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, it -> maxDeployTasksPerNode = it); this.masterKey = ML_COMMONS_MASTER_SECRET_KEY.get(settings); - clusterService - .getClusterSettings() - .addSettingsUpdateConsumer(ML_COMMONS_MASTER_SECRET_KEY, it -> { - masterKey = it; - mlEngine.setMasterKey(masterKey); - }); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MASTER_SECRET_KEY, it -> { + masterKey = it; + mlEngine.setMasterKey(masterKey); + }); } public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, ActionListener listener) { @@ -371,11 +368,7 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa } } - private void indexRemoteModel( - MLRegisterModelInput registerModelInput, - MLTask mlTask, - String modelVersion - ) { + private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion) { String taskId = mlTask.getTaskId(); FunctionName functionName = mlTask.getFunctionName(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { @@ -435,11 +428,7 @@ private void indexRemoteModel( } } - private void uploadModel( - MLRegisterModelInput registerModelInput, - MLTask mlTask, - String modelVersion - ) throws PrivilegedActionException { + private void uploadModel(MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion) throws PrivilegedActionException { if (registerModelInput.getUrl() != null) { registerModelFromUrl(registerModelInput, mlTask, modelVersion); } else if (registerModelInput.getFunctionName() == FunctionName.REMOTE || registerModelInput.getConnectorId() != null) { @@ -1164,5 +1153,4 @@ public boolean isModelRunningOnNode(String modelId) { return modelCacheHelper.isModelRunningOnNode(modelId); } - } diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java index 8bc6a306c4..db5371ae50 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java @@ -16,7 +16,6 @@ import static org.mockito.Mockito.when; import java.io.IOException; -import java.util.Map; import org.apache.lucene.search.TotalHits; import org.junit.Before; @@ -104,7 +103,18 @@ public void setup() { Metadata metadata = mock(Metadata.class); when(metadata.hasIndex(anyString())).thenReturn(true); - ClusterState testState = new ClusterState(new ClusterName("mock"), 123l, "111111", metadata, null, null, null, ImmutableOpenMap.of(), 0, false); + ClusterState testState = new ClusterState( + new ClusterName("mock"), + 123l, + "111111", + metadata, + null, + null, + null, + ImmutableOpenMap.of(), + 0, + false + ); when(clusterService.state()).thenReturn(testState); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java index f60a13c80f..ebd3d9e98a 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/register/TransportRegisterModelActionTests.java @@ -18,7 +18,8 @@ import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX; import static org.opensearch.ml.utils.TestHelper.clusterSetting; -import com.google.common.collect.ImmutableList; +import java.util.List; + import org.junit.Before; import org.junit.Rule; import org.junit.rules.ExpectedException; @@ -62,7 +63,7 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import java.util.List; +import com.google.common.collect.ImmutableList; public class TransportRegisterModelActionTests extends OpenSearchTestCase { @Rule @@ -127,11 +128,8 @@ public class TransportRegisterModelActionTests extends OpenSearchTestCase { private String trustedUrlRegex = "^(https?|ftp|file)://[-a-zA-Z0-9+&@#/%?=~_|!:,.;]*[-a-zA-Z0-9+&@#/%=~_|]"; - private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList.of( - "^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", - "^https://api\\.openai\\.com/.*$", - "^https://api\\.cohere\\.ai/.*$" - ); + private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList + .of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$"); @Mock private ModelAccessControlHelper modelAccessControlHelper; diff --git a/plugin/src/test/java/org/opensearch/ml/action/remote/SearchConnectorTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/remote/SearchConnectorTransportActionTests.java new file mode 100644 index 0000000000..74d7b83afc --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/remote/SearchConnectorTransportActionTests.java @@ -0,0 +1,120 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.remote; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.authuser.User; +import org.opensearch.ml.helper.ConnectorAccessControlHelper; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class SearchConnectorTransportActionTests extends OpenSearchTestCase { + + @Mock + Client client; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + SearchRequest searchRequest; + + SearchSourceBuilder searchSourceBuilder; + + @Mock + FetchSourceContext fetchSourceContext; + + @Mock + ActionListener actionListener; + + @Mock + ThreadPool threadPool; + + @Mock + private Task task; + SearchConnectorTransportAction searchConnectorTransportAction; + ThreadContext threadContext; + @Mock + private ConnectorAccessControlHelper connectorAccessControlHelper; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + searchConnectorTransportAction = new SearchConnectorTransportAction(transportService, actionFilters, client, connectorAccessControlHelper); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + + searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.fetchSource(fetchSourceContext); + when(searchRequest.source()).thenReturn(searchSourceBuilder); + when(fetchSourceContext.includes()).thenReturn(new String[]{}); + when(fetchSourceContext.excludes()).thenReturn(new String[]{}); + } + + public void test_doExecute_connectorAccessControlNotEnabled_searchSuccess() { + String userString = "admin|role-1|all_access"; + threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userString); + when(connectorAccessControlHelper.skipConnectorAccessControl(any(User.class))).thenReturn(true); + SearchResponse searchResponse = mock(SearchResponse.class); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + searchConnectorTransportAction.doExecute(task, searchRequest, actionListener); + verify(actionListener).onResponse(any(SearchResponse.class)); + } + + public void test_doExecute_connectorAccessControlEnabled_searchSuccess() { + when(connectorAccessControlHelper.skipConnectorAccessControl(any(User.class))).thenReturn(false); + SearchResponse searchResponse = mock(SearchResponse.class); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(searchResponse); + return null; + }).when(client).search(any(SearchRequest.class), isA(ActionListener.class)); + searchConnectorTransportAction.doExecute(task, searchRequest, actionListener); + verify(actionListener).onResponse(any(SearchResponse.class)); + } + + public void test_doExecute_exception() { + when(searchRequest.source()).thenThrow(new RuntimeException("runtime exception")); + searchConnectorTransportAction.doExecute(task, searchRequest, actionListener); + verify(actionListener).onFailure(any(RuntimeException.class)); + } + + +} diff --git a/plugin/src/test/java/org/opensearch/ml/action/remote/TransportCreateConnectorActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/remote/TransportCreateConnectorActionTests.java index 612bd6ee71..85fc567e6e 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/remote/TransportCreateConnectorActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/remote/TransportCreateConnectorActionTests.java @@ -93,16 +93,16 @@ public class TransportCreateConnectorActionTests extends OpenSearchTestCase { private Settings settings; - private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList.of( - "^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", - "^https://api\\.openai\\.com/.*$", - "^https://api\\.cohere\\.ai/.*$" - ); + private static final List TRUSTED_CONNECTOR_ENDPOINTS_REGEXES = ImmutableList + .of("^https://runtime\\.sagemaker\\..*\\.amazonaws\\.com/.*$", "^https://api\\.openai\\.com/.*$", "^https://api\\.cohere\\.ai/.*$"); @Before public void setup() { MockitoAnnotations.openMocks(this); - settings = Settings.builder().putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES).build(); + settings = Settings + .builder() + .putList(ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX.getKey(), TRUSTED_CONNECTOR_ENDPOINTS_REGEXES) + .build(); ClusterSettings clusterSettings = clusterSetting( settings, ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, @@ -140,7 +140,8 @@ public void setup() { Map parameters = ImmutableMap.of("endpoint", "api.openai.com"); Map credential = ImmutableMap.of("access_key", "mockKey", "secret_key", "mockSecret"); - input = MLCreateConnectorInput.builder() + input = MLCreateConnectorInput + .builder() .actions(actions) .parameters(parameters) .protocol(ConnectorProtocols.HTTP) diff --git a/plugin/src/test/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployerTests.java b/plugin/src/test/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployerTests.java index f566593cf3..cdb50648dc 100644 --- a/plugin/src/test/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/autoredeploy/MLModelAutoReDeployerTests.java @@ -18,11 +18,8 @@ import java.nio.file.Path; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.function.Consumer; - -import com.google.common.collect.ImmutableMap; import org.apache.lucene.search.TotalHits; import org.junit.Before; import org.mockito.Mock; @@ -635,7 +632,10 @@ private MLModel buildModelWithJsonFile(String file) throws Exception { private void mockClusterDataNodes(ClusterService clusterService) { ClusterState clusterState = mock(ClusterState.class); DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class); - ImmutableOpenMap dataNodes = ImmutableOpenMap.builder().fPut("dataNodeId", mock(DiscoveryNode.class)).build(); + ImmutableOpenMap dataNodes = ImmutableOpenMap + .builder() + .fPut("dataNodeId", mock(DiscoveryNode.class)) + .build(); when(discoveryNodes.getDataNodes()).thenReturn(dataNodes); when(discoveryNodes.getSize()).thenReturn(2); // a ml node join cluster. when(clusterState.nodes()).thenReturn(discoveryNodes); diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/DiscoveryNodeHelperTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/DiscoveryNodeHelperTests.java index 6485964682..83f283b89b 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/DiscoveryNodeHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/DiscoveryNodeHelperTests.java @@ -17,7 +17,6 @@ import java.io.IOException; import java.util.Arrays; import java.util.HashSet; -import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -130,7 +129,18 @@ public void setup() throws IOException { .add(mlNode1) .add(mlNode2) .build(); - clusterState = new ClusterState(new ClusterName(clusterName), 123l, "111111", null, null, nodes, null, ImmutableOpenMap.of(), 0, false); + clusterState = new ClusterState( + new ClusterName(clusterName), + 123l, + "111111", + null, + null, + nodes, + null, + ImmutableOpenMap.of(), + 0, + false + ); when(clusterService.state()).thenReturn(clusterState); discoveryNodeHelper = new DiscoveryNodeHelper(clusterService, settings); @@ -159,7 +169,18 @@ public void testGetEligibleNodes_DataNode() { mockSettings(false, nonExistingNodeName); DiscoveryNodeHelper discoveryNodeHelper = new DiscoveryNodeHelper(clusterService, settings); DiscoveryNodes nodes = DiscoveryNodes.builder().add(clusterManagerNode).add(dataNode1).add(dataNode2).add(warmDataNode1).build(); - clusterState = new ClusterState(new ClusterName(clusterName), 123l, "111111", null, null, nodes, null, ImmutableOpenMap.of(), 0, false); + clusterState = new ClusterState( + new ClusterName(clusterName), + 123l, + "111111", + null, + null, + nodes, + null, + ImmutableOpenMap.of(), + 0, + false + ); when(clusterService.state()).thenReturn(clusterState); DiscoveryNode[] eligibleNodes = discoveryNodeHelper.getEligibleNodes(); diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java index 9a81100f57..32d9b921f7 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java @@ -83,7 +83,18 @@ public void setup() { Set mlRoleSet = ImmutableSet.of(ML_ROLE); mlNode = new DiscoveryNode("mlNode", buildNewFakeTransportAddress(), new HashMap<>(), mlRoleSet, Version.CURRENT); DiscoveryNodes nodes = DiscoveryNodes.builder().add(dataNode1).add(dataNode2).build(); - testState = new ClusterState(new ClusterName(clusterName), 123l, "111111", null, null, nodes, null, ImmutableOpenMap.of(), 0, false); + testState = new ClusterState( + new ClusterName(clusterName), + 123l, + "111111", + null, + null, + nodes, + null, + ImmutableOpenMap.of(), + 0, + false + ); when(clusterService.state()).thenReturn(testState); doAnswer(invocation -> { @@ -146,7 +157,18 @@ public void testGetEligibleNodes_DataNodeOnly() { @Ignore public void testGetEligibleNodes_MlAndDataNodes() { DiscoveryNodes nodes = DiscoveryNodes.builder().add(dataNode1).add(dataNode2).add(mlNode).build(); - testState = new ClusterState(new ClusterName(clusterName), 123l, "111111", null, null, nodes, null, ImmutableOpenMap.of(), 0, false); + testState = new ClusterState( + new ClusterName(clusterName), + 123l, + "111111", + null, + null, + nodes, + null, + ImmutableOpenMap.of(), + 0, + false + ); when(clusterService.state()).thenReturn(testState); DiscoveryNode[] eligibleNodes = nodeHelper.getEligibleNodes(); diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index dff7d45486..8fb98de50c 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -64,7 +64,6 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.index.Index; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.MLInputDataType; import org.opensearch.ml.common.dataset.SearchQueryInputDataset; @@ -321,7 +320,10 @@ public static ClusterState state( final Settings.Builder existingSettings = Settings.builder().put(indexSettings).put(IndexMetadata.SETTING_INDEX_UUID, "test2UUID"); IndexMetadata indexMetaData = IndexMetadata.builder(indexName).settings(existingSettings).putMapping(mapping).build(); - final ImmutableOpenMap indices = ImmutableOpenMap.builder().fPut(indexName, indexMetaData).build(); + final ImmutableOpenMap indices = ImmutableOpenMap + .builder() + .fPut(indexName, indexMetaData) + .build(); return ClusterState.builder(name).metadata(Metadata.builder().indices(indices).build()).build(); } @@ -373,10 +375,11 @@ public static ClusterState setupTestClusterState() { .put("index.version.created", Version.CURRENT.id) ) .build(); - ImmutableOpenMap indices = ImmutableOpenMap.builder().fPut(ML_MODEL_INDEX, indexMetadata).build(); - Metadata metadata = new Metadata.Builder() - .indices(indices) + ImmutableOpenMap indices = ImmutableOpenMap + .builder() + .fPut(ML_MODEL_INDEX, indexMetadata) .build(); + Metadata metadata = new Metadata.Builder().indices(indices).build(); return new ClusterState( new ClusterName("test cluster"), 123l,