Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 36 additions & 54 deletions dataloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"log"
"runtime"
"sync"
"sync/atomic"
"time"
)

Expand All @@ -29,8 +30,12 @@ type Interface[K comparable, V any] interface {
Flush()
}

var ErrNoResultProvided = errors.New("no result provided")

// BatchFunc is a function, which when given a slice of keys (string), returns a slice of `results`.
// It's important that the length of the input keys matches the length of the output results.
// Should the batch function return nil for a result, it will be treated as returning an error
// of `ErrNoResultProvided` for that key.
//
// The keys passed to this function are guaranteed to be unique
type BatchFunc[K comparable, V any] func(context.Context, []K) []*Result[V]
Expand Down Expand Up @@ -131,8 +136,9 @@ type ThunkMany[V any] func() ([]V, []error)

// type used on the input channel
type batchRequest[K comparable, V any] struct {
key K
channel chan *Result[V]
key K
result atomic.Pointer[Result[V]]
done chan struct{}
}

// Option allows for configuration of Loader fields.
Expand Down Expand Up @@ -221,11 +227,9 @@ func NewBatchedLoader[K comparable, V any](batchFn BatchFunc[K, V], opts ...Opti
// the registered BatchFunc.
func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
ctx, finish := l.tracer.TraceLoad(originalContext, key)

c := make(chan *Result[V], 1)
var result struct {
mu sync.RWMutex
value *Result[V]
req := &batchRequest[K, V]{
key: key,
done: make(chan struct{}),
}

// We need to lock both the batchLock and cacheLock because the batcher can
Expand Down Expand Up @@ -254,34 +258,19 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
defer l.cacheLock.Unlock()

thunk := func() (V, error) {
result.mu.RLock()
resultNotSet := result.value == nil
result.mu.RUnlock()

if resultNotSet {
result.mu.Lock()
if v, ok := <-c; ok {
result.value = v
}
result.mu.Unlock()
}
result.mu.RLock()
defer result.mu.RUnlock()
<-req.done
result := req.result.Load()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

result can be nil which would cause a panic. Please, add a nil check!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

result can be nil which would cause a panic. Please, add a nil check!

Ok, I see, it can be nil if the user-provided batch function did not set the items[i] element. Rather than check it each time the thunk is invoked, what if we always ensured that it was never nil?

In this loop:

for i, req := range reqs {
		req.result.Store(items[i])
		close(req.done)
	}

We could check whether items[i] was nil and set it to an item with an error that indicates that no value was set. What would you like the error to be? Should I just make one with fmt.Errorf() or should we declare one as a var ErrNoValueProvided = errors.New("no value provided") so the user could use errors.Is against it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For starters, I'd be happy with just fmt.Errorf() - at the very least we remove the risk of a panic. Whether you add a specific error for that is up to you - no strong opinion here.

Copy link
Member

@pavelnikolov pavelnikolov Oct 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if we always ensured that it was never nil

How do you imagine that? Can you provide an example, please?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the initial code also has would panic if results[i] of the batch function was nil.

dataloader/dataloader.go

Lines 264 to 272 in ab5318f

result.value = v
}
result.mu.Unlock()
}
result.mu.RLock()
defer result.mu.RUnlock()
var ev *PanicErrorWrapper
var es *SkipCacheError
if result.value.Error != nil && (errors.As(result.value.Error, &ev) || errors.As(result.value.Error, &es)){
the code will check result.value.Error and result.value` could be nil.

I have made a separate commit to address this with a known error and updated the comment. This way, if the batch function doesn't return a result for a given item, we can automatically give it an error result. (Similar to how it does that if the length of the result isn't the same.)

efe1a00

var ev *PanicErrorWrapper
var es *SkipCacheError
if result.value.Error != nil && (errors.As(result.value.Error, &ev) || errors.As(result.value.Error, &es)){
if result.Error != nil && (errors.As(result.Error, &ev) || errors.As(result.Error, &es)) {
l.Clear(ctx, key)
}
return result.value.Data, result.value.Error
return result.Data, result.Error
}
defer finish(thunk)

l.cache.Set(ctx, key, thunk)

// this is sent to batch fn. It contains the key and the channel to return
// the result on
req := &batchRequest[K, V]{key, c}

// start the batch window if it hasn't already started.
if l.curBatcher == nil {
l.curBatcher = l.newBatcher(l.silent, l.tracer)
Expand Down Expand Up @@ -338,8 +327,9 @@ func (l *Loader[K, V]) LoadMany(originalContext context.Context, keys []K) Thunk
length = len(keys)
data = make([]V, length)
errors = make([]error, length)
c = make(chan *ResultMany[V], 1)
result atomic.Pointer[ResultMany[V]]
wg sync.WaitGroup
done = make(chan struct{})
)

resolve := func(ctx context.Context, i int) {
Expand All @@ -356,6 +346,7 @@ func (l *Loader[K, V]) LoadMany(originalContext context.Context, keys []K) Thunk
}

go func() {
defer close(done)
wg.Wait()

// errs is nil unless there exists a non-nil error.
Expand All @@ -368,30 +359,13 @@ func (l *Loader[K, V]) LoadMany(originalContext context.Context, keys []K) Thunk
}
}

c <- &ResultMany[V]{Data: data, Error: errs}
close(c)
result.Store(&ResultMany[V]{Data: data, Error: errs})
}()

var result struct {
mu sync.RWMutex
value *ResultMany[V]
}

thunkMany := func() ([]V, []error) {
result.mu.RLock()
resultNotSet := result.value == nil
result.mu.RUnlock()

if resultNotSet {
result.mu.Lock()
if v, ok := <-c; ok {
result.value = v
}
result.mu.Unlock()
}
result.mu.RLock()
defer result.mu.RUnlock()
return result.value.Data, result.value.Error
<-done
r := result.Load()
return r.Data, r.Error
}

defer finish(thunkMany)
Expand Down Expand Up @@ -498,8 +472,8 @@ func (b *batcher[K, V]) batch(originalContext context.Context) {

if panicErr != nil {
for _, req := range reqs {
req.channel <- &Result[V]{Error: &PanicErrorWrapper{panicError: fmt.Errorf("Panic received in batch function: %v", panicErr)}}
close(req.channel)
req.result.Store(&Result[V]{Error: &PanicErrorWrapper{panicError: fmt.Errorf("Panic received in batch function: %v", panicErr)}})
close(req.done)
}
return
}
Expand All @@ -517,16 +491,24 @@ func (b *batcher[K, V]) batch(originalContext context.Context) {
`, keys, items)}

for _, req := range reqs {
req.channel <- err
close(req.channel)
req.result.Store(err)
close(req.done)
}

return
}

var notSetResult *Result[V] // don't allocate unless we need it
for i, req := range reqs {
req.channel <- items[i]
close(req.channel)
if items[i] == nil {
if notSetResult == nil {
notSetResult = &Result[V]{Error: ErrNoResultProvided}
}
req.result.Store(notSetResult)
} else {
req.result.Store(items[i])
}
close(req.done)
}
}

Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/graph-gophers/dataloader/v7

go 1.18
go 1.19

require (
github.com/hashicorp/golang-lru v0.5.4
Expand Down