Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
)

* system error handling

Signed-off-by: Jing Zhang <[email protected]>

* remove NullPointerException from client error

Signed-off-by: Jing Zhang <[email protected]>

* replace jsonobject with ObjectMapper

Signed-off-by: Jing Zhang <[email protected]>

* add more UT

Signed-off-by: Jing Zhang <[email protected]>

* fix format issue

Signed-off-by: Jing Zhang <[email protected]>

* spotless

Signed-off-by: Jing Zhang <[email protected]>

---------

Signed-off-by: Jing Zhang <[email protected]>
(cherry picked from commit c2a1d82)

Co-authored-by: Jing Zhang <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and jngz-es authored Feb 6, 2024
1 parent 8c1a85f commit a9af85f
Show file tree
Hide file tree
Showing 6 changed files with 354 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.ml.rest;

import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG;
Expand All @@ -17,17 +19,23 @@
import java.util.Locale;

import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.utils.error.ErrorMessageFactory;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.action.RestToXContentListener;

Expand Down Expand Up @@ -62,7 +70,28 @@ public List<Route> routes() {
@Override
public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(request);
return channel -> client.execute(MLExecuteTaskAction.INSTANCE, mlExecuteTaskRequest, new RestToXContentListener<>(channel));

return channel -> client.execute(MLExecuteTaskAction.INSTANCE, mlExecuteTaskRequest, new ActionListener<>() {
@Override
public void onResponse(MLExecuteTaskResponse response) {
try {
sendResponse(channel, response);
} catch (Exception e) {
reportError(channel, e, INTERNAL_SERVER_ERROR);
}
}

@Override
public void onFailure(Exception e) {
RestStatus status;
if (isClientError(e)) {
status = BAD_REQUEST;
} else {
status = INTERNAL_SERVER_ERROR;
}
reportError(channel, e, status);
}
});
}

/**
Expand Down Expand Up @@ -95,4 +124,16 @@ MLExecuteTaskRequest getRequest(RestRequest request) throws IOException {

return new MLExecuteTaskRequest(functionName, input);
}

private void sendResponse(RestChannel channel, MLExecuteTaskResponse response) throws Exception {
channel.sendResponse(new RestToXContentListener<MLExecuteTaskResponse>(channel).buildResponse(response));
}

private void reportError(final RestChannel channel, final Exception e, final RestStatus status) {
channel.sendResponse(new BytesRestResponse(status, ErrorMessageFactory.createErrorMessage(e, status.getStatus()).toString()));
}

private boolean isClientError(Exception e) {
return e instanceof IllegalArgumentException || e instanceof IllegalAccessException;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.utils.error;

import java.util.HashMap;
import java.util.Map;

import org.opensearch.core.rest.RestStatus;

import com.fasterxml.jackson.databind.ObjectMapper;

import lombok.Getter;
import lombok.SneakyThrows;

/** Error Message. */
public class ErrorMessage {

protected Throwable exception;

private final int status;

@Getter
private final String type;

@Getter
private final String reason;

@Getter
private final String details;

/** Error Message Constructor. */
public ErrorMessage(Throwable exception, int status) {
this.exception = exception;
this.status = status;

this.type = fetchType();
this.reason = fetchReason();
this.details = fetchDetails();
}

private String fetchType() {
return exception.getClass().getSimpleName();
}

protected String fetchReason() {
return status == RestStatus.BAD_REQUEST.getStatus() ? "Invalid Request" : "System Error";
}

protected String fetchDetails() {
// Some exception prints internal information (full class name) which is security concern
return emptyStringIfNull(exception.getLocalizedMessage());
}

private String emptyStringIfNull(String str) {
return str != null ? str : "";
}

@SneakyThrows
@Override
public String toString() {
ObjectMapper objectMapper = new ObjectMapper();
Map<String, Object> errorContent = new HashMap<>();
errorContent.put("type", type);
errorContent.put("reason", reason);
errorContent.put("details", details);
Map<String, Object> errMessage = new HashMap<>();
errMessage.put("status", status);
errMessage.put("error", errorContent);

return objectMapper.writeValueAsString(errMessage);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.utils.error;

import org.opensearch.OpenSearchException;

import lombok.experimental.UtilityClass;

@UtilityClass
public class ErrorMessageFactory {
/**
* Create error message based on the exception type.
*
* @param e exception to create error message
* @param status exception status code
* @return error message
*/
public static ErrorMessage createErrorMessage(Throwable e, int status) {
Throwable t = e;
int st = status;
if (t instanceof OpenSearchException) {
st = ((OpenSearchException) t).status().getStatus();
} else {
t = unwrapCause(e);
}

return new ErrorMessage(t, st);
}

protected static Throwable unwrapCause(Throwable t) {
Throwable result = t;
if (result instanceof OpenSearchException) {
return result;
}
if (result.getCause() == null) {
return result;
}
result = unwrapCause(result.getCause());
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand All @@ -33,9 +36,11 @@
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -132,6 +137,69 @@ public void testPrepareRequest() throws Exception {
assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName());
}

public void testPrepareRequest1() throws Exception {
doNothing().when(channel).sendResponse(isA(RestResponse.class));
doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
actionListener.onResponse(new MLExecuteTaskResponse(FunctionName.LOCAL_SAMPLE_CALCULATOR, null));
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
RestRequest request = getLocalSampleCalculatorRestRequest();
restMLExecuteAction.handleRequest(request, channel, client);

ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
Input input = argumentCaptor.getValue().getInput();
assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName());
}

public void testPrepareRequest2() throws Exception {
doThrow(new IllegalArgumentException("input error")).when(channel).sendResponse(isA(RestResponse.class));
doNothing().when(channel).sendResponse(isA(BytesRestResponse.class));
doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
actionListener.onResponse(new MLExecuteTaskResponse(FunctionName.LOCAL_SAMPLE_CALCULATOR, null));
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
RestRequest request = getLocalSampleCalculatorRestRequest();
restMLExecuteAction.handleRequest(request, channel, client);

ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
Input input = argumentCaptor.getValue().getInput();
assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName());
}

public void testPrepareRequestClientError() throws Exception {
doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new IllegalArgumentException("input error"));
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
RestRequest request = getLocalSampleCalculatorRestRequest();
restMLExecuteAction.handleRequest(request, channel, client);

ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
Input input = argumentCaptor.getValue().getInput();
assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName());
}

public void testPrepareRequestSystemError() throws Exception {
doAnswer(invocation -> {
ActionListener<MLExecuteTaskResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new RuntimeException("system error"));
return null;
}).when(client).execute(eq(MLExecuteTaskAction.INSTANCE), any(), any());
RestRequest request = getLocalSampleCalculatorRestRequest();
restMLExecuteAction.handleRequest(request, channel, client);

ArgumentCaptor<MLExecuteTaskRequest> argumentCaptor = ArgumentCaptor.forClass(MLExecuteTaskRequest.class);
verify(client, times(1)).execute(eq(MLExecuteTaskAction.INSTANCE), argumentCaptor.capture(), any());
Input input = argumentCaptor.getValue().getInput();
assertEquals(FunctionName.LOCAL_SAMPLE_CALCULATOR, input.getFunctionName());
}

public void testPrepareRequest_disabled() {
RestRequest request = getExecuteAgentRestRequest();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.utils.error;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

import org.junit.Test;
import org.opensearch.OpenSearchException;
import org.opensearch.core.rest.RestStatus;

public class ErrorMessageFactoryTests {

private Throwable nonOpenSearchThrowable = new Throwable();
private Throwable openSearchThrowable = new OpenSearchException(nonOpenSearchThrowable);

@Test
public void openSearchExceptionShouldCreateEsErrorMessage() {
Exception exception = new OpenSearchException(nonOpenSearchThrowable);
ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus());
assertTrue(msg.exception instanceof OpenSearchException);
}

@Test
public void nonOpenSearchExceptionShouldCreateGenericErrorMessage() {
Exception exception = new Exception(nonOpenSearchThrowable);
ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus());
assertFalse(msg.exception instanceof OpenSearchException);
}

@Test
public void nonOpenSearchExceptionWithWrappedEsExceptionCauseShouldCreateEsErrorMessage() {
Exception exception = (Exception) openSearchThrowable;
ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus());
assertTrue(msg.exception instanceof OpenSearchException);
}

@Test
public void nonOpenSearchExceptionWithMultiLayerWrappedEsExceptionCauseShouldCreateEsErrorMessage() {
Exception exception = new Exception(new Throwable(new Throwable(openSearchThrowable)));
ErrorMessage msg = ErrorMessageFactory.createErrorMessage(exception, RestStatus.BAD_REQUEST.getStatus());
assertTrue(msg.exception instanceof OpenSearchException);
}
}
Loading

0 comments on commit a9af85f

Please sign in to comment.