
Commit 0d77204

mergify[bot], evan-forbes, tzdybal, and rach-id authored
feat: parallel binary merkle root and proof generation (backport #2366) (#2462)
PR to make proof generation 15x faster (for 64 KiB leaves and 16 threads) and Merkle root generation 12x faster (also for 64 KiB leaves and 16 threads), at the cost of higher memory use. Most of the benefit comes from hashing the leaves in parallel. This PR is almost entirely AI generated; note, however, the fuzzers and correctness tests that compare against the old implementation.

Benchmark for root calculation:

```
cpu: AMD Ryzen 7 7840U w/ Radeon 780M Graphics
BenchmarkParallelComparison/Original_1000_64KiB_leaves-16         38   30281229 ns/op     64096 B/op    2000 allocs/op
BenchmarkParallelComparison/_1000_64KiB_leaves-16                484    2489376 ns/op    254185 B/op    4076 allocs/op
BenchmarkParallelComparison/Original_100_000_2KiB_leaves-16        9  114582639 ns/op   6400096 B/op  200000 allocs/op
BenchmarkParallelComparison/_100_000_2KiB_leaves-16               58   20056233 ns/op  24808424 B/op  400076 allocs/op
BenchmarkParallelComparison/Original_1000_1KiB_leaves-16        3314     346491 ns/op     32096 B/op    1000 allocs/op
BenchmarkParallelComparison/_1000_1KiB_leaves-16                6133     175902 ns/op    130487 B/op    2076 allocs/op
BenchmarkParallelComparison/Original_1000_32B_leaves-16         5510     208403 ns/op     64097 B/op    2000 allocs/op
BenchmarkParallelComparison/_1000_32B_leaves-16                 4570     225208 ns/op    254107 B/op    4076 allocs/op
PASS
ok  github.com/cometbft/cometbft/crypto/merkle  11.189s
```

Benchmark for proof generation (4000 64 KiB leaves only):

```
BenchmarkParallelProofGeneration/4000_64KiB_leaves_Original_ProofsFromByteSlices-16    7  153370828 ns/op   1709.22 MB/s  299266662 B/op  52000 allocs/op
BenchmarkParallelProofGeneration/4000_64KiB_leaves_Parallel_ProofsFromByteSlices-16  100   10183035 ns/op  25743.21 MB/s    4781977 B/op  64094 allocs/op
```

---

This is an automatic backport of pull request #2366 done by [Mergify](https://mergify.com).

Co-authored-by: Evan Forbes <[email protected]>
Co-authored-by: Tomasz Zdybał <[email protected]>
Co-authored-by: CHAMI Rachid <[email protected]>
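The core of the speedup, as the message notes, is the leaf-hashing pass. As a rough illustration (not the merged implementation), the sketch below parallelizes leaf hashing across CPUs and then folds the hashes up sequentially; it assumes CometBFT's RFC 6962-style domain separation (a 0x00 prefix for leaf hashes, 0x01 for inner hashes, SHA-256 throughout) and CometBFT's split point of the largest power of two below the item count.

```go
package main

import (
	"crypto/sha256"
	"fmt"
	"runtime"
	"sync"
)

// Domain-separation prefixes as in RFC 6962 (and CometBFT's merkle package).
var (
	leafPrefix  = []byte{0x00}
	innerPrefix = []byte{0x01}
)

func leafHash(leaf []byte) []byte {
	data := make([]byte, 0, 1+len(leaf))
	data = append(data, leafPrefix...)
	data = append(data, leaf...)
	h := sha256.Sum256(data)
	return h[:]
}

func innerHash(left, right []byte) []byte {
	data := make([]byte, 0, 1+len(left)+len(right))
	data = append(data, innerPrefix...)
	data = append(data, left...)
	data = append(data, right...)
	h := sha256.Sum256(data)
	return h[:]
}

// parallelLeafHashes hashes the leaves in fixed-size chunks, one goroutine
// per chunk. For large leaves (e.g. 64 KiB) this pass dominates total work.
func parallelLeafHashes(items [][]byte) [][]byte {
	hashes := make([][]byte, len(items))
	chunk := (len(items) + runtime.NumCPU() - 1) / runtime.NumCPU()
	var wg sync.WaitGroup
	for start := 0; start < len(items); start += chunk {
		end := min(start+chunk, len(items))
		wg.Add(1)
		go func(lo, hi int) {
			defer wg.Done()
			for i := lo; i < hi; i++ {
				hashes[i] = leafHash(items[i])
			}
		}(start, end)
	}
	wg.Wait()
	return hashes
}

// rootFromLeafHashes folds the precomputed leaf hashes up to the root,
// splitting at the largest power of two strictly below the length.
func rootFromLeafHashes(hs [][]byte) []byte {
	switch len(hs) {
	case 0:
		h := sha256.Sum256(nil) // empty-tree hash
		return h[:]
	case 1:
		return hs[0]
	default:
		k := 1
		for k < len(hs) {
			k <<= 1
		}
		k /= 2 // largest power of two < len(hs)
		return innerHash(rootFromLeafHashes(hs[:k]), rootFromLeafHashes(hs[k:]))
	}
}

func main() {
	items := [][]byte{[]byte("a"), []byte("b"), []byte("c")}
	fmt.Printf("root: %x\n", rootFromLeafHashes(parallelLeafHashes(items)))
}
```

Since the inner-node pass touches only 32-byte hashes, the win grows with leaf size; consistent with that, the benchmark above shows the parallel version slightly losing to the original for 32-byte leaves, where goroutine overhead exceeds the hashing saved.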
1 parent 9656670 commit 0d77204

File tree

8 files changed, +1262 −113 lines


consensus/propagation/types/types.go

2 additions, 2 deletions

```diff
@@ -208,14 +208,14 @@ func (c *CompactBlock) Proofs() ([]*merkle.Proof, error) {
 
 	c.proofsCache = make([]*merkle.Proof, 0, len(c.PartsHashes))
 
-	root, proofs := merkle.ProofsFromLeafHashes(c.PartsHashes[:total])
+	root, proofs := merkle.ParallelProofsFromLeafHashes(c.PartsHashes[:total])
 	c.proofsCache = append(c.proofsCache, proofs...)
 
 	if !bytes.Equal(root, c.Proposal.BlockID.PartSetHeader.Hash) {
 		return c.proofsCache, fmt.Errorf("incorrect PartsHash: original root")
 	}
 
-	parityRoot, eproofs := merkle.ProofsFromLeafHashes(c.PartsHashes[total:])
+	parityRoot, eproofs := merkle.ParallelProofsFromLeafHashes(c.PartsHashes[total:])
 	c.proofsCache = append(c.proofsCache, eproofs...)
 
 	if !bytes.Equal(c.BpHash, parityRoot) {
```
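Both call sites take the same arguments and consume the same (root, proofs) return values, so `merkle.ParallelProofsFromLeafHashes` is a drop-in replacement for `merkle.ProofsFromLeafHashes` here; the fuzz and property tests below verify that the parallel variants produce byte-identical roots and proofs to the sequential originals.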

crypto/merkle/fuzz_test.go

247 additions, 0 deletions (new file)

```go
package merkle

import (
	"bytes"
	"crypto/rand"
	"testing"
	"testing/quick"
)

// FuzzParallelImplementations tests that all parallel implementations
// produce identical results to the original sequential implementation.
func FuzzParallelImplementations(f *testing.F) {
	// Seed with diverse test cases
	f.Add([]byte{1, 2, 3, 4, 5})
	f.Add([]byte{0xFF, 0xAB, 0xCD, 0xEF})
	f.Add(make([]byte, 1000)) // Large zero-filled data

	f.Fuzz(func(t *testing.T, data []byte) {
		if len(data) == 0 {
			return
		}

		// Use data bytes to determine variable parameters for more robust testing
		dataIdx := 0
		nextByte := func() byte {
			if dataIdx >= len(data) {
				dataIdx = 0
			}
			b := data[dataIdx]
			dataIdx++
			return b
		}

		// Vary the number of items (1 to 100)
		numItems := max(1, int(nextByte())%100+1)

		// Vary leaf size ranges based on fuzz input
		leafSizeVariant := nextByte() % 5
		var minLeafSize, maxLeafSize int
		switch leafSizeVariant {
		case 0: // Small leaves (1-32 bytes)
			minLeafSize, maxLeafSize = 1, 32
		case 1: // Medium leaves (32-512 bytes)
			minLeafSize, maxLeafSize = 32, 512
		case 2: // Large leaves (512-8192 bytes)
			minLeafSize, maxLeafSize = 512, 8192
		case 3: // Mixed sizes (1-2048 bytes)
			minLeafSize, maxLeafSize = 1, 2048
		case 4: // Very large leaves (~64 KiB, like Celestia blocks)
			minLeafSize, maxLeafSize = 65536-1024, 65536+1024 // 64 KiB ± 1 KiB
		}

		// Create items with varying sizes
		items := make([][]byte, numItems)
		for i := 0; i < numItems; i++ {
			// Vary leaf size within the range
			leafSize := minLeafSize + int(nextByte())%(maxLeafSize-minLeafSize+1)

			// Create leaf data by cycling through available data
			leaf := make([]byte, leafSize)
			for j := 0; j < leafSize; j++ {
				leaf[j] = nextByte()
			}
			items[i] = leaf
		}

		// Test with the generated inputs
		testParallelCorrectness(t, items)

		// Additional test with a different leaf count but the same data pattern
		if numItems > 1 {
			// Test with fewer items (stress different tree shapes)
			reducedCount := max(1, numItems/2)
			reducedItems := make([][]byte, reducedCount)
			for i := 0; i < reducedCount; i++ {
				reducedItems[i] = items[i*2%len(items)] // Sample every other item
			}
			testParallelCorrectness(t, reducedItems)
		}

		// Test with power-of-2 and non-power-of-2 counts (different tree structures)
		if numItems >= 4 {
			// Find the largest power of 2 <= numItems
			powerOf2Count := 1
			for powerOf2Count < numItems {
				powerOf2Count *= 2
			}
			powerOf2Count /= 2

			if powerOf2Count >= 2 && powerOf2Count != numItems {
				powerOf2Items := make([][]byte, powerOf2Count)
				for i := 0; i < powerOf2Count; i++ {
					powerOf2Items[i] = items[i%len(items)]
				}
				testParallelCorrectness(t, powerOf2Items)
			}
		}
	})
}

func testParallelCorrectness(t *testing.T, items [][]byte) {
	// Get the reference result from the original implementation
	expected := HashFromByteSlices(items)

	// Test the optimized parallel implementation
	implementations := map[string]func([][]byte) []byte{
		"ParallelHashFromByteSlices": ParallelHashFromByteSlices,
	}

	for name, impl := range implementations {
		result := impl(items)
		if !bytes.Equal(expected, result) {
			t.Errorf("%s produced different result than HashFromByteSlices", name)
			t.Errorf("Expected: %x", expected)
			t.Errorf("Got: %x", result)
			t.Errorf("Items count: %d", len(items))
		}
	}

	// Also test proof generation correctness
	testParallelProofCorrectness(t, items)
}

// testParallelProofCorrectness tests that parallel proof generation
// produces identical results to the original implementation.
func testParallelProofCorrectness(t *testing.T, items [][]byte) {
	// Get reference results from the original implementation
	expectedRoot, expectedProofs := ProofsFromByteSlices(items)

	// Test parallel proof generation
	actualRoot, actualProofs := ParallelProofsFromByteSlices(items)

	// Root hashes must match
	if !bytes.Equal(expectedRoot, actualRoot) {
		t.Errorf("ParallelProofsFromByteSlices root hash differs from ProofsFromByteSlices")
		t.Errorf("Expected root: %x", expectedRoot)
		t.Errorf("Got root: %x", actualRoot)
		t.Errorf("Items count: %d", len(items))
		return
	}

	// Number of proofs must match
	if len(expectedProofs) != len(actualProofs) {
		t.Errorf("ParallelProofsFromByteSlices proof count differs from ProofsFromByteSlices")
		t.Errorf("Expected: %d proofs", len(expectedProofs))
		t.Errorf("Got: %d proofs", len(actualProofs))
		return
	}

	// Each proof must be identical
	for i := range expectedProofs {
		expected := expectedProofs[i]
		actual := actualProofs[i]

		if expected.Total != actual.Total {
			t.Errorf("Proof %d Total differs: expected %d, got %d", i, expected.Total, actual.Total)
		}
		if expected.Index != actual.Index {
			t.Errorf("Proof %d Index differs: expected %d, got %d", i, expected.Index, actual.Index)
		}
		if !bytes.Equal(expected.LeafHash, actual.LeafHash) {
			t.Errorf("Proof %d LeafHash differs", i)
		}
		if len(expected.Aunts) != len(actual.Aunts) {
			t.Errorf("Proof %d Aunts count differs: expected %d, got %d", i, len(expected.Aunts), len(actual.Aunts))
			continue
		}
		for j := range expected.Aunts {
			if !bytes.Equal(expected.Aunts[j], actual.Aunts[j]) {
				t.Errorf("Proof %d Aunt %d differs", i, j)
			}
		}

		// Verify the proof against the root with the original item
		if err := actual.Verify(actualRoot, items[i]); err != nil {
			t.Errorf("Parallel proof %d failed verification: %v", i, err)
		}
	}
}

// TestParallelImplementationsProperty uses property-based testing.
func TestParallelImplementationsProperty(t *testing.T) {
	property := func(numItems uint8, itemSize uint16) bool {
		if numItems == 0 || numItems > 100 || itemSize == 0 || itemSize > 1000 {
			return true // Skip invalid inputs
		}

		items := make([][]byte, numItems)
		for i := range items {
			items[i] = make([]byte, itemSize)
			if _, err := rand.Read(items[i]); err != nil {
				t.Fatalf("Failed to read random data: %v", err)
			}
		}

		expected := HashFromByteSlices(items)

		// Test the optimized implementation
		implementations := []func([][]byte) []byte{
			ParallelHashFromByteSlices,
		}

		for _, impl := range implementations {
			result := impl(items)
			if !bytes.Equal(expected, result) {
				return false
			}
		}

		return true
	}

	if err := quick.Check(property, nil); err != nil {
		t.Error(err)
	}
}

// TestParallelImplementationsLargeDataset tests with realistic dataset sizes.
func TestParallelImplementationsLargeDataset(t *testing.T) {
	// Same pattern as the actual use case (4000 items of 64 KiB each),
	// with the item count reduced to keep the test fast.
	const numItems = 128
	const itemSize = 1024 * 64 // 64 KiB per item

	items := make([][]byte, numItems)
	for i := range items {
		items[i] = make([]byte, itemSize)
		if _, err := rand.Read(items[i]); err != nil {
			t.Fatalf("Failed to read random data: %v", err)
		}
	}

	expected := HashFromByteSlices(items)

	// Test the optimized implementation
	implementations := map[string]func([][]byte) []byte{
		"ParallelHashFromByteSlices": ParallelHashFromByteSlices,
	}

	for name, impl := range implementations {
		t.Run(name, func(t *testing.T) {
			result := impl(items)
			if !bytes.Equal(expected, result) {
				t.Errorf("%s produced incorrect result", name)
			}
		})
	}
}
```
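These tests need only the standard toolchain: `go test ./crypto/merkle` runs the seed corpus plus the property and large-dataset tests, while `go test -fuzz=FuzzParallelImplementations ./crypto/merkle` (package path as in the benchmark output above) keeps generating new inputs against the sequential reference.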
