Skip to content

Commit

Permalink
Incorporate AsyncShardBatchFetch class changes
Browse files (browse the repository at this point in the history)
Signed-off-by: Aman Khare <[email protected]>
  • Loading branch information
Aman Khare committed Apr 12, 2024
1 parent 0f502df commit baca309
Show file tree
Hide file tree
Showing 10 changed files with 51 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,7 @@ public void testBatchModeEnabled() throws Exception {
ensureGreen("test");
assertEquals(0, gatewayAllocator.getNumberOfStartedShardBatches());
assertEquals(0, gatewayAllocator.getNumberOfStoreShardBatches());
assertEquals(0,gatewayAllocator.getNumberOfInFlightFetches());
assertEquals(0, gatewayAllocator.getNumberOfInFlightFetches());
}

public void testBatchModeDisabled() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,13 +154,7 @@ public ClusterModule(
this.shardsAllocator = createShardsAllocator(settings, clusterService.getClusterSettings(), clusterPlugins);
this.clusterService = clusterService;
this.indexNameExpressionResolver = new IndexNameExpressionResolver(threadContext);
this.allocationService = new AllocationService(
allocationDeciders,
shardsAllocator,
clusterInfoService,
snapshotsInfoService,
settings
);
this.allocationService = new AllocationService(allocationDeciders, shardsAllocator, clusterInfoService, snapshotsInfoService);
}

public static List<Entry> getNamedWriteables() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public class GatewayAllocator implements ExistingShardsAllocator {
private final ConcurrentMap<
ShardId,
AsyncShardFetch<TransportNodesListGatewayStartedShards.NodeGatewayStartedShards>> asyncFetchStarted = ConcurrentCollections
.newConcurrentMap();
.newConcurrentMap();
private final ConcurrentMap<ShardId, AsyncShardFetch<TransportNodesListShardStoreMetadata.NodeStoreFilesMetadata>> asyncFetchStore =
ConcurrentCollections.newConcurrentMap();
private Set<String> lastSeenEphemeralIds = Collections.emptySet();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@
import org.opensearch.common.util.set.Sets;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.gateway.TransportNodesGatewayStartedShardHelper.GatewayStartedShard;
import org.opensearch.index.store.Store;
import org.opensearch.indices.store.ShardAttributes;
import org.opensearch.indices.store.TransportNodesListShardStoreMetadataBatch;
import org.opensearch.indices.store.TransportNodesListShardStoreMetadataBatch.NodeStoreFilesMetadata;
import org.opensearch.indices.store.TransportNodesListShardStoreMetadataHelper;
import org.opensearch.indices.store.TransportNodesListShardStoreMetadataHelper.StoreFilesMetadata;

import java.util.Collections;
import java.util.HashMap;
Expand All @@ -48,10 +51,7 @@
import java.util.Set;
import java.util.Spliterators;
import java.util.concurrent.ConcurrentMap;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
Expand Down Expand Up @@ -110,9 +110,7 @@ public ShardsBatchGatewayAllocator(
@Override
public void cleanCaches() {
Stream.of(batchIdToStartedShardBatch, batchIdToStoreShardBatch).forEach(b -> {
Releasables.close(
b.values().stream().map(shardsBatch -> shardsBatch.asyncBatch).collect(Collectors.toList())
);
Releasables.close(b.values().stream().map(shardsBatch -> shardsBatch.asyncBatch).collect(Collectors.toList()));
b.clear();
});
}
Expand Down Expand Up @@ -227,9 +225,7 @@ protected Set<String> createAndUpdateBatches(RoutingAllocation allocation, boole
// get all batched shards
Map<ShardId, String> currentBatchedShards = new HashMap<>();
for (Map.Entry<String, ShardsBatch> batchEntry : currentBatches.entrySet()) {
batchEntry.getValue().getBatchedShards()
.forEach(shardId -> currentBatchedShards.put(shardId,
batchEntry.getKey()));
batchEntry.getValue().getBatchedShards().forEach(shardId -> currentBatchedShards.put(shardId, batchEntry.getKey()));
}

Set<ShardRouting> newShardsToBatch = Sets.newHashSet();
Expand Down Expand Up @@ -447,23 +443,11 @@ class InternalBatchAsyncFetch<T extends BaseNodeResponse, V> extends AsyncShardB
AsyncShardFetch.Lister<? extends BaseNodesResponse<T>, T> action,
String batchUUId,
Class<V> clazz,
BiFunction<DiscoveryNode, Map<ShardId, V>, T> responseBuilder,
Function<T, Map<ShardId, V>> shardsBatchDataGetter,
Supplier<V> emptyResponseBuilder,
Consumer<ShardId> handleFailedShard
V emptyShardResponse,
Predicate<V> emptyShardResponsePredicate,
ShardBatchResponseFactory<T, V> responseFactory
) {
super(
logger,
type,
map,
action,
batchUUId,
clazz,
responseBuilder,
shardsBatchDataGetter,
emptyResponseBuilder,
handleFailedShard
);
super(logger, type, map, action, batchUUId, clazz, emptyShardResponse, emptyShardResponsePredicate, responseFactory);
}

@Override
Expand Down Expand Up @@ -622,11 +606,10 @@ public ShardsBatch(String batchId, Map<ShardId, ShardEntry> shardsWithInfo, bool
shardIdsMap,
batchStartedAction,
batchId,
TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard.class,
TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch::new,
TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch::getNodeGatewayStartedShardsBatch,
() -> new TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard(null, false, null, null),
this::removeShard
GatewayStartedShard.class,
new GatewayStartedShard(null, false, null, null),
GatewayStartedShard::isEmpty,
new ShardBatchResponseFactory<>(true)
);
} else {
asyncBatch = new InternalBatchAsyncFetch<>(
Expand All @@ -635,11 +618,10 @@ public ShardsBatch(String batchId, Map<ShardId, ShardEntry> shardsWithInfo, bool
shardIdsMap,
batchStoreAction,
batchId,
TransportNodesListShardStoreMetadataBatch.NodeStoreFilesMetadata.class,
TransportNodesListShardStoreMetadataBatch.NodeStoreFilesMetadataBatch::new,
TransportNodesListShardStoreMetadataBatch.NodeStoreFilesMetadataBatch::getNodeStoreFilesMetadataBatch,
this::buildEmptyReplicaShardResponse,
this::removeShard
NodeStoreFilesMetadata.class,
new NodeStoreFilesMetadata(new StoreFilesMetadata(null, Store.MetadataSnapshot.EMPTY, Collections.emptyList()), null),
NodeStoreFilesMetadata::isEmpty,
new ShardBatchResponseFactory<>(false)
);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ public void writeTo(StreamOutput out) throws IOException {
}
}

boolean isEmpty(NodeStoreFilesMetadata response) {
public static boolean isEmpty(NodeStoreFilesMetadata response) {
return response.storeFilesMetadata() == null
|| response.storeFilesMetadata().isEmpty() && response.getStoreFileFetchException() == null;
}
Expand Down
8 changes: 5 additions & 3 deletions server/src/main/java/org/opensearch/node/Node.java
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,12 @@
import org.opensearch.extensions.ExtensionsManager;
import org.opensearch.extensions.NoopExtensionsManager;
import org.opensearch.gateway.GatewayAllocator;
import org.opensearch.gateway.ShardsBatchGatewayAllocator;
import org.opensearch.gateway.GatewayMetaState;
import org.opensearch.gateway.GatewayModule;
import org.opensearch.gateway.GatewayService;
import org.opensearch.gateway.MetaStateService;
import org.opensearch.gateway.PersistedClusterStateService;
import org.opensearch.gateway.ShardsBatchGatewayAllocator;
import org.opensearch.gateway.remote.RemoteClusterStateService;
import org.opensearch.http.HttpServerTransport;
import org.opensearch.identity.IdentityService;
Expand Down Expand Up @@ -1324,8 +1324,10 @@ protected Node(
// completes we trigger another reroute to try the allocation again. This means there is a circular dependency: the allocation
// service needs access to the existing shards allocators (e.g. the GatewayAllocator, ShardsBatchGatewayAllocator) which
// need to be able to trigger a reroute, which needs to call into the allocation service. We close the loop here:
clusterModule.setExistingShardsAllocators(injector.getInstance(GatewayAllocator.class),
injector.getInstance(ShardsBatchGatewayAllocator.class));
clusterModule.setExistingShardsAllocators(
injector.getInstance(GatewayAllocator.class),
injector.getInstance(ShardsBatchGatewayAllocator.class)
);

List<LifecycleComponent> pluginLifecycleComponents = pluginComponents.stream()
.filter(p -> p instanceof LifecycleComponent)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -298,8 +297,10 @@ public void testRejectsReservedExistingShardsAllocatorName() {
null,
threadContext
);
expectThrows(IllegalArgumentException.class, () -> clusterModule.setExistingShardsAllocators(new TestGatewayAllocator(),
new TestShardBatchGatewayAllocator()));
expectThrows(
IllegalArgumentException.class,
() -> clusterModule.setExistingShardsAllocators(new TestGatewayAllocator(), new TestShardBatchGatewayAllocator())
);
}

public void testRejectsDuplicateExistingShardsAllocatorName() {
Expand All @@ -311,8 +312,10 @@ public void testRejectsDuplicateExistingShardsAllocatorName() {
null,
threadContext
);
expectThrows(IllegalArgumentException.class, () -> clusterModule.setExistingShardsAllocators(new TestGatewayAllocator(),
new TestShardBatchGatewayAllocator()));
expectThrows(
IllegalArgumentException.class,
() -> clusterModule.setExistingShardsAllocators(new TestGatewayAllocator(), new TestShardBatchGatewayAllocator())
);
}

private static ClusterPlugin existingShardsAllocatorPlugin(final String allocatorName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,9 @@ public void testGetBatchIdExisting() {
for (String batchId : primaryBatches) {
if (shardRouting.primary() == true
&& testShardsBatchGatewayAllocator.getBatchIdToStartedShardBatch()
.get(batchId)
.getBatchedShards()
.contains(shardRouting.shardId())) {
.get(batchId)
.getBatchedShards()
.contains(shardRouting.shardId())) {
if (shardIdToBatchIdForStartedShards.containsKey(shardRouting.shardId())) {
fail("found duplicate shard routing for shard. One shard cant be in multiple batches " + shardRouting.shardId());
}
Expand All @@ -272,9 +272,9 @@ public void testGetBatchIdExisting() {
for (String batchId : replicaBatches) {
if (shardRouting.primary() == false
&& testShardsBatchGatewayAllocator.getBatchIdToStoreShardBatch()
.get(batchId)
.getBatchedShards()
.contains(shardRouting.shardId())) {
.get(batchId)
.getBatchedShards()
.contains(shardRouting.shardId())) {
if (shardIdToBatchIdForStoreShards.containsKey(shardRouting.shardId())) {
fail("found duplicate shard routing for shard. One shard cant be in multiple batches " + shardRouting.shardId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ public class ShardBatchCacheTests extends OpenSearchAllocationTestCase {
private static final String BATCH_ID = "b1";
private final DiscoveryNode node1 = newNode("node1");
private final DiscoveryNode node2 = newNode("node2");
// Needs to be enabled once ShardsBatchGatewayAllocator is pushed
// private final Map<ShardId, ShardsBatchGatewayAllocator.ShardEntry> batchInfo = new HashMap<>();
private final Map<ShardId, ShardsBatchGatewayAllocator.ShardEntry> batchInfo = new HashMap<>();
private AsyncShardBatchFetch.ShardBatchCache<NodeGatewayStartedShardsBatch, GatewayStartedShard> shardCache;
private List<ShardId> shardsInBatch = new ArrayList<>();
private static final int NUMBER_OF_SHARDS_DEFAULT = 10;
Expand Down Expand Up @@ -162,7 +161,7 @@ public void testShardsDataWithException() {
null
);

// assertEquals(5, batchInfo.size());
assertEquals(10, batchInfo.size());
assertEquals(2, fetchData.size());
assertEquals(10, fetchData.get(node1).getNodeGatewayStartedShardsBatch().size());
assertTrue(fetchData.get(node2).getNodeGatewayStartedShardsBatch().isEmpty());
Expand Down Expand Up @@ -210,10 +209,10 @@ private void fillShards(Map<ShardId, ShardAttributes> shardAttributesMap, int nu
for (ShardId shardId : shardsInBatch) {
ShardAttributes attr = new ShardAttributes("");
shardAttributesMap.put(shardId, attr);
// batchInfo.put(
// shardId,
// new ShardsBatchGatewayAllocator.ShardEntry(attr, randomShardRouting(shardId.getIndexName(), shardId.id()))
// );
batchInfo.put(
shardId,
new ShardsBatchGatewayAllocator.ShardEntry(attr, randomShardRouting(shardId.getIndexName(), shardId.id()))
);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.gateway.PrimaryShardBatchAllocator;
import org.opensearch.gateway.ReplicaShardBatchAllocator;
import org.opensearch.gateway.ShardsBatchGatewayAllocator;
import org.opensearch.gateway.TransportNodesGatewayStartedShardHelper;
import org.opensearch.gateway.TransportNodesListGatewayStartedShardsBatch;
import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint;
import org.opensearch.indices.store.TransportNodesListShardStoreMetadataBatch;
Expand Down Expand Up @@ -46,15 +47,15 @@ protected AsyncShardFetch.FetchResult<TransportNodesListGatewayStartedShardsBatc
for (Map.Entry<String, Map<ShardId, ShardRouting>> entry : knownAllocations.entrySet()) {
String nodeId = entry.getKey();
Map<ShardId, ShardRouting> shardsOnNode = entry.getValue();
HashMap<ShardId, TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard> adaptedResponse = new HashMap<>();
HashMap<ShardId, TransportNodesGatewayStartedShardHelper.GatewayStartedShard> adaptedResponse = new HashMap<>();

for (ShardRouting shardRouting : eligibleShards) {
ShardId shardId = shardRouting.shardId();
Set<String> ignoreNodes = allocation.getIgnoreNodes(shardId);

if (shardsOnNode.containsKey(shardId) && ignoreNodes.contains(nodeId) == false && currentNodes.nodeExists(nodeId)) {
TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard nodeShard =
new TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard(
TransportNodesGatewayStartedShardHelper.GatewayStartedShard nodeShard =
new TransportNodesGatewayStartedShardHelper.GatewayStartedShard(
shardsOnNode.get(shardId).allocationId().getId(),
shardsOnNode.get(shardId).primary(),
getReplicationCheckpoint(shardId, nodeId)
Expand Down

0 comments on commit baca309

Please sign in to comment.