Commit da6f9cb

Remove ShardId from ShardAttributes as it's not useful and add test for ShardBatchCache

Signed-off-by: Aman Khare <[email protected]>
Aman Khare committed Mar 5, 2024
1 parent 5ab6aed commit da6f9cb
Showing 11 changed files with 149 additions and 39 deletions.
@@ -54,7 +54,7 @@ public static Map<ShardId, ShardAttributes> prepareRequestMap(String[] indices,
);
for (int shardIdNum = 0; shardIdNum < primaryShardCount; shardIdNum++) {
final ShardId shardId = new ShardId(index, shardIdNum);
shardIdShardAttributesMap.put(shardId, new ShardAttributes(shardId, customDataPath));
shardIdShardAttributesMap.put(shardId, new ShardAttributes(customDataPath));
}
}
return shardIdShardAttributesMap;
@@ -94,7 +94,7 @@ protected AsyncShardFetch(
this.logger = logger;
this.type = type;
shardAttributesMap = new HashMap<>();
shardAttributesMap.put(shardId, new ShardAttributes(shardId, customDataPath));
shardAttributesMap.put(shardId, new ShardAttributes(customDataPath));
this.action = (Lister<BaseNodesResponse<T>, T>) action;
this.reroutingKey = "ShardId=[" + shardId.toString() + "]";
this.shardCache = new ShardCache<T>(logger, reroutingKey, type);
@@ -136,7 +136,7 @@ boolean hasAnyNodeFetching() {
* @param failedNodes populated with the nodes on which the fetch has failed.
* @return Map of cache data for every DiscoveryNode.
*/
Map<DiscoveryNode, K> getCacheData(DiscoveryNodes nodes, Set<String> failedNodes) {
Map<DiscoveryNode, K> fetchData = new HashMap<>();
for (Iterator<? extends Map.Entry<String, ? extends BaseNodeEntry>> it = getCache().entrySet().iterator(); it.hasNext();) {
Map.Entry<String, BaseNodeEntry> entry = (Map.Entry<String, BaseNodeEntry>) it.next();
3 changes: 2 additions & 1 deletion server/src/main/java/org/opensearch/gateway/ShardCache.java
@@ -26,10 +26,11 @@
*/
public class ShardCache<K extends BaseNodeResponse> extends BaseShardCache<K> {

private final Map<String, NodeEntry<K>> cache = new HashMap<>();
private final Map<String, NodeEntry<K>> cache;

public ShardCache(Logger logger, String logKey, String type) {
super(logger, logKey, type);
cache = new HashMap<>();
}

@Override
@@ -261,7 +261,6 @@ else if (shardRouting.primary() == primary) {
if (batchSize > 0) {
ShardEntry shardEntry = new ShardEntry(
new ShardAttributes(
currentShard.shardId(),
IndexMetadata.INDEX_DATA_PATH_SETTING.get(allocation.metadata().index(currentShard.index()).getSettings())
),
currentShard
@@ -705,7 +704,7 @@ public String toString() {
/**
* Holds information about a shard to be allocated in a batch.
*/
public class ShardEntry {
public static class ShardEntry {

private final ShardAttributes shardAttributes;

@@ -12,7 +12,6 @@
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.index.shard.ShardId;

import java.io.IOException;

@@ -22,24 +21,17 @@
* @opensearch.internal
*/
public class ShardAttributes implements Writeable {
private final ShardId shardId;
@Nullable
private final String customDataPath;

public ShardAttributes(ShardId shardId, String customDataPath) {
this.shardId = shardId;
public ShardAttributes(String customDataPath) {
this.customDataPath = customDataPath;
}

public ShardAttributes(StreamInput in) throws IOException {
shardId = new ShardId(in);
customDataPath = in.readString();
}

public ShardId getShardId() {
return shardId;
}

/**
* Returns the custom data path that is used to look up information for this shard.
* Returns an empty string if no custom data path is used for this index.
@@ -51,12 +43,11 @@ public String getCustomDataPath() {
}

public void writeTo(StreamOutput out) throws IOException {
shardId.writeTo(out);
out.writeString(customDataPath);
}

@Override
public String toString() {
return "ShardAttributes{" + "shardId=" + shardId + ", customDataPath='" + customDataPath + '\'' + '}';
return "ShardAttributes{" + ", customDataPath='" + customDataPath + '\'' + '}';
}
}
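
One consequence visible in this file: the wire format now carries only the custom data path, so writeTo and the StreamInput constructor stay symmetric without a ShardId. A minimal round-trip sketch, assuming OpenSearch's BytesStreamOutput (package location and the bytes().streamInput() helper are assumptions here; the updated ShardAttributesTests further down writes through a DataOutputStreamOutput over a ByteArrayOutputStream instead):

    import org.opensearch.common.io.stream.BytesStreamOutput; // assumed package location
    import org.opensearch.core.common.io.stream.StreamInput;
    import org.opensearch.indices.store.ShardAttributes;

    class ShardAttributesRoundTripSketch {
        static ShardAttributes roundTrip(ShardAttributes original) throws java.io.IOException {
            BytesStreamOutput out = new BytesStreamOutput();
            original.writeTo(out);                       // writes only the custom data path now
            StreamInput in = out.bytes().streamInput();  // assumed helper to read back what was written; no ShardId on the wire
            return new ShardAttributes(in);
        }
    }
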
@@ -92,8 +92,8 @@ public void setUp() throws Exception {
HashMap<ShardId, ShardAttributes> shardToCustomDataPath = new HashMap<>();
ShardId shardId0 = new ShardId("index1", "index_uuid1", 0);
ShardId shardId1 = new ShardId("index2", "index_uuid2", 0);
shardToCustomDataPath.put(shardId0, new ShardAttributes(shardId0, ""));
shardToCustomDataPath.put(shardId1, new ShardAttributes(shardId1, ""));
shardToCustomDataPath.put(shardId0, new ShardAttributes(""));
shardToCustomDataPath.put(shardId1, new ShardAttributes(""));
this.test = new TestFetch(threadPool, shardToCustomDataPath);
}
}
@@ -10,12 +10,12 @@

import org.opensearch.core.index.shard.ShardId;

import java.util.HashSet;
import java.util.Set;
import java.util.ArrayList;
import java.util.List;

public class BatchTestUtil {
public static Set<ShardId> setUpShards(int numberOfShards) {
Set<ShardId> shards = new HashSet<>();
public static List<ShardId> setUpShards(int numberOfShards) {
List<ShardId> shards = new ArrayList<>();
for (int shardNumber = 0; shardNumber < numberOfShards; shardNumber++) {
ShardId shardId = new ShardId("test", "_na_", shardNumber);
shards.add(shardId);
@@ -49,7 +49,7 @@
public class PrimaryShardBatchAllocatorTests extends OpenSearchAllocationTestCase {

private final ShardId shardId = new ShardId("test", "_na_", 0);
private static Set<ShardId> shardsInBatch;
private static List<ShardId> shardsInBatch;
private final DiscoveryNode node1 = newNode("node1");
private final DiscoveryNode node2 = newNode("node2");
private final DiscoveryNode node3 = newNode("node3");
145 changes: 132 additions & 13 deletions server/src/test/java/org/opensearch/gateway/ShardBatchCacheTests.java
@@ -8,20 +8,30 @@

package org.opensearch.gateway;

import org.opensearch.cluster.OpenSearchAllocationTestCase;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.routing.ShardRouting;
import org.opensearch.cluster.routing.ShardRoutingState;
import org.opensearch.cluster.routing.TestShardRouting;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.gateway.TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard;
import org.opensearch.gateway.TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch;
import org.opensearch.indices.store.ShardAttributes;
import org.opensearch.test.OpenSearchTestCase;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class ShardBatchCacheTests extends OpenSearchTestCase {
private final ShardId shardId = new ShardId("test", "_na_", 0);

public class ShardBatchCacheTests extends OpenSearchAllocationTestCase {
private static final String BATCH_ID = "b1";
private final DiscoveryNode node1 = newNode("node1");
private final DiscoveryNode node2 = newNode("node2");
private final DiscoveryNode node3 = newNode("node3");
private Map<ShardId, ShardsBatchGatewayAllocator.ShardEntry> batchInfo = new HashMap<>();
private ShardBatchCache<
TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch,
TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard> shardCache;
private ShardBatchCache<NodeGatewayStartedShardsBatch, NodeGatewayStartedShard> shardCache;
private List<ShardId> shardsInBatch = new ArrayList<>();

public void setupShardBatchCache(String batchId) {
Map<ShardId, ShardAttributes> shardAttributesMap = new HashMap<>();
@@ -31,21 +41,130 @@ public void setupShardBatchCache(String batchId) {
"batch_shards_started",
shardAttributesMap,
"BatchID=[" + batchId + "]",
TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard.class,
TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch::new,
TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShardsBatch::getNodeGatewayStartedShardsBatch,
() -> new TransportNodesListGatewayStartedShardsBatch.NodeGatewayStartedShard(null, false, null, null),
NodeGatewayStartedShard.class,
NodeGatewayStartedShardsBatch::new,
NodeGatewayStartedShardsBatch::getNodeGatewayStartedShardsBatch,
() -> new NodeGatewayStartedShard(null, false, null, null),
this::removeShard
);
}

public void testClearShardCache() {
setupShardBatchCache(BATCH_ID);
ShardId shard = shardsInBatch.iterator().next();
this.shardCache.initData(node1);
this.shardCache.markAsFetching(List.of(node1.getId()), 1);
this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getEmptyPrimaryResponse(shardsInBatch)));
assertTrue(
this.shardCache.getCacheData(DiscoveryNodes.builder().add(node1).build(), null)
.get(node1)
.getNodeGatewayStartedShardsBatch()
.containsKey(shard)
);
this.shardCache.clearShardCache(shard);
assertFalse(
this.shardCache.getCacheData(DiscoveryNodes.builder().add(node1).build(), null)
.get(node1)
.getNodeGatewayStartedShardsBatch()
.containsKey(shard)
);
}

public void testGetCacheData() {
setupShardBatchCache(BATCH_ID);
ShardId shard = shardsInBatch.iterator().next();
this.shardCache.initData(node1);
this.shardCache.initData(node2);
this.shardCache.markAsFetching(List.of(node1.getId(), node2.getId()), 1);
this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getEmptyPrimaryResponse(shardsInBatch)));
assertTrue(
this.shardCache.getCacheData(DiscoveryNodes.builder().add(node1).build(), null)
.get(node1)
.getNodeGatewayStartedShardsBatch()
.containsKey(shard)
);
assertTrue(
this.shardCache.getCacheData(DiscoveryNodes.builder().add(node2).build(), null)
.get(node2)
.getNodeGatewayStartedShardsBatch()
.isEmpty()
);
}

public void testInitCacheData() {
setupShardBatchCache(BATCH_ID);
this.shardCache.initData(node1);
this.shardCache.initData(node2);
assertEquals(2, shardCache.getCache().size());
}

public void testPutData() {
setupShardBatchCache(BATCH_ID);
ShardId shard = shardsInBatch.iterator().next();
this.shardCache.initData(node1);
this.shardCache.initData(node2);
this.shardCache.markAsFetching(List.of(node1.getId(), node2.getId()), 1);
this.shardCache.putData(node1, new NodeGatewayStartedShardsBatch(node1, getActualPrimaryResponse(shardsInBatch)));
this.shardCache.putData(node2, new NodeGatewayStartedShardsBatch(node2, getEmptyPrimaryResponse(shardsInBatch)));

Map<DiscoveryNode, NodeGatewayStartedShardsBatch> fetchData = shardCache.getCacheData(
DiscoveryNodes.builder().add(node1).add(node2).build(),
null
);
assertEquals(2, fetchData.size());
assertEquals(10, fetchData.get(node1).getNodeGatewayStartedShardsBatch().size());
assertEquals("alloc-1", fetchData.get(node1).getNodeGatewayStartedShardsBatch().get(shard).allocationId());

assertEquals(10, fetchData.get(node2).getNodeGatewayStartedShardsBatch().size());
assertTrue(fetchData.get(node2).getNodeGatewayStartedShardsBatch().get(shard).isEmpty());
}

public void testFilterFailedShards() {
// ToDo
}

private Map<ShardId, NodeGatewayStartedShard> getEmptyPrimaryResponse(List<ShardId> shards) {
Map<ShardId, NodeGatewayStartedShard> shardData = new HashMap<>();
for (ShardId shard : shards) {
shardData.put(shard, new NodeGatewayStartedShard(null, false, null, null));
}
return shardData;
}

private Map<ShardId, NodeGatewayStartedShard> getActualPrimaryResponse(List<ShardId> shards) {
int allocationId = 1;
Map<ShardId, NodeGatewayStartedShard> shardData = new HashMap<>();
for (ShardId shard : shards) {
shardData.put(shard, new NodeGatewayStartedShard("alloc-" + allocationId++, false, null, null));
}
return shardData;
}

public void removeShard(ShardId shardId) {
batchInfo.remove(shardId);
}

private void fillShards(Map<ShardId, ShardAttributes> shardAttributesMap) {
BatchTestUtil.setUpShards(10);
// ToDo
shardsInBatch = BatchTestUtil.setUpShards(10);
for (ShardId shardId : shardsInBatch) {
ShardAttributes attr = new ShardAttributes("");
shardAttributesMap.put(shardId, attr);
batchInfo.put(
shardId,
new ShardsBatchGatewayAllocator.ShardEntry(attr, randomShardRouting(shardId.getIndexName(), shardId.id()))
);
}
}

private ShardRouting randomShardRouting(String index, int shard) {
ShardRoutingState state = randomFrom(ShardRoutingState.values());
return TestShardRouting.newShardRouting(
index,
shard,
state == ShardRoutingState.UNASSIGNED ? null : "1",
state == ShardRoutingState.RELOCATING ? "2" : null,
state != ShardRoutingState.UNASSIGNED && randomBoolean(),
state
);
}
}
@@ -28,12 +28,12 @@ public class ShardAttributesTests extends OpenSearchTestCase {
String customDataPath = "/path/to/data";

public void testShardAttributesConstructor() {
ShardAttributes attributes = new ShardAttributes(shardId, customDataPath);
ShardAttributes attributes = new ShardAttributes(customDataPath);
assertEquals(attributes.getCustomDataPath(), customDataPath);
}

public void testSerialization() throws IOException {
ShardAttributes attributes1 = new ShardAttributes(shardId, customDataPath);
ShardAttributes attributes1 = new ShardAttributes(customDataPath);
ByteArrayOutputStream bytes = new ByteArrayOutputStream();
StreamOutput output = new DataOutputStreamOutput(new DataOutputStream(bytes));
attributes1.writeTo(output);