Skip to content

Commit

Permalink
Merge commit from fork (#9)
Browse files Browse the repository at this point in the history
* Merge commit from fork

lower than what was previously reported
GHSA-22qq-3xwm-r5x4

* Merge commit from fork

* fix test

* fix(test): fix TestBlockPoolMaliciousNode DATA RACE (backport cometbft#4636) (cometbft#4641)

Co-authored-by: Anton Kaliaev <[email protected]>

---------

Co-authored-by: Anton Kaliaev <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 3, 2025
1 parent ce418f8 commit 1964da4
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- `[blocksync]` Ban peer if it reports height lower than what was previously reported
([ASA-2025-001](https://github.com/cometbft/cometbft/security/advisories/GHSA-22qq-3xwm-r5x4))
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- `[types]` Check that `Part.Index` equals `Part.Proof.Index`
([ASA-2025-001](https://github.com/cometbft/cometbft/security/advisories/GHSA-r3r4-g7hq-pq4f))
22 changes: 21 additions & 1 deletion blocksync/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,21 @@ func (pool *BlockPool) SetPeerRange(peerID p2p.ID, base int64, height int64) {

peer := pool.peers[peerID]
if peer != nil {
if base < peer.base || height < peer.height {
pool.Logger.Info("Peer is reporting height/base that is lower than what it previously reported",
"peer", peerID,
"height", height, "base", base,
"prevHeight", peer.height, "prevBase", peer.base)
// RemovePeer will redo all requesters associated with this peer.
pool.removePeer(peerID)
pool.banPeer(peerID)
return
}
peer.base = base
peer.height = height
} else {
if pool.isPeerBanned(peerID) {
pool.Logger.Debug("Ignoring banned peer", peerID)
pool.Logger.Debug("Ignoring banned peer", "peer", peerID)
return
}
peer = newBPPeer(pool, peerID, base, height)
Expand All @@ -400,6 +410,7 @@ func (pool *BlockPool) RemovePeer(peerID p2p.ID) {
pool.removePeer(peerID)
}

// CONTRACT: pool.mtx must be locked.
func (pool *BlockPool) removePeer(peerID p2p.ID) {
for _, requester := range pool.requesters {
if requester.didRequestFrom(peerID) {
Expand Down Expand Up @@ -440,11 +451,20 @@ func (pool *BlockPool) updateMaxPeerHeight() {
pool.maxPeerHeight = max
}

// IsPeerBanned returns true if the peer is banned.
func (pool *BlockPool) IsPeerBanned(peerID p2p.ID) bool {
pool.mtx.Lock()
defer pool.mtx.Unlock()
return pool.isPeerBanned(peerID)
}

// CONTRACT: pool.mtx must be locked.
func (pool *BlockPool) isPeerBanned(peerID p2p.ID) bool {
// Todo: replace with cmttime.Since in future versions
return time.Since(pool.bannedPeers[peerID]) < time.Second*60
}

// CONTRACT: pool.mtx must be locked.
func (pool *BlockPool) banPeer(peerID p2p.ID) {
pool.Logger.Debug("Banning peer", peerID)
pool.bannedPeers[peerID] = cmttime.Now()
Expand Down
122 changes: 121 additions & 1 deletion blocksync/pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package blocksync

import (
"fmt"
"math"
"testing"
"time"

Expand Down Expand Up @@ -362,7 +363,126 @@ func TestBlockPoolMaliciousNode(t *testing.T) {
// Process request
peers[request.PeerID].inputChan <- inputData{t, pool, request}
case <-testTicker.C:
banned := pool.isPeerBanned("bad")
banned := pool.IsPeerBanned("bad")
bannedOnce = bannedOnce || banned // Keep bannedOnce true, even if the malicious peer gets unbanned
caughtUp := pool.IsCaughtUp()
// Success: pool caught up and malicious peer was banned at least once
if caughtUp && bannedOnce {
t.Logf("Pool caught up, malicious peer was banned at least once, start consensus.")
return
}
// Failure: the pool caught up without banning the bad peer at least once
require.False(t, caughtUp, "Network caught up without banning the malicious peer at least once.")
// Failure: the network could not catch up in the allotted time
require.True(t, time.Since(startTime) < MaliciousTestMaximumLength, "Network ran too long, stopping test.")
}
}
}

func TestBlockPoolMaliciousNodeMaxInt64(t *testing.T) {
// Setup:
// * each peer has blocks 1..N but the malicious peer reports 1..max(int64) (blocks N+1... do not exist)
// * The malicious peer then reports 1..N this time
// * Afterwards, it can choose to disconnect or stay connected to serve blocks that it has
// * The node ends up stuck in blocksync forever because max height is never reached (as of 63a2a6458)
// Additional notes:
// * When a peer is removed, we only update max height if it equals peer's
// height. The aforementioned scenario where peer reports its height twice
// lowering the height was not accounted for.
const initialHeight = 7
peers := testPeers{
p2p.ID("good"): &testPeer{p2p.ID("good"), 1, initialHeight, make(chan inputData), false},
p2p.ID("bad"): &testPeer{p2p.ID("bad"), 1, math.MaxInt64, make(chan inputData), true},
p2p.ID("good1"): &testPeer{p2p.ID("good1"), 1, initialHeight, make(chan inputData), false},
}
errorsCh := make(chan peerError, 3)
requestsCh := make(chan BlockRequest)

pool := NewBlockPool(1, requestsCh, errorsCh)
pool.SetLogger(log.TestingLogger())

err := pool.Start()
if err != nil {
t.Error(err)
}

t.Cleanup(func() {
if err := pool.Stop(); err != nil {
t.Error(err)
}
})

peers.start()
t.Cleanup(func() { peers.stop() })

// Simulate blocks created on each peer regularly and update pool max height.
go func() {
// Introduce each peer
for _, peer := range peers {
pool.SetPeerRange(peer.id, peer.base, peer.height)
}

// Report the lower height
peers["bad"].height = initialHeight
pool.SetPeerRange(p2p.ID("bad"), 1, initialHeight)

ticker := time.NewTicker(1 * time.Second) // Speed of new block creation
defer ticker.Stop()
for {
select {
case <-pool.Quit():
return
case <-ticker.C:
for _, peer := range peers {
peer.height++ // Network height increases on all peers
pool.SetPeerRange(peer.id, peer.base, peer.height) // Tell the pool that a new height is available
}
}
}
}()

// Start a goroutine to verify blocks
go func() {
ticker := time.NewTicker(500 * time.Millisecond) // Speed of new block creation
defer ticker.Stop()
for {
select {
case <-pool.Quit():
return
case <-ticker.C:
first, second, _ := pool.PeekTwoBlocks()
if first != nil && second != nil {
if second.LastCommit == nil {
// Second block is fake
pool.RemovePeerAndRedoAllPeerRequests(second.Height)
} else {
pool.PopRequest()
}
}
}
}
}()

testTicker := time.NewTicker(200 * time.Millisecond) // speed of test execution
t.Cleanup(func() { testTicker.Stop() })

bannedOnce := false // true when the malicious peer was banned at least once
startTime := time.Now()

// Pull from channels
for {
select {
case err := <-errorsCh:
if err.peerID == "bad" { // ignore errors from the malicious peer
t.Log(err)
} else {
t.Error(err)
}
case request := <-requestsCh:
// Process request
peers[request.PeerID].inputChan <- inputData{t, pool, request}
case <-testTicker.C:
banned := pool.IsPeerBanned("bad")
bannedOnce = bannedOnce || banned // Keep bannedOnce true, even if the malicious peer gets unbanned
caughtUp := pool.IsCaughtUp()
// Success: pool caught up and malicious peer was banned at least once
Expand Down
19 changes: 18 additions & 1 deletion types/part_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ var (
ErrPartInvalidSize = errors.New("error inner part with invalid size")
)

// ErrInvalidPart is an error type for invalid parts.
type ErrInvalidPart struct {
Reason error
}

func (e ErrInvalidPart) Error() string {
return fmt.Sprintf("invalid part: %v", e.Reason)
}

func (e ErrInvalidPart) Unwrap() error {
return e.Reason
}

type Part struct {
Index uint32 `json:"index"`
Bytes cmtbytes.HexBytes `json:"bytes"`
Expand All @@ -37,8 +50,11 @@ func (part *Part) ValidateBasic() error {
if int64(part.Index) < part.Proof.Total-1 && len(part.Bytes) != int(BlockPartSizeBytes) {
return ErrPartInvalidSize
}
if int64(part.Index) != part.Proof.Index {
return ErrInvalidPart{Reason: fmt.Errorf("part index %d != proof index %d", part.Index, part.Proof.Index)}
}
if err := part.Proof.ValidateBasic(); err != nil {
return fmt.Errorf("wrong Proof: %w", err)
return ErrInvalidPart{Reason: fmt.Errorf("wrong Proof: %w", err)}
}
return nil
}
Expand Down Expand Up @@ -275,6 +291,7 @@ func (ps *PartSet) Total() uint32 {
return ps.total
}

// CONTRACT: part is validated using ValidateBasic.
func (ps *PartSet) AddPart(part *Part) (bool, error) {
// TODO: remove this? would be preferable if this only returned (false, nil)
// when its a duplicate block part
Expand Down
8 changes: 7 additions & 1 deletion types/part_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func TestPartSetHeaderValidateBasic(t *testing.T) {
}
}

func TestPartValidateBasic(t *testing.T) {
func TestPart_ValidateBasic(t *testing.T) {
testCases := []struct {
testName string
malleatePart func(*Part)
Expand All @@ -137,6 +137,7 @@ func TestPartValidateBasic(t *testing.T) {
pt.Index = 1
pt.Bytes = make([]byte, BlockPartSizeBytes-1)
pt.Proof.Total = 2
pt.Proof.Index = 1
}, false},
{"Too small inner part", func(pt *Part) {
pt.Index = 0
Expand All @@ -149,6 +150,11 @@ func TestPartValidateBasic(t *testing.T) {
Index: 1,
LeafHash: make([]byte, 1024*1024),
}
pt.Index = 1
}, true},
{"Index mismatch", func(pt *Part) {
pt.Index = 1
pt.Proof.Index = 0
}, true},
}

Expand Down

0 comments on commit 1964da4

Please sign in to comment.