Skip to content

Commit

Permalink
Support value mutation from non-readonly iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
fxamacker committed Sep 29, 2023
1 parent d3de291 commit e88a73e
Show file tree
Hide file tree
Showing 4 changed files with 396 additions and 66 deletions.
55 changes: 32 additions & 23 deletions array.go
Original file line number Diff line number Diff line change
Expand Up @@ -3112,12 +3112,13 @@ func (a *Array) Storable(_ SlabStorage, _ Address, maxInlineSize uint64) (Storab
var emptyArrayIterator = &ArrayIterator{}

type ArrayIterator struct {
storage SlabStorage
id SlabID
dataSlab *ArrayDataSlab
index int
remainingCount int
readOnly bool
array *Array
id SlabID
dataSlab *ArrayDataSlab
indexInArray int
indexInDataSlab int
remainingCount int
readOnly bool
}

func (i *ArrayIterator) Next() (Value, error) {
Expand All @@ -3130,7 +3131,7 @@ func (i *ArrayIterator) Next() (Value, error) {
return nil, nil
}

slab, found, err := i.storage.Retrieve(i.id)
slab, found, err := i.array.Storage.Retrieve(i.id)
if err != nil {
// Wrap err as external error (if needed) because err is returned by SlabStorage interface.
return nil, wrapErrorfAsExternalErrorIfNeeded(err, fmt.Sprintf("failed to retrieve slab %s", i.id))
Expand All @@ -3140,22 +3141,29 @@ func (i *ArrayIterator) Next() (Value, error) {
}

i.dataSlab = slab.(*ArrayDataSlab)
i.index = 0
i.indexInDataSlab = 0
}

var element Value
var err error
if i.index < len(i.dataSlab.elements) {
element, err = i.dataSlab.elements[i.index].StoredValue(i.storage)
if i.indexInDataSlab < len(i.dataSlab.elements) {
element, err = i.dataSlab.elements[i.indexInDataSlab].StoredValue(i.array.Storage)
if err != nil {
// Wrap err as external error (if needed) because err is returned by Storable interface.
return nil, wrapErrorfAsExternalErrorIfNeeded(err, "failed to get storable's stored value")
}

i.index++
if !i.readOnly {
// Set up notification callback in child value so
// when child value is modified parent a is notified.
i.array.setCallbackWithChild(uint64(i.indexInArray), element, maxInlineArrayElementSize)
}

i.indexInDataSlab++
i.indexInArray++
}

if i.index >= len(i.dataSlab.elements) {
if i.indexInDataSlab >= len(i.dataSlab.elements) {
i.id = i.dataSlab.next
i.dataSlab = nil
}
Expand All @@ -3173,7 +3181,7 @@ func (a *Array) Iterator() (*ArrayIterator, error) {
}

return &ArrayIterator{
storage: a.Storage,
array: a,
id: slab.SlabID(),
dataSlab: slab,
remainingCount: int(a.Count()),
Expand Down Expand Up @@ -3235,11 +3243,12 @@ func (a *Array) RangeIterator(startIndex uint64, endIndex uint64) (*ArrayIterato
}

return &ArrayIterator{
storage: a.Storage,
id: dataSlab.SlabID(),
dataSlab: dataSlab,
index: int(index),
remainingCount: int(numberOfElements),
array: a,
id: dataSlab.SlabID(),
dataSlab: dataSlab,
indexInArray: int(startIndex),
indexInDataSlab: int(index),
remainingCount: int(numberOfElements),
}, nil
}

Expand All @@ -3254,7 +3263,7 @@ func (a *Array) ReadOnlyRangeIterator(startIndex uint64, endIndex uint64) (*Arra

type ArrayIterationFunc func(element Value) (resume bool, err error)

func iterate(iterator *ArrayIterator, fn ArrayIterationFunc) error {
func iterateArray(iterator *ArrayIterator, fn ArrayIterationFunc) error {
for {
value, err := iterator.Next()
if err != nil {
Expand All @@ -3281,7 +3290,7 @@ func (a *Array) Iterate(fn ArrayIterationFunc) error {
// Don't need to wrap error as external error because err is already categorized by Array.Iterator().
return err
}
return iterate(iterator, fn)
return iterateArray(iterator, fn)
}

func (a *Array) IterateReadOnly(fn ArrayIterationFunc) error {
Expand All @@ -3290,7 +3299,7 @@ func (a *Array) IterateReadOnly(fn ArrayIterationFunc) error {
// Don't need to wrap error as external error because err is already categorized by Array.ReadOnlyIterator().
return err
}
return iterate(iterator, fn)
return iterateArray(iterator, fn)
}

func (a *Array) IterateRange(startIndex uint64, endIndex uint64, fn ArrayIterationFunc) error {
Expand All @@ -3299,7 +3308,7 @@ func (a *Array) IterateRange(startIndex uint64, endIndex uint64, fn ArrayIterati
// Don't need to wrap error as external error because err is already categorized by Array.RangeIterator().
return err
}
return iterate(iterator, fn)
return iterateArray(iterator, fn)
}

func (a *Array) IterateReadOnlyRange(startIndex uint64, endIndex uint64, fn ArrayIterationFunc) error {
Expand All @@ -3308,7 +3317,7 @@ func (a *Array) IterateReadOnlyRange(startIndex uint64, endIndex uint64, fn Arra
// Don't need to wrap error as external error because err is already categorized by Array.ReadOnlyRangeIterator().
return err
}
return iterate(iterator, fn)
return iterateArray(iterator, fn)
}

func (a *Array) Count() uint64 {
Expand Down
125 changes: 125 additions & 0 deletions array_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,67 @@ func TestArrayIterate(t *testing.T) {

require.Equal(t, count/2, i)
})

t.Run("mutation", func(t *testing.T) {
SetThreshold(256)
defer SetThreshold(1024)

const arraySize = 15

typeInfo := testTypeInfo{42}
storage := newTestPersistentStorage(t)
address := Address{1, 2, 3, 4, 5, 6, 7, 8}

array, err := NewArray(storage, address, typeInfo)
require.NoError(t, err)

expectedValues := make([]Value, arraySize)
for i := uint64(0); i < arraySize; i++ {
childArray, err := NewArray(storage, address, typeInfo)
require.NoError(t, err)

v := Uint64Value(i)
err = childArray.Append(v)
require.NoError(t, err)

err = array.Append(childArray)
require.NoError(t, err)

expectedValues[i] = arrayValue{v}
}
require.True(t, array.root.IsData())

sizeBeforeMutation := array.root.Header().size

i := 0
newElement := Uint64Value(0)
err = array.Iterate(func(v Value) (bool, error) {
childArray, ok := v.(*Array)
require.True(t, ok)
require.Equal(t, uint64(1), childArray.Count())
require.True(t, childArray.Inlined())

err := childArray.Append(newElement)
require.NoError(t, err)

expectedChildArrayValues, ok := expectedValues[i].(arrayValue)
require.True(t, ok)

expectedChildArrayValues = append(expectedChildArrayValues, newElement)
expectedValues[i] = expectedChildArrayValues

i++

require.Equal(t, array.root.Header().size, sizeBeforeMutation+uint32(i)*newElement.ByteSize())

return true, nil
})
require.NoError(t, err)
require.Equal(t, arraySize, i)
require.True(t, array.root.IsData())

verifyArray(t, storage, typeInfo, address, array, expectedValues, false)
})
}

func testArrayIterateRange(t *testing.T, array *Array, values []Value) {
Expand Down Expand Up @@ -1058,6 +1119,70 @@ func TestArrayIterateRange(t *testing.T) {
require.Equal(t, testErr, externalError.Unwrap())
require.Equal(t, count/2, i)
})

t.Run("mutation", func(t *testing.T) {
SetThreshold(256)
defer SetThreshold(1024)

const arraySize = 15

typeInfo := testTypeInfo{42}
storage := newTestPersistentStorage(t)
address := Address{1, 2, 3, 4, 5, 6, 7, 8}

array, err := NewArray(storage, address, typeInfo)
require.NoError(t, err)

expectedValues := make([]Value, arraySize)
for i := uint64(0); i < arraySize; i++ {
childArray, err := NewArray(storage, address, typeInfo)
require.NoError(t, err)

v := Uint64Value(i)
err = childArray.Append(v)
require.NoError(t, err)

err = array.Append(childArray)
require.NoError(t, err)

expectedValues[i] = arrayValue{v}
}
require.True(t, array.root.IsData())

sizeBeforeMutation := array.root.Header().size

i := 0
startIndex := uint64(1)
endIndex := array.Count() - 2
newElement := Uint64Value(0)
err = array.IterateRange(startIndex, endIndex, func(v Value) (bool, error) {
childArray, ok := v.(*Array)
require.True(t, ok)
require.Equal(t, uint64(1), childArray.Count())
require.True(t, childArray.Inlined())

err := childArray.Append(newElement)
require.NoError(t, err)

index := int(startIndex) + i
expectedChildArrayValues, ok := expectedValues[index].(arrayValue)
require.True(t, ok)

expectedChildArrayValues = append(expectedChildArrayValues, newElement)
expectedValues[index] = expectedChildArrayValues

i++

require.Equal(t, array.root.Header().size, sizeBeforeMutation+uint32(i)*newElement.ByteSize())

return true, nil
})
require.NoError(t, err)
require.Equal(t, endIndex-startIndex, uint64(i))
require.True(t, array.root.IsData())

verifyArray(t, storage, typeInfo, address, array, expectedValues, false)
})
}

func TestArrayRootSlabID(t *testing.T) {
Expand Down
Loading

0 comments on commit e88a73e

Please sign in to comment.