Skip to content

Commit

Permalink
Retry bulk request to OpenSearch (#572)
Browse files Browse the repository at this point in the history
* Add retry to bulk request

Signed-off-by: Tomoyuki Morita <[email protected]>

* Retry only failed items

Signed-off-by: Tomoyuki Morita <[email protected]>

* Address comments

Signed-off-by: Tomoyuki Morita <[email protected]>

* Fix isCreateConflict

Signed-off-by: Tomoyuki Morita <[email protected]>

* Add and fix unit tests

Signed-off-by: Tomoyuki Morita <[email protected]>

---------

Signed-off-by: Tomoyuki Morita <[email protected]>
  • Loading branch information
ykmr1224 authored Aug 22, 2024
1 parent b407a06 commit 3db16ec
Show file tree
Hide file tree
Showing 5 changed files with 299 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

import java.io.IOException;
import org.opensearch.flint.core.storage.BulkRequestRateLimiter;
import org.opensearch.flint.core.storage.OpenSearchBulkRetryWrapper;

import static org.opensearch.flint.core.metrics.MetricConstants.OS_READ_OP_METRIC_PREFIX;
import static org.opensearch.flint.core.metrics.MetricConstants.OS_WRITE_OP_METRIC_PREFIX;
Expand All @@ -49,6 +50,7 @@
public class RestHighLevelClientWrapper implements IRestHighLevelClient {
private final RestHighLevelClient client;
private final BulkRequestRateLimiter rateLimiter;
private final OpenSearchBulkRetryWrapper bulkRetryWrapper;

private final static JacksonJsonpMapper JACKSON_MAPPER = new JacksonJsonpMapper();

Expand All @@ -57,17 +59,18 @@ public class RestHighLevelClientWrapper implements IRestHighLevelClient {
*
* @param client the RestHighLevelClient instance to wrap
*/
public RestHighLevelClientWrapper(RestHighLevelClient client, BulkRequestRateLimiter rateLimiter) {
public RestHighLevelClientWrapper(RestHighLevelClient client, BulkRequestRateLimiter rateLimiter, OpenSearchBulkRetryWrapper bulkRetryWrapper) {
this.client = client;
this.rateLimiter = rateLimiter;
this.bulkRetryWrapper = bulkRetryWrapper;
}

@Override
public BulkResponse bulk(BulkRequest bulkRequest, RequestOptions options) throws IOException {
return execute(OS_WRITE_OP_METRIC_PREFIX, () -> {
try {
rateLimiter.acquirePermit();
return client.bulk(bulkRequest, options);
return bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options);
} catch (InterruptedException e) {
throw new RuntimeException("rateLimiter.acquirePermit was interrupted.", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@
import static java.time.temporal.ChronoUnit.SECONDS;

import dev.failsafe.RetryPolicy;
import dev.failsafe.event.ExecutionAttemptedEvent;
import dev.failsafe.function.CheckedPredicate;
import java.time.Duration;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Logger;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.flint.core.http.handler.ExceptionClassNameFailurePredicate;
import org.opensearch.flint.core.http.handler.HttpStatusCodeResultPredicate;
import java.io.Serializable;
Expand Down Expand Up @@ -71,13 +74,33 @@ public <T> RetryPolicy<T> getRetryPolicy() {
.handleIf(ExceptionClassNameFailurePredicate.create(getRetryableExceptionClassNames()))
.handleResultIf(new HttpStatusCodeResultPredicate<>(getRetryableHttpStatusCodes()))
// Logging listener
.onFailedAttempt(event ->
LOG.severe("Attempt to execute request failed: " + event))
.onRetry(ex ->
LOG.warning("Retrying failed request at #" + ex.getAttemptCount()))
.onFailedAttempt(FlintRetryOptions::onFailure)
.onRetry(FlintRetryOptions::onRetry)
.build();
}

public RetryPolicy<BulkResponse> getBulkRetryPolicy(CheckedPredicate<BulkResponse> resultPredicate) {
return RetryPolicy.<BulkResponse>builder()
// Using higher initial backoff to mitigate throttling quickly
.withBackoff(4, 30, SECONDS)
.withJitter(Duration.ofMillis(100))
.withMaxRetries(getMaxRetries())
// Do not retry on exception (will be handled by the other retry policy
.handleIf((ex) -> false)
.handleResultIf(resultPredicate)
.onFailedAttempt(FlintRetryOptions::onFailure)
.onRetry(FlintRetryOptions::onRetry)
.build();
}

private static <T> void onFailure(ExecutionAttemptedEvent<T> event) {
LOG.severe("Attempt to execute request failed: " + event);
}

private static <T> void onRetry(ExecutionAttemptedEvent<T> event) {
LOG.warning("Retrying failed request at #" + event.getAttemptCount());
}

/**
* @return maximum retry option value
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package org.opensearch.flint.core.storage;

import dev.failsafe.Failsafe;
import dev.failsafe.FailsafeException;
import dev.failsafe.RetryPolicy;
import dev.failsafe.function.CheckedPredicate;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Logger;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.bulk.BulkItemResponse;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.client.RequestOptions;
import org.opensearch.client.RestHighLevelClient;
import org.opensearch.flint.core.http.FlintRetryOptions;
import org.opensearch.rest.RestStatus;

public class OpenSearchBulkRetryWrapper {

private static final Logger LOG = Logger.getLogger(OpenSearchBulkRetryWrapper.class.getName());

private final RetryPolicy<BulkResponse> retryPolicy;

public OpenSearchBulkRetryWrapper(FlintRetryOptions retryOptions) {
this.retryPolicy = retryOptions.getBulkRetryPolicy(bulkItemRetryableResultPredicate);
}

/**
* Delegate bulk request to the client, and retry the request if the response contains retryable
* failure. It won't retry when bulk call thrown exception.
* @param client used to call bulk API
* @param bulkRequest requests passed to bulk method
* @param options options passed to bulk method
* @return Last result
*/
public BulkResponse bulkWithPartialRetry(RestHighLevelClient client, BulkRequest bulkRequest,
RequestOptions options) {
try {
final AtomicReference<BulkRequest> nextRequest = new AtomicReference<>(bulkRequest);
return Failsafe
.with(retryPolicy)
.get(() -> {
BulkResponse response = client.bulk(nextRequest.get(), options);
if (retryPolicy.getConfig().allowsRetries() && bulkItemRetryableResultPredicate.test(
response)) {
nextRequest.set(getRetryableRequest(nextRequest.get(), response));
}
return response;
});
} catch (FailsafeException ex) {
LOG.severe("Request failed permanently. Re-throwing original exception.");

// unwrap original exception and throw
throw new RuntimeException(ex.getCause());
}
}

private BulkRequest getRetryableRequest(BulkRequest request, BulkResponse response) {
List<DocWriteRequest<?>> bulkItemRequests = request.requests();
BulkItemResponse[] bulkItemResponses = response.getItems();
BulkRequest nextRequest = new BulkRequest()
.setRefreshPolicy(request.getRefreshPolicy());
nextRequest.setParentTask(request.getParentTask());
for (int i = 0; i < bulkItemRequests.size(); i++) {
if (isItemRetryable(bulkItemResponses[i])) {
verifyIdMatch(bulkItemRequests.get(i), bulkItemResponses[i]);
nextRequest.add(bulkItemRequests.get(i));
}
}
LOG.info(String.format("Added %d requests to nextRequest", nextRequest.requests().size()));
return nextRequest;
}

private static void verifyIdMatch(DocWriteRequest<?> request, BulkItemResponse response) {
if (request.id() != null && !request.id().equals(response.getId())) {
throw new RuntimeException("id doesn't match: " + request.id() + " / " + response.getId());
}
}

/**
* A predicate to decide if a BulkResponse is retryable or not.
*/
private static final CheckedPredicate<BulkResponse> bulkItemRetryableResultPredicate = bulkResponse ->
bulkResponse.hasFailures() && isRetryable(bulkResponse);

private static boolean isRetryable(BulkResponse bulkResponse) {
if (Arrays.stream(bulkResponse.getItems())
.anyMatch(itemResp -> isItemRetryable(itemResp))) {
LOG.info("Found retryable failure in the bulk response");
return true;
}
return false;
}

private static boolean isItemRetryable(BulkItemResponse itemResponse) {
return itemResponse.isFailed() && !isCreateConflict(itemResponse);
}

private static boolean isCreateConflict(BulkItemResponse itemResp) {
return itemResp.getOpType() == DocWriteRequest.OpType.CREATE &&
itemResp.getFailure().getStatus() == RestStatus.CONFLICT;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ public static RestHighLevelClient createRestHighLevelClient(FlintOptions options

public static IRestHighLevelClient createClient(FlintOptions options) {
return new RestHighLevelClientWrapper(createRestHighLevelClient(options),
BulkRequestRateLimiterHolder.getBulkRequestRateLimiter(options));
BulkRequestRateLimiterHolder.getBulkRequestRateLimiter(options),
new OpenSearchBulkRetryWrapper(options.getRetryOptions()));
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package org.opensearch.flint.core.storage;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableList;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.action.DocWriteRequest.OpType;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.bulk.BulkItemResponse;
import org.opensearch.action.bulk.BulkItemResponse.Failure;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.client.RequestOptions;
import org.opensearch.client.RestHighLevelClient;
import org.opensearch.flint.core.http.FlintRetryOptions;
import org.opensearch.rest.RestStatus;

@ExtendWith(MockitoExtension.class)
class OpenSearchBulkRetryWrapperTest {

@Mock
BulkRequest bulkRequest;
@Mock
RequestOptions options;
@Mock
BulkResponse successResponse;
@Mock
BulkResponse failureResponse;
@Mock
BulkResponse conflictResponse;
@Mock
RestHighLevelClient client;
@Mock
DocWriteResponse docWriteResponse;
@Mock
IndexRequest indexRequest0, indexRequest1;
@Mock IndexRequest docWriteRequest2;
// BulkItemRequest[] bulkItemRequests = new BulkItemRequest[] {
// new BulkItemRequest(0, docWriteRequest0),
// new BulkItemRequest(1, docWriteRequest1),
// new BulkItemRequest(2, docWriteRequest2),
// };
BulkItemResponse successItem = new BulkItemResponse(0, OpType.CREATE, docWriteResponse);
BulkItemResponse failureItem = new BulkItemResponse(0, OpType.CREATE,
new Failure("index", "id", null,
RestStatus.TOO_MANY_REQUESTS));
BulkItemResponse conflictItem = new BulkItemResponse(0, OpType.CREATE,
new Failure("index", "id", null,
RestStatus.CONFLICT));

FlintRetryOptions retryOptionsWithRetry = new FlintRetryOptions(Map.of("retry.max_retries", "2"));
FlintRetryOptions retryOptionsWithoutRetry = new FlintRetryOptions(
Map.of("retry.max_retries", "0"));

@Test
public void withRetryWhenCallSucceed() throws Exception {
OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper(
retryOptionsWithRetry);
when(client.bulk(bulkRequest, options)).thenReturn(successResponse);
when(successResponse.hasFailures()).thenReturn(false);

BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options);

assertEquals(response, successResponse);
verify(client).bulk(bulkRequest, options);
}

@Test
public void withRetryWhenCallConflict() throws Exception {
OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper(
retryOptionsWithRetry);
when(client.bulk(any(), eq(options)))
.thenReturn(conflictResponse);
mockConflictResponse();
when(conflictResponse.hasFailures()).thenReturn(true);

BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options);

assertEquals(response, conflictResponse);
verify(client).bulk(bulkRequest, options);
}

@Test
public void withRetryWhenCallFailOnce() throws Exception {
OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper(
retryOptionsWithRetry);
when(client.bulk(any(), eq(options)))
.thenReturn(failureResponse)
.thenReturn(successResponse);
mockFailureResponse();
when(successResponse.hasFailures()).thenReturn(false);
when(bulkRequest.requests()).thenReturn(ImmutableList.of(indexRequest0, indexRequest1));

BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options);

assertEquals(response, successResponse);
verify(client, times(2)).bulk(any(), eq(options));
}

@Test
public void withRetryWhenAllCallFail() throws Exception {
OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper(
retryOptionsWithRetry);
when(client.bulk(any(), eq(options)))
.thenReturn(failureResponse);
mockFailureResponse();

BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options);

assertEquals(response, failureResponse);
verify(client, times(3)).bulk(any(), eq(options));
}

@Test
public void withRetryWhenCallThrowsShouldNotRetry() throws Exception {
OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper(
retryOptionsWithRetry);
when(client.bulk(bulkRequest, options)).thenThrow(new RuntimeException("test"));

assertThrows(RuntimeException.class,
() -> bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options));

verify(client).bulk(bulkRequest, options);
}

@Test
public void withoutRetryWhenCallFail() throws Exception {
OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper(
retryOptionsWithoutRetry);
when(client.bulk(bulkRequest, options))
.thenReturn(failureResponse);
mockFailureResponse();

BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options);

assertEquals(response, failureResponse);
verify(client).bulk(bulkRequest, options);
}

private void mockFailureResponse() {
when(failureResponse.hasFailures()).thenReturn(true);
when(failureResponse.getItems()).thenReturn(new BulkItemResponse[]{successItem, failureItem});
}

private void mockConflictResponse() {
when(conflictResponse.hasFailures()).thenReturn(true);
when(conflictResponse.getItems()).thenReturn(new BulkItemResponse[]{successItem, conflictItem});
}
}

0 comments on commit 3db16ec

Please sign in to comment.