Skip to content

Commit

Permalink
parallel ForEach unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
i-norden committed Oct 17, 2023
1 parent 15c1451 commit f65a943
Showing 1 changed file with 195 additions and 3 deletions.
198 changes: 195 additions & 3 deletions amt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,18 @@ func init() {
}

var (
bitWidths3 = []uint{8}
bitWidths2to18 = []uint{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}
bitWidths2to3 = []uint{2, 3}
)

func runTestWithBitWidthsOnly(t *testing.T, bitwidths []uint, fn func(*testing.T, ...Option)) {
t.Helper()
for _, bw := range bitwidths {
t.Run(fmt.Sprintf("bitwidth=%d", bw), func(t *testing.T) { fn(t, UseTreeBitWidth(bw)) })
}
}

func runTestWithBitWidths(t *testing.T, bitwidths []uint, fn func(*testing.T, ...Option)) {
t.Helper()
if testing.Short() {
Expand Down Expand Up @@ -62,7 +70,7 @@ func newMockBlocks() *mockBlocks {
return &mockBlocks{make(map[cid.Cid]block.Block), sync.Mutex{}, 0, 0}
}

func (mb *mockBlocks) Get(c cid.Cid) (block.Block, error) {
func (mb *mockBlocks) Get(ctx context.Context, c cid.Cid) (block.Block, error) {
mb.dataMu.Lock()
defer mb.dataMu.Unlock()
d, ok := mb.data[c]
Expand All @@ -73,14 +81,41 @@ func (mb *mockBlocks) Get(c cid.Cid) (block.Block, error) {
return nil, fmt.Errorf("Not Found")
}

func (mb *mockBlocks) Put(b block.Block) error {
func (mb *mockBlocks) GetMany(ctx context.Context, cs []cid.Cid) ([]block.Block, []cid.Cid, error) {
mb.dataMu.Lock()
defer mb.dataMu.Unlock()
blocks := make([]block.Block, 0, len(cs))
missingCIDs := make([]cid.Cid, 0, len(cs))
for _, c := range cs {
mb.getCount++
d, ok := mb.data[c]
if !ok {
missingCIDs = append(missingCIDs, c)
} else {
blocks = append(blocks, d)
}
}
return blocks, missingCIDs, nil
}

func (mb *mockBlocks) Put(ctx context.Context, b block.Block) error {
mb.dataMu.Lock()
defer mb.dataMu.Unlock()
mb.putCount++
mb.data[b.Cid()] = b
return nil
}

func (mb *mockBlocks) PutMany(ctx context.Context, bs []block.Block) error {
mb.dataMu.Lock()
defer mb.dataMu.Unlock()
for _, b := range bs {
mb.putCount++
mb.data[b.Cid()] = b
}
return nil
}

func (mb *mockBlocks) report(b *testing.B) {
mb.dataMu.Lock()
defer mb.dataMu.Unlock()
Expand Down Expand Up @@ -342,6 +377,7 @@ func TestForEachWithoutFlush(t *testing.T) {
require.NoError(t, err)
set1 := make(map[uint64]struct{})
set2 := make(map[uint64]struct{})
set3 := make(map[uint64]struct{})
for _, val := range vals {
err := amt.Set(ctx, val, cborstr(""))
require.NoError(t, err)
Expand All @@ -357,14 +393,23 @@ func TestForEachWithoutFlush(t *testing.T) {
assert.Equal(t, make(map[uint64]struct{}), set1)

// ensure it still works after flush
_, err = amt.Flush(ctx)
c, err := amt.Flush(ctx)
require.NoError(t, err)

amt.ForEach(ctx, func(u uint64, deferred *cbg.Deferred) error {
delete(set2, u)
return nil
})
assert.Equal(t, make(map[uint64]struct{}), set2)

// ensure that it works with a loaded AMT
loadedAMT, err := LoadAMT(ctx, bs, c, opts...)
err = loadedAMT.ForEach(ctx, func(u uint64, deferred *cbg.Deferred) error {
delete(set3, u)
return nil
})
require.NoError(t, err)
assert.Equal(t, make(map[uint64]struct{}), set3)
}
})
}
Expand Down Expand Up @@ -794,6 +839,94 @@ func TestForEach(t *testing.T) {
})
}

func TestForEachParallel(t *testing.T) {
bs := cbor.NewGetManyCborStore(newMockBlocks())
ctx := context.Background()
a, err := NewAMT(bs)
require.NoError(t, err)

r := rand.New(rand.NewSource(101))

indexes := make(map[uint64]struct{})
for i := 0; i < 10000; i++ {
if r.Intn(2) == 0 {
indexes[uint64(i)] = struct{}{}
}
}

for i := range indexes {
if err := a.Set(ctx, i, cborstr("value")); err != nil {
t.Fatal(err)
}
}

for i := range indexes {
assertGet(ctx, t, a, i, "value")
}

assertCount(t, a, uint64(len(indexes)))

// test before flush
m := sync.Mutex{}
foundVals := make(map[uint64]struct{})
err = a.ForEachParallel(ctx, 16, func(i uint64, v *cbg.Deferred) error {
m.Lock()
foundVals[i] = struct{}{}
m.Unlock()
return nil
})
if err != nil {
t.Fatal(err)
}
if len(foundVals) != len(indexes) {
t.Fatal("didnt see enough values")
}

c, err := a.Flush(ctx)
if err != nil {
t.Fatal(err)
}

assertCount(t, a, uint64(len(indexes)))

// test after flush
foundVals = make(map[uint64]struct{})
err = a.ForEachParallel(ctx, 16, func(i uint64, v *cbg.Deferred) error {
m.Lock()
foundVals[i] = struct{}{}
m.Unlock()
return nil
})
if err != nil {
t.Fatal(err)
}
if len(foundVals) != len(indexes) {
t.Fatal("didnt see enough values")
}

na, err := LoadAMT(ctx, bs, c)
if err != nil {
t.Fatal(err)
}

assertCount(t, na, uint64(len(indexes)))

// test from loaded AMT
foundVals = make(map[uint64]struct{})
err = na.ForEachParallel(ctx, 16, func(i uint64, v *cbg.Deferred) error {
m.Lock()
foundVals[i] = struct{}{}
m.Unlock()
return nil
})
if err != nil {
t.Fatal(err)
}
if len(foundVals) != len(indexes) {
t.Fatal("didnt see enough values")
}
}

func TestForEachAt(t *testing.T) {
runTestWithBitWidths(t, bitWidths2to18, func(t *testing.T, opts ...Option) {
bs := cbor.NewCborStore(newMockBlocks())
Expand Down Expand Up @@ -858,6 +991,65 @@ func TestForEachAt(t *testing.T) {
})
}

func TestForEachAtParallel(t *testing.T) {
runTestWithBitWidths(t, bitWidths2to18, func(t *testing.T, opts ...Option) {
bs := cbor.NewGetManyCborStore(newMockBlocks())
ctx := context.Background()
a, err := NewAMT(bs, opts...)
require.NoError(t, err)

r := rand.New(rand.NewSource(101))

var indexes []uint64
for i := 0; i < cbg.MaxLength; i++ { // above bitwidth 13, inserting more than cbg.MaxLength causes node.Values to exceed the cbg.MaxLength
indexes = append(indexes, uint64(i))
if err := a.Set(ctx, uint64(i), cborstr(fmt.Sprint(i))); err != nil {
t.Fatal(err)
}
}

for _, i := range indexes {
assertGet(ctx, t, a, i, fmt.Sprint(i))
}

assertCount(t, a, uint64(len(indexes)))

c, err := a.Flush(ctx)
if err != nil {
t.Fatal(err)
}

na, err := LoadAMT(ctx, bs, c, opts...)
if err != nil {
t.Fatal(err)
}

assertCount(t, na, uint64(len(indexes)))
m := sync.Mutex{}
for try := 0; try < 10; try++ {
start := uint64(r.Intn(cbg.MaxLength))

expectedIndexes := make(map[uint64]struct{})
for i := start; i < cbg.MaxLength; i++ {
expectedIndexes[i] = struct{}{}
}

err = na.ForEachAtParallel(ctx, 16, start, func(i uint64, v *cbg.Deferred) error {
m.Lock()
delete(expectedIndexes, i)
m.Unlock()
return nil
})
if err != nil {
t.Fatal(err)
}
if len(expectedIndexes) != 0 {
t.Fatal("didnt see enough values")
}
}
})
}

func TestFirstSetIndex(t *testing.T) {
runTestWithBitWidths(t, bitWidths2to18, func(t *testing.T, opts ...Option) {
bs := cbor.NewCborStore(newMockBlocks())
Expand Down

0 comments on commit f65a943

Please sign in to comment.