diff --git a/server/src/main/java/org/opensearch/cluster/routing/allocation/AllocationService.java b/server/src/main/java/org/opensearch/cluster/routing/allocation/AllocationService.java index 1a9eaf99a5ad8..48b0e957a826f 100644 --- a/server/src/main/java/org/opensearch/cluster/routing/allocation/AllocationService.java +++ b/server/src/main/java/org/opensearch/cluster/routing/allocation/AllocationService.java @@ -58,7 +58,6 @@ import org.opensearch.cluster.routing.allocation.decider.Decision; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; -import org.opensearch.common.util.concurrent.TimeoutAwareRunnable; import org.opensearch.gateway.GatewayAllocator; import org.opensearch.gateway.PriorityComparator; import org.opensearch.gateway.ShardsBatchGatewayAllocator; @@ -75,9 +74,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; import java.util.function.Function; -import java.util.function.Supplier; import java.util.stream.Collectors; import static java.util.Collections.emptyList; @@ -627,26 +624,10 @@ private void allocateExistingUnassignedShards(RoutingAllocation allocation) { private void allocateAllUnassignedShards(RoutingAllocation allocation) { ExistingShardsAllocator allocator = existingShardsAllocators.get(ShardsBatchGatewayAllocator.ALLOCATOR_NAME); - executeTimedRunnables(allocator.allocateAllUnassignedShards(allocation, true), () -> allocator.getPrimaryBatchAllocatorTimeout().millis(), true); + allocator.allocateAllUnassignedShards(allocation, true).run(); allocator.afterPrimariesBeforeReplicas(allocation); // Replicas Assignment - executeTimedRunnables(allocator.allocateAllUnassignedShards(allocation, false), () -> allocator.getReplicaBatchAllocatorTimeout().millis(), false); - } - - private void executeTimedRunnables(List runnables, Supplier maxRunTimeSupplier, boolean primary) { - logger.info("Executing timed runnables for primary [{}] of size [{}]", primary, runnables.size()); - Collections.shuffle(runnables); - long startTime = System.nanoTime(); - for (TimeoutAwareRunnable workQueue : runnables) { - if (System.nanoTime() - startTime < TimeValue.timeValueMillis(maxRunTimeSupplier.get()).nanos()) { - logger.info("Starting primary [{}] batch to allocate", primary); - workQueue.run(); - } else { - logger.info("Timing out primary [{}] batch to allocate", primary); - workQueue.onTimeout(); - } - } - logger.info("Time taken to execute timed runnables in this cycle:[{}ms]", TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime)); + allocator.allocateAllUnassignedShards(allocation, false).run(); } private void disassociateDeadNodes(RoutingAllocation allocation) { diff --git a/server/src/main/java/org/opensearch/cluster/routing/allocation/ExistingShardsAllocator.java b/server/src/main/java/org/opensearch/cluster/routing/allocation/ExistingShardsAllocator.java index 91d9d08f4c333..57098dab89b95 100644 --- a/server/src/main/java/org/opensearch/cluster/routing/allocation/ExistingShardsAllocator.java +++ b/server/src/main/java/org/opensearch/cluster/routing/allocation/ExistingShardsAllocator.java @@ -39,13 +39,13 @@ import org.opensearch.common.Nullable; import org.opensearch.common.settings.Setting; import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.BatchRunnableExecutor; import org.opensearch.common.util.concurrent.TimeoutAwareRunnable; import org.opensearch.gateway.GatewayAllocator; import org.opensearch.gateway.ShardsBatchGatewayAllocator; import java.util.ArrayList; import java.util.List; -import java.util.function.Consumer; /** * Searches for, and allocates, shards for which there is an existing on-disk copy somewhere in the cluster. The default implementation is @@ -112,7 +112,7 @@ void allocateUnassigned( * * Allocation service will currently run the default implementation of it implemented by {@link ShardsBatchGatewayAllocator} */ - default List allocateAllUnassignedShards(RoutingAllocation allocation, boolean primary) { + default BatchRunnableExecutor allocateAllUnassignedShards(RoutingAllocation allocation, boolean primary) { RoutingNodes.UnassignedShards.UnassignedIterator iterator = allocation.routingNodes().unassigned().iterator(); List runnables = new ArrayList<>(); while (iterator.hasNext()) { @@ -131,15 +131,7 @@ public void run() { }); } } - return runnables; - } - - default TimeValue getPrimaryBatchAllocatorTimeout() { - return TimeValue.MINUS_ONE; - } - - default TimeValue getReplicaBatchAllocatorTimeout() { - return TimeValue.MINUS_ONE; + return new BatchRunnableExecutor(runnables, () -> TimeValue.MINUS_ONE); } /** diff --git a/server/src/main/java/org/opensearch/common/util/BatchRunnableExecutor.java b/server/src/main/java/org/opensearch/common/util/BatchRunnableExecutor.java new file mode 100644 index 0000000000000..dc610161da22f --- /dev/null +++ b/server/src/main/java/org/opensearch/common/util/BatchRunnableExecutor.java @@ -0,0 +1,54 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.util; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.TimeoutAwareRunnable; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +/** + * The executor that executes a batch of {@link TimeoutAwareRunnable} and triggers a timeout based on {@link TimeValue} timeout + */ +public class BatchRunnableExecutor implements Runnable { + + private final Supplier timeoutSupplier; + + private final List timeoutAwareRunnables; + + private static final Logger logger = LogManager.getLogger(BatchRunnableExecutor.class); + + public BatchRunnableExecutor(List timeoutAwareRunnables, Supplier timeoutSupplier) { + this.timeoutSupplier = timeoutSupplier; + this.timeoutAwareRunnables = timeoutAwareRunnables; + } + + @Override + public void run() { + logger.debug("Starting execution of runnable of size [{}]", timeoutAwareRunnables.size()); + Collections.shuffle(timeoutAwareRunnables); + long startTime = System.nanoTime(); + for (TimeoutAwareRunnable workQueue : timeoutAwareRunnables) { + if (System.nanoTime() - startTime > timeoutSupplier.get().nanos()) { + workQueue.run(); + } else { + logger.debug("Executing timeout for runnable of size [{}]", timeoutAwareRunnables.size()); + workQueue.onTimeout(); + } + } + logger.debug("Time taken to execute timed runnables in this cycle:[{}ms]", + TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime)); + } + +} diff --git a/server/src/main/java/org/opensearch/common/util/concurrent/TimeoutAwareRunnable.java b/server/src/main/java/org/opensearch/common/util/concurrent/TimeoutAwareRunnable.java index b84a43004b54d..2d890c5e01e45 100644 --- a/server/src/main/java/org/opensearch/common/util/concurrent/TimeoutAwareRunnable.java +++ b/server/src/main/java/org/opensearch/common/util/concurrent/TimeoutAwareRunnable.java @@ -9,7 +9,7 @@ package org.opensearch.common.util.concurrent; /** - * Runnable that is aware of a timeout and can execute another {@link Runnable} when a timeout is reached + * Runnable that is aware of a timeout */ public interface TimeoutAwareRunnable extends Runnable { diff --git a/server/src/main/java/org/opensearch/gateway/ShardsBatchGatewayAllocator.java b/server/src/main/java/org/opensearch/gateway/ShardsBatchGatewayAllocator.java index 936b25f0dd435..36a4049bac8fd 100644 --- a/server/src/main/java/org/opensearch/gateway/ShardsBatchGatewayAllocator.java +++ b/server/src/main/java/org/opensearch/gateway/ShardsBatchGatewayAllocator.java @@ -31,6 +31,7 @@ import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.BatchRunnableExecutor; import org.opensearch.common.util.concurrent.ConcurrentCollections; import org.opensearch.common.util.concurrent.TimeoutAwareRunnable; import org.opensearch.common.util.set.Sets; @@ -56,7 +57,6 @@ import java.util.Spliterators; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -168,24 +168,6 @@ protected ShardsBatchGatewayAllocator(long batchSize) { this.replicaShardsBatchGatewayAllocatorTimeout = null; } - @Override - public TimeValue getPrimaryBatchAllocatorTimeout() { - return this.primaryShardsBatchGatewayAllocatorTimeout; - } - - public void setPrimaryBatchAllocatorTimeout(TimeValue primaryShardsBatchGatewayAllocatorTimeout) { - this.primaryShardsBatchGatewayAllocatorTimeout = primaryShardsBatchGatewayAllocatorTimeout; - } - - @Override - public TimeValue getReplicaBatchAllocatorTimeout() { - return this.primaryShardsBatchGatewayAllocatorTimeout; - } - - public void setReplicaBatchAllocatorTimeout(TimeValue replicaShardsBatchGatewayAllocatorTimeout) { - this.replicaShardsBatchGatewayAllocatorTimeout = replicaShardsBatchGatewayAllocatorTimeout; - } - // for tests @Override @@ -245,14 +227,14 @@ public void allocateUnassigned( } @Override - public List allocateAllUnassignedShards(final RoutingAllocation allocation, boolean primary) { + public BatchRunnableExecutor allocateAllUnassignedShards(final RoutingAllocation allocation, boolean primary) { assert primaryShardBatchAllocator != null; assert replicaShardBatchAllocator != null; return innerAllocateUnassignedBatch(allocation, primaryShardBatchAllocator, replicaShardBatchAllocator, primary); } - protected List innerAllocateUnassignedBatch( + protected BatchRunnableExecutor innerAllocateUnassignedBatch( RoutingAllocation allocation, PrimaryShardBatchAllocator primaryBatchShardAllocator, ReplicaShardBatchAllocator replicaBatchShardAllocator, @@ -261,7 +243,7 @@ protected List innerAllocateUnassignedBatch( // create batches for unassigned shards Set batchesToAssign = createAndUpdateBatches(allocation, primary); if (batchesToAssign.isEmpty()) { - return Collections.emptyList(); + return null; } List runnables = new ArrayList<>(); if (primary) { @@ -287,6 +269,7 @@ public void run() { } })); + return new BatchRunnableExecutor(runnables, () -> primaryShardsBatchGatewayAllocatorTimeout); } else { batchIdToStoreShardBatch.values() .stream() @@ -311,8 +294,8 @@ public void run() { } })); + return new BatchRunnableExecutor(runnables, () -> replicaShardsBatchGatewayAllocatorTimeout); } - return runnables; } // visible for testing @@ -816,4 +799,12 @@ public int getNumberOfStartedShardBatches() { public int getNumberOfStoreShardBatches() { return batchIdToStoreShardBatch.size(); } + + private void setPrimaryBatchAllocatorTimeout(TimeValue primaryShardsBatchGatewayAllocatorTimeout) { + this.primaryShardsBatchGatewayAllocatorTimeout = primaryShardsBatchGatewayAllocatorTimeout; + } + + private void setReplicaBatchAllocatorTimeout(TimeValue replicaShardsBatchGatewayAllocatorTimeout) { + this.replicaShardsBatchGatewayAllocatorTimeout = replicaShardsBatchGatewayAllocatorTimeout; + } } diff --git a/test/framework/src/main/java/org/opensearch/test/gateway/TestShardBatchGatewayAllocator.java b/test/framework/src/main/java/org/opensearch/test/gateway/TestShardBatchGatewayAllocator.java index 5f76c4f9b8189..0eb4bb6935bac 100644 --- a/test/framework/src/main/java/org/opensearch/test/gateway/TestShardBatchGatewayAllocator.java +++ b/test/framework/src/main/java/org/opensearch/test/gateway/TestShardBatchGatewayAllocator.java @@ -13,7 +13,7 @@ import org.opensearch.cluster.routing.ShardRouting; import org.opensearch.cluster.routing.allocation.AllocateUnassignedDecision; import org.opensearch.cluster.routing.allocation.RoutingAllocation; -import org.opensearch.common.util.concurrent.TimeoutAwareRunnable; +import org.opensearch.common.util.BatchRunnableExecutor; import org.opensearch.core.index.shard.ShardId; import org.opensearch.gateway.AsyncShardFetch; import org.opensearch.gateway.PrimaryShardBatchAllocator; @@ -29,7 +29,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.function.Consumer; public class TestShardBatchGatewayAllocator extends ShardsBatchGatewayAllocator { @@ -104,7 +103,7 @@ protected boolean hasInitiatedFetching(ShardRouting shard) { }; @Override - public List allocateAllUnassignedShards(RoutingAllocation allocation, boolean primary) { + public BatchRunnableExecutor allocateAllUnassignedShards(RoutingAllocation allocation, boolean primary) { currentNodes = allocation.nodes(); return innerAllocateUnassignedBatch(allocation, primaryBatchShardAllocator, replicaBatchShardAllocator, primary); }