Skip to content

Commit

Permalink
Merge pull request #215 from onflow/fxamacker/fix-omt-has
Browse files Browse the repository at this point in the history
Make OrderedMap.Has distinguish KeyNotFoundError
  • Loading branch information
fxamacker authored Nov 19, 2021
2 parents ee5b67b + b3e3b5c commit 6a06b64
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 36 deletions.
9 changes: 8 additions & 1 deletion map.go
Original file line number Diff line number Diff line change
Expand Up @@ -3386,7 +3386,14 @@ func NewMapWithRootID(storage SlabStorage, rootID StorageID, digestBuilder Diges

func (m *OrderedMap) Has(comparator ValueComparator, hip HashInputProvider, key Value) (bool, error) {
_, err := m.Get(comparator, hip, key)
return err == nil, nil
if err != nil {
var knf *KeyNotFoundError
if errors.As(err, &knf) {
return false, nil
}
return false, err
}
return true, nil
}

func (m *OrderedMap) Get(comparator ValueComparator, hip HashInputProvider, key Value) (Storable, error) {
Expand Down
107 changes: 72 additions & 35 deletions map_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package atree

import (
"errors"
"fmt"
"math/rand"
"reflect"
Expand Down Expand Up @@ -70,6 +71,23 @@ func (d mockDigester) Levels() int {

func (d mockDigester) Reset() {}

type errorDigesterBuilder struct {
err error
}

var _ DigesterBuilder = &errorDigesterBuilder{}

func newErrorDigesterBuilder(err error) *errorDigesterBuilder {
return &errorDigesterBuilder{err: err}
}

func (h *errorDigesterBuilder) SetSeed(_ uint64, _ uint64) {
}

func (h *errorDigesterBuilder) Digest(hip HashInputProvider, value Value) (Digester, error) {
return nil, h.err
}

func verifyEmptyMap(
t *testing.T,
storage *PersistentSlabStorage,
Expand Down Expand Up @@ -351,53 +369,72 @@ func TestMapSetAndGet(t *testing.T) {

func TestMapHas(t *testing.T) {

const (
mapSize = 2048
keyStringSize = 16
)
t.Run("no error", func(t *testing.T) {
const (
mapSize = 2048
keyStringSize = 16
)

r := newRand(t)
r := newRand(t)

keys := make(map[Value]bool, mapSize*2)
keysToInsert := make([]Value, 0, mapSize)
keysToNotInsert := make([]Value, 0, mapSize)
for len(keysToInsert) < mapSize || len(keysToNotInsert) < mapSize {
k := NewStringValue(randStr(r, keyStringSize))
if !keys[k] {
keys[k] = true
keys := make(map[Value]bool, mapSize*2)
keysToInsert := make([]Value, 0, mapSize)
keysToNotInsert := make([]Value, 0, mapSize)
for len(keysToInsert) < mapSize || len(keysToNotInsert) < mapSize {
k := NewStringValue(randStr(r, keyStringSize))
if !keys[k] {
keys[k] = true

if len(keysToInsert) < mapSize {
keysToInsert = append(keysToInsert, k)
} else {
keysToNotInsert = append(keysToNotInsert, k)
if len(keysToInsert) < mapSize {
keysToInsert = append(keysToInsert, k)
} else {
keysToNotInsert = append(keysToNotInsert, k)
}
}
}
}

typeInfo := testTypeInfo{42}
address := Address{1, 2, 3, 4, 5, 6, 7, 8}
storage := newTestPersistentStorage(t)

m, err := NewMap(storage, address, newBasicDigesterBuilder(), typeInfo)
require.NoError(t, err)
typeInfo := testTypeInfo{42}
address := Address{1, 2, 3, 4, 5, 6, 7, 8}
storage := newTestPersistentStorage(t)

for i, k := range keysToInsert {
existingStorable, err := m.Set(compare, hashInputProvider, k, Uint64Value(i))
m, err := NewMap(storage, address, newBasicDigesterBuilder(), typeInfo)
require.NoError(t, err)
require.Nil(t, existingStorable)
}

for _, k := range keysToInsert {
exist, err := m.Has(compare, hashInputProvider, k)
require.NoError(t, err)
require.True(t, exist)
}
for i, k := range keysToInsert {
existingStorable, err := m.Set(compare, hashInputProvider, k, Uint64Value(i))
require.NoError(t, err)
require.Nil(t, existingStorable)
}

for _, k := range keysToNotInsert {
exist, err := m.Has(compare, hashInputProvider, k)
for _, k := range keysToInsert {
exist, err := m.Has(compare, hashInputProvider, k)
require.NoError(t, err)
require.True(t, exist)
}

for _, k := range keysToNotInsert {
exist, err := m.Has(compare, hashInputProvider, k)
require.NoError(t, err)
require.False(t, exist)
}
})

t.Run("error", func(t *testing.T) {

typeInfo := testTypeInfo{42}
address := Address{1, 2, 3, 4, 5, 6, 7, 8}
storage := newTestPersistentStorage(t)

testErr := errors.New("test")
digesterBuilder := newErrorDigesterBuilder(testErr)

m, err := NewMap(storage, address, digesterBuilder, typeInfo)
require.NoError(t, err)

exist, err := m.Has(compare, hashInputProvider, Uint64Value(0))
require.Equal(t, testErr, err)
require.False(t, exist)
}
})
}

func TestMapRemove(t *testing.T) {
Expand Down

0 comments on commit 6a06b64

Please sign in to comment.