Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(experiment): add a parallel AMT traversal function #84

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions amt.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,14 @@ func (r *Root) ForEachAt(ctx context.Context, start uint64, cb func(uint64, *cbg
return r.node.forEachAt(ctx, r.store, r.bitWidth, r.height, start, 0, cb)
}

// ForEachParallel runs the parallel traversal over the whole AMT: it is
// ForEachAtParallel with a start index of 0. ctx, concurrency and cb are
// forwarded unchanged and the traversal's error, if any, is returned.
//
// NOTE(review): cb is presumably invoked from several goroutines at once (the
// tests guard shared state with a mutex) — confirm against forEachAtParallel.
func (r *Root) ForEachParallel(ctx context.Context, concurrency int, cb func(uint64, *cbg.Deferred) error) error {
	const firstIndex uint64 = 0
	return r.ForEachAtParallel(ctx, concurrency, firstIndex, cb)
}

// ForEachAtParallel hands the traversal to the root node's forEachAtParallel,
// supplying this Root's store, bit width and height, the caller's start
// index, cb, and concurrency. The error from that call is returned as-is.
//
// NOTE(review): presumably only entries at index >= start are visited and
// concurrency bounds the number of parallel workers — confirm against
// forEachAtParallel, which is not visible here.
func (r *Root) ForEachAtParallel(ctx context.Context, concurrency int, start uint64, cb func(uint64, *cbg.Deferred) error) error {
	return r.node.forEachAtParallel(
		ctx,
		r.store, r.bitWidth, r.height,
		start, 0,
		cb,
		concurrency,
	)
}

// FirstSetIndex finds the lowest index in this AMT that has a value set for
// it. If this operation is called on an empty AMT, an ErrNoValues will be
// returned.
Expand Down
197 changes: 194 additions & 3 deletions amt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ var (
bitWidths2to3 = []uint{2, 3}
)

// runTestWithBitWidthsOnly invokes fn once for every entry of bitwidths. Each
// invocation happens inside its own subtest named "bitwidth=<n>" and receives
// a single option built with UseTreeBitWidth for that width.
func runTestWithBitWidthsOnly(t *testing.T, bitwidths []uint, fn func(*testing.T, ...Option)) {
	t.Helper()
	for idx := range bitwidths {
		width := bitwidths[idx]
		title := fmt.Sprintf("bitwidth=%d", width)
		t.Run(title, func(st *testing.T) {
			fn(st, UseTreeBitWidth(width))
		})
	}
}

func runTestWithBitWidths(t *testing.T, bitwidths []uint, fn func(*testing.T, ...Option)) {
t.Helper()
if testing.Short() {
Expand Down Expand Up @@ -62,7 +69,7 @@ func newMockBlocks() *mockBlocks {
return &mockBlocks{make(map[cid.Cid]block.Block), sync.Mutex{}, 0, 0}
}

func (mb *mockBlocks) Get(c cid.Cid) (block.Block, error) {
func (mb *mockBlocks) Get(ctx context.Context, c cid.Cid) (block.Block, error) {
mb.dataMu.Lock()
defer mb.dataMu.Unlock()
d, ok := mb.data[c]
Expand All @@ -73,14 +80,41 @@ func (mb *mockBlocks) Get(c cid.Cid) (block.Block, error) {
return nil, fmt.Errorf("Not Found")
}

func (mb *mockBlocks) Put(b block.Block) error {
// GetMany looks up every CID in cs while holding dataMu. Blocks that are
// present are returned in the first slice, CIDs with no stored block in the
// second; the error is always nil. getCount is bumped once per requested CID,
// whether or not it was found. ctx is not consulted.
func (mb *mockBlocks) GetMany(ctx context.Context, cs []cid.Cid) ([]block.Block, []cid.Cid, error) {
	mb.dataMu.Lock()
	defer mb.dataMu.Unlock()

	found := make([]block.Block, 0, len(cs))
	missing := make([]cid.Cid, 0, len(cs))
	for idx := range cs {
		key := cs[idx]
		mb.getCount++
		if blk, ok := mb.data[key]; ok {
			found = append(found, blk)
			continue
		}
		missing = append(missing, key)
	}
	return found, missing, nil
}

// Put stores b in the in-memory map under its CID while holding dataMu, and
// counts the write in putCount. It always returns nil; ctx is not consulted.
func (mb *mockBlocks) Put(ctx context.Context, b block.Block) error {
	mb.dataMu.Lock()
	defer mb.dataMu.Unlock()

	mb.putCount++
	key := b.Cid()
	mb.data[key] = b
	return nil
}

// PutMany stores every block of bs in the in-memory map under its CID, taking
// dataMu once for the whole batch. putCount grows by one per block. It always
// returns nil; ctx is not consulted.
func (mb *mockBlocks) PutMany(ctx context.Context, bs []block.Block) error {
	mb.dataMu.Lock()
	defer mb.dataMu.Unlock()

	for idx := range bs {
		blk := bs[idx]
		mb.putCount++
		mb.data[blk.Cid()] = blk
	}
	return nil
}

func (mb *mockBlocks) report(b *testing.B) {
mb.dataMu.Lock()
defer mb.dataMu.Unlock()
Expand Down Expand Up @@ -342,6 +376,7 @@ func TestForEachWithoutFlush(t *testing.T) {
require.NoError(t, err)
set1 := make(map[uint64]struct{})
set2 := make(map[uint64]struct{})
set3 := make(map[uint64]struct{})
for _, val := range vals {
err := amt.Set(ctx, val, cborstr(""))
require.NoError(t, err)
Expand All @@ -357,14 +392,23 @@ func TestForEachWithoutFlush(t *testing.T) {
assert.Equal(t, make(map[uint64]struct{}), set1)

// ensure it still works after flush
_, err = amt.Flush(ctx)
c, err := amt.Flush(ctx)
require.NoError(t, err)

amt.ForEach(ctx, func(u uint64, deferred *cbg.Deferred) error {
delete(set2, u)
return nil
})
assert.Equal(t, make(map[uint64]struct{}), set2)

// ensure that it works with a loaded AMT
loadedAMT, err := LoadAMT(ctx, bs, c, opts...)
err = loadedAMT.ForEach(ctx, func(u uint64, deferred *cbg.Deferred) error {
delete(set3, u)
return nil
})
require.NoError(t, err)
assert.Equal(t, make(map[uint64]struct{}), set3)
}
})
}
Expand Down Expand Up @@ -794,6 +838,94 @@ func TestForEach(t *testing.T) {
})
}

// TestForEachParallel checks that ForEachParallel reports exactly the indexes
// that were set — each of them once — in three situations: on a freshly
// populated AMT before it is flushed, on the same AMT after Flush, and on an
// AMT loaded back from the store via LoadAMT.
func TestForEachParallel(t *testing.T) {
	bs := cbor.NewGetManyCborStore(newMockBlocks())
	ctx := context.Background()
	a, err := NewAMT(bs)
	require.NoError(t, err)

	// Fixed seed so the populated index set is the same on every run.
	r := rand.New(rand.NewSource(101))

	// Pick each of the first 10000 indexes with probability 1/2.
	indexes := make(map[uint64]struct{})
	for i := 0; i < 10000; i++ {
		if r.Intn(2) == 0 {
			indexes[uint64(i)] = struct{}{}
		}
	}

	for i := range indexes {
		if err := a.Set(ctx, i, cborstr("value")); err != nil {
			t.Fatal(err)
		}
	}

	for i := range indexes {
		assertGet(ctx, t, a, i, "value")
	}

	assertCount(t, a, uint64(len(indexes)))

	// assertTraversal runs one parallel traversal and fails the test unless
	// the callback was invoked exactly once for every index in indexes and
	// for no other index. Counting visits (instead of only collecting a set
	// and comparing sizes) catches duplicated and unexpected indexes too.
	assertTraversal := func(stage string, forEach func(context.Context, int, func(uint64, *cbg.Deferred) error) error) {
		t.Helper()

		var mu sync.Mutex // the callback's shared map is guarded, as it may run concurrently
		visits := make(map[uint64]int, len(indexes))
		err := forEach(ctx, 16, func(i uint64, _ *cbg.Deferred) error {
			mu.Lock()
			visits[i]++
			mu.Unlock()
			return nil
		})
		if err != nil {
			t.Fatalf("%s: %v", stage, err)
		}
		if len(visits) != len(indexes) {
			t.Fatalf("%s: visited %d distinct indexes, want %d", stage, len(visits), len(indexes))
		}
		// Same number of distinct keys plus every expected key seen once
		// implies the two sets are identical.
		for i := range indexes {
			if n := visits[i]; n != 1 {
				t.Fatalf("%s: index %d visited %d times, want 1", stage, i, n)
			}
		}
	}

	// test before flush
	assertTraversal("before flush", a.ForEachParallel)

	c, err := a.Flush(ctx)
	if err != nil {
		t.Fatal(err)
	}

	assertCount(t, a, uint64(len(indexes)))

	// test after flush
	assertTraversal("after flush", a.ForEachParallel)

	na, err := LoadAMT(ctx, bs, c)
	if err != nil {
		t.Fatal(err)
	}

	assertCount(t, na, uint64(len(indexes)))

	// test from loaded AMT
	assertTraversal("loaded AMT", na.ForEachParallel)
}

func TestForEachAt(t *testing.T) {
runTestWithBitWidths(t, bitWidths2to18, func(t *testing.T, opts ...Option) {
bs := cbor.NewCborStore(newMockBlocks())
Expand Down Expand Up @@ -858,6 +990,65 @@ func TestForEachAt(t *testing.T) {
})
}

// TestForEachAtParallel fills an AMT with the indexes [0, cbg.MaxLength),
// flushes it, loads it back, and then — for ten pseudo-random start indexes —
// checks that ForEachAtParallel invokes the callback exactly once for every
// index in [start, cbg.MaxLength) and for no index outside that range. It is
// run for every bit width in bitWidths2to18 via runTestWithBitWidths.
func TestForEachAtParallel(t *testing.T) {
	runTestWithBitWidths(t, bitWidths2to18, func(t *testing.T, opts ...Option) {
		bs := cbor.NewGetManyCborStore(newMockBlocks())
		ctx := context.Background()
		a, err := NewAMT(bs, opts...)
		require.NoError(t, err)

		// Fixed seed so the sequence of start indexes is reproducible.
		r := rand.New(rand.NewSource(101))

		// Above bitwidth 13, inserting more than cbg.MaxLength entries causes
		// node.Values to exceed cbg.MaxLength, so the population stops there.
		indexes := make([]uint64, 0, cbg.MaxLength)
		for i := 0; i < cbg.MaxLength; i++ {
			indexes = append(indexes, uint64(i))
			if err := a.Set(ctx, uint64(i), cborstr(fmt.Sprint(i))); err != nil {
				t.Fatal(err)
			}
		}

		for _, i := range indexes {
			assertGet(ctx, t, a, i, fmt.Sprint(i))
		}

		assertCount(t, a, uint64(len(indexes)))

		c, err := a.Flush(ctx)
		if err != nil {
			t.Fatal(err)
		}

		na, err := LoadAMT(ctx, bs, c, opts...)
		if err != nil {
			t.Fatal(err)
		}

		assertCount(t, na, uint64(len(indexes)))

		var mu sync.Mutex // guards visits; the callback may run concurrently
		for try := 0; try < 10; try++ {
			start := uint64(r.Intn(cbg.MaxLength))
			want := uint64(cbg.MaxLength) - start

			// Count visits per index rather than deleting from an expected
			// set: a delete silently ignores indexes below start and repeated
			// visits, both of which must fail the test.
			visits := make(map[uint64]int)
			err = na.ForEachAtParallel(ctx, 16, start, func(i uint64, _ *cbg.Deferred) error {
				mu.Lock()
				visits[i]++
				mu.Unlock()
				return nil
			})
			if err != nil {
				t.Fatal(err)
			}
			if got := uint64(len(visits)); got != want {
				t.Fatalf("start=%d: visited %d distinct indexes, want %d", start, got, want)
			}
			// Matching distinct count plus one visit for every index in
			// range implies nothing outside [start, cbg.MaxLength) was seen.
			for i := start; i < cbg.MaxLength; i++ {
				if n := visits[i]; n != 1 {
					t.Fatalf("start=%d: index %d visited %d times, want 1", start, i, n)
				}
			}
		}
	})
}

func TestFirstSetIndex(t *testing.T) {
runTestWithBitWidths(t, bitWidths2to18, func(t *testing.T, opts ...Option) {
bs := cbor.NewCborStore(newMockBlocks())
Expand Down
40 changes: 22 additions & 18 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,37 @@ module github.com/filecoin-project/go-amt-ipld/v4
go 1.20

require (
github.com/ipfs/go-block-format v0.0.2
github.com/ipfs/go-cid v0.0.7
github.com/ipfs/go-ipld-cbor v0.0.4
github.com/ipfs/go-block-format v0.1.2
github.com/ipfs/go-cid v0.4.1
github.com/ipfs/go-ipld-cbor v0.1.0
github.com/stretchr/testify v1.7.0
github.com/whyrusleeping/cbor-gen v0.0.0-20220323183124-98fa8256a799
github.com/whyrusleeping/cbor-gen v0.0.0-20230818171029-f91ae536ca25
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2
)

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/ipfs/go-ipfs-util v0.0.1 // indirect
github.com/ipfs/go-ipld-format v0.0.1 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/ipfs/go-ipfs-util v0.0.2 // indirect
github.com/ipfs/go-ipld-format v0.5.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.5 // indirect
github.com/kr/pretty v0.1.0 // indirect
github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 // indirect
github.com/minio/sha256-simd v0.1.1-0.20190913151208-6de447530771 // indirect
github.com/mr-tron/base58 v1.1.3 // indirect
github.com/multiformats/go-base32 v0.0.3 // indirect
github.com/multiformats/go-base36 v0.1.0 // indirect
github.com/multiformats/go-multibase v0.0.3 // indirect
github.com/multiformats/go-multihash v0.0.13 // indirect
github.com/multiformats/go-varint v0.0.5 // indirect
github.com/minio/sha256-simd v1.0.1 // indirect
github.com/mr-tron/base58 v1.2.0 // indirect
github.com/multiformats/go-base32 v0.1.0 // indirect
github.com/multiformats/go-base36 v0.2.0 // indirect
github.com/multiformats/go-multibase v0.2.0 // indirect
github.com/multiformats/go-multihash v0.2.3 // indirect
github.com/multiformats/go-varint v0.0.7 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/polydawn/refmt v0.0.0-20190221155625-df39d6c2d992 // indirect
github.com/polydawn/refmt v0.89.0 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
golang.org/x/crypto v0.1.0 // indirect
golang.org/x/sys v0.1.0 // indirect
golang.org/x/crypto v0.12.0 // indirect
golang.org/x/sys v0.12.0 // indirect
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/yaml.v3 v3.0.0 // indirect
lukechampine.com/blake3 v1.2.1 // indirect
)

replace github.com/ipfs/go-ipld-cbor => github.com/vulcanize/go-ipld-cbor v0.1.1-internal-0.0.1
Loading