diff --git a/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java b/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java index 28fad39a0..1b83f032a 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/RestHighLevelClientWrapper.java @@ -38,6 +38,7 @@ import java.io.IOException; import org.opensearch.flint.core.storage.BulkRequestRateLimiter; +import org.opensearch.flint.core.storage.OpenSearchBulkRetryWrapper; import static org.opensearch.flint.core.metrics.MetricConstants.OS_READ_OP_METRIC_PREFIX; import static org.opensearch.flint.core.metrics.MetricConstants.OS_WRITE_OP_METRIC_PREFIX; @@ -49,6 +50,7 @@ public class RestHighLevelClientWrapper implements IRestHighLevelClient { private final RestHighLevelClient client; private final BulkRequestRateLimiter rateLimiter; + private final OpenSearchBulkRetryWrapper bulkRetryWrapper; private final static JacksonJsonpMapper JACKSON_MAPPER = new JacksonJsonpMapper(); @@ -57,9 +59,10 @@ public class RestHighLevelClientWrapper implements IRestHighLevelClient { * * @param client the RestHighLevelClient instance to wrap */ - public RestHighLevelClientWrapper(RestHighLevelClient client, BulkRequestRateLimiter rateLimiter) { + public RestHighLevelClientWrapper(RestHighLevelClient client, BulkRequestRateLimiter rateLimiter, OpenSearchBulkRetryWrapper bulkRetryWrapper) { this.client = client; this.rateLimiter = rateLimiter; + this.bulkRetryWrapper = bulkRetryWrapper; } @Override @@ -67,7 +70,7 @@ public BulkResponse bulk(BulkRequest bulkRequest, RequestOptions options) throws return execute(OS_WRITE_OP_METRIC_PREFIX, () -> { try { rateLimiter.acquirePermit(); - return client.bulk(bulkRequest, options); + return bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options); } catch (InterruptedException e) { throw new RuntimeException("rateLimiter.acquirePermit was 
interrupted.", e); } diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/http/FlintRetryOptions.java b/flint-core/src/main/scala/org/opensearch/flint/core/http/FlintRetryOptions.java index bf5352f3b..8f6e2c07e 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/http/FlintRetryOptions.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/http/FlintRetryOptions.java @@ -8,10 +8,13 @@ import static java.time.temporal.ChronoUnit.SECONDS; import dev.failsafe.RetryPolicy; +import dev.failsafe.event.ExecutionAttemptedEvent; +import dev.failsafe.function.CheckedPredicate; import java.time.Duration; import java.util.Map; import java.util.Optional; import java.util.logging.Logger; +import org.opensearch.action.bulk.BulkResponse; import org.opensearch.flint.core.http.handler.ExceptionClassNameFailurePredicate; import org.opensearch.flint.core.http.handler.HttpStatusCodeResultPredicate; import java.io.Serializable; @@ -71,13 +74,33 @@ public RetryPolicy getRetryPolicy() { .handleIf(ExceptionClassNameFailurePredicate.create(getRetryableExceptionClassNames())) .handleResultIf(new HttpStatusCodeResultPredicate<>(getRetryableHttpStatusCodes())) // Logging listener - .onFailedAttempt(event -> - LOG.severe("Attempt to execute request failed: " + event)) - .onRetry(ex -> - LOG.warning("Retrying failed request at #" + ex.getAttemptCount())) + .onFailedAttempt(FlintRetryOptions::onFailure) + .onRetry(FlintRetryOptions::onRetry) .build(); } + public RetryPolicy getBulkRetryPolicy(CheckedPredicate resultPredicate) { + return RetryPolicy.builder() + // Using higher initial backoff to mitigate throttling quickly + .withBackoff(4, 30, SECONDS) + .withJitter(Duration.ofMillis(100)) + .withMaxRetries(getMaxRetries()) + // Do not retry on exception (will be handled by the other retry policy) + .handleIf((ex) -> false) + .handleResultIf(resultPredicate) + .onFailedAttempt(FlintRetryOptions::onFailure) + .onRetry(FlintRetryOptions::onRetry) + .build(); + } + + 
private static void onFailure(ExecutionAttemptedEvent event) { + LOG.severe("Attempt to execute request failed: " + event); + } + + private static void onRetry(ExecutionAttemptedEvent event) { + LOG.warning("Retrying failed request at #" + event.getAttemptCount()); + } + /** * @return maximum retry option value */ diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchBulkRetryWrapper.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchBulkRetryWrapper.java new file mode 100644 index 000000000..279c9b642 --- /dev/null +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchBulkRetryWrapper.java @@ -0,0 +1,105 @@ +package org.opensearch.flint.core.storage; + +import dev.failsafe.Failsafe; +import dev.failsafe.FailsafeException; +import dev.failsafe.RetryPolicy; +import dev.failsafe.function.CheckedPredicate; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import java.util.logging.Logger; +import org.opensearch.action.DocWriteRequest; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.flint.core.http.FlintRetryOptions; +import org.opensearch.rest.RestStatus; + +public class OpenSearchBulkRetryWrapper { + + private static final Logger LOG = Logger.getLogger(OpenSearchBulkRetryWrapper.class.getName()); + + private final RetryPolicy retryPolicy; + + public OpenSearchBulkRetryWrapper(FlintRetryOptions retryOptions) { + this.retryPolicy = retryOptions.getBulkRetryPolicy(bulkItemRetryableResultPredicate); + } + + /** + * Delegates the bulk request to the client and retries the request if the response contains a retryable + * failure. It won't retry when the bulk call throws an exception. 
+ * @param client used to call bulk API + * @param bulkRequest requests passed to bulk method + * @param options options passed to bulk method + * @return Last result + */ + public BulkResponse bulkWithPartialRetry(RestHighLevelClient client, BulkRequest bulkRequest, + RequestOptions options) { + try { + final AtomicReference nextRequest = new AtomicReference<>(bulkRequest); + return Failsafe + .with(retryPolicy) + .get(() -> { + BulkResponse response = client.bulk(nextRequest.get(), options); + if (retryPolicy.getConfig().allowsRetries() && bulkItemRetryableResultPredicate.test( + response)) { + nextRequest.set(getRetryableRequest(nextRequest.get(), response)); + } + return response; + }); + } catch (FailsafeException ex) { + LOG.severe("Request failed permanently. Re-throwing original exception."); + + // unwrap original exception and throw + throw new RuntimeException(ex.getCause()); + } + } + + private BulkRequest getRetryableRequest(BulkRequest request, BulkResponse response) { + List> bulkItemRequests = request.requests(); + BulkItemResponse[] bulkItemResponses = response.getItems(); + BulkRequest nextRequest = new BulkRequest() + .setRefreshPolicy(request.getRefreshPolicy()); + nextRequest.setParentTask(request.getParentTask()); + for (int i = 0; i < bulkItemRequests.size(); i++) { + if (isItemRetryable(bulkItemResponses[i])) { + verifyIdMatch(bulkItemRequests.get(i), bulkItemResponses[i]); + nextRequest.add(bulkItemRequests.get(i)); + } + } + LOG.info(String.format("Added %d requests to nextRequest", nextRequest.requests().size())); + return nextRequest; + } + + private static void verifyIdMatch(DocWriteRequest request, BulkItemResponse response) { + if (request.id() != null && !request.id().equals(response.getId())) { + throw new RuntimeException("id doesn't match: " + request.id() + " / " + response.getId()); + } + } + + /** + * A predicate to decide if a BulkResponse is retryable or not. 
+ */ + private static final CheckedPredicate bulkItemRetryableResultPredicate = bulkResponse -> + bulkResponse.hasFailures() && isRetryable(bulkResponse); + + private static boolean isRetryable(BulkResponse bulkResponse) { + if (Arrays.stream(bulkResponse.getItems()) + .anyMatch(itemResp -> isItemRetryable(itemResp))) { + LOG.info("Found retryable failure in the bulk response"); + return true; + } + return false; + } + + private static boolean isItemRetryable(BulkItemResponse itemResponse) { + return itemResponse.isFailed() && !isCreateConflict(itemResponse); + } + + private static boolean isCreateConflict(BulkItemResponse itemResp) { + return itemResp.getOpType() == DocWriteRequest.OpType.CREATE && + itemResp.getFailure().getStatus() == RestStatus.CONFLICT; + } +} diff --git a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java index 0f80d07c9..8bf80e90e 100644 --- a/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java +++ b/flint-core/src/main/scala/org/opensearch/flint/core/storage/OpenSearchClientUtils.java @@ -69,7 +69,8 @@ public static RestHighLevelClient createRestHighLevelClient(FlintOptions options public static IRestHighLevelClient createClient(FlintOptions options) { return new RestHighLevelClientWrapper(createRestHighLevelClient(options), - BulkRequestRateLimiterHolder.getBulkRequestRateLimiter(options)); + BulkRequestRateLimiterHolder.getBulkRequestRateLimiter(options), + new OpenSearchBulkRetryWrapper(options.getRetryOptions())); } /** diff --git a/flint-core/src/test/scala/org/opensearch/flint/core/storage/OpenSearchBulkRetryWrapperTest.java b/flint-core/src/test/scala/org/opensearch/flint/core/storage/OpenSearchBulkRetryWrapperTest.java new file mode 100644 index 000000000..fa57da842 --- /dev/null +++ 
b/flint-core/src/test/scala/org/opensearch/flint/core/storage/OpenSearchBulkRetryWrapperTest.java @@ -0,0 +1,160 @@ +package org.opensearch.flint.core.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.action.DocWriteRequest.OpType; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkItemResponse.Failure; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.client.RequestOptions; +import org.opensearch.client.RestHighLevelClient; +import org.opensearch.flint.core.http.FlintRetryOptions; +import org.opensearch.rest.RestStatus; + +@ExtendWith(MockitoExtension.class) +class OpenSearchBulkRetryWrapperTest { + + @Mock + BulkRequest bulkRequest; + @Mock + RequestOptions options; + @Mock + BulkResponse successResponse; + @Mock + BulkResponse failureResponse; + @Mock + BulkResponse conflictResponse; + @Mock + RestHighLevelClient client; + @Mock + DocWriteResponse docWriteResponse; + @Mock + IndexRequest indexRequest0, indexRequest1; + @Mock IndexRequest docWriteRequest2; +// BulkItemRequest[] bulkItemRequests = new BulkItemRequest[] { +// new BulkItemRequest(0, docWriteRequest0), +// new BulkItemRequest(1, docWriteRequest1), +// new BulkItemRequest(2, docWriteRequest2), +// }; + BulkItemResponse successItem = new 
BulkItemResponse(0, OpType.CREATE, docWriteResponse); + BulkItemResponse failureItem = new BulkItemResponse(0, OpType.CREATE, + new Failure("index", "id", null, + RestStatus.TOO_MANY_REQUESTS)); + BulkItemResponse conflictItem = new BulkItemResponse(0, OpType.CREATE, + new Failure("index", "id", null, + RestStatus.CONFLICT)); + + FlintRetryOptions retryOptionsWithRetry = new FlintRetryOptions(Map.of("retry.max_retries", "2")); + FlintRetryOptions retryOptionsWithoutRetry = new FlintRetryOptions( + Map.of("retry.max_retries", "0")); + + @Test + public void withRetryWhenCallSucceed() throws Exception { + OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper( + retryOptionsWithRetry); + when(client.bulk(bulkRequest, options)).thenReturn(successResponse); + when(successResponse.hasFailures()).thenReturn(false); + + BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options); + + assertEquals(response, successResponse); + verify(client).bulk(bulkRequest, options); + } + + @Test + public void withRetryWhenCallConflict() throws Exception { + OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper( + retryOptionsWithRetry); + when(client.bulk(any(), eq(options))) + .thenReturn(conflictResponse); + mockConflictResponse(); + when(conflictResponse.hasFailures()).thenReturn(true); + + BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options); + + assertEquals(response, conflictResponse); + verify(client).bulk(bulkRequest, options); + } + + @Test + public void withRetryWhenCallFailOnce() throws Exception { + OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper( + retryOptionsWithRetry); + when(client.bulk(any(), eq(options))) + .thenReturn(failureResponse) + .thenReturn(successResponse); + mockFailureResponse(); + when(successResponse.hasFailures()).thenReturn(false); + when(bulkRequest.requests()).thenReturn(ImmutableList.of(indexRequest0, 
indexRequest1)); + + BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options); + + assertEquals(response, successResponse); + verify(client, times(2)).bulk(any(), eq(options)); + } + + @Test + public void withRetryWhenAllCallFail() throws Exception { + OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper( + retryOptionsWithRetry); + when(client.bulk(any(), eq(options))) + .thenReturn(failureResponse); + mockFailureResponse(); + + BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options); + + assertEquals(response, failureResponse); + verify(client, times(3)).bulk(any(), eq(options)); + } + + @Test + public void withRetryWhenCallThrowsShouldNotRetry() throws Exception { + OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper( + retryOptionsWithRetry); + when(client.bulk(bulkRequest, options)).thenThrow(new RuntimeException("test")); + + assertThrows(RuntimeException.class, + () -> bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options)); + + verify(client).bulk(bulkRequest, options); + } + + @Test + public void withoutRetryWhenCallFail() throws Exception { + OpenSearchBulkRetryWrapper bulkRetryWrapper = new OpenSearchBulkRetryWrapper( + retryOptionsWithoutRetry); + when(client.bulk(bulkRequest, options)) + .thenReturn(failureResponse); + mockFailureResponse(); + + BulkResponse response = bulkRetryWrapper.bulkWithPartialRetry(client, bulkRequest, options); + + assertEquals(response, failureResponse); + verify(client).bulk(bulkRequest, options); + } + + private void mockFailureResponse() { + when(failureResponse.hasFailures()).thenReturn(true); + when(failureResponse.getItems()).thenReturn(new BulkItemResponse[]{successItem, failureItem}); + } + + private void mockConflictResponse() { + when(conflictResponse.hasFailures()).thenReturn(true); + when(conflictResponse.getItems()).thenReturn(new BulkItemResponse[]{successItem, conflictItem}); + } 
+}