Skip to content

Commit

Permalink
Fix failure UTs
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Jul 10, 2023
1 parent 7cbe26a commit 65e31b1
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@

package org.opensearch.ml.action.connector;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import org.opensearch.action.ActionListener;
Expand Down Expand Up @@ -58,12 +60,16 @@ protected void doExecute(Task task, SearchRequest request, ActionListener<Search
private void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
User user = RestActionUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
FetchSourceContext fetchSourceContext = request.source().fetchSource();
List<String> excludes = Arrays.stream(fetchSourceContext.excludes()).collect(Collectors.toList());
List<String> excludes = Optional.ofNullable(request.source())
.map(SearchSourceBuilder::fetchSource)
.map(FetchSourceContext::excludes)
.map(x -> Arrays.stream(x)
.collect(Collectors.toList()))
.orElse(new ArrayList<>());
excludes.add(HttpConnector.CREDENTIAL_FIELD);
FetchSourceContext rebuiltFetchSourceContext = new FetchSourceContext(
fetchSourceContext.fetchSource(),
fetchSourceContext.includes(),
Optional.ofNullable(request.source()).map(SearchSourceBuilder::fetchSource).map(FetchSourceContext::fetchSource).orElse(true),
Optional.ofNullable(request.source()).map(SearchSourceBuilder::fetchSource).map(FetchSourceContext::includes).orElse(null),
excludes.toArray(new String[0])
);
request.source().fetchSource(rebuiltFetchSourceContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,16 @@ public void search(SearchRequest request, ActionListener<SearchResponse> actionL
User user = RestActionUtils.getUserContext(client);
ActionListener<SearchResponse> listener = wrapRestActionListener(actionListener, "Fail to search model version");
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
FetchSourceContext fetchSourceContext = request.source().fetchSource();
List<String> excludes = Arrays.stream(fetchSourceContext.excludes()).collect(Collectors.toList());
List<String> excludes = Optional.ofNullable(request.source())
.map(SearchSourceBuilder::fetchSource)
.map(FetchSourceContext::excludes)
.map(x -> Arrays.stream(x)
.collect(Collectors.toList()))
.orElse(new ArrayList<>());
excludes.add(MLModel.CONNECTOR_FIELD + "." + HttpConnector.CREDENTIAL_FIELD);
FetchSourceContext rebuiltFetchSourceContext = new FetchSourceContext(
fetchSourceContext.fetchSource(),
fetchSourceContext.includes(),
Optional.ofNullable(request.source()).map(SearchSourceBuilder::fetchSource).map(FetchSourceContext::fetchSource).orElse(true),
Optional.ofNullable(request.source()).map(SearchSourceBuilder::fetchSource).map(FetchSourceContext::includes).orElse(null),
excludes.toArray(new String[0])
);
request.source().fetchSource(rebuiltFetchSourceContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class SearchModelITTests extends MLCommonsIntegTestCase {
public ExpectedException exceptionRule = ExpectedException.none();

private static final String PRE_BUILD_MODEL_URL =
"https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/msmarco-distilbert-base-tas-b/1.0.1/torch_script/sentence-transformers_msmarco-distilbert-base-tas-b-1.0.1-torch_script.zip";
"https://artifacts.opensearch.org/models/ml-models/huggingface/sentence-transformers/all-MiniLM-L6-v2/1.0.1/torch_script/sentence-transformers_all-MiniLM-L6-v2-1.0.1-torch_script.zip";

private String modelGroupId;

Expand Down Expand Up @@ -70,7 +70,7 @@ private void registerModelVersion() throws InterruptedException {
.modelFormat(MLModelFormat.TORCH_SCRIPT)
.modelConfig(modelConfig)
.url(PRE_BUILD_MODEL_URL)
.hashValue("acdc81b652b83121f914c5912ae27c0fca8fabf270e6f191ace6979a19830413")
.hashValue("c15f0d2e62d872be5b5bc6c84d2e0f4921541e29fefbef51d59cc10a8ae30e0f")
.description("mock model desc")
.build();
MLRegisterModelRequest registerModelRequest = new MLRegisterModelRequest(input);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
Expand Down Expand Up @@ -84,6 +85,9 @@ public class SearchModelTransportActionTests extends OpenSearchTestCase {
@Mock
private ClusterService clusterService;

@Mock
private FetchSourceContext fetchSourceContext;

@Rule
public ExpectedException thrown = ExpectedException.none();

Expand All @@ -98,6 +102,9 @@ public void setup() {
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);

when(fetchSourceContext.includes()).thenReturn(new String[]{});
when(fetchSourceContext.excludes()).thenReturn(new String[]{});
searchSourceBuilder.fetchSource(fetchSourceContext);
when(searchRequest.source()).thenReturn(searchSourceBuilder);
when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(false);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ public void testPrepareRequest() throws Exception {
String[] indices = searchRequest.indices();
assertArrayEquals(new String[] { ML_MODEL_INDEX }, indices);
assertEquals(
"{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"content\",\"model_content\",\"ui_metadata\"]}}",
"{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"model_content\",\"ui_metadata\",\"content\"]}}",
searchRequest.source().toString()
);
RestResponse restResponse = responseCaptor.getValue();
Expand Down Expand Up @@ -184,7 +184,7 @@ public void testPrepareRequest_timeout() throws Exception {
String[] indices = searchRequest.indices();
assertArrayEquals(new String[] { ML_MODEL_INDEX }, indices);
assertEquals(
"{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"content\",\"model_content\",\"ui_metadata\"]}}",
"{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"model_content\",\"ui_metadata\",\"content\"]}}",
searchRequest.source().toString()
);
RestResponse restResponse = responseCaptor.getValue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ public void testPrepareRequest() throws Exception {
String[] indices = searchRequest.indices();
assertArrayEquals(new String[] { ML_MODEL_GROUP_INDEX }, indices);
assertEquals(
"{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"content\",\"model_content\",\"ui_metadata\"]}}",
"{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"model_content\",\"ui_metadata\",\"content\"]}}",
searchRequest.source().toString()
);
RestResponse restResponse = responseCaptor.getValue();
Expand Down Expand Up @@ -184,7 +184,7 @@ public void testPrepareRequest_timeout() throws Exception {
String[] indices = searchRequest.indices();
assertArrayEquals(new String[] { ML_MODEL_GROUP_INDEX }, indices);
assertEquals(
"{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"content\",\"model_content\",\"ui_metadata\"]}}",
"{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"model_content\",\"ui_metadata\",\"content\"]}}",
searchRequest.source().toString()
);
RestResponse restResponse = responseCaptor.getValue();
Expand Down

0 comments on commit 65e31b1

Please sign in to comment.