Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add leadlock and simple single flight #11

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions lock/leadlock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package lock

import (
"sync"
"sync/atomic"
"unsafe"
)

type cond struct {
uptr unsafe.Pointer
mux sync.Mutex
cond *sync.Cond
}

func newCond() *cond {
cond := &cond{}
cond.cond = sync.NewCond(&cond.mux)
return cond
}

func (this *cond) wait() {
this.mux.Lock()
for this.uptr == nil {
this.cond.Wait()
}
this.mux.Unlock()
}

func (this *cond) wake(uptr unsafe.Pointer) {
this.mux.Lock()
this.uptr = uptr
this.mux.Unlock()
this.cond.Broadcast()
}

type LeadLock struct {
nr int32
cond *cond
batch *cond
mux sync.Mutex
}

func NewLeadLock() *LeadLock {
leadlock := &LeadLock{}
leadlock.cond = newCond()
return leadlock
}

func (this *LeadLock) Lock() unsafe.Pointer {
var uptr unsafe.Pointer
n := atomic.AddInt32(&this.nr, 1)
ptr := (*unsafe.Pointer)(unsafe.Pointer(&this.cond))
if n > 1 {
cond := (*cond)(atomic.LoadPointer(ptr))
cond.wait()
uptr = cond.uptr
} else {
this.mux.Lock()
atomic.StoreInt32(&this.nr, 0)
this.batch = (*cond)(atomic.SwapPointer(ptr, unsafe.Pointer(newCond())))
}
return uptr
}

func (this *LeadLock) Unlock(uptr unsafe.Pointer) {
ptr := unsafe.Pointer(&this.batch)
batch := atomic.SwapPointer((*unsafe.Pointer)(ptr), unsafe.Pointer(nil))
if batch == nil {
return
}
(*cond)(batch).wake(uptr)
this.mux.Unlock()
}
63 changes: 63 additions & 0 deletions singleflight/simple.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package singleflight

import (
"runtime"
"sync"
"unsafe"

"golang.org/x/sync/lock"
)

type SimpleSingleFlight struct {
sync.Map
}

func (this *SimpleSingleFlight) leadLock(key string) *lock.LeadLock {
iface, ok := this.Load(key)
if !ok {
iface, _ = this.LoadOrStore(key, lock.NewLeadLock())
}
return iface.(*lock.LeadLock)
}

type result struct {
v interface{}
err error
}

func (this *SimpleSingleFlight) Do(key string, fn func() (interface{}, error)) (interface{}, error) {
var res *result
defer func() {
if e, ok := res.err.(*panicError); ok {
panic(e)
} else if res.err == errGoexit {
runtime.Goexit()
}
}()

ll := this.leadLock(key)
ptr := ll.Lock()
if ptr != nil {
res = (*result)(ptr)
return res.v, res.err
}

res = &result{}
func() {
var getBack bool
defer func() {
if !getBack {
if r := recover(); r != nil {
res.err = newPanicError(r)
} else {
res.err = errGoexit
}
}
ll.Unlock(unsafe.Pointer(res))
}()
res.v, res.err = fn()
getBack = true
}()

return res.v, res.err
}
160 changes: 160 additions & 0 deletions singleflight/simple_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package singleflight

import (
"errors"
"fmt"
"runtime"
"runtime/debug"
"sync"
"sync/atomic"
"testing"
"time"
)

func TestDo1(t *testing.T) {
var g SimpleSingleFlight
v, err := g.Do("key", func() (interface{}, error) {
return "bar", nil
})
if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
t.Errorf("Do = %v; want %v", got, want)
}
if err != nil {
t.Errorf("Do error = %v", err)
}
}

func TestDoErr1(t *testing.T) {
var g SimpleSingleFlight
someErr := errors.New("Some error")
v, err := g.Do("key", func() (interface{}, error) {
return nil, someErr
})
if err != someErr {
t.Errorf("Do error = %v; want someErr %v", err, someErr)
}
if v != nil {
t.Errorf("unexpected non-nil value %#v", v)
}
}

func TestDoDupSuppress1(t *testing.T) {
var g SimpleSingleFlight
var wg1, wg2 sync.WaitGroup
c := make(chan string, 1)
var calls int32
fn := func() (interface{}, error) {
if atomic.AddInt32(&calls, 1) == 1 {
// First invocation.
wg1.Done()
}
v := <-c
c <- v // pump; make available for any future calls

time.Sleep(10 * time.Millisecond) // let more goroutines enter Do

return v, nil
}

const n = 10
wg1.Add(1)
for i := 0; i < n; i++ {
wg1.Add(1)
wg2.Add(1)
go func() {
defer wg2.Done()
wg1.Done()
v, err := g.Do("key", fn)
if err != nil {
t.Errorf("Do error: %v", err)
return
}
if s, _ := v.(string); s != "bar" {
t.Errorf("Do = %T %v; want %q", v, v, "bar")
}
}()
}
wg1.Wait()
// At least one goroutine is in fn now and all of them have at
// least reached the line before the Do.
c <- "bar"
wg2.Wait()
if got := atomic.LoadInt32(&calls); got <= 0 || got >= n {
t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
}
}

// Test singleflight behaves correctly after Do panic.
// See https://github.com/golang/go/issues/41133
func TestPanicDo1(t *testing.T) {
var g SimpleSingleFlight
fn := func() (interface{}, error) {
panic("invalid memory address or nil pointer dereference")
}

const n = 5
waited := int32(n)
panicCount := int32(0)
done := make(chan struct{})
for i := 0; i < n; i++ {
go func() {
defer func() {
if err := recover(); err != nil {
t.Logf("Got panic: %v\n%s", err, debug.Stack())
atomic.AddInt32(&panicCount, 1)
}

if atomic.AddInt32(&waited, -1) == 0 {
close(done)
}
}()

g.Do("key", fn)
}()
}

select {
case <-done:
if panicCount != n {
t.Errorf("Expect %d panic, but got %d", n, panicCount)
}
case <-time.After(time.Second):
t.Fatalf("Do hangs")
}
}

func TestGoexitDo1(t *testing.T) {
var g SimpleSingleFlight
fn := func() (interface{}, error) {
runtime.Goexit()
return nil, nil
}

const n = 5
waited := int32(n)
done := make(chan struct{})
for i := 0; i < n; i++ {
go func() {
var err error
defer func() {
if err != nil {
t.Errorf("Error should be nil, but got: %v", err)
}
if atomic.AddInt32(&waited, -1) == 0 {
close(done)
}
}()
_, err = g.Do("key", fn)
}()
}

select {
case <-done:
case <-time.After(time.Second):
t.Fatalf("Do hangs")
}
}