Skip to content

Commit

Permalink
Fix gochannel-related races (#115)
Browse files Browse the repository at this point in the history
* Use closeLock to synchronize access to WaitGroup

* Add closedLock to fix races on router.Close()
  • Loading branch information
m110 authored Aug 28, 2019
1 parent 4768833 commit d5532e2
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 13 deletions.
19 changes: 15 additions & 4 deletions message/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,10 @@ type Router struct {
handlersWg *sync.WaitGroup
runningHandlersWg *sync.WaitGroup

closeCh chan struct{}
closedCh chan struct{}
closed bool
closeCh chan struct{}
closedCh chan struct{}
closed bool
closedLock sync.Mutex

logger watermill.LoggerAdapter

Expand Down Expand Up @@ -325,7 +326,7 @@ func (r *Router) Run(ctx context.Context) (err error) {
// because for example all subscriptions are closed.
func (r *Router) closeWhenAllHandlersStopped() {
r.handlersWg.Wait()
if r.closed {
if r.isClosed() {
// already closed
return
}
Expand All @@ -348,6 +349,9 @@ func (r *Router) Running() chan struct{} {
}

func (r *Router) Close() error {
r.closedLock.Lock()
defer r.closedLock.Unlock()

if r.closed {
return nil
}
Expand All @@ -367,6 +371,13 @@ func (r *Router) Close() error {
return nil
}

func (r *Router) isClosed() bool {
r.closedLock.Lock()
defer r.closedLock.Unlock()

return r.closed
}

type handler struct {
name string
logger watermill.LoggerAdapter
Expand Down
51 changes: 49 additions & 2 deletions message/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import (
"testing"
"time"

"github.com/ThreeDotsLabs/watermill/pubsub/tests"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand All @@ -17,6 +15,7 @@ import (
"github.com/ThreeDotsLabs/watermill/message"
"github.com/ThreeDotsLabs/watermill/message/subscriber"
"github.com/ThreeDotsLabs/watermill/pubsub/gochannel"
"github.com/ThreeDotsLabs/watermill/pubsub/tests"
)

func TestRouter_functional(t *testing.T) {
Expand Down Expand Up @@ -444,6 +443,54 @@ func TestRouterDecoratorsOrder(t *testing.T) {
assert.Equal(t, "foobar", transformedMessage.Metadata.Get("sub"))
}

func TestRouter_concurrent_close(t *testing.T) {
logger := watermill.NewStdLogger(true, true)

router, err := message.NewRouter(message.RouterConfig{}, logger)
require.NoError(t, err)

go func() {
err := router.Close()
require.NoError(t, err)
}()

err = router.Close()
require.NoError(t, err)
}

func TestRouter_concurrent_close_on_handlers_closed(t *testing.T) {
logger := watermill.NewStdLogger(true, true)

router, err := message.NewRouter(message.RouterConfig{}, logger)
require.NoError(t, err)

_, sub := createPubSub()

router.AddNoPublisherHandler(
"handler",
"subTopic",
sub,
func(msg *message.Message) error {
return nil
},
)

go func() {
if err := router.Run(context.Background()); err != nil {
panic(err)
}
}()
<-router.Running()

go func() {
err := sub.Close()
require.NoError(t, err)
}()

err = router.Close()
require.NoError(t, err)
}

func createBenchSubscriber(b *testing.B) benchMockSubscriber {
var messagesToSend []*message.Message
for i := 0; i < b.N; i++ {
Expand Down
15 changes: 13 additions & 2 deletions pubsub/gochannel/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func NewGoChannel(config Config, logger watermill.LoggerAdapter) *GoChannel {
//
// Messages may be persisted or not, depending of persistent attribute.
func (g *GoChannel) Publish(topic string, messages ...*message.Message) error {
if g.closed {
if g.isClosed() {
return errors.New("Pub/Sub closed")
}

Expand Down Expand Up @@ -160,10 +160,15 @@ func (g *GoChannel) sendMessage(topic string, message *message.Message) (<-chan
//
// There are no consumer groups support etc. Every consumer will receive every produced message.
func (g *GoChannel) Subscribe(ctx context.Context, topic string) (<-chan *message.Message, error) {
g.closedLock.Lock()

if g.closed {
return nil, errors.New("Pub/Sub closed")
}

g.subscribersWg.Add(1)
g.closedLock.Unlock()

g.subscribersLock.Lock()

subLock, _ := g.subscribersByTopicLock.LoadOrStore(topic, &sync.Mutex{})
Expand All @@ -176,7 +181,6 @@ func (g *GoChannel) Subscribe(ctx context.Context, topic string) (<-chan *messag
logger: g.logger,
closing: make(chan struct{}),
}
g.subscribersWg.Add(1)

go func(s *subscriber, g *GoChannel) {
select {
Expand Down Expand Up @@ -261,6 +265,13 @@ func (g *GoChannel) topicSubscribers(topic string) []*subscriber {
return subscribers
}

func (g *GoChannel) isClosed() bool {
g.closedLock.Lock()
defer g.closedLock.Unlock()

return g.closed
}

func (g *GoChannel) Close() error {
g.closedLock.Lock()
defer g.closedLock.Unlock()
Expand Down
55 changes: 50 additions & 5 deletions pubsub/gochannel/pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,19 @@ import (
"context"
"fmt"
"log"
"strconv"
"sync"
"testing"
"time"

"github.com/ThreeDotsLabs/watermill/pubsub/tests"

"github.com/ThreeDotsLabs/watermill/pubsub/gochannel"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ThreeDotsLabs/watermill"
"github.com/ThreeDotsLabs/watermill/message"
"github.com/ThreeDotsLabs/watermill/message/subscriber"
"github.com/stretchr/testify/require"
"github.com/ThreeDotsLabs/watermill/pubsub/gochannel"
"github.com/ThreeDotsLabs/watermill/pubsub/tests"
)

func createPersistentPubSub(t *testing.T) (message.Publisher, message.Subscriber) {
Expand Down Expand Up @@ -123,6 +122,52 @@ func TestPublishSubscribe_race_condition_on_subscribe(t *testing.T) {
}
}

func TestSubscribe_race_condition_when_closing(t *testing.T) {
testsCount := 15
if testing.Short() {
testsCount = 3
}

for i := 0; i < testsCount; i++ {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
t.Parallel()
pubSub := gochannel.NewGoChannel(
gochannel.Config{},
watermill.NewStdLogger(true, false),
)
go func() {
err := pubSub.Close()
require.NoError(t, err)
}()
_, err := pubSub.Subscribe(context.Background(), "topic")
require.NoError(t, err)
})
}
}

func TestPublish_race_condition_when_closing(t *testing.T) {
testsCount := 15
if testing.Short() {
testsCount = 3
}

for i := 0; i < testsCount; i++ {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
t.Parallel()
pubSub := gochannel.NewGoChannel(
gochannel.Config{},
watermill.NewStdLogger(true, false),
)
go func() {
err := pubSub.Close()
require.NoError(t, err)
}()
err := pubSub.Publish("topic", message.NewMessage(strconv.Itoa(i), nil))
require.NoError(t, err)
})
}
}

func testPublishSubscribeSubRace(t *testing.T) {
t.Helper()

Expand Down

0 comments on commit d5532e2

Please sign in to comment.