diff --git a/message/router.go b/message/router.go index 6c2c36861..5c6c99e59 100644 --- a/message/router.go +++ b/message/router.go @@ -111,9 +111,10 @@ type Router struct { handlersWg *sync.WaitGroup runningHandlersWg *sync.WaitGroup - closeCh chan struct{} - closedCh chan struct{} - closed bool + closeCh chan struct{} + closedCh chan struct{} + closed bool + closedLock sync.Mutex logger watermill.LoggerAdapter @@ -325,7 +326,7 @@ func (r *Router) Run(ctx context.Context) (err error) { // because for example all subscriptions are closed. func (r *Router) closeWhenAllHandlersStopped() { r.handlersWg.Wait() - if r.closed { + if r.isClosed() { // already closed return } @@ -348,6 +349,9 @@ func (r *Router) Running() chan struct{} { } func (r *Router) Close() error { + r.closedLock.Lock() + defer r.closedLock.Unlock() + if r.closed { return nil } @@ -367,6 +371,13 @@ func (r *Router) Close() error { return nil } +func (r *Router) isClosed() bool { + r.closedLock.Lock() + defer r.closedLock.Unlock() + + return r.closed +} + type handler struct { name string logger watermill.LoggerAdapter diff --git a/message/router_test.go b/message/router_test.go index ab54d3329..9551fbe58 100644 --- a/message/router_test.go +++ b/message/router_test.go @@ -7,8 +7,6 @@ import ( "testing" "time" - "github.com/ThreeDotsLabs/watermill/pubsub/tests" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -17,6 +15,7 @@ import ( "github.com/ThreeDotsLabs/watermill/message" "github.com/ThreeDotsLabs/watermill/message/subscriber" "github.com/ThreeDotsLabs/watermill/pubsub/gochannel" + "github.com/ThreeDotsLabs/watermill/pubsub/tests" ) func TestRouter_functional(t *testing.T) { @@ -444,6 +443,54 @@ func TestRouterDecoratorsOrder(t *testing.T) { assert.Equal(t, "foobar", transformedMessage.Metadata.Get("sub")) } +func TestRouter_concurrent_close(t *testing.T) { + logger := watermill.NewStdLogger(true, true) + + router, err := message.NewRouter(message.RouterConfig{}, logger) + require.NoError(t, err) + + go func() { + err := router.Close() + require.NoError(t, err) + }() + + err = router.Close() + require.NoError(t, err) +} + +func TestRouter_concurrent_close_on_handlers_closed(t *testing.T) { + logger := watermill.NewStdLogger(true, true) + + router, err := message.NewRouter(message.RouterConfig{}, logger) + require.NoError(t, err) + + _, sub := createPubSub() + + router.AddNoPublisherHandler( + "handler", + "subTopic", + sub, + func(msg *message.Message) error { + return nil + }, + ) + + go func() { + if err := router.Run(context.Background()); err != nil { + panic(err) + } + }() + <-router.Running() + + go func() { + err := sub.Close() + require.NoError(t, err) + }() + + err = router.Close() + require.NoError(t, err) +} + func createBenchSubscriber(b *testing.B) benchMockSubscriber { var messagesToSend []*message.Message for i := 0; i < b.N; i++ { diff --git a/pubsub/gochannel/pubsub.go b/pubsub/gochannel/pubsub.go index 877a474c4..58c231b18 100644 --- a/pubsub/gochannel/pubsub.go +++ b/pubsub/gochannel/pubsub.go @@ -80,7 +80,7 @@ func NewGoChannel(config Config, logger watermill.LoggerAdapter) *GoChannel { // // Messages may be persisted or not, depending of persistent attribute. func (g *GoChannel) Publish(topic string, messages ...*message.Message) error { - if g.closed { + if g.isClosed() { return errors.New("Pub/Sub closed") } @@ -160,10 +160,15 @@ func (g *GoChannel) sendMessage(topic string, message *message.Message) (<-chan // // There are no consumer groups support etc. Every consumer will receive every produced message. func (g *GoChannel) Subscribe(ctx context.Context, topic string) (<-chan *message.Message, error) { + g.closedLock.Lock() + if g.closed { return nil, errors.New("Pub/Sub closed") } + g.subscribersWg.Add(1) + g.closedLock.Unlock() + g.subscribersLock.Lock() subLock, _ := g.subscribersByTopicLock.LoadOrStore(topic, &sync.Mutex{}) @@ -176,7 +181,6 @@ func (g *GoChannel) Subscribe(ctx context.Context, topic string) (<-chan *messag logger: g.logger, closing: make(chan struct{}), } - g.subscribersWg.Add(1) go func(s *subscriber, g *GoChannel) { select { @@ -261,6 +265,13 @@ func (g *GoChannel) topicSubscribers(topic string) []*subscriber { return subscribers } +func (g *GoChannel) isClosed() bool { + g.closedLock.Lock() + defer g.closedLock.Unlock() + + return g.closed +} + func (g *GoChannel) Close() error { g.closedLock.Lock() defer g.closedLock.Unlock() diff --git a/pubsub/gochannel/pubsub_test.go b/pubsub/gochannel/pubsub_test.go index b47e8438b..2a6d47493 100644 --- a/pubsub/gochannel/pubsub_test.go +++ b/pubsub/gochannel/pubsub_test.go @@ -4,20 +4,19 @@ import ( "context" "fmt" "log" + "strconv" "sync" "testing" "time" - "github.com/ThreeDotsLabs/watermill/pubsub/tests" - - "github.com/ThreeDotsLabs/watermill/pubsub/gochannel" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/ThreeDotsLabs/watermill" "github.com/ThreeDotsLabs/watermill/message" "github.com/ThreeDotsLabs/watermill/message/subscriber" - "github.com/stretchr/testify/require" + "github.com/ThreeDotsLabs/watermill/pubsub/gochannel" + "github.com/ThreeDotsLabs/watermill/pubsub/tests" ) func createPersistentPubSub(t *testing.T) (message.Publisher, message.Subscriber) { @@ -123,6 +122,52 @@ func TestPublishSubscribe_race_condition_on_subscribe(t *testing.T) { } } +func TestSubscribe_race_condition_when_closing(t *testing.T) { + testsCount := 15 + if testing.Short() { + testsCount = 3 + } + + for i := 0; i < testsCount; i++ { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + t.Parallel() + pubSub := gochannel.NewGoChannel( + gochannel.Config{}, + watermill.NewStdLogger(true, false), + ) + go func() { + err := pubSub.Close() + require.NoError(t, err) + }() + _, err := pubSub.Subscribe(context.Background(), "topic") + require.NoError(t, err) + }) + } +} + +func TestPublish_race_condition_when_closing(t *testing.T) { + testsCount := 15 + if testing.Short() { + testsCount = 3 + } + + for i := 0; i < testsCount; i++ { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + t.Parallel() + pubSub := gochannel.NewGoChannel( + gochannel.Config{}, + watermill.NewStdLogger(true, false), + ) + go func() { + err := pubSub.Close() + require.NoError(t, err) + }() + err := pubSub.Publish("topic", message.NewMessage(strconv.Itoa(i), nil)) + require.NoError(t, err) + }) + } +} + func testPublishSubscribeSubRace(t *testing.T) { t.Helper()