Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make result order deterministic #126

Merged
merged 8 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions pool/result_context_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,16 @@ import (
type ResultContextPool[T any] struct {
contextPool ContextPool
agg resultAggregator[T]
collectErrored bool
includeErrored bool
}

// Go submits a task to the pool. If all goroutines in the pool
// are busy, a call to Go() will block until the task can be started.
func (p *ResultContextPool[T]) Go(f func(context.Context) (T, error)) {
idx := p.agg.nextIndex()
p.contextPool.Go(func(ctx context.Context) error {
res, err := f(ctx)
if err == nil || p.collectErrored {
p.agg.add(res)
}
p.agg.save(idx, res, err != nil)
return err
})
}
Expand All @@ -33,15 +32,15 @@ func (p *ResultContextPool[T]) Go(f func(context.Context) (T, error)) {
// returns an error if any of the tasks errored.
func (p *ResultContextPool[T]) Wait() ([]T, error) {
err := p.contextPool.Wait()
return p.agg.results, err
return p.agg.collect(p.includeErrored), err
}

// WithCollectErrored configures the pool to still collect the result of a task
// even if the task returned an error. By default, the result of tasks that errored
// are ignored and only the error is collected.
func (p *ResultContextPool[T]) WithCollectErrored() *ResultContextPool[T] {
p.panicIfInitialized()
p.collectErrored = true
p.includeErrored = true
return p
}

Expand Down
2 changes: 0 additions & 2 deletions pool/result_context_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"sort"
"strconv"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -223,7 +222,6 @@ func TestResultContextPool(t *testing.T) {
})
}
res, err := g.Wait()
sort.Ints(res)
require.Equal(t, expected, res)
require.NoError(t, err)
require.Equal(t, int64(0), currentConcurrent.Load())
Expand Down
16 changes: 7 additions & 9 deletions pool/result_error_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,24 @@ import (
// type and an error. Tasks are executed in the pool with Go(), then the
// results of the tasks are returned by Wait().
//
// The order of the results is not guaranteed to be the same as the order the
// tasks were submitted. If your use case requires consistent ordering,
// consider using the `stream` package or `Map` from the `iter` package.
// The order of the results is guaranteed to be the same as the order the
// tasks were submitted.
//
// The configuration methods (With*) will panic if they are used after calling
// Go() for the first time.
type ResultErrorPool[T any] struct {
errorPool ErrorPool
agg resultAggregator[T]
collectErrored bool
includeErrored bool
}

// Go submits a task to the pool. If all goroutines in the pool
// are busy, a call to Go() will block until the task can be started.
func (p *ResultErrorPool[T]) Go(f func() (T, error)) {
idx := p.agg.nextIndex()
p.errorPool.Go(func() error {
res, err := f()
if err == nil || p.collectErrored {
p.agg.add(res)
}
p.agg.save(idx, res, err != nil)
return err
})
}
Expand All @@ -36,15 +34,15 @@ func (p *ResultErrorPool[T]) Go(f func() (T, error)) {
// returning the results and any errors from tasks.
func (p *ResultErrorPool[T]) Wait() ([]T, error) {
err := p.errorPool.Wait()
return p.agg.results, err
return p.agg.collect(p.includeErrored), err
}

// WithCollectErrored configures the pool to still collect the result of a task
// even if the task returned an error. By default, the result of tasks that errored
// are ignored and only the error is collected.
func (p *ResultErrorPool[T]) WithCollectErrored() *ResultErrorPool[T] {
p.panicIfInitialized()
p.collectErrored = true
p.includeErrored = true
return p
}

Expand Down
2 changes: 1 addition & 1 deletion pool/result_error_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/stretchr/testify/require"
)

func TestResultErrorGroup(t *testing.T) {
func TestResultErrorPool(t *testing.T) {
t.Parallel()

err1 := errors.New("err1")
Expand Down
63 changes: 55 additions & 8 deletions pool/result_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pool

import (
"context"
"sort"
"sync"
)

Expand All @@ -19,9 +20,8 @@ func NewWithResults[T any]() *ResultPool[T] {
// Tasks are executed in the pool with Go(), then the results of the tasks are
// returned by Wait().
//
// The order of the results is not guaranteed to be the same as the order the
// tasks were submitted. If your use case requires consistent ordering,
// consider using the `stream` package or `Map` from the `iter` package.
// The order of the results is guaranteed to be the same as the order the
// tasks were submitted.
type ResultPool[T any] struct {
pool Pool
agg resultAggregator[T]
Expand All @@ -30,16 +30,17 @@ type ResultPool[T any] struct {
// Go submits a task to the pool. If all goroutines in the pool
// are busy, a call to Go() will block until the task can be started.
func (p *ResultPool[T]) Go(f func() T) {
idx := p.agg.nextIndex()
p.pool.Go(func() {
p.agg.add(f())
p.agg.save(idx, f(), false)
})
}

// Wait cleans up all spawned goroutines, propagating any panics, and returning
// a slice of results from tasks that did not panic.
func (p *ResultPool[T]) Wait() []T {
p.pool.Wait()
return p.agg.results
return p.agg.collect(true)
}

// MaxGoroutines returns the maximum size of the pool.
Expand Down Expand Up @@ -83,11 +84,57 @@ func (p *ResultPool[T]) panicIfInitialized() {
// goroutines. The zero value is valid and ready to use.
type resultAggregator[T any] struct {
mu sync.Mutex
len int
results []T
errored []int
}
Comment on lines 85 to 90
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes to resultAggregator are central to this PR.

Summary:

  • Add a nextIndex method that makes it possible to reserve a slot in the result slice
  • Change the add method to save, which also takes an the index from nextIndex and whether the result errored
  • Add a collect method that respects the WithCollectErrored option by filtering out errored results.


func (r *resultAggregator[T]) add(res T) {
// nextIndex reserves a slot for a result. The returned value should be passed
// to save() when adding a result to the aggregator.
func (r *resultAggregator[T]) nextIndex() int {
r.mu.Lock()
r.results = append(r.results, res)
r.mu.Unlock()
defer r.mu.Unlock()

nextIdx := r.len
r.len += 1
return nextIdx
}

func (r *resultAggregator[T]) save(i int, res T, errored bool) {
r.mu.Lock()
defer r.mu.Unlock()

if i >= len(r.results) {
old := r.results
r.results = make([]T, r.len)
copy(r.results, old)
}

r.results[i] = res

if errored {
r.errored = append(r.errored, i)
}
}

// collect returns the set of aggregated results.
func (r *resultAggregator[T]) collect(includeErrored bool) []T {
if !r.mu.TryLock() {
panic("collect should not be called until all goroutines have exited")
}

if includeErrored || len(r.errored) == 0 {
return r.results
}

filtered := r.results[:0]
sort.Ints(r.errored)
for i, e := range r.errored {
if i == 0 {
filtered = append(filtered, r.results[:e]...)
} else {
filtered = append(filtered, r.results[r.errored[i-1]+1:e]...)
}
}
return filtered
}
27 changes: 22 additions & 5 deletions pool/result_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package pool_test

import (
"fmt"
"sort"
"math/rand"
"strconv"
"sync/atomic"
"testing"
Expand All @@ -22,8 +22,6 @@ func ExampleResultPool() {
})
}
res := p.Wait()
// Result order is nondeterministic, so sort them first
sort.Ints(res)
fmt.Println(res)

// Output:
Expand Down Expand Up @@ -62,10 +60,30 @@ func TestResultGroup(t *testing.T) {
})
}
res := g.Wait()
sort.Ints(res)
require.Equal(t, expected, res)
})

t.Run("deterministic order", func(t *testing.T) {
t.Parallel()
p := pool.NewWithResults[int]()
results := make([]int, 100)
for i := 0; i < 100; i++ {
results[i] = i
}
for _, result := range results {
result := result
p.Go(func() int {
// Add a random sleep to make it exceedingly unlikely that the
// results are returned in the order they are submitted.
time.Sleep(time.Duration(rand.Int()%100) * time.Millisecond)
return result
})
}
got := p.Wait()
require.Equal(t, results, got)

})

t.Run("limit", func(t *testing.T) {
t.Parallel()
for _, maxGoroutines := range []int{1, 10, 100} {
Expand All @@ -90,7 +108,6 @@ func TestResultGroup(t *testing.T) {
})
}
res := g.Wait()
sort.Ints(res)
require.Equal(t, expected, res)
require.Equal(t, int64(0), errCount.Load())
require.Equal(t, int64(0), currentConcurrent.Load())
Expand Down
Loading