refactor AggregateAttestationPool #9225

Draft · wants to merge 14 commits into base: master

AggregatingAttestationPool.java
@@ -13,24 +13,23 @@

package tech.pegasys.teku.statetransition.attestation;

import com.google.common.collect.Sets;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.NavigableMap;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Predicate;
import java.util.stream.Stream;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.tuweni.bytes.Bytes;
import org.apache.tuweni.bytes.Bytes32;
import org.hyperledger.besu.plugin.services.MetricsSystem;
import tech.pegasys.teku.ethereum.events.SlotEventsChannel;
@@ -76,14 +75,15 @@ public class AggregatingAttestationPool implements SlotEventsChannel {
*/
public static final int DEFAULT_MAXIMUM_ATTESTATION_COUNT = 187_500;

private final Map<Bytes, MatchingDataAttestationGroup> attestationGroupByDataHash =
new HashMap<>();
private final NavigableMap<UInt64, Set<Bytes>> dataHashBySlot = new TreeMap<>();
private final Map<Bytes32, MatchingDataAttestationGroup> attestationGroupByDataHash =
new ConcurrentSkipListMap<>();
private final NavigableMap<UInt64, Set<Bytes32>> dataHashBySlot = new ConcurrentSkipListMap<>();

private final Spec spec;
private final RecentChainData recentChainData;
private final SettableGauge sizeGauge;
private final int maximumAttestationCount;
final AtomicBoolean isCleanupRunning = new AtomicBoolean(false);

private final AtomicInteger size = new AtomicInteger(0);
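
// Illustrative sketch (not part of this PR): the fields above replace the
// previously synchronized HashMap/TreeMap indexes with ConcurrentSkipListMap
// instances keyed by Bytes32, plus an AtomicBoolean guarding cleanup. The
// hypothetical class below shows the same indexing pattern with plain long
// slots and String data roots, using the JDK's ConcurrentHashMap.newKeySet()
// as an equivalent of Guava's Sets.newConcurrentHashSet().

import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.atomic.AtomicInteger;

class ConcurrentSlotIndexSketch {
  private final ConcurrentSkipListMap<Long, Set<String>> dataRootsBySlot =
      new ConcurrentSkipListMap<>();
  private final AtomicInteger size = new AtomicInteger();

  void add(final long slot, final String dataRoot) {
    // computeIfAbsent installs exactly one set per slot and returns it, so
    // concurrent adds for the same slot never lose entries.
    if (dataRootsBySlot
        .computeIfAbsent(slot, unused -> ConcurrentHashMap.newKeySet())
        .add(dataRoot)) {
      size.incrementAndGet();
    }
  }

  int getSize() {
    return size.get();
  }
}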

@@ -103,9 +103,10 @@ public AggregatingAttestationPool(
this.maximumAttestationCount = maximumAttestationCount;
}

public synchronized void add(final ValidatableAttestation attestation) {
public void add(final ValidatableAttestation attestation) {
final Optional<Int2IntMap> committeesSize =
attestation.getCommitteesSize().or(() -> getCommitteesSize(attestation.getAttestation()));

getOrCreateAttestationGroup(attestation.getAttestation(), committeesSize)
.ifPresent(
attestationGroup -> {
@@ -114,14 +115,8 @@ public synchronized void add(final ValidatableAttestation attestation) {
updateSize(1);
}
});
// Always keep the latest slot attestations, so we don't discard everything
int currentSize = getSize();
while (dataHashBySlot.size() > 1 && currentSize > maximumAttestationCount) {
LOG.trace("Attestation cache at {} exceeds {}, ", currentSize, maximumAttestationCount);
final UInt64 firstSlotToKeep = dataHashBySlot.firstKey().plus(1);
removeAttestationsPriorToSlot(firstSlotToKeep);
currentSize = getSize();
}

cleanupCache(Optional.empty());
}

private Optional<Int2IntMap> getCommitteesSize(final Attestation attestation) {
@@ -148,8 +143,13 @@ private Optional<MatchingDataAttestationGroup> getOrCreateAttestationGroup(
attestationData.getTarget().getRoot());
return Optional.empty();
}
return maybeCreateAttestationGroup(attestationData, committeesSize);
}

private synchronized Optional<MatchingDataAttestationGroup> maybeCreateAttestationGroup(
final AttestationData attestationData, final Optional<Int2IntMap> committeesSize) {
dataHashBySlot
.computeIfAbsent(attestationData.getSlot(), slot -> new HashSet<>())
.computeIfAbsent(attestationData.getSlot(), slot -> Sets.newConcurrentHashSet())
.add(attestationData.hashTreeRoot());
final MatchingDataAttestationGroup attestationGroup =
attestationGroupByDataHash.computeIfAbsent(
@@ -211,36 +211,24 @@ private Optional<Int2IntMap> getCommitteesSizeUsingTheState(
}

@Override
public synchronized void onSlot(final UInt64 slot) {
if (slot.compareTo(ATTESTATION_RETENTION_SLOTS) <= 0) {
public void onSlot(final UInt64 slot) {
if (slot.isLessThan(ATTESTATION_RETENTION_SLOTS)) {
return;
}
final UInt64 firstValidAttestationSlot = slot.minus(ATTESTATION_RETENTION_SLOTS);
removeAttestationsPriorToSlot(firstValidAttestationSlot);
cleanupCache(Optional.of(slot.minus(ATTESTATION_RETENTION_SLOTS)));
}

private void removeAttestationsPriorToSlot(final UInt64 firstValidAttestationSlot) {
final Collection<Set<Bytes>> dataHashesToRemove =
dataHashBySlot.headMap(firstValidAttestationSlot, false).values();
dataHashesToRemove.stream()
.flatMap(Set::stream)
.forEach(
key -> {
final int removed = attestationGroupByDataHash.get(key).size();
attestationGroupByDataHash.remove(key);
updateSize(-removed);
});
if (!dataHashesToRemove.isEmpty()) {
public void onAttestationsIncludedInBlock(
final UInt64 slot, final Iterable<Attestation> attestations) {
final Optional<UInt64> maybeCurrentSlot = recentChainData.getCurrentSlot();
if (maybeCurrentSlot.isEmpty()
|| maybeCurrentSlot.get().minusMinZero(slot).isGreaterThan(ATTESTATION_RETENTION_SLOTS)) {
LOG.trace(
"firstValidAttestationSlot: {}, removing: {}",
() -> firstValidAttestationSlot,
dataHashesToRemove::size);
"Attestations included in block at slot {}, head slot {} - skipping.",
slot,
maybeCurrentSlot);
return;
}
dataHashesToRemove.clear();
}

public synchronized void onAttestationsIncludedInBlock(
final UInt64 slot, final Iterable<Attestation> attestations) {
attestations.forEach(attestation -> onAttestationIncludedInBlock(slot, attestation));
}
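
// Illustrative sketch (not part of this PR) of the staleness check above, with
// plain longs standing in for UInt64 slots: attestations included in a block
// more than ATTESTATION_RETENTION_SLOTS behind the current head are skipped,
// and the minusMinZero-style saturating subtraction keeps the distance at zero
// when the block slot is ahead of the current-slot reading. The 32-slot window
// matches the retention described in MatchingDataAttestationGroup's javadoc.

class RetentionGuardSketch {
  static final long ATTESTATION_RETENTION_SLOTS = 32;

  static boolean shouldProcess(final long blockSlot, final long currentSlot) {
    final long distance = Math.max(0L, currentSlot - blockSlot); // minusMinZero analogue
    return distance <= ATTESTATION_RETENTION_SLOTS;
  }
}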

@@ -259,11 +247,11 @@ private void updateSize(final int delta) {
sizeGauge.set(currentSize);
}

public synchronized int getSize() {
public int getSize() {
return size.get();
}

public synchronized SszList<Attestation> getAttestationsForBlock(
public SszList<Attestation> getAttestationsForBlock(
final BeaconState stateAtBlockSlot, final AttestationForkChecker forkChecker) {
final UInt64 currentEpoch = spec.getCurrentEpoch(stateAtBlockSlot);
final int previousEpochLimit = spec.getPreviousEpochAttestationCapacity(stateAtBlockSlot);
@@ -305,8 +293,49 @@ public synchronized SszList<Attestation> getAttestationsForBlock(
.collect(attestationsSchema.collector());
}

void cleanupCache(final Optional<UInt64> maybeSlot) {
// one cleanup at a time can run
if (!isCleanupRunning.compareAndSet(false, true)) {
return;
}

try {
if (maybeSlot.isEmpty()) {
while (dataHashBySlot.size() > 1 && size.get() > maximumAttestationCount) {
LOG.trace("Attestation cache at {} exceeds {}, ", size.get(), maximumAttestationCount);
removeAttestationsPriorToSlot(dataHashBySlot.firstKey().plus(1));
}
} else {
removeAttestationsPriorToSlot(maybeSlot.get());
}
} finally {
isCleanupRunning.set(false);
}
}
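
// Illustrative sketch (not part of this PR) of the "one cleanup at a time"
// guard used in cleanupCache(): compareAndSet(false, true) lets exactly one
// caller win, every other caller returns immediately instead of blocking, and
// the winner releases the flag in a finally block.

import java.util.concurrent.atomic.AtomicBoolean;

class SingleFlightCleanupSketch {
  private final AtomicBoolean running = new AtomicBoolean(false);

  void runExclusively(final Runnable cleanupWork) {
    if (!running.compareAndSet(false, true)) {
      return; // another thread is already cleaning up
    }
    try {
      cleanupWork.run();
    } finally {
      running.set(false);
    }
  }
}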

private void removeAttestationsPriorToSlot(final UInt64 firstValidAttestationSlot) {
final AtomicInteger count = new AtomicInteger(0);
dataHashBySlot.headMap(firstValidAttestationSlot, false).values().stream()
.flatMap(Set::stream)
.forEach(
key -> {
final MatchingDataAttestationGroup matchingDataAttestationGroup =
attestationGroupByDataHash.remove(key);
if (matchingDataAttestationGroup != null) {
updateSize(-matchingDataAttestationGroup.size());
count.incrementAndGet();
}
});
if (count.get() > 0) {
LOG.trace(
"firstValidAttestationSlot: {}, removed {} keys",
() -> firstValidAttestationSlot,
count::get);
}
}
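
// Illustrative sketch (not part of this PR) of the pruning pattern above:
// headMap() is a live, weakly consistent view of every entry strictly before
// the cut-off, and the null check after remove() matters because another
// thread may have removed the same group concurrently. Hypothetical long/String
// types again stand in for UInt64 slots and attestation data roots.

import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.atomic.AtomicInteger;

class PruneBeforeSketch {
  private final ConcurrentSkipListMap<Long, Set<String>> dataRootsBySlot =
      new ConcurrentSkipListMap<>();
  private final ConcurrentHashMap<String, Integer> groupSizeByDataRoot = new ConcurrentHashMap<>();
  private final AtomicInteger size = new AtomicInteger();

  void pruneBefore(final long firstSlotToKeep) {
    dataRootsBySlot.headMap(firstSlotToKeep, false).values().stream()
        .flatMap(Set::stream)
        .forEach(
            dataRoot -> {
              final Integer removedGroupSize = groupSizeByDataRoot.remove(dataRoot);
              // Only adjust the counter when this thread actually removed the group.
              if (removedGroupSize != null) {
                size.addAndGet(-removedGroupSize);
              }
            });
  }
}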

private Stream<Attestation> streamAggregatesForDataHashesBySlot(
final Set<Bytes> dataHashSetForSlot,
final Set<Bytes32> dataHashSetForSlot,
final BeaconState stateAtBlockSlot,
final AttestationForkChecker forkChecker,
final boolean blockRequiresAttestationsWithCommitteeBits) {
@@ -324,15 +353,13 @@ private Stream<Attestation> streamAggregatesForDataHashesBySlot(
.sorted(ATTESTATION_INCLUSION_COMPARATOR);
}

public synchronized List<Attestation> getAttestations(
public List<Attestation> getAttestations(
final Optional<UInt64> maybeSlot, final Optional<UInt64> maybeCommitteeIndex) {

final Predicate<Map.Entry<UInt64, Set<Bytes>>> filterForSlot =
final Predicate<Map.Entry<UInt64, Set<Bytes32>>> filterForSlot =
(entry) -> maybeSlot.map(slot -> entry.getKey().equals(slot)).orElse(true);

final UInt64 slot = maybeSlot.orElse(recentChainData.getCurrentSlot().orElse(UInt64.ZERO));
final SchemaDefinitions schemaDefinitions = spec.atSlot(slot).getSchemaDefinitions();

final SchemaDefinitions schemaDefinitions = spec.atSlot(slot).getSchemaDefinitions();
final boolean requiresCommitteeBits =
schemaDefinitions.getAttestationSchema().requiresCommitteeBits();

@@ -354,13 +381,13 @@ private boolean isValid(
return spec.validateAttestation(stateAtBlockSlot, attestationData).isEmpty();
}

public synchronized Optional<ValidatableAttestation> createAggregateFor(
public Optional<ValidatableAttestation> createAggregateFor(
final Bytes32 attestationHashTreeRoot, final Optional<UInt64> committeeIndex) {
return Optional.ofNullable(attestationGroupByDataHash.get(attestationHashTreeRoot))
.flatMap(attestations -> attestations.stream(committeeIndex).findFirst());
}

public synchronized void onReorg(final UInt64 commonAncestorSlot) {
public void onReorg(final UInt64 commonAncestorSlot) {
attestationGroupByDataHash.values().forEach(group -> group.onReorg(commonAncestorSlot));
}
}

MatchingDataAttestationGroup.java
@@ -13,17 +13,17 @@

package tech.pegasys.teku.statetransition.attestation;

import com.google.common.collect.Sets;
import it.unimi.dsi.fastutil.ints.Int2IntMap;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.NavigableMap;
import java.util.Optional;
import java.util.Set;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.tuweni.bytes.Bytes32;
@@ -50,7 +50,7 @@
public class MatchingDataAttestationGroup implements Iterable<ValidatableAttestation> {

private final NavigableMap<Integer, Set<ValidatableAttestation>> attestationsByValidatorCount =
new TreeMap<>(Comparator.reverseOrder()); // Most validators first
new ConcurrentSkipListMap<>(Comparator.reverseOrder()); // Most validators first

private final Spec spec;
private Optional<Bytes32> committeeShufflingSeed = Optional.empty();
@@ -70,7 +70,7 @@ public class MatchingDataAttestationGroup implements Iterable<ValidatableAttestation> {
* {@link AggregatingAttestationPool} once it is too old to be included in blocks (32 slots).
*/
private final NavigableMap<UInt64, AttestationBitsAggregator> includedValidatorsBySlot =
new TreeMap<>();
new ConcurrentSkipListMap<>();

/** Precalculated combined list of included validators across all blocks. */
private AttestationBitsAggregator includedValidators;
@@ -110,10 +110,11 @@ public boolean add(final ValidatableAttestation attestation) {
if (committeeShufflingSeed.isEmpty()) {
committeeShufflingSeed = attestation.getCommitteeShufflingSeed();
}
// uses a concurrent Set for safety
return attestationsByValidatorCount
.computeIfAbsent(
attestation.getAttestation().getAggregationBits().getBitCount(),
count -> new HashSet<>())
count -> Sets.newConcurrentHashSet())
.add(attestation);
}
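
// Illustrative sketch (not part of this PR): add() above buckets attestations
// in a reverse-ordered ConcurrentSkipListMap of concurrent sets, so readers see
// the most populated buckets first while writers never take a monitor. The
// hypothetical class below uses Integer bit counts and String payloads, with
// ConcurrentHashMap.newKeySet() as the JDK equivalent of Guava's
// Sets.newConcurrentHashSet().

import java.util.Comparator;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.stream.Stream;

class CountBucketsSketch {
  private final ConcurrentSkipListMap<Integer, Set<String>> byCount =
      new ConcurrentSkipListMap<>(Comparator.reverseOrder()); // most validators first

  boolean add(final int bitCount, final String attestation) {
    return byCount
        .computeIfAbsent(bitCount, unused -> ConcurrentHashMap.newKeySet())
        .add(attestation);
  }

  // Values stream in descending key order, i.e. the largest buckets first.
  Stream<String> mostPopulatedFirst() {
    return byCount.values().stream().flatMap(Set::stream);
  }
}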

@@ -155,7 +156,7 @@ public Stream<ValidatableAttestation> stream(
return StreamSupport.stream(spliterator(committeeIndex), false);
}

public Spliterator<ValidatableAttestation> spliterator(final Optional<UInt64> committeeIndex) {
private Spliterator<ValidatableAttestation> spliterator(final Optional<UInt64> committeeIndex) {
return Spliterators.spliteratorUnknownSize(iterator(committeeIndex), 0);
}

@@ -254,40 +255,35 @@ private class AggregatingIterator implements Iterator<ValidatableAttestation> {
private final Optional<UInt64> maybeCommitteeIndex;
private final AttestationBitsAggregator includedValidators;

private Iterator<ValidatableAttestation> remainingAttestations = getRemainingAttestations();

private AggregatingIterator(final Optional<UInt64> committeeIndex) {
this.maybeCommitteeIndex = committeeIndex;
includedValidators = MatchingDataAttestationGroup.this.includedValidators.copy();
}

@Override
public boolean hasNext() {
if (!remainingAttestations.hasNext()) {
remainingAttestations = getRemainingAttestations();
}
return remainingAttestations.hasNext();
return streamRemainingAttestations().findAny().isPresent();
}

@Override
public ValidatableAttestation next() {
final AggregateAttestationBuilder builder =
new AggregateAttestationBuilder(spec, attestationData);
remainingAttestations.forEachRemaining(
candidate -> {
if (builder.aggregate(candidate)) {
includedValidators.or(candidate.getAttestation());
}
});
streamRemainingAttestations()
.forEach(
candidate -> {
if (builder.aggregate(candidate)) {
includedValidators.or(candidate.getAttestation());
}
});
return builder.buildAggregate();
}

public Iterator<ValidatableAttestation> getRemainingAttestations() {
private Stream<ValidatableAttestation> streamRemainingAttestations() {
return attestationsByValidatorCount.values().stream()
.flatMap(Set::stream)
.filter(this::isAttestationRelevant)
.filter(candidate -> !includedValidators.isSuperSetOf(candidate.getAttestation()))
.iterator();
.filter(candidate -> !includedValidators.isSuperSetOf(candidate.getAttestation()));
}
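
// Illustrative sketch (not part of this PR) of the stream-based iteration that
// AggregatingIterator switches to: instead of caching an Iterator of remaining
// attestations, each hasNext()/next() call re-filters the live buckets against
// the validators aggregated so far. Integer elements and a plain Set stand in
// for attestations and AttestationBitsAggregator here.

import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.stream.Stream;

class ReFilteringIteratorSketch implements Iterator<Set<Integer>> {
  private final List<Integer> backing; // candidate pool to aggregate from
  private final Set<Integer> included = new HashSet<>();

  ReFilteringIteratorSketch(final List<Integer> backing) {
    this.backing = backing;
  }

  private Stream<Integer> streamRemaining() {
    return backing.stream().filter(candidate -> !included.contains(candidate));
  }

  @Override
  public boolean hasNext() {
    // findAny() short-circuits on the first match, so checking for leftovers is cheap.
    return streamRemaining().findAny().isPresent();
  }

  @Override
  public Set<Integer> next() {
    // Fold every still-relevant candidate into one aggregate, recording what was consumed.
    final Set<Integer> aggregate = new HashSet<>();
    streamRemaining()
        .forEach(
            candidate -> {
              aggregate.add(candidate);
              included.add(candidate);
            });
    return aggregate;
  }
}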

private boolean isAttestationRelevant(final ValidatableAttestation candidate) {