Skip to content

Commit

Permalink
Make result order deterministic (#126)
Browse files Browse the repository at this point in the history
This makes the order of results in a `Result.*Pool` deterministic so
that the order of the result slice corresponds with the order of tasks
submitted. As an example of why this would be useful, it makes it easy
to rewrite `iter.Map` in terms of `ResultPool`. Additionally, it's a
generally nice and intuitive property to be able to match the index of
the result slice with the index of the input slice.
  • Loading branch information
camdencheek authored Jan 19, 2024
1 parent 4afefce commit 8427ccd
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 27 deletions.
7 changes: 3 additions & 4 deletions pool/result_context_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ type ResultContextPool[T any] struct {
// Go submits a task to the pool. If all goroutines in the pool
// are busy, a call to Go() will block until the task can be started.
func (p *ResultContextPool[T]) Go(f func(context.Context) (T, error)) {
	// Reserve a result slot before submitting so that the result slice
	// keeps the order in which tasks were submitted.
	idx := p.agg.nextIndex()
	p.contextPool.Go(func(ctx context.Context) error {
		res, err := f(ctx)
		// Save unconditionally; Wait() filters out errored results
		// unless WithCollectErrored was configured.
		p.agg.save(idx, res, err != nil)
		return err
	})
}
Expand All @@ -33,7 +32,7 @@ func (p *ResultContextPool[T]) Go(f func(context.Context) (T, error)) {
// Wait cleans up all spawned goroutines, propagating any panics, and
// returns an error if any of the tasks errored.
func (p *ResultContextPool[T]) Wait() ([]T, error) {
	err := p.contextPool.Wait()
	// Errored results are filtered out unless WithCollectErrored was set.
	return p.agg.collect(p.collectErrored), err
}

// WithCollectErrored configures the pool to still collect the result of a task
Expand Down
2 changes: 0 additions & 2 deletions pool/result_context_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"sort"
"strconv"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -223,7 +222,6 @@ func TestResultContextPool(t *testing.T) {
})
}
res, err := g.Wait()
sort.Ints(res)
require.Equal(t, expected, res)
require.NoError(t, err)
require.Equal(t, int64(0), currentConcurrent.Load())
Expand Down
12 changes: 5 additions & 7 deletions pool/result_error_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ import (
// type and an error. Tasks are executed in the pool with Go(), then the
// results of the tasks are returned by Wait().
//
// The order of the results is guaranteed to be the same as the order the
// tasks were submitted.
//
// The configuration methods (With*) will panic if they are used after calling
// Go() for the first time.
Expand All @@ -23,11 +22,10 @@ type ResultErrorPool[T any] struct {
// Go submits a task to the pool. If all goroutines in the pool
// are busy, a call to Go() will block until the task can be started.
func (p *ResultErrorPool[T]) Go(f func() (T, error)) {
	// Reserve a result slot before submitting so that the result slice
	// keeps the order in which tasks were submitted.
	idx := p.agg.nextIndex()
	p.errorPool.Go(func() error {
		res, err := f()
		// Save unconditionally; Wait() filters out errored results
		// unless WithCollectErrored was configured.
		p.agg.save(idx, res, err != nil)
		return err
	})
}
Expand All @@ -36,7 +34,7 @@ func (p *ResultErrorPool[T]) Go(f func() (T, error)) {
// Wait cleans up all spawned goroutines, propagating any panics, and
// returning the results and any errors from tasks.
func (p *ResultErrorPool[T]) Wait() ([]T, error) {
	err := p.errorPool.Wait()
	// Errored results are filtered out unless WithCollectErrored was set.
	return p.agg.collect(p.collectErrored), err
}

// WithCollectErrored configures the pool to still collect the result of a task
Expand Down
2 changes: 1 addition & 1 deletion pool/result_error_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/stretchr/testify/require"
)

func TestResultErrorGroup(t *testing.T) {
func TestResultErrorPool(t *testing.T) {
t.Parallel()

err1 := errors.New("err1")
Expand Down
63 changes: 55 additions & 8 deletions pool/result_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pool

import (
"context"
"sort"
"sync"
)

Expand All @@ -19,9 +20,8 @@ func NewWithResults[T any]() *ResultPool[T] {
// Tasks are executed in the pool with Go(), then the results of the tasks are
// returned by Wait().
//
// The order of the results is guaranteed to be the same as the order the
// tasks were submitted.
type ResultPool[T any] struct {
pool Pool
agg resultAggregator[T]
Expand All @@ -30,16 +30,17 @@ type ResultPool[T any] struct {
// Go submits a task to the pool. If all goroutines in the pool
// are busy, a call to Go() will block until the task can be started.
func (p *ResultPool[T]) Go(f func() T) {
	// Reserve a result slot before submitting so that the result slice
	// keeps the order in which tasks were submitted.
	idx := p.agg.nextIndex()
	p.pool.Go(func() {
		// Tasks in a plain ResultPool cannot fail, so the result is
		// never marked as errored.
		p.agg.save(idx, f(), false)
	})
}

// Wait cleans up all spawned goroutines, propagating any panics, and returning
// a slice of results from tasks that did not panic.
func (p *ResultPool[T]) Wait() []T {
	p.pool.Wait()
	// Tasks are infallible here, so collect every saved result.
	return p.agg.collect(true)
}

// MaxGoroutines returns the maximum size of the pool.
Expand Down Expand Up @@ -83,11 +84,57 @@ func (p *ResultPool[T]) panicIfInitialized() {
// resultAggregator is a utility type that lets us safely append from multiple
// goroutines. The zero value is valid and ready to use.
type resultAggregator[T any] struct {
	mu sync.Mutex
	// len is the number of result slots reserved via nextIndex. It may
	// exceed len(results) until save() grows the slice on first write.
	len     int
	results []T
	// errored holds the indices of results whose task returned an error.
	errored []int
}

// nextIndex reserves a slot for a result. The returned value should be passed
// to save() when adding a result to the aggregator.
func (r *resultAggregator[T]) nextIndex() int {
	r.mu.Lock()
	defer r.mu.Unlock()

	nextIdx := r.len
	r.len++
	return nextIdx
}

// save stores a result in the slot reserved by nextIndex. If errored is true,
// the index is recorded so collect(false) can filter the result out.
func (r *resultAggregator[T]) save(i int, res T, errored bool) {
	r.mu.Lock()
	defer r.mu.Unlock()

	// Grow lazily on first save rather than in nextIndex so the common
	// path only allocates once all reservations are known.
	if i >= len(r.results) {
		old := r.results
		r.results = make([]T, r.len)
		copy(r.results, old)
	}

	r.results[i] = res

	if errored {
		r.errored = append(r.errored, i)
	}
}

// collect returns the set of aggregated results in submission order. If
// collectErrored is false, results whose tasks errored are filtered out.
//
// collect must not be called until every goroutine saving results has exited;
// the TryLock below detects (and panics on) concurrent use.
func (r *resultAggregator[T]) collect(collectErrored bool) []T {
	if !r.mu.TryLock() {
		panic("collect should not be called until all goroutines have exited")
	}
	// Unlock so that collect can be called again (e.g. repeated Wait()).
	defer r.mu.Unlock()

	if collectErrored || len(r.errored) == 0 {
		return r.results
	}

	// Filter errored results out in place. Reusing the backing array is
	// safe because each write position trails the read position.
	filtered := r.results[:0]
	sort.Ints(r.errored)
	prev := -1
	for _, e := range r.errored {
		filtered = append(filtered, r.results[prev+1:e]...)
		prev = e
	}
	// BUG FIX: the original loop dropped every result after the last
	// errored index; append the remaining tail as well.
	filtered = append(filtered, r.results[prev+1:]...)
	return filtered
}
26 changes: 21 additions & 5 deletions pool/result_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package pool_test

import (
"fmt"
"sort"
"math/rand"
"strconv"
"sync/atomic"
"testing"
Expand All @@ -22,8 +22,6 @@ func ExampleResultPool() {
})
}
res := p.Wait()
// Result order is nondeterministic, so sort them first
sort.Ints(res)
fmt.Println(res)

// Output:
Expand Down Expand Up @@ -62,10 +60,29 @@ func TestResultGroup(t *testing.T) {
})
}
res := g.Wait()
sort.Ints(res)
require.Equal(t, expected, res)
})

t.Run("deterministic order", func(t *testing.T) {
t.Parallel()
p := pool.NewWithResults[int]()
results := make([]int, 100)
for i := 0; i < 100; i++ {
results[i] = i
}
for _, result := range results {
result := result
p.Go(func() int {
// Add a random sleep to make it exceedingly unlikely that the
// results are returned in the order they are submitted.
time.Sleep(time.Duration(rand.Int()%100) * time.Millisecond)
return result
})
}
got := p.Wait()
require.Equal(t, results, got)
})

t.Run("limit", func(t *testing.T) {
t.Parallel()
for _, maxGoroutines := range []int{1, 10, 100} {
Expand All @@ -90,7 +107,6 @@ func TestResultGroup(t *testing.T) {
})
}
res := g.Wait()
sort.Ints(res)
require.Equal(t, expected, res)
require.Equal(t, int64(0), errCount.Load())
require.Equal(t, int64(0), currentConcurrent.Load())
Expand Down

0 comments on commit 8427ccd

Please sign in to comment.