diff --git a/lock/leadlock.go b/lock/leadlock.go new file mode 100644 index 0000000..310016e --- /dev/null +++ b/lock/leadlock.go @@ -0,0 +1,73 @@ +package lock + +import ( + "sync" + "sync/atomic" + "unsafe" +) + +type cond struct { + uptr unsafe.Pointer + mux sync.Mutex + cond *sync.Cond +} + +func newCond() *cond { + cond := &cond{} + cond.cond = sync.NewCond(&cond.mux) + return cond +} + +func (this *cond) wait() { + this.mux.Lock() + for this.uptr == nil { + this.cond.Wait() + } + this.mux.Unlock() +} + +func (this *cond) wake(uptr unsafe.Pointer) { + this.mux.Lock() + this.uptr = uptr + this.mux.Unlock() + this.cond.Broadcast() +} + +type LeadLock struct { + nr int32 + cond *cond + batch *cond + mux sync.Mutex +} + +func NewLeadLock() *LeadLock { + leadlock := &LeadLock{} + leadlock.cond = newCond() + return leadlock +} + +func (this *LeadLock) Lock() unsafe.Pointer { + var uptr unsafe.Pointer + n := atomic.AddInt32(&this.nr, 1) + ptr := (*unsafe.Pointer)(unsafe.Pointer(&this.cond)) + if n > 1 { + cond := (*cond)(atomic.LoadPointer(ptr)) + cond.wait() + uptr = cond.uptr + } else { + this.mux.Lock() + atomic.StoreInt32(&this.nr, 0) + this.batch = (*cond)(atomic.SwapPointer(ptr, unsafe.Pointer(newCond()))) + } + return uptr +} + +func (this *LeadLock) Unlock(uptr unsafe.Pointer) { + ptr := unsafe.Pointer(&this.batch) + batch := atomic.SwapPointer((*unsafe.Pointer)(ptr), unsafe.Pointer(nil)) + if batch == nil { + return + } + (*cond)(batch).wake(uptr) + this.mux.Unlock() +} diff --git a/singleflight/simple.go b/singleflight/simple.go new file mode 100644 index 0000000..74a413f --- /dev/null +++ b/singleflight/simple.go @@ -0,0 +1,63 @@ +package singleflight + +import ( + "runtime" + "sync" + "unsafe" + + "golang.org/x/sync/lock" +) + +type SimpleSingleFlight struct { + sync.Map +} + +func (this *SimpleSingleFlight) leadLock(key string) *lock.LeadLock { + iface, ok := this.Load(key) + if !ok { + iface, _ = this.LoadOrStore(key, lock.NewLeadLock()) + } + return iface.(*lock.LeadLock) +} + +type result struct { + v interface{} + err error +} + +func (this *SimpleSingleFlight) Do(key string, fn func() (interface{}, error)) (interface{}, error) { + var res *result + defer func() { + if e, ok := res.err.(*panicError); ok { + panic(e) + } else if res.err == errGoexit { + runtime.Goexit() + } + }() + + ll := this.leadLock(key) + ptr := ll.Lock() + if ptr != nil { + res = (*result)(ptr) + return res.v, res.err + } + + res = &result{} + func() { + var getBack bool + defer func() { + if !getBack { + if r := recover(); r != nil { + res.err = newPanicError(r) + } else { + res.err = errGoexit + } + } + ll.Unlock(unsafe.Pointer(res)) + }() + res.v, res.err = fn() + getBack = true + }() + + return res.v, res.err +} diff --git a/singleflight/simple_test.go b/singleflight/simple_test.go new file mode 100644 index 0000000..bfcd4c6 --- /dev/null +++ b/singleflight/simple_test.go @@ -0,0 +1,160 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package singleflight + +import ( + "errors" + "fmt" + "runtime" + "runtime/debug" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestDo1(t *testing.T) { + var g SimpleSingleFlight + v, err := g.Do("key", func() (interface{}, error) { + return "bar", nil + }) + if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { + t.Errorf("Do = %v; want %v", got, want) + } + if err != nil { + t.Errorf("Do error = %v", err) + } +} + +func TestDoErr1(t *testing.T) { + var g SimpleSingleFlight + someErr := errors.New("Some error") + v, err := g.Do("key", func() (interface{}, error) { + return nil, someErr + }) + if err != someErr { + t.Errorf("Do error = %v; want someErr %v", err, someErr) + } + if v != nil { + t.Errorf("unexpected non-nil value %#v", v) + } +} + +func TestDoDupSuppress1(t *testing.T) { + var g SimpleSingleFlight + var wg1, wg2 sync.WaitGroup + c := make(chan string, 1) + var calls int32 + fn := func() (interface{}, error) { + if atomic.AddInt32(&calls, 1) == 1 { + // First invocation. + wg1.Done() + } + v := <-c + c <- v // pump; make available for any future calls + + time.Sleep(10 * time.Millisecond) // let more goroutines enter Do + + return v, nil + } + + const n = 10 + wg1.Add(1) + for i := 0; i < n; i++ { + wg1.Add(1) + wg2.Add(1) + go func() { + defer wg2.Done() + wg1.Done() + v, err := g.Do("key", fn) + if err != nil { + t.Errorf("Do error: %v", err) + return + } + if s, _ := v.(string); s != "bar" { + t.Errorf("Do = %T %v; want %q", v, v, "bar") + } + }() + } + wg1.Wait() + // At least one goroutine is in fn now and all of them have at + // least reached the line before the Do. + c <- "bar" + wg2.Wait() + if got := atomic.LoadInt32(&calls); got <= 0 || got >= n { + t.Errorf("number of calls = %d; want over 0 and less than %d", got, n) + } +} + +// Test singleflight behaves correctly after Do panic. +// See https://github.com/golang/go/issues/41133 +func TestPanicDo1(t *testing.T) { + var g SimpleSingleFlight + fn := func() (interface{}, error) { + panic("invalid memory address or nil pointer dereference") + } + + const n = 5 + waited := int32(n) + panicCount := int32(0) + done := make(chan struct{}) + for i := 0; i < n; i++ { + go func() { + defer func() { + if err := recover(); err != nil { + t.Logf("Got panic: %v\n%s", err, debug.Stack()) + atomic.AddInt32(&panicCount, 1) + } + + if atomic.AddInt32(&waited, -1) == 0 { + close(done) + } + }() + + g.Do("key", fn) + }() + } + + select { + case <-done: + if panicCount != n { + t.Errorf("Expect %d panic, but got %d", n, panicCount) + } + case <-time.After(time.Second): + t.Fatalf("Do hangs") + } +} + +func TestGoexitDo1(t *testing.T) { + var g SimpleSingleFlight + fn := func() (interface{}, error) { + runtime.Goexit() + return nil, nil + } + + const n = 5 + waited := int32(n) + done := make(chan struct{}) + for i := 0; i < n; i++ { + go func() { + var err error + defer func() { + if err != nil { + t.Errorf("Error should be nil, but got: %v", err) + } + if atomic.AddInt32(&waited, -1) == 0 { + close(done) + } + }() + _, err = g.Do("key", fn) + }() + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatalf("Do hangs") + } +}