From 65e31b1e33b371cc84d5f593917f7f98b3e651d2 Mon Sep 17 00:00:00 2001 From: zane-neo Date: Mon, 10 Jul 2023 16:55:27 +0800 Subject: [PATCH] Fix failure UTs Signed-off-by: zane-neo --- .../connector/SearchConnectorTransportAction.java | 14 ++++++++++---- .../ml/action/handler/MLSearchHandler.java | 12 ++++++++---- .../ml/action/models/SearchModelITTests.java | 4 ++-- .../models/SearchModelTransportActionTests.java | 7 +++++++ .../ml/rest/RestMLSearchModelActionTests.java | 4 ++-- .../ml/rest/RestMLSearchModelGroupActionTests.java | 4 ++-- 6 files changed, 31 insertions(+), 14 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java index f55291e60f..8eea4d5e04 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java @@ -5,8 +5,10 @@ package org.opensearch.ml.action.connector; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; import org.opensearch.action.ActionListener; @@ -58,12 +60,16 @@ protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { User user = RestActionUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - FetchSourceContext fetchSourceContext = request.source().fetchSource(); - List excludes = Arrays.stream(fetchSourceContext.excludes()).collect(Collectors.toList()); + List excludes = Optional.ofNullable(request.source()) + .map(SearchSourceBuilder::fetchSource) + .map(FetchSourceContext::excludes) + .map(x -> Arrays.stream(x) + .collect(Collectors.toList())) + .orElse(new ArrayList<>()); excludes.add(HttpConnector.CREDENTIAL_FIELD); FetchSourceContext rebuiltFetchSourceContext = new FetchSourceContext( - fetchSourceContext.fetchSource(), - fetchSourceContext.includes(), + Optional.ofNullable(request.source()).map(SearchSourceBuilder::fetchSource).map(FetchSourceContext::fetchSource).orElse(true), + Optional.ofNullable(request.source()).map(SearchSourceBuilder::fetchSource).map(FetchSourceContext::includes).orElse(null), excludes.toArray(new String[0]) ); request.source().fetchSource(rebuiltFetchSourceContext); diff --git a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java index a205f17e30..ab3ea09a3d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java @@ -80,12 +80,16 @@ public void search(SearchRequest request, ActionListener actionL User user = RestActionUtils.getUserContext(client); ActionListener listener = wrapRestActionListener(actionListener, "Fail to search model version"); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - FetchSourceContext fetchSourceContext = request.source().fetchSource(); - List excludes = Arrays.stream(fetchSourceContext.excludes()).collect(Collectors.toList()); + List excludes = Optional.ofNullable(request.source()) + .map(SearchSourceBuilder::fetchSource) + .map(FetchSourceContext::excludes) + .map(x -> Arrays.stream(x) + .collect(Collectors.toList())) + .orElse(new ArrayList<>()); excludes.add(MLModel.CONNECTOR_FIELD + "." + HttpConnector.CREDENTIAL_FIELD); FetchSourceContext rebuiltFetchSourceContext = new FetchSourceContext( - fetchSourceContext.fetchSource(), - fetchSourceContext.includes(), + Optional.ofNullable(request.source()).map(SearchSourceBuilder::fetchSource).map(FetchSourceContext::fetchSource).orElse(true), + Optional.ofNullable(request.source()).map(SearchSourceBuilder::fetchSource).map(FetchSourceContext::includes).orElse(null), excludes.toArray(new String[0]) ); request.source().fetchSource(rebuiltFetchSourceContext); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java index 3421df9189..6a1889fde2 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelITTests.java @@ -33,7 +33,7 @@ public class SearchModelITTests extends MLCommonsIntegTestCase { public ExpectedException exceptionRule = ExpectedException.none(); private static final String PRE_BUILD_MODEL_URL = - "https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.1/torch_script/sentence-transformers_msmarco-distilbert-base-tas-b-1.0.1-torch_script.zip"; + "https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-MiniLM-L6-v2/1.0.1/torch_script/sentence-transformers_all-MiniLM-L6-v2-1.0.1-torch_script.zip"; private String modelGroupId; @@ -70,7 +70,7 @@ private void registerModelVersion() throws InterruptedException { .modelFormat(MLModelFormat.TORCH_SCRIPT) .modelConfig(modelConfig) .url(PRE_BUILD_MODEL_URL) - .hashValue("acdc81b652b83121f914c5912ae27c0fca8fabf270e6f191ace6979a19830413") + .hashValue("c15f0d2e62d872be5b5bc6c84d2e0f4921541e29fefbef51d59cc10a8ae30e0f") .description("mock model desc") .build(); MLRegisterModelRequest registerModelRequest = new MLRegisterModelRequest(input); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java index a323a9c693..2a96714c56 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java @@ -46,6 +46,7 @@ import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; @@ -84,6 +85,9 @@ public class SearchModelTransportActionTests extends OpenSearchTestCase { @Mock private ClusterService clusterService; + @Mock + private FetchSourceContext fetchSourceContext; + @Rule public ExpectedException thrown = ExpectedException.none(); @@ -98,6 +102,9 @@ public void setup() { when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); + when(fetchSourceContext.includes()).thenReturn(new String[]{}); + when(fetchSourceContext.excludes()).thenReturn(new String[]{}); + searchSourceBuilder.fetchSource(fetchSourceContext); when(searchRequest.source()).thenReturn(searchSourceBuilder); when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(false); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java index baf254f9e6..31a49d3ea7 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelActionTests.java @@ -138,7 +138,7 @@ public void testPrepareRequest() throws Exception { String[] indices = searchRequest.indices(); assertArrayEquals(new String[] { ML_MODEL_INDEX }, indices); assertEquals( - "{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"content\",\"model_content\",\"ui_metadata\"]}}", + "{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"model_content\",\"ui_metadata\",\"content\"]}}", searchRequest.source().toString() ); RestResponse restResponse = responseCaptor.getValue(); @@ -184,7 +184,7 @@ public void testPrepareRequest_timeout() throws Exception { String[] indices = searchRequest.indices(); assertArrayEquals(new String[] { ML_MODEL_INDEX }, indices); assertEquals( - "{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"content\",\"model_content\",\"ui_metadata\"]}}", + "{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"model_content\",\"ui_metadata\",\"content\"]}}", searchRequest.source().toString() ); RestResponse restResponse = responseCaptor.getValue(); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelGroupActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelGroupActionTests.java index 09c8aebec3..209b5c2f1d 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelGroupActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLSearchModelGroupActionTests.java @@ -138,7 +138,7 @@ public void testPrepareRequest() throws Exception { String[] indices = searchRequest.indices(); assertArrayEquals(new String[] { ML_MODEL_GROUP_INDEX }, indices); assertEquals( - "{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"content\",\"model_content\",\"ui_metadata\"]}}", + "{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"model_content\",\"ui_metadata\",\"content\"]}}", searchRequest.source().toString() ); RestResponse restResponse = responseCaptor.getValue(); @@ -184,7 +184,7 @@ public void testPrepareRequest_timeout() throws Exception { String[] indices = searchRequest.indices(); assertArrayEquals(new String[] { ML_MODEL_GROUP_INDEX }, indices); assertEquals( - "{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"content\",\"model_content\",\"ui_metadata\"]}}", + "{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"model_content\",\"ui_metadata\",\"content\"]}}", searchRequest.source().toString() ); RestResponse restResponse = responseCaptor.getValue();