Skip to content

Commit

Permalink
Merge pull request #9 from bstasyszyn/sharedconnection-2
Browse files Browse the repository at this point in the history
feat: Allow subscribers/publishers to reuse a single connection
  • Loading branch information
roblaszczak authored Nov 4, 2021
2 parents 92366a9 + 34fc75b commit 4f337d7
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 47 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.17
require (
github.com/ThreeDotsLabs/watermill v1.2.0-rc.6
github.com/cenkalti/backoff/v3 v3.2.2
github.com/google/uuid v1.3.0
github.com/hashicorp/go-multierror v1.1.1
github.com/pkg/errors v0.9.1
github.com/streadway/amqp v1.0.0
Expand All @@ -13,7 +14,6 @@ require (

require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/hashicorp/errwrap v1.0.0 // indirect
github.com/lithammer/shortuuid/v3 v3.0.7 // indirect
github.com/oklog/ulid v1.3.1 // indirect
Expand Down
62 changes: 34 additions & 28 deletions pkg/amqp/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@ package amqp

import (
"sync"
"sync/atomic"

"github.com/ThreeDotsLabs/watermill"
"github.com/cenkalti/backoff/v3"
"github.com/pkg/errors"
"github.com/streadway/amqp"
)

type connectionWrapper struct {
config Config
// ConnectionWrapper manages an AMQP connection.
type ConnectionWrapper struct {
config ConnectionConfig

logger watermill.LoggerAdapter

Expand All @@ -19,21 +21,21 @@ type connectionWrapper struct {
connected chan struct{}

closing chan struct{}
closed bool
closed uint32

publishingWg sync.WaitGroup
subscribingWg sync.WaitGroup
connectionWaitGroup sync.WaitGroup
}

func newConnection(
config Config,
// NewConnection returns a new connection wrapper.
func NewConnection(
config ConnectionConfig,
logger watermill.LoggerAdapter,
) (*connectionWrapper, error) {
) (*ConnectionWrapper, error) {
if logger == nil {
logger = watermill.NopLogger{}
}

pubSub := &connectionWrapper{
pubSub := &ConnectionWrapper{
config: config,
logger: logger,
closing: make(chan struct{}),
Expand All @@ -48,18 +50,18 @@ func newConnection(
return pubSub, nil
}

func (c *connectionWrapper) Close() error {
if c.closed {
func (c *ConnectionWrapper) Close() error {
if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) {
// Already closed.
return nil
}
c.closed = true

close(c.closing)

c.logger.Info("Closing AMQP Pub/Sub", nil)
defer c.logger.Info("Closed AMQP Pub/Sub", nil)

c.publishingWg.Wait()
c.subscribingWg.Wait()
c.connectionWaitGroup.Wait()

if err := c.amqpConnection.Close(); err != nil {
c.logger.Error("Connection close error", err, nil)
Expand All @@ -68,24 +70,24 @@ func (c *connectionWrapper) Close() error {
return nil
}

func (c *connectionWrapper) connect() error {
func (c *ConnectionWrapper) connect() error {
c.amqpConnectionLock.Lock()
defer c.amqpConnectionLock.Unlock()

amqpConfig := c.config.Connection.AmqpConfig
if amqpConfig != nil && amqpConfig.TLSClientConfig != nil && c.config.Connection.TLSConfig != nil {
amqpConfig := c.config.AmqpConfig
if amqpConfig != nil && amqpConfig.TLSClientConfig != nil && c.config.TLSConfig != nil {
return errors.New("both Config.AmqpConfig.TLSClientConfig and Config.TLSConfig are set")
}

var connection *amqp.Connection
var err error

if amqpConfig != nil {
connection, err = amqp.DialConfig(c.config.Connection.AmqpURI, *c.config.Connection.AmqpConfig)
} else if c.config.Connection.TLSConfig != nil {
connection, err = amqp.DialTLS(c.config.Connection.AmqpURI, c.config.Connection.TLSConfig)
connection, err = amqp.DialConfig(c.config.AmqpURI, *c.config.AmqpConfig)
} else if c.config.TLSConfig != nil {
connection, err = amqp.DialTLS(c.config.AmqpURI, c.config.TLSConfig)
} else {
connection, err = amqp.Dial(c.config.Connection.AmqpURI)
connection, err = amqp.Dial(c.config.AmqpURI)
}

if err != nil {
Expand All @@ -99,15 +101,15 @@ func (c *connectionWrapper) connect() error {
return nil
}

func (c *connectionWrapper) Connection() *amqp.Connection {
func (c *ConnectionWrapper) Connection() *amqp.Connection {
return c.amqpConnection
}

func (c *connectionWrapper) Connected() chan struct{} {
func (c *ConnectionWrapper) Connected() chan struct{} {
return c.connected
}

func (c *connectionWrapper) IsConnected() bool {
func (c *ConnectionWrapper) IsConnected() bool {
select {
case <-c.connected:
return true
Expand All @@ -116,7 +118,11 @@ func (c *connectionWrapper) IsConnected() bool {
}
}

func (c *connectionWrapper) handleConnectionClose() {
func (c *ConnectionWrapper) Closed() bool {
return atomic.LoadUint32(&c.closed) == 1
}

func (c *ConnectionWrapper) handleConnectionClose() {
for {
c.logger.Debug("handleConnectionClose is waiting for c.connected", nil)
<-c.connected
Expand All @@ -137,8 +143,8 @@ func (c *connectionWrapper) handleConnectionClose() {
}
}

func (c *connectionWrapper) reconnect() {
reconnectConfig := c.config.Connection.Reconnect
func (c *ConnectionWrapper) reconnect() {
reconnectConfig := c.config.Reconnect
if reconnectConfig == nil {
reconnectConfig = DefaultReconnectConfig()
}
Expand All @@ -151,7 +157,7 @@ func (c *connectionWrapper) reconnect() {

c.logger.Error("Cannot reconnect to AMQP, retrying", err, nil)

if c.closed {
if c.Closed() {
return backoff.Permanent(errors.Wrap(err, "closing AMQP connection"))
}

Expand Down
50 changes: 44 additions & 6 deletions pkg/amqp/publisher.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package amqp

import (
"fmt"
"sync"

"github.com/ThreeDotsLabs/watermill"
Expand All @@ -11,28 +12,60 @@ import (
)

type Publisher struct {
*connectionWrapper
*ConnectionWrapper

config Config
publishBindingsLock sync.RWMutex
publishBindingsPrepared map[string]struct{}
closePublisher func() error
}

func NewPublisher(config Config, logger watermill.LoggerAdapter) (*Publisher, error) {
if err := config.ValidatePublisher(); err != nil {
return nil, err
}

conn, err := newConnection(config, logger)
var err error

conn, err := NewConnection(config.Connection, logger)
if err != nil {
return nil, fmt.Errorf("create new connection: %w", err)
}

// Close the connection when the publisher is closed since this publisher owns the connection.
closePublisher := func() error {
logger.Debug("Closing publisher connection.", nil)

return conn.Close()
}

return &Publisher{
conn,
config,
sync.RWMutex{},
make(map[string]struct{}),
closePublisher,
}, nil
}

func NewPublisherWithConnection(config Config, logger watermill.LoggerAdapter, conn *ConnectionWrapper) (*Publisher, error) {
if err := config.ValidatePublisher(); err != nil {
return nil, err
}

// Shared connections should not be closed by the publisher.
closePublisher := func() error {
logger.Debug("Publisher closed.", nil)

return nil
}

return &Publisher{
conn,
config,
sync.RWMutex{},
make(map[string]struct{}),
closePublisher,
}, nil
}

Expand All @@ -44,16 +77,17 @@ func NewPublisher(config Config, logger watermill.LoggerAdapter) (*Publisher, er
// to exchange, queue or routing key.
// For detailed description of nomenclature mapping, please check "Nomenclature" paragraph in doc.go file.
func (p *Publisher) Publish(topic string, messages ...*message.Message) (err error) {
if p.closed {
return errors.New("pub/sub is connection closed")
if p.Closed() {
return errors.New("pub/sub is connection closedChan")
}
p.publishingWg.Add(1)
defer p.publishingWg.Done()

if !p.IsConnected() {
return errors.New("not connected to AMQP")
}

p.connectionWaitGroup.Add(1)
defer p.connectionWaitGroup.Done()

channel, err := p.amqpConnection.Channel()
if err != nil {
return errors.Wrap(err, "cannot open channel")
Expand Down Expand Up @@ -95,6 +129,10 @@ func (p *Publisher) Publish(topic string, messages ...*message.Message) (err err
return nil
}

func (p *Publisher) Close() error {
return p.closePublisher()
}

func (p *Publisher) beginTransaction(channel *amqp.Channel) error {
if err := channel.Tx(); err != nil {
return errors.Wrap(err, "cannot start transaction")
Expand Down
82 changes: 82 additions & 0 deletions pkg/amqp/pubsub_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package amqp_test

import (
"context"
"os"
"testing"
"time"

"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -122,6 +124,18 @@ func createQueuePubSub(t *testing.T) (message.Publisher, message.Subscriber) {
return publisher, subscriber
}

func createQueuePubSubWithSharedConnection(t *testing.T, config amqp.Config, conn *amqp.ConnectionWrapper, logger watermill.LoggerAdapter) (message.Publisher, message.Subscriber) {
t.Logf("Creating publisher/subscriber with shared connection")

publisher, err := amqp.NewPublisherWithConnection(config, logger, conn)
require.NoError(t, err)

subscriber, err := amqp.NewSubscriberWithConnection(config, logger, conn)
require.NoError(t, err)

return publisher, subscriber
}

func TestPublishSubscribe_queue(t *testing.T) {
tests.TestPubSub(
t,
Expand Down Expand Up @@ -154,6 +168,74 @@ func TestPublishSubscribe_transactional_publish(t *testing.T) {
)
}

func TestPublishSubscribe_queue_with_shared_connection(t *testing.T) {
config := amqp.NewDurableQueueConfig(
amqpURI(),
)

logger := watermill.NewStdLogger(true, true)

conn, err := amqp.NewConnection(config.Connection, logger)
require.NoError(t, err)

tests.TestPubSub(
t,
tests.Features{
ConsumerGroups: false,
ExactlyOnceDelivery: false,
GuaranteedOrder: true,
GuaranteedOrderWithSingleSubscriber: true,
Persistent: true,
},
func(t *testing.T) (message.Publisher, message.Subscriber) {
return createQueuePubSubWithSharedConnection(t, config, conn, logger)
},
nil,
)
}

func TestSharedConnection(t *testing.T) {
const topic = "topicXXX"

config := amqp.NewDurableQueueConfig(
amqpURI(),
)

logger := watermill.NewStdLogger(true, true)

conn, err := amqp.NewConnection(config.Connection, logger)
require.NoError(t, err)

s, err := amqp.NewSubscriberWithConnection(config, logger, conn)
require.NoError(t, err)

msgChan, err := s.Subscribe(context.Background(), topic)
require.NoError(t, err)

p, err := amqp.NewPublisherWithConnection(config, logger, conn)
require.NoError(t, err)

require.NoError(t, p.Publish(topic, message.NewMessage(watermill.NewUUID(), []byte("payload"))))

select {
case <-time.After(time.Second):
t.Fatal("Timed out waiting for message")
case msg := <-msgChan:
msg.Ack()
}

require.NoError(t, conn.Close())

// After closing the connection, the subscriber message channel should also be closed.

select {
case _, open := <-msgChan:
require.False(t, open)
default:
t.Error("messages channel is not closed")
}
}

//func TestClose(t *testing.T) {
// t.Parallel()
//
Expand Down
Loading

0 comments on commit 4f337d7

Please sign in to comment.