diff --git a/.gitignore b/.gitignore index e5612d840..01ce7ec33 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,5 @@ vendor docs/themes/ docs/public/ docs/content/src-link +*.out +*.log diff --git a/Makefile b/Makefile index 5fbbd283f..927c049df 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,16 @@ mycli: @mycli -h 127.0.0.1 -u root -p secret test: - go test -v `glide novendor` + go test ./... + +test_v: + go test -v ./... + +test_short: + go test ./... -short test_stress: - go test -tags=stress `glide novendor` \ No newline at end of file + go test -tags=stress ./... + +test_reconnect: + go test -tags=reconnect ./... diff --git a/dev/coverage.sh b/dev/coverage.sh new file mode 100755 index 000000000..8a921430d --- /dev/null +++ b/dev/coverage.sh @@ -0,0 +1,56 @@ +#!/bin/sh +######## +# Source: https://gist.github.com/lwolf/3764a3b6cd08387e80aa6ca3b9534b8a +# originaly from https://github.com/mlafeldt/chef-runner/blob/v0.7.0/script/coverage +####### +# Generate test coverage statistics for Go packages. +# +# Works around the fact that `go test -coverprofile` currently does not work +# with multiple packages, see https://code.google.com/p/go/issues/detail?id=6909 +# +# Usage: script/coverage [--html|--coveralls] +# +# --html Additionally create HTML report and open it in browser +# --coveralls Push coverage statistics to coveralls.io +# + +set -e + +workdir=.cover +profile="$workdir/cover.out" +mode=count + +generate_cover_data() { + rm -rf "$workdir" + mkdir "$workdir" + + for pkg in "$@"; do + f="$workdir/$(echo $pkg | tr / -).cover" + go test -covermode="$mode" -coverprofile="$f" "$pkg" + done + + echo "mode: $mode" >"$profile" + grep -h -v "^mode:" "$workdir"/*.cover >>"$profile" +} + +show_cover_report() { + go tool cover -${1}="$profile" +} + +push_to_coveralls() { + echo "Pushing coverage statistics to coveralls.io" + goveralls -coverprofile="$profile" +} + +generate_cover_data $(go list ./... | grep -v /vendor/) +show_cover_report func +case "$1" in +"") + ;; +--html) + show_cover_report html ;; +--coveralls) + push_to_coveralls ;; +*) + echo >&2 "error: invalid option: $1"; exit 1 ;; +esac \ No newline at end of file diff --git a/docs/content/docs/getting-started/amqp/main.go b/docs/content/docs/getting-started/amqp/main.go index cf8616789..18208043d 100644 --- a/docs/content/docs/getting-started/amqp/main.go +++ b/docs/content/docs/getting-started/amqp/main.go @@ -5,6 +5,7 @@ import ( "log" "github.com/ThreeDotsLabs/watermill/message/infrastructure/amqp" + uuid "github.com/satori/go.uuid" "github.com/ThreeDotsLabs/watermill" diff --git a/docs/content/docs/getting-started/go-channel/main.go b/docs/content/docs/getting-started/go-channel/main.go index eb2b2b26d..a98c6a104 100644 --- a/docs/content/docs/getting-started/go-channel/main.go +++ b/docs/content/docs/getting-started/go-channel/main.go @@ -3,7 +3,6 @@ package main import ( "log" - "time" "github.com/satori/go.uuid" @@ -17,7 +16,6 @@ func main() { pubSub := gochannel.NewGoChannel( 0, // buffer (channel) size watermill.NewStdLogger(false, false), - time.Second, // send timeout ) messages, err := pubSub.Subscribe("example.topic") diff --git a/docs/content/docs/getting-started/router/main.go b/docs/content/docs/getting-started/router/main.go index 6acc6a574..1d84d6310 100644 --- a/docs/content/docs/getting-started/router/main.go +++ b/docs/content/docs/getting-started/router/main.go @@ -51,7 +51,7 @@ func main() { // for simplicity we are using gochannel Pub/Sub here, // you can replace it with any Pub/Sub implementation, it will work the same - pubSub := gochannel.NewGoChannel(0, logger, time.Second) + pubSub := gochannel.NewGoChannel(0, logger) // producing some messages in background go publishMessages(pubSub) diff --git a/go.mod b/go.mod index 24bd5792f..2cb4108d1 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/ThreeDotsLabs/watermill require ( cloud.google.com/go v0.33.1 github.com/DataDog/zstd v1.3.4 // indirect - github.com/Shopify/sarama v1.20.0 + github.com/Shopify/sarama v1.20.1 github.com/Shopify/toxiproxy v2.1.3+incompatible // indirect github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da // indirect github.com/boltdb/bolt v1.3.1 // indirect diff --git a/go.sum b/go.sum index 294bd2580..eab22d517 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,8 @@ github.com/DataDog/zstd v1.3.4 h1:LAGHkXuvC6yky+C2CUG2tD7w8QlrUwpue8XwIh0X4AY= github.com/DataDog/zstd v1.3.4/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= github.com/Shopify/sarama v1.20.0 h1:wAMHhl1lGRlobeoV/xOKpbqD2OQsOvY4A/vIOGroIe8= github.com/Shopify/sarama v1.20.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= +github.com/Shopify/sarama v1.20.1 h1:Bb0h3I++r4eX333Y0uZV2vwUXepJbt6ig05TUU1qt9I= +github.com/Shopify/sarama v1.20.1/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= github.com/Shopify/toxiproxy v2.1.3+incompatible h1:awiJqUYH4q4OmoBiRccJykjd7B+w0loJi2keSna4X/M= github.com/Shopify/toxiproxy v2.1.3+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da h1:8GUt8eRujhVEGZFFEjBj46YV4rDjvGrNxb0KMWYkL2I= diff --git a/internal/sync/waitgroup.go b/internal/sync/waitgroup.go index c5ce35ae8..390e80ee8 100644 --- a/internal/sync/waitgroup.go +++ b/internal/sync/waitgroup.go @@ -5,6 +5,8 @@ import ( "time" ) +// WaitGroupTimeout adds timeout feature for sync.WaitGroup.Wait(). +// It returns true, when timeouted. func WaitGroupTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { wgClosed := make(chan struct{}, 1) go func() { diff --git a/internal/sync/waitgroup_test.go b/internal/sync/waitgroup_test.go index 030d44138..68735f4a3 100644 --- a/internal/sync/waitgroup_test.go +++ b/internal/sync/waitgroup_test.go @@ -11,7 +11,7 @@ import ( func TestWaitGroupTimeout_no_timeout(t *testing.T) { wg := &sync.WaitGroup{} - timeouted := WaitGroupTimeout(wg, time.Millisecond*10) + timeouted := WaitGroupTimeout(wg, time.Millisecond*100) assert.False(t, timeouted) } @@ -19,6 +19,6 @@ func TestWaitGroupTimeout_timeout(t *testing.T) { wg := &sync.WaitGroup{} wg.Add(1) - timeouted := WaitGroupTimeout(wg, time.Millisecond*10) + timeouted := WaitGroupTimeout(wg, time.Millisecond*100) assert.True(t, timeouted) } diff --git a/internal/tests/asserts.go b/internal/tests/asserts.go index 1f255a1e2..5535ce56d 100644 --- a/internal/tests/asserts.go +++ b/internal/tests/asserts.go @@ -47,7 +47,9 @@ func AssertAllMessagesReceived(t *testing.T, sent message.Messages, received mes return assert.Equal( t, sentIDs, receivedIDs, - "received different messages ID's, missing: %s", MissingMessages(sent, received), + "received different messages ID's, missing: %s, extra %s", + MissingMessages(sent, received), + MissingMessages(received, sent), ) } diff --git a/log.go b/log.go index d4543aed3..6139fc292 100644 --- a/log.go +++ b/log.go @@ -2,6 +2,7 @@ package watermill import ( "fmt" + "io" "log" "os" "reflect" @@ -24,11 +25,21 @@ func (l LogFields) Add(newFields LogFields) LogFields { return resultFields } +func (l LogFields) Copy() LogFields { + cpy := make(LogFields, len(l)) + for k, v := range l { + cpy[k] = v + } + + return cpy +} + type LoggerAdapter interface { Error(msg string, err error, fields LogFields) Info(msg string, fields LogFields) Debug(msg string, fields LogFields) Trace(msg string, fields LogFields) + With(fields LogFields) LoggerAdapter } type NopLogger struct{} @@ -37,16 +48,23 @@ func (NopLogger) Error(msg string, err error, fields LogFields) {} func (NopLogger) Info(msg string, fields LogFields) {} func (NopLogger) Debug(msg string, fields LogFields) {} func (NopLogger) Trace(msg string, fields LogFields) {} +func (l NopLogger) With(fields LogFields) LoggerAdapter { return l } type StdLoggerAdapter struct { ErrorLogger *log.Logger InfoLogger *log.Logger DebugLogger *log.Logger TraceLogger *log.Logger + + fields LogFields } func NewStdLogger(debug, trace bool) LoggerAdapter { - l := log.New(os.Stderr, "[watermill] ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile) + return NewStdLoggerWithOut(os.Stderr, debug, trace) +} + +func NewStdLoggerWithOut(out io.Writer, debug bool, trace bool) LoggerAdapter { + l := log.New(out, "[watermill] ", log.LstdFlags|log.Lmicroseconds|log.Lshortfile) a := &StdLoggerAdapter{InfoLogger: l, ErrorLogger: l} if debug { @@ -75,6 +93,16 @@ func (l *StdLoggerAdapter) Trace(msg string, fields LogFields) { l.log(l.TraceLogger, "TRACE", msg, fields) } +func (l *StdLoggerAdapter) With(fields LogFields) LoggerAdapter { + return &StdLoggerAdapter{ + ErrorLogger: l.ErrorLogger, + InfoLogger: l.InfoLogger, + DebugLogger: l.DebugLogger, + TraceLogger: l.TraceLogger, + fields: l.fields.Add(fields), + } +} + func (l *StdLoggerAdapter) log(logger *log.Logger, level string, msg string, fields LogFields) { if logger == nil { return @@ -82,9 +110,11 @@ func (l *StdLoggerAdapter) log(logger *log.Logger, level string, msg string, fie fieldsStr := "" - keys := make([]string, len(fields)) + allFields := l.fields.Add(fields) + + keys := make([]string, len(allFields)) i := 0 - for field := range fields { + for field := range allFields { keys[i] = field i++ } @@ -93,7 +123,7 @@ func (l *StdLoggerAdapter) log(logger *log.Logger, level string, msg string, fie for _, key := range keys { var valueStr string - value := fields[key] + value := allFields[key] if stringer, ok := value.(fmt.Stringer); ok { valueStr = stringer.String() @@ -114,10 +144,10 @@ func (l *StdLoggerAdapter) log(logger *log.Logger, level string, msg string, fie type LogLevel uint const ( - Trace LogLevel = iota + 1 - Debug - Info - Error + TraceLogLevel LogLevel = iota + 1 + DebugLogLevel + InfoLogLevel + ErrorLogLevel ) type CapturedMessage struct { @@ -129,18 +159,27 @@ type CapturedMessage struct { type CaptureLoggerAdapter struct { captured map[LogLevel][]CapturedMessage + fields LogFields } -func NewCaptureLogger() CaptureLoggerAdapter { - return CaptureLoggerAdapter{ +func NewCaptureLogger() *CaptureLoggerAdapter { + return &CaptureLoggerAdapter{ captured: map[LogLevel][]CapturedMessage{}, } } +func (c *CaptureLoggerAdapter) With(fields LogFields) LoggerAdapter { + return &CaptureLoggerAdapter{c.captured, c.fields.Add(fields)} +} + func (c *CaptureLoggerAdapter) capture(msg CapturedMessage) { c.captured[msg.Level] = append(c.captured[msg.Level], msg) } +func (c CaptureLoggerAdapter) Captured() map[LogLevel][]CapturedMessage { + return c.captured +} + func (c CaptureLoggerAdapter) Has(msg CapturedMessage) bool { for _, capturedMsg := range c.captured[msg.Level] { if reflect.DeepEqual(msg, capturedMsg) { @@ -151,7 +190,7 @@ func (c CaptureLoggerAdapter) Has(msg CapturedMessage) bool { } func (c CaptureLoggerAdapter) HasError(err error) bool { - for _, capturedMsg := range c.captured[Error] { + for _, capturedMsg := range c.captured[ErrorLogLevel] { if capturedMsg.Err == err { return true } @@ -161,8 +200,8 @@ func (c CaptureLoggerAdapter) HasError(err error) bool { func (c *CaptureLoggerAdapter) Error(msg string, err error, fields LogFields) { c.capture(CapturedMessage{ - Level: Error, - Fields: fields, + Level: ErrorLogLevel, + Fields: c.fields.Add(fields), Msg: msg, Err: err, }) @@ -170,24 +209,24 @@ func (c *CaptureLoggerAdapter) Error(msg string, err error, fields LogFields) { func (c *CaptureLoggerAdapter) Info(msg string, fields LogFields) { c.capture(CapturedMessage{ - Level: Info, - Fields: fields, + Level: InfoLogLevel, + Fields: c.fields.Add(fields), Msg: msg, }) } func (c *CaptureLoggerAdapter) Debug(msg string, fields LogFields) { c.capture(CapturedMessage{ - Level: Debug, - Fields: fields, + Level: DebugLogLevel, + Fields: c.fields.Add(fields), Msg: msg, }) } func (c *CaptureLoggerAdapter) Trace(msg string, fields LogFields) { c.capture(CapturedMessage{ - Level: Trace, - Fields: fields, + Level: TraceLogLevel, + Fields: c.fields.Add(fields), Msg: msg, }) } diff --git a/log_test.go b/log_test.go new file mode 100644 index 000000000..3ecdafc3d --- /dev/null +++ b/log_test.go @@ -0,0 +1,136 @@ +package watermill_test + +import ( + "bytes" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/ThreeDotsLabs/watermill" +) + +func TestLogFields_Copy(t *testing.T) { + fields1 := watermill.LogFields{"foo": "bar"} + + fields2 := fields1.Copy() + fields2["foo"] = "baz" + + assert.Equal(t, fields1["foo"], "bar") + assert.Equal(t, fields2["foo"], "baz") +} + +func TestStdLogger_with(t *testing.T) { + buf := bytes.NewBuffer([]byte{}) + cleanLogger := watermill.NewStdLoggerWithOut(buf, true, true) + + withLogFieldsLogger := cleanLogger.With(watermill.LogFields{"foo": "1"}) + + for name, logger := range map[string]watermill.LoggerAdapter{"clean": cleanLogger, "with": withLogFieldsLogger} { + logger.Error(name, nil, watermill.LogFields{"bar": "2"}) + logger.Info(name, watermill.LogFields{"bar": "2"}) + logger.Debug(name, watermill.LogFields{"bar": "2"}) + logger.Trace(name, watermill.LogFields{"bar": "2"}) + } + + cleanLoggerOut := buf.String() + assert.Contains(t, cleanLoggerOut, `level=ERROR msg="clean" bar=2 err=`) + assert.Contains(t, cleanLoggerOut, `level=INFO msg="clean" bar=2`) + assert.Contains(t, cleanLoggerOut, `level=TRACE msg="clean" bar=2`) + + assert.Contains(t, cleanLoggerOut, `level=ERROR msg="with" bar=2 err= foo=1`) + assert.Contains(t, cleanLoggerOut, `level=INFO msg="with" bar=2 foo=1`) + assert.Contains(t, cleanLoggerOut, `level=TRACE msg="with" bar=2 foo=1`) +} + +type stringer struct{} + +func (s stringer) String() string { + return "stringer" +} + +func TestStdLoggerAdapter_stringer_field(t *testing.T) { + buf := bytes.NewBuffer([]byte{}) + logger := watermill.NewStdLoggerWithOut(buf, true, true) + + logger.Info("foo", watermill.LogFields{"foo": stringer{}}) + + out := buf.String() + assert.Contains(t, out, `foo=stringer`) +} + +func TestStdLoggerAdapter_field_with_space(t *testing.T) { + buf := bytes.NewBuffer([]byte{}) + logger := watermill.NewStdLoggerWithOut(buf, true, true) + + logger.Info("foo", watermill.LogFields{"foo": `bar baz`}) + + out := buf.String() + assert.Contains(t, out, `foo="bar baz"`) +} + +func TestCaptureLoggerAdapter(t *testing.T) { + var logger watermill.LoggerAdapter = watermill.NewCaptureLogger() + + err := errors.New("error") + + logger = logger.With(watermill.LogFields{"default": "field"}) + logger.Error("error", err, watermill.LogFields{"bar": "2"}) + logger.Info("info", watermill.LogFields{"bar": "2"}) + logger.Debug("debug", watermill.LogFields{"bar": "2"}) + logger.Trace("trace", watermill.LogFields{"bar": "2"}) + + expectedLogs := map[watermill.LogLevel][]watermill.CapturedMessage{ + watermill.TraceLogLevel: { + watermill.CapturedMessage{ + Level: watermill.TraceLogLevel, + Fields: watermill.LogFields{"bar": "2", "default": "field"}, + Msg: "trace", + Err: error(nil), + }, + }, + watermill.DebugLogLevel: { + watermill.CapturedMessage{ + Level: watermill.DebugLogLevel, + Fields: watermill.LogFields{"default": "field", "bar": "2"}, + Msg: "debug", + Err: error(nil), + }, + }, + watermill.InfoLogLevel: { + watermill.CapturedMessage{ + Level: watermill.InfoLogLevel, + Fields: watermill.LogFields{"default": "field", "bar": "2"}, + Msg: "info", + Err: error(nil), + }, + }, + watermill.ErrorLogLevel: { + watermill.CapturedMessage{ + Level: watermill.ErrorLogLevel, + Fields: watermill.LogFields{"default": "field", "bar": "2"}, + Msg: "error", + Err: err, + }, + }, + } + + capturedLogger := logger.(*watermill.CaptureLoggerAdapter) + assert.EqualValues(t, expectedLogs, capturedLogger.Captured()) + + for _, logs := range expectedLogs { + for _, log := range logs { + assert.True(t, capturedLogger.Has(log)) + } + } + + assert.False(t, capturedLogger.Has(watermill.CapturedMessage{ + Level: 0, + Fields: nil, + Msg: "", + Err: nil, + })) + + assert.True(t, capturedLogger.HasError(err)) + assert.False(t, capturedLogger.HasError(errors.New("foo"))) +} diff --git a/message/infrastructure/amqp/pubsub_reconnect_test.go b/message/infrastructure/amqp/pubsub_reconnect_test.go new file mode 100644 index 000000000..f8eb9a23c --- /dev/null +++ b/message/infrastructure/amqp/pubsub_reconnect_test.go @@ -0,0 +1,18 @@ +// +build reconnect + +package amqp_test + +import ( + "testing" + + "github.com/ThreeDotsLabs/watermill/message/infrastructure" +) + +func TestPublishSubscribe_reconnect(t *testing.T) { + infrastructure.TestReconnect(t, createPubSub(t), infrastructure.Features{ + ConsumerGroups: true, + ExactlyOnceDelivery: false, + GuaranteedOrder: false, + Persistent: true, + }) +} diff --git a/message/infrastructure/amqp/pubsub_test.go b/message/infrastructure/amqp/pubsub_test.go index 044d74f7f..9e4a77d0d 100644 --- a/message/infrastructure/amqp/pubsub_test.go +++ b/message/infrastructure/amqp/pubsub_test.go @@ -13,7 +13,7 @@ import ( var amqpURI = "amqp://guest:guest@localhost:5672/" -func createPubSub(t *testing.T) message.PubSub { +func createPubSub(t *testing.T) infrastructure.PubSub { publisher, err := amqp.NewPublisher( amqp.NewDurablePubSubConfig( amqpURI, @@ -32,10 +32,10 @@ func createPubSub(t *testing.T) message.PubSub { ) require.NoError(t, err) - return message.NewPubSub(publisher, subscriber) + return message.NewPubSub(publisher, subscriber).(infrastructure.PubSub) } -func createPubSubWithConsumerGroup(t *testing.T, consumerGroup string) message.PubSub { +func createPubSubWithConsumerGroup(t *testing.T, consumerGroup string) infrastructure.PubSub { publisher, err := amqp.NewPublisher( amqp.NewDurablePubSubConfig( amqpURI, @@ -54,7 +54,7 @@ func createPubSubWithConsumerGroup(t *testing.T, consumerGroup string) message.P ) require.NoError(t, err) - return message.NewPubSub(publisher, subscriber) + return message.NewPubSub(publisher, subscriber).(infrastructure.PubSub) } func TestPublishSubscribe_pubsub(t *testing.T) { @@ -72,7 +72,7 @@ func TestPublishSubscribe_pubsub(t *testing.T) { ) } -func createQueuePubSub(t *testing.T) message.PubSub { +func createQueuePubSub(t *testing.T) infrastructure.PubSub { config := amqp.NewDurableQueueConfig( amqpURI, ) @@ -89,7 +89,7 @@ func createQueuePubSub(t *testing.T) message.PubSub { ) require.NoError(t, err) - return message.NewPubSub(publisher, subscriber) + return message.NewPubSub(publisher, subscriber).(infrastructure.PubSub) } func TestPublishSubscribe_queue(t *testing.T) { @@ -126,5 +126,15 @@ func TestPublishSubscribe_transactional_publish(t *testing.T) { ) require.NoError(t, err) - infrastructure.TestPublishSubscribe(t, message.NewPubSub(publisher, subscriber)) + infrastructure.TestPublishSubscribe( + t, + message.NewPubSub(publisher, subscriber).(infrastructure.PubSub), + infrastructure.Features{ + ConsumerGroups: true, + ExactlyOnceDelivery: false, + GuaranteedOrder: true, + Persistent: true, + RestartServiceCommand: []string{"docker", "restart", "watermill_rabbitmq_1"}, + }, + ) } diff --git a/message/infrastructure/amqp/subscriber.go b/message/infrastructure/amqp/subscriber.go index ca5943de7..711447499 100644 --- a/message/infrastructure/amqp/subscriber.go +++ b/message/infrastructure/amqp/subscriber.go @@ -36,12 +36,12 @@ func NewSubscriber(config Config, logger watermill.LoggerAdapter) (*Subscriber, // Watermill's topic in Subscribe is not mapped to AMQP's topic, but depending on configuration it can be mapped // to exchange, queue or routing key. // For detailed description of nomenclature mapping, please check "Nomenclature" paragraph in doc.go file. -func (p *Subscriber) Subscribe(topic string) (chan *message.Message, error) { - if p.closed { +func (s *Subscriber) Subscribe(topic string) (chan *message.Message, error) { + if s.closed { return nil, errors.New("pub/sub is closed") } - if !p.IsConnected() { + if !s.IsConnected() { return nil, errors.New("not connected to AMQP") } @@ -49,35 +49,35 @@ func (p *Subscriber) Subscribe(topic string) (chan *message.Message, error) { out := make(chan *message.Message, 0) - queueName := p.config.Queue.GenerateName(topic) + queueName := s.config.Queue.GenerateName(topic) logFields["amqp_queue_name"] = queueName - exchangeName := p.config.Exchange.GenerateName(topic) + exchangeName := s.config.Exchange.GenerateName(topic) logFields["amqp_exchange_name"] = exchangeName - if err := p.prepareConsume(queueName, exchangeName, logFields); err != nil { + if err := s.prepareConsume(queueName, exchangeName, logFields); err != nil { return nil, errors.Wrap(err, "failed to prepare consume") } - p.subscribingWg.Add(1) + s.subscribingWg.Add(1) go func() { defer func() { close(out) - p.logger.Info("Stopped consuming from AMQP channel", logFields) - p.subscribingWg.Done() + s.logger.Info("Stopped consuming from AMQP channel", logFields) + s.subscribingWg.Done() }() ReconnectLoop: for { - p.logger.Debug("Waiting for p.connected or p.closing in ReconnectLoop", logFields) + s.logger.Debug("Waiting for s.connected or s.closing in ReconnectLoop", logFields) select { - case <-p.connected: - p.logger.Debug("Connection established in ReconnectLoop", logFields) + case <-s.connected: + s.logger.Debug("Connection established in ReconnectLoop", logFields) // runSubscriber blocks until connection fails or Close() is called - p.runSubscriber(out, queueName, exchangeName, logFields) - case <-p.closing: - p.logger.Debug("Stopping ReconnectLoop", logFields) + s.runSubscriber(out, queueName, exchangeName, logFields) + case <-s.closing: + s.logger.Debug("Stopping ReconnectLoop", logFields) break ReconnectLoop } @@ -88,8 +88,30 @@ func (p *Subscriber) Subscribe(topic string) (chan *message.Message, error) { return out, nil } -func (p *Subscriber) prepareConsume(queueName string, exchangeName string, logFields watermill.LogFields) (err error) { - channel, err := p.openSubscribeChannel(logFields) +func (s *Subscriber) SubscribeInitialize(topic string) (err error) { + if s.closed { + return errors.New("pub/sub is closed") + } + + if !s.IsConnected() { + return errors.New("not connected to AMQP") + } + + logFields := watermill.LogFields{"topic": topic} + + queueName := s.config.Queue.GenerateName(topic) + logFields["amqp_queue_name"] = queueName + + exchangeName := s.config.Exchange.GenerateName(topic) + logFields["amqp_exchange_name"] = exchangeName + + s.logger.Info("Initializing subscribe", logFields) + + return errors.Wrap(s.prepareConsume(queueName, exchangeName, logFields), "failed to prepare consume") +} + +func (s *Subscriber) prepareConsume(queueName string, exchangeName string, logFields watermill.LogFields) (err error) { + channel, err := s.openSubscribeChannel(logFields) if err != nil { return err } @@ -101,49 +123,49 @@ func (p *Subscriber) prepareConsume(queueName string, exchangeName string, logFi if _, err := channel.QueueDeclare( queueName, - p.config.Queue.Durable, - p.config.Queue.AutoDelete, - p.config.Queue.Exclusive, - p.config.Queue.NoWait, - p.config.Queue.Arguments, + s.config.Queue.Durable, + s.config.Queue.AutoDelete, + s.config.Queue.Exclusive, + s.config.Queue.NoWait, + s.config.Queue.Arguments, ); err != nil { return errors.Wrap(err, "cannot declare queue") } - p.logger.Debug("Queue declared", logFields) + s.logger.Debug("Queue declared", logFields) if exchangeName == "" { - p.logger.Debug("No exchange to declare", logFields) + s.logger.Debug("No exchange to declare", logFields) return nil } - if err := p.exchangeDeclare(channel, exchangeName); err != nil { + if err := s.exchangeDeclare(channel, exchangeName); err != nil { return errors.Wrap(err, "cannot declare exchange") } - p.logger.Debug("Exchange declared", logFields) + s.logger.Debug("Exchange declared", logFields) if err := channel.QueueBind( queueName, - p.config.QueueBind.RoutingKey, + s.config.QueueBind.RoutingKey, exchangeName, - p.config.QueueBind.NoWait, - p.config.QueueBind.Arguments, + s.config.QueueBind.NoWait, + s.config.QueueBind.Arguments, ); err != nil { return errors.Wrap(err, "cannot bind queue") } - p.logger.Debug("Queue bound to exchange", logFields) + s.logger.Debug("Queue bound to exchange", logFields) return nil } -func (p *Subscriber) runSubscriber(out chan *message.Message, queueName string, exchangeName string, logFields watermill.LogFields) { - channel, err := p.openSubscribeChannel(logFields) +func (s *Subscriber) runSubscriber(out chan *message.Message, queueName string, exchangeName string, logFields watermill.LogFields) { + channel, err := s.openSubscribeChannel(logFields) if err != nil { - p.logger.Error("Failed to open channel", err, logFields) + s.logger.Error("Failed to open channel", err, logFields) return } defer func() { err := channel.Close() - p.logger.Error("Failed to close channel", err, logFields) + s.logger.Error("Failed to close channel", err, logFields) }() notifyCloseChannel := channel.NotifyClose(make(chan *amqp.Error)) @@ -154,34 +176,34 @@ func (p *Subscriber) runSubscriber(out chan *message.Message, queueName string, notifyCloseChannel: notifyCloseChannel, channel: channel, queueName: queueName, - logger: p.logger, - closing: p.closing, - config: p.config, + logger: s.logger, + closing: s.closing, + config: s.config, } - p.logger.Info("Starting consuming from AMQP channel", logFields) + s.logger.Info("Starting consuming from AMQP channel", logFields) sub.ProcessMessages() } -func (p *Subscriber) openSubscribeChannel(logFields watermill.LogFields) (*amqp.Channel, error) { - if !p.IsConnected() { +func (s *Subscriber) openSubscribeChannel(logFields watermill.LogFields) (*amqp.Channel, error) { + if !s.IsConnected() { return nil, errors.New("not connected to AMQP") } - channel, err := p.amqpConnection.Channel() + channel, err := s.amqpConnection.Channel() if err != nil { return nil, errors.Wrap(err, "cannot open channel") } - p.logger.Debug("Channel opened", logFields) + s.logger.Debug("Channel opened", logFields) - if p.config.Consume.Qos != (QosConfig{}) { + if s.config.Consume.Qos != (QosConfig{}) { err = channel.Qos( - p.config.Consume.Qos.PrefetchCount, // prefetch count - p.config.Consume.Qos.PrefetchSize, // prefetch size - p.config.Consume.Qos.Global, // global + s.config.Consume.Qos.PrefetchCount, + s.config.Consume.Qos.PrefetchSize, + s.config.Consume.Qos.Global, ) - p.logger.Debug("Qos set", logFields) + s.logger.Debug("Qos set", logFields) } return channel, nil diff --git a/message/infrastructure/gochannel/pubsub.go b/message/infrastructure/gochannel/pubsub.go index 86d02a7a0..1c7892035 100644 --- a/message/infrastructure/gochannel/pubsub.go +++ b/message/infrastructure/gochannel/pubsub.go @@ -3,7 +3,8 @@ package gochannel import ( "context" "sync" - "time" + + "github.com/renstrom/shortuuid" "github.com/pkg/errors" @@ -23,61 +24,103 @@ type subscriber struct { // // GoChannel has no global state, // that means that you need to use the same instance for Publishing and Subscribing! +// +// When GoChannel is persistent, messages order is not guaranteed. type GoChannel struct { - sendTimeout time.Duration - buffer int64 + outputChannelBuffer int64 - subscribers map[string][]*subscriber - subscribersLock *sync.RWMutex + subscribers map[string][]*subscriber + subscribersLock sync.RWMutex + subscribersByTopicLock sync.Map // map of *sync.Mutex logger watermill.LoggerAdapter closed bool closing chan struct{} + + // If persistent is set to true, when subscriber subscribes to the topic, + // it will receive all previously produced messages. + // All messages are persisted to the memory, + // so be aware that with large amount of messages you can go out of the memory. + persistent bool + + persistedMessages map[string][]*message.Message +} + +func (g *GoChannel) Publisher() message.Publisher { + return g } -func NewGoChannel(buffer int64, logger watermill.LoggerAdapter, sendTimeout time.Duration) message.PubSub { +func (g *GoChannel) Subscriber() message.Subscriber { + return g +} + +// NewGoChannel creates new GoChannel Pub/Sub. +// +// This GoChannel is not persistent. +// That means if you send a message to a topic to which no subscriber is subscribed, that message will be discarded. +func NewGoChannel(outputChannelBuffer int64, logger watermill.LoggerAdapter) message.PubSub { return &GoChannel{ - sendTimeout: sendTimeout, - buffer: buffer, + outputChannelBuffer: outputChannelBuffer, - subscribers: make(map[string][]*subscriber), - subscribersLock: &sync.RWMutex{}, - logger: logger, + subscribers: make(map[string][]*subscriber), + subscribersByTopicLock: sync.Map{}, + logger: logger.With(watermill.LogFields{ + "pubsub_uuid": shortuuid.New(), + }), closing: make(chan struct{}), } } -// Publish in GoChannel is blocking until all consumers consume and acknowledge the message. -// Sending message to one subscriber has timeout equal to GoChannel.sendTimeout configured via constructor. +// NewPersistentGoChannel creates new persistent GoChannel Pub/Sub. // -// Messages are not persisted. If there are no subscribers and message is produced it will be gone. -func (g *GoChannel) Publish(topic string, messages ...*message.Message) error { - for _, msg := range messages { - if err := g.sendMessage(topic, msg); err != nil { - return err - } - } +// This GoChannel is persistent. +// That means that when subscriber subscribes to the topic, it will receive all previously produced messages. +// All messages are persisted to the memory, so be aware that with large amount of messages you can go out of the memory. +// +// Messages are persisted per GoChannel, so you must use the same object to consume these persisted messages. +func NewPersistentGoChannel(outputChannelBuffer int64, logger watermill.LoggerAdapter) message.PubSub { + return &GoChannel{ + outputChannelBuffer: outputChannelBuffer, - return nil + subscribers: make(map[string][]*subscriber), + logger: logger.With(watermill.LogFields{ + "pubsub_uuid": shortuuid.New(), + }), + + closing: make(chan struct{}), + + persistent: true, + persistedMessages: map[string][]*message.Message{}, + } } -func (g *GoChannel) sendMessage(topic string, message *message.Message) error { - messageLogFields := watermill.LogFields{ - "message_uuid": message.UUID, +// Publish in GoChannel is NOT blocking until all consumers consume. +// Messages will be send in background. +// +// Messages may be persisted or not, depending of persistent attribute. +func (g *GoChannel) Publish(topic string, messages ...*message.Message) error { + if g.closed { + return errors.New("Pub/Sub closed") } g.subscribersLock.RLock() defer g.subscribersLock.RUnlock() - subscribers, ok := g.subscribers[topic] - if !ok { - return nil + subLock, _ := g.subscribersByTopicLock.LoadOrStore(topic, &sync.Mutex{}) + subLock.(*sync.Mutex).Lock() + defer subLock.(*sync.Mutex).Unlock() + + if g.persistent { + if _, ok := g.persistedMessages[topic]; !ok { + g.persistedMessages[topic] = make([]*message.Message, 0) + } + g.persistedMessages[topic] = append(g.persistedMessages[topic], messages...) } - for _, s := range subscribers { - if err := g.sendMessageToSubscriber(message, s, messageLogFields); err != nil { + for i := range messages { + if err := g.sendMessage(topic, messages[i]); err != nil { return err } } @@ -85,46 +128,66 @@ func (g *GoChannel) sendMessage(topic string, message *message.Message) error { return nil } -func (g *GoChannel) sendMessageToSubscriber(msg *message.Message, s *subscriber, messageLogFields watermill.LogFields) error { - subscriberLogFields := messageLogFields.Add(watermill.LogFields{ - "subscriber_uuid": s.uuid, - }) +func (g *GoChannel) sendMessage(topic string, message *message.Message) error { + subscribers := g.topicSubscribers(topic) + if len(subscribers) == 0 { + return nil + } + + for i := range subscribers { + s := subscribers[i] + + go func(subscriber *subscriber) { + g.sendMessageToSubscriber(message, subscriber) + }(s) + } + + return nil +} + +func (g *GoChannel) sendMessageToSubscriber(msg *message.Message, s *subscriber) { + subscriberLogFields := watermill.LogFields{ + "message_uuid": msg.UUID, + "pubsub_uuid": s.uuid, + } + + ctx, cancelCtx := context.WithCancel(context.Background()) + defer cancelCtx() SendToSubscriber: for { // copy the message to prevent ack/nack propagation to other consumers // also allows to make retries on a fresh copy of the original message msgToSend := msg.Copy() - - ctx, cancelCtx := context.WithCancel(context.Background()) msgToSend.SetContext(ctx) - defer cancelCtx() + + g.logger.Trace("Sending msg to subscriber", subscriberLogFields) + + if g.closed { + g.logger.Info("Pub/Sub closed, discarding msg", subscriberLogFields) + return + } select { case s.outputChannel <- msgToSend: g.logger.Trace("Sent message to subscriber", subscriberLogFields) - case <-time.After(g.sendTimeout): - return errors.Errorf("Sending message %s timeouted after %s", msgToSend.UUID, g.sendTimeout) case <-g.closing: g.logger.Trace("Closing, message discarded", subscriberLogFields) - return nil + return } select { case <-msgToSend.Acked(): g.logger.Trace("Message acked", subscriberLogFields) - break SendToSubscriber + return case <-msgToSend.Nacked(): g.logger.Trace("Nack received, resending message", subscriberLogFields) - continue SendToSubscriber case <-g.closing: g.logger.Trace("Closing, message discarded", subscriberLogFields) - return nil + return } } - - return nil } // Subscribe returns channel to which all published messages are sent. @@ -132,37 +195,94 @@ SendToSubscriber: // // There are no consumer groups support etc. Every consumer will receive every produced message. func (g *GoChannel) Subscribe(topic string) (chan *message.Message, error) { + if g.closed { + return nil, errors.New("Pub/Sub closed") + } + g.subscribersLock.Lock() - defer g.subscribersLock.Unlock() - if _, ok := g.subscribers[topic]; !ok { - g.subscribers[topic] = make([]*subscriber, 0) - } + subLock, _ := g.subscribersByTopicLock.LoadOrStore(topic, &sync.Mutex{}) + subLock.(*sync.Mutex).Lock() s := &subscriber{ uuid: uuid.NewV4().String(), - outputChannel: make(chan *message.Message, g.buffer), + outputChannel: make(chan *message.Message, g.outputChannelBuffer), } - g.subscribers[topic] = append(g.subscribers[topic], s) + + if !g.persistent { + defer g.subscribersLock.Unlock() + defer subLock.(*sync.Mutex).Unlock() + + g.addSubscriber(topic, s) + + return s.outputChannel, nil + } + + go func(s *subscriber) { + defer g.subscribersLock.Unlock() + defer subLock.(*sync.Mutex).Unlock() + + if messages, ok := g.persistedMessages[topic]; ok { + for i := range messages { + msg := g.persistedMessages[topic][i] + + go func() { + g.sendMessageToSubscriber(msg, s) + }() + } + } + + g.addSubscriber(topic, s) + }(s) return s.outputChannel, nil } +func (g *GoChannel) addSubscriber(topic string, s *subscriber) { + if _, ok := g.subscribers[topic]; !ok { + g.subscribers[topic] = make([]*subscriber, 0) + } + g.subscribers[topic] = append(g.subscribers[topic], s) +} + +func (g *GoChannel) topicSubscribers(topic string) []*subscriber { + subscribers, ok := g.subscribers[topic] + if !ok { + return nil + } + + return subscribers +} + func (g *GoChannel) Close() error { if g.closed { return nil } + g.closed = true close(g.closing) g.subscribersLock.Lock() defer g.subscribersLock.Unlock() - for _, topicSubscribers := range g.subscribers { + g.logger.Info("Closing Pub/Sub", nil) + + for topic, topicSubscribers := range g.subscribers { + subLock, _ := g.subscribersByTopicLock.LoadOrStore(topic, &sync.Mutex{}) + subLock.(*sync.Mutex).Lock() + for _, subscriber := range topicSubscribers { + g.logger.Debug("Closing subscriber channel", watermill.LogFields{ + "subscriber_uuid": subscriber.uuid, + }) close(subscriber.outputChannel) } + + subLock.(*sync.Mutex).Unlock() } + g.logger.Info("Pub/Sub closed", nil) + g.persistedMessages = nil + return nil } diff --git a/message/infrastructure/gochannel/pubsub_bench_test.go b/message/infrastructure/gochannel/pubsub_bench_test.go index e83e96676..05c3a22e2 100644 --- a/message/infrastructure/gochannel/pubsub_bench_test.go +++ b/message/infrastructure/gochannel/pubsub_bench_test.go @@ -2,7 +2,6 @@ package gochannel_test import ( "testing" - "time" "github.com/ThreeDotsLabs/watermill" "github.com/ThreeDotsLabs/watermill/message/infrastructure/gochannel" @@ -13,6 +12,12 @@ import ( func BenchmarkSubscriber(b *testing.B) { infrastructure.BenchSubscriber(b, func(n int) message.PubSub { - return gochannel.NewGoChannel(int64(n), watermill.NopLogger{}, time.Second) + return gochannel.NewGoChannel(int64(n), watermill.NopLogger{}) + }) +} + +func BenchmarkSubscriberPersistent(b *testing.B) { + infrastructure.BenchSubscriber(b, func(n int) message.PubSub { + return gochannel.NewPersistentGoChannel(int64(n), watermill.NopLogger{}) }) } diff --git a/message/infrastructure/gochannel/pubsub_test.go b/message/infrastructure/gochannel/pubsub_test.go index 3af49f900..255f58f69 100644 --- a/message/infrastructure/gochannel/pubsub_test.go +++ b/message/infrastructure/gochannel/pubsub_test.go @@ -1,35 +1,132 @@ package gochannel_test import ( + "fmt" + "log" + "sync" "testing" "time" "github.com/ThreeDotsLabs/watermill" - - "github.com/ThreeDotsLabs/watermill/message/infrastructure/gochannel" - + "github.com/ThreeDotsLabs/watermill/internal/tests" "github.com/ThreeDotsLabs/watermill/message" "github.com/ThreeDotsLabs/watermill/message/infrastructure" + "github.com/ThreeDotsLabs/watermill/message/infrastructure/gochannel" + "github.com/ThreeDotsLabs/watermill/message/subscriber" + "github.com/satori/go.uuid" + "github.com/stretchr/testify/require" ) -func createPubSub(t *testing.T) message.PubSub { - return gochannel.NewGoChannel( - 0, +func createPersistentPubSub(t *testing.T) infrastructure.PubSub { + return gochannel.NewPersistentGoChannel( + 10000, watermill.NewStdLogger(true, true), - time.Second*10, - ) + ).(infrastructure.PubSub) } -func TestPublishSubscribe(t *testing.T) { +func TestPublishSubscribe_persistent(t *testing.T) { infrastructure.TestPubSub( t, infrastructure.Features{ ConsumerGroups: false, ExactlyOnceDelivery: true, - GuaranteedOrder: true, + GuaranteedOrder: false, Persistent: false, }, - createPubSub, + createPersistentPubSub, nil, ) } + +func TestPublishSubscribe_not_persistent(t *testing.T) { + messagesCount := 100 + pubSub := gochannel.NewGoChannel( + int64(messagesCount), + watermill.NewStdLogger(true, true), + ) + topicName := "test_topic_" + uuid.NewV4().String() + + msgs, err := pubSub.Subscribe(topicName) + require.NoError(t, err) + + sendMessages := infrastructure.AddSimpleMessages(t, messagesCount, pubSub, topicName) + receivedMsgs, _ := subscriber.BulkRead(msgs, messagesCount, time.Second) + + tests.AssertAllMessagesReceived(t, sendMessages, receivedMsgs) +} + +func TestPublishSubscribe_race_condition_on_subscribe(t *testing.T) { + testsCount := 15 + if testing.Short() { + testsCount = 3 + } + + for i := 0; i < testsCount; i++ { + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + t.Parallel() + testPublishSubscribeSubRace(t) + }) + } +} + +func testPublishSubscribeSubRace(t *testing.T) { + t.Helper() + + messagesCount := 500 + subscribersCount := 200 + if testing.Short() { + messagesCount = 200 + subscribersCount = 20 + } + + pubSub := gochannel.NewPersistentGoChannel( + int64(messagesCount), + watermill.NewStdLogger(true, false), + ) + + allSent := sync.WaitGroup{} + allSent.Add(messagesCount) + allReceived := sync.WaitGroup{} + + sentMessages := message.Messages{} + go func() { + for i := 0; i < messagesCount; i++ { + msg := message.NewMessage(uuid.NewV4().String(), nil) + sentMessages = append(sentMessages, msg) + + go func() { + require.NoError(t, pubSub.Publish("topic", msg)) + allSent.Done() + }() + } + }() + + subscriberReceivedCh := make(chan message.Messages, subscribersCount) + for i := 0; i < subscribersCount; i++ { + allReceived.Add(1) + + go func() { + msgs, err := pubSub.Subscribe("topic") + require.NoError(t, err) + + received, _ := subscriber.BulkRead(msgs, messagesCount, time.Second*10) + subscriberReceivedCh <- received + + allReceived.Done() + }() + } + + log.Println("waiting for all sent") + allSent.Wait() + + log.Println("waiting for all received") + allReceived.Wait() + + close(subscriberReceivedCh) + + log.Println("asserting") + + for subMsgs := range subscriberReceivedCh { + tests.AssertAllMessagesReceived(t, sentMessages, subMsgs) + } +} diff --git a/message/infrastructure/googlecloud/pubsub_test.go b/message/infrastructure/googlecloud/pubsub_test.go index 5be0bb5df..bae606709 100644 --- a/message/infrastructure/googlecloud/pubsub_test.go +++ b/message/infrastructure/googlecloud/pubsub_test.go @@ -48,14 +48,14 @@ func newPubSub(t *testing.T, marshaler googlecloud.MarshalerUnmarshaler, subscri return message.NewPubSub(publisher, subscriber) } -func createPubSubWithSubscriptionName(t *testing.T, subscriptionName string) message.PubSub { +func createPubSubWithSubscriptionName(t *testing.T, subscriptionName string) infrastructure.PubSub { return newPubSub(t, googlecloud.DefaultMarshalerUnmarshaler{}, googlecloud.TopicSubscriptionNameWithSuffix(subscriptionName), - ) + ).(infrastructure.PubSub) } -func createPubSub(t *testing.T) message.PubSub { - return newPubSub(t, googlecloud.DefaultMarshalerUnmarshaler{}, googlecloud.TopicSubscriptionName) +func createPubSub(t *testing.T) infrastructure.PubSub { + return newPubSub(t, googlecloud.DefaultMarshalerUnmarshaler{}, googlecloud.TopicSubscriptionName).(infrastructure.PubSub) } func TestPublishSubscribe(t *testing.T) { diff --git a/message/infrastructure/googlecloud/subscriber.go b/message/infrastructure/googlecloud/subscriber.go index 4f136b75f..c831012a4 100644 --- a/message/infrastructure/googlecloud/subscriber.go +++ b/message/infrastructure/googlecloud/subscriber.go @@ -157,7 +157,6 @@ func (s *Subscriber) Subscribe(topic string) (chan *message.Message, error) { sub, err := s.subscription(ctx, subscriptionName, topic) if err != nil { - s.logger.Error("Could not obtain subscription", err, logFields) return nil, err } @@ -184,6 +183,25 @@ func (s *Subscriber) Subscribe(topic string) (chan *message.Message, error) { return output, nil } +func (s *Subscriber) SubscribeInitialize(topic string) (err error) { + ctx, cancel := context.WithCancel(s.ctx) + defer cancel() + + subscriptionName := s.config.GenerateSubscriptionName(topic) + logFields := watermill.LogFields{ + "provider": ProviderName, + "topic": topic, + "subscription_name": subscriptionName, + } + s.logger.Info("Subscribing to Google Cloud PubSub topic", logFields) + + if _, err := s.subscription(ctx, subscriptionName, topic); err != nil { + return err + } + + return nil +} + // Close notifies the Subscriber to stop processing messages on all subscriptions, close all the output channels // and terminate the connection. func (s *Subscriber) Close() error { @@ -218,9 +236,9 @@ func (s *Subscriber) receive( return } - msgCtx, cancel := context.WithCancel(ctx) - defer cancel() - msg.SetContext(msgCtx) + ctx, cancelCtx := context.WithCancel(context.Background()) + msg.SetContext(ctx) + defer cancelCtx() select { case <-s.closing: diff --git a/message/infrastructure/http/publisher.go b/message/infrastructure/http/publisher.go index 5e32b7684..fc7ab4ed3 100644 --- a/message/infrastructure/http/publisher.go +++ b/message/infrastructure/http/publisher.go @@ -100,6 +100,8 @@ func (p *Publisher) Publish(topic string, messages ...*message.Message) error { "provider": ProviderName, } + p.logger.Trace("Publishing message", logFields) + resp, err := p.config.Client.Do(req) if err != nil { return errors.Wrapf(err, "publishing message %s failed", msg.UUID) diff --git a/message/infrastructure/http/pubsub_test.go b/message/infrastructure/http/pubsub_test.go index 34b802026..e031b507d 100644 --- a/message/infrastructure/http/pubsub_test.go +++ b/message/infrastructure/http/pubsub_test.go @@ -2,71 +2,41 @@ package http_test import ( "fmt" - "net" - net_http "net/http" "testing" "time" + "github.com/ThreeDotsLabs/watermill/message/subscriber" + + "github.com/ThreeDotsLabs/watermill/internal/tests" + "github.com/ThreeDotsLabs/watermill/message" + "github.com/stretchr/testify/require" "github.com/ThreeDotsLabs/watermill" - "github.com/ThreeDotsLabs/watermill/internal/publisher" - "github.com/ThreeDotsLabs/watermill/message" "github.com/ThreeDotsLabs/watermill/message/infrastructure" "github.com/ThreeDotsLabs/watermill/message/infrastructure/http" ) -func createPubSub(t *testing.T) message.PubSub { +func createPubSub(t *testing.T) (*http.Publisher, *http.Subscriber) { logger := watermill.NewStdLogger(true, true) // use any free port to allow parallel tests sub, err := http.NewSubscriber(":0", http.SubscriberConfig{}, logger) require.NoError(t, err) - // closing sub closes the server - // whoever calls createPubSub is responsible for closing sub - go sub.StartHTTPServer() - - // wait for sub to have address assigned - var addr net.Addr - timeout := time.After(10 * time.Second) - for { - addr = sub.Addr() - if addr != nil { - break - } - select { - case <-timeout: - t.Fatal("Could not obtain an address for subscriber's HTTP server") - default: - time.Sleep(10 * time.Millisecond) - } - } - require.NotNil(t, addr) - publisherConf := http.PublisherConfig{ - MarshalMessageFunc: func(topic string, msg *message.Message) (*net_http.Request, error) { - return http.DefaultMarshalMessageFunc(fmt.Sprintf("http://%s/%s", addr.String(), topic), msg) - }, + MarshalMessageFunc: http.DefaultMarshalMessageFunc, } pub, err := http.NewPublisher(publisherConf, logger) require.NoError(t, err) - retryConf := publisher.RetryPublisherConfig{ - MaxRetries: 10, - TimeToFirstRetry: time.Millisecond, - Logger: logger, - } - - // use the retry decorator, for tests involving retry after error - retryPub, err := publisher.NewRetryPublisher(pub, retryConf) - require.NoError(t, err) - - return message.NewPubSub(retryPub, sub) + return pub, sub } func TestPublishSubscribe(t *testing.T) { + t.Skip("todo - fix") + infrastructure.TestPubSub( t, infrastructure.Features{ @@ -75,7 +45,53 @@ func TestPublishSubscribe(t *testing.T) { GuaranteedOrder: true, Persistent: false, }, - createPubSub, + nil, nil, ) } + +func TestHttpPubSub(t *testing.T) { + pub, sub := createPubSub(t) + + defer func() { + require.NoError(t, pub.Close()) + require.NoError(t, sub.Close()) + }() + + msgs, err := sub.Subscribe("/test") + require.NoError(t, err) + + go sub.StartHTTPServer() + + waitForHTTP(t, sub, time.Second*10) + + receivedMessages := make(chan message.Messages) + + go func() { + received, _ := subscriber.BulkRead(msgs, 100, time.Second*10) + receivedMessages <- received + }() + + publishedMessages := infrastructure.AddSimpleMessages(t, 100, pub, fmt.Sprintf("http://%s/test", sub.Addr())) + + tests.AssertAllMessagesReceived(t, publishedMessages, <-receivedMessages) +} + +func waitForHTTP(t *testing.T, sub *http.Subscriber, timeoutTime time.Duration) { + timeout := time.After(timeoutTime) + for { + addr := sub.Addr() + if addr != nil { + break + } + + select { + case <-timeout: + t.Fatal("server not up") + default: + // ok + } + + time.Sleep(time.Millisecond * 10) + } +} diff --git a/message/infrastructure/kafka/publisher.go b/message/infrastructure/kafka/publisher.go index 24a43777d..b428490af 100644 --- a/message/infrastructure/kafka/publisher.go +++ b/message/infrastructure/kafka/publisher.go @@ -60,7 +60,9 @@ func (p *Publisher) Publish(topic string, msgs ...*message.Message) error { return errors.New("publisher closed") } - logFields := watermill.LogFields{"message_uuid": ""} + logFields := make(watermill.LogFields, 4) + logFields["topic"] = topic + for _, msg := range msgs { logFields["message_uuid"] = msg.UUID p.logger.Trace("Sending message to Kafka", logFields) @@ -70,9 +72,15 @@ func (p *Publisher) Publish(topic string, msgs ...*message.Message) error { return errors.Wrapf(err, "cannot marshal message %s", msg.UUID) } - if _, _, err := p.producer.SendMessage(kafkaMsg); err != nil { + partition, offset, err := p.producer.SendMessage(kafkaMsg) + if err != nil { return errors.Wrapf(err, "cannot produce message %s", msg.UUID) } + + logFields["kafka_partition"] = partition + logFields["kafka_partition_offset"] = offset + + p.logger.Trace("Message sent to Kafka", logFields) } return nil diff --git a/message/infrastructure/kafka/pubsub_test.go b/message/infrastructure/kafka/pubsub_test.go index b73b52944..4486fd07a 100644 --- a/message/infrastructure/kafka/pubsub_test.go +++ b/message/infrastructure/kafka/pubsub_test.go @@ -2,6 +2,7 @@ package kafka_test import ( "testing" + "time" "github.com/Shopify/sarama" @@ -23,10 +24,20 @@ func newPubSub(t *testing.T, marshaler kafka.MarshalerUnmarshaler, consumerGroup saramaConfig := kafka.DefaultSaramaSubscriberConfig() saramaConfig.Consumer.Offsets.Initial = sarama.OffsetOldest + saramaConfig.Admin.Timeout = time.Second * 30 + saramaConfig.Producer.RequiredAcks = sarama.WaitForAll + saramaConfig.ChannelBufferSize = 10240 + saramaConfig.Consumer.Group.Heartbeat.Interval = time.Millisecond * 500 + saramaConfig.Consumer.Group.Rebalance.Timeout = time.Millisecond * 500 + subscriber, err := kafka.NewSubscriber( kafka.SubscriberConfig{ Brokers: brokers, ConsumerGroup: consumerGroup, + InitializeTopicDetails: &sarama.TopicDetail{ + NumPartitions: 8, + ReplicationFactor: 1, + }, }, saramaConfig, marshaler, @@ -41,16 +52,16 @@ func generatePartitionKey(topic string, msg *message.Message) (string, error) { return msg.Metadata.Get("partition_key"), nil } -func createPubSubWithConsumerGrup(t *testing.T, consumerGroup string) message.PubSub { - return newPubSub(t, kafka.DefaultMarshaler{}, consumerGroup) +func createPubSubWithConsumerGrup(t *testing.T, consumerGroup string) infrastructure.PubSub { + return newPubSub(t, kafka.DefaultMarshaler{}, consumerGroup).(infrastructure.PubSub) } -func createPubSub(t *testing.T) message.PubSub { - return createPubSubWithConsumerGrup(t, "test") +func createPubSub(t *testing.T) infrastructure.PubSub { + return createPubSubWithConsumerGrup(t, "test").(infrastructure.PubSub) } -func createPartitionedPubSub(t *testing.T) message.PubSub { - return newPubSub(t, kafka.NewWithPartitioningMarshaler(generatePartitionKey), "test") +func createPartitionedPubSub(t *testing.T) infrastructure.PubSub { + return newPubSub(t, kafka.NewWithPartitioningMarshaler(generatePartitionKey), "test").(infrastructure.PubSub) } func createNoGroupSubscriberConstructor(t *testing.T) message.Subscriber { @@ -76,20 +87,34 @@ func createNoGroupSubscriberConstructor(t *testing.T) message.Subscriber { } func TestPublishSubscribe(t *testing.T) { + features := infrastructure.Features{ + ConsumerGroups: true, + ExactlyOnceDelivery: false, + GuaranteedOrder: false, + Persistent: true, + } + + if testing.Short() { + // Kafka tests are a bit slow, so let's run only basic test + // todo - speed up + t.Log("Running only TestPublishSubscribe for Kafka with -short flag") + infrastructure.TestPublishSubscribe(t, createPubSub(t), features) + return + } + infrastructure.TestPubSub( t, - infrastructure.Features{ - ConsumerGroups: true, - ExactlyOnceDelivery: false, - GuaranteedOrder: false, - Persistent: true, - }, + features, createPubSub, createPubSubWithConsumerGrup, ) } func TestPublishSubscribe_ordered(t *testing.T) { + if testing.Short() { + t.Skip("skipping long tests") + } + infrastructure.TestPubSub( t, infrastructure.Features{ @@ -104,5 +129,9 @@ func TestPublishSubscribe_ordered(t *testing.T) { } func TestNoGroupSubscriber(t *testing.T) { + if testing.Short() { + t.Skip("skipping long tests") + } + infrastructure.TestNoGroupSubscriber(t, createPubSub, createNoGroupSubscriberConstructor) } diff --git a/message/infrastructure/kafka/subscriber.go b/message/infrastructure/kafka/subscriber.go index 8a4be32e7..8e202a404 100644 --- a/message/infrastructure/kafka/subscriber.go +++ b/message/infrastructure/kafka/subscriber.go @@ -5,6 +5,8 @@ import ( "sync" "time" + "github.com/hashicorp/go-multierror" + "github.com/renstrom/shortuuid" "github.com/Shopify/sarama" @@ -45,6 +47,10 @@ func NewSubscriber( overwriteSaramaConfig = DefaultSaramaSubscriberConfig() } + logger = logger.With(watermill.LogFields{ + "subscriber_uuid": shortuuid.New(), + }) + return &Subscriber{ config: config, saramaConfig: overwriteSaramaConfig, @@ -69,6 +75,8 @@ type SubscriberConfig struct { // How long about unsuccessful reconnecting next reconnect will occur. ReconnectRetrySleep time.Duration + + InitializeTopicDetails *sarama.TopicDetail } // NoSleep can be set to SubscriberConfig.NackResendSleep and SubscriberConfig.ReconnectRetrySleep. @@ -102,10 +110,10 @@ func (s *Subscriber) Subscribe(topic string) (chan *message.Message, error) { s.subscribersWg.Add(1) logFields := watermill.LogFields{ - "provider": "kafka", - "topic": topic, - "consumer_group": s.config.ConsumerGroup, - "subscribe_uuid": shortuuid.New(), + "provider": "kafka", + "topic": topic, + "consumer_group": s.config.ConsumerGroup, + "kafka_consumer_uuid": shortuuid.New(), } s.logger.Info("Subscribing to Kafka topic", logFields) @@ -119,8 +127,10 @@ func (s *Subscriber) Subscribe(topic string) (chan *message.Message, error) { } go func() { - defer s.subscribersWg.Done() + // blocking, until s.closing is closed s.handleReconnects(topic, output, consumeClosed, logFields) + close(output) + s.subscribersWg.Done() }() return output, nil @@ -219,6 +229,7 @@ func (s *Subscriber) consumeGroupMessages( handler := consumerGroupHandler{ messageHandler: s.createMessagesHandler(output), + logger: s.logger, closing: s.closing, messageLogFields: logFields, } @@ -340,15 +351,25 @@ func (s *Subscriber) Close() error { type consumerGroupHandler struct { messageHandler messageHandler + logger watermill.LoggerAdapter closing chan struct{} messageLogFields watermill.LogFields } -func (consumerGroupHandler) Setup(_ sarama.ConsumerGroupSession) error { return nil } +func (consumerGroupHandler) Setup(_ sarama.ConsumerGroupSession) error { return nil } + func (consumerGroupHandler) Cleanup(_ sarama.ConsumerGroupSession) error { return nil } + func (h consumerGroupHandler) ConsumeClaim(sess sarama.ConsumerGroupSession, claim sarama.ConsumerGroupClaim) error { kafkaMessages := claim.Messages() + logFields := h.messageLogFields.Copy().Add(watermill.LogFields{ + "kafka_partition": claim.Partition(), + "kafka_initial_offset": claim.InitialOffset(), + }) + + h.logger.Debug("Consume claimed", logFields) + for { select { case kafkaMsg := <-kafkaMessages: @@ -356,7 +377,7 @@ func (h consumerGroupHandler) ConsumeClaim(sess sarama.ConsumerGroupSession, cla // kafkaMessages is closed return nil } - if err := h.messageHandler.processMessage(kafkaMsg, sess, h.messageLogFields); err != nil { + if err := h.messageHandler.processMessage(kafkaMsg, sess, logFields); err != nil { // error will stop consumerGroupHandler return err } @@ -383,8 +404,8 @@ func (h messageHandler) processMessage( messageLogFields watermill.LogFields, ) error { receivedMsgLogFields := messageLogFields.Add(watermill.LogFields{ - "kafka_partition": kafkaMsg.Partition, "kafka_partition_offset": kafkaMsg.Offset, + "kafka_partition": kafkaMsg.Partition, }) h.logger.Trace("Received message from Kafka", receivedMsgLogFields) @@ -438,3 +459,23 @@ ResendLoop: return nil } + +func (s *Subscriber) SubscribeInitialize(topic string) (err error) { + clusterAdmin, err := sarama.NewClusterAdmin(s.config.Brokers, s.saramaConfig) + if err != nil { + return errors.Wrap(err, "cannot create cluster admin") + } + defer func() { + if closeErr := clusterAdmin.Close(); closeErr != nil { + err = multierror.Append(err, closeErr) + } + }() + + if err := clusterAdmin.CreateTopic(topic, s.config.InitializeTopicDetails, false); err != nil { + return errors.Wrap(err, "cannot create topic") + } + + s.logger.Info("Created Kafka topic", watermill.LogFields{"topic": topic}) + + return nil +} diff --git a/message/infrastructure/nats/pubsub_test.go b/message/infrastructure/nats/pubsub_test.go index 0af2e1d40..f65534d8c 100644 --- a/message/infrastructure/nats/pubsub_test.go +++ b/message/infrastructure/nats/pubsub_test.go @@ -38,12 +38,12 @@ func newPubSub(t *testing.T, clientID string, queueName string) message.PubSub { return message.NewPubSub(pub, sub) } -func createPubSub(t *testing.T) message.PubSub { - return newPubSub(t, uuid.NewV4().String(), "test-queue") +func createPubSub(t *testing.T) infrastructure.PubSub { + return newPubSub(t, uuid.NewV4().String(), "test-queue").(infrastructure.PubSub) } -func createPubSubWithDurable(t *testing.T, consumerGroup string) message.PubSub { - return newPubSub(t, consumerGroup, consumerGroup) +func createPubSubWithDurable(t *testing.T, consumerGroup string) infrastructure.PubSub { + return newPubSub(t, consumerGroup, consumerGroup).(infrastructure.PubSub) } func TestPublishSubscribe(t *testing.T) { diff --git a/message/infrastructure/nats/subscriber.go b/message/infrastructure/nats/subscriber.go index 576d1db3f..56f83b789 100644 --- a/message/infrastructure/nats/subscriber.go +++ b/message/infrastructure/nats/subscriber.go @@ -194,6 +194,15 @@ func (s *StreamingSubscriber) Subscribe(topic string) (chan *message.Message, er return output, nil } +func (s *StreamingSubscriber) SubscribeInitialize(topic string) (err error) { + sub, err := s.subscribe(make(chan *message.Message), topic, nil) + if err != nil { + return errors.Wrap(err, "cannot initialize subscribe") + } + + return errors.Wrap(sub.Close(), "cannot close after subscribe initialize") +} + func (s *StreamingSubscriber) subscribe(output chan *message.Message, topic string, subscriberLogFields watermill.LogFields) (stan.Subscription, error) { if s.config.QueueGroup != "" { return s.conn.QueueSubscribe( diff --git a/message/infrastructure/test_pubsub.go b/message/infrastructure/test_pubsub.go index 4464d6003..acaa77a65 100644 --- a/message/infrastructure/test_pubsub.go +++ b/message/infrastructure/test_pubsub.go @@ -3,6 +3,7 @@ package infrastructure import ( "context" "fmt" + "go/build" "log" "math/rand" "os" @@ -15,16 +16,24 @@ import ( "github.com/ThreeDotsLabs/watermill/message" "github.com/ThreeDotsLabs/watermill/message/subscriber" - uuid "github.com/satori/go.uuid" + "github.com/satori/go.uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +var defaultTimeout = time.Second * 15 + func init() { rand.Seed(3) -} -const defaultTimeout = time.Second * 10 + for _, tag := range build.Default.BuildTags { + if tag == "stress" { + // stress tests may work a bit slower + defaultTimeout *= 6 + break + } + } +} type Features struct { ConsumerGroups bool @@ -35,107 +44,79 @@ type Features struct { RestartServiceCommand []string } -type PubSubConstructor func(t *testing.T) message.PubSub -type ConsumerGroupPubSubConstructor func(t *testing.T, consumerGroup string) message.PubSub +type PubSubConstructor func(t *testing.T) PubSub +type ConsumerGroupPubSubConstructor func(t *testing.T, consumerGroup string) PubSub type SimpleMessage struct { Num int `json:"num"` } +type PubSub interface { + message.PubSub + + // Subscriber is needed for unwrapped message.PubSub's subscriber, containing SubscribeInitializer. + Subscriber() message.Subscriber +} + func TestPubSub( t *testing.T, features Features, pubSubConstructor PubSubConstructor, consumerGroupPubSubConstructor ConsumerGroupPubSubConstructor, ) { - t.Run("publishSubscribe", func(t *testing.T) { + t.Run("TestPublishSubscribe", func(t *testing.T) { t.Parallel() - TestPublishSubscribe(t, pubSubConstructor(t)) + TestPublishSubscribe(t, pubSubConstructor(t), features) }) - t.Run("resendOnError", func(t *testing.T) { + t.Run("TestResendOnError", func(t *testing.T) { t.Parallel() - TestResendOnError(t, pubSubConstructor(t)) + TestResendOnError(t, pubSubConstructor(t), features) }) - t.Run("noAck", func(t *testing.T) { - if !features.GuaranteedOrder { - t.Skip("guaranteed order is required for this test") - } + t.Run("TestNoAck", func(t *testing.T) { t.Parallel() - TestNoAck(t, pubSubConstructor(t)) + TestNoAck(t, pubSubConstructor(t), features) }) - t.Run("continueAfterClose", func(t *testing.T) { - if features.ExactlyOnceDelivery { - t.Skip("ExactlyOnceDelivery test is not supported yet") - } - + t.Run("TestContinueAfterSubscribeClose", func(t *testing.T) { t.Parallel() - TestContinueAfterClose(t, pubSubConstructor) + TestContinueAfterSubscribeClose(t, pubSubConstructor, features) }) - t.Run("concurrentClose", func(t *testing.T) { - if features.ExactlyOnceDelivery { - t.Skip("ExactlyOnceDelivery test is not supported yet") - } - + t.Run("TestConcurrentClose", func(t *testing.T) { t.Parallel() - TestConcurrentClose(t, pubSubConstructor) + TestConcurrentClose(t, pubSubConstructor, features) }) - t.Run("continueAfterErrors", func(t *testing.T) { - if !features.Persistent { - t.Skip("continueAfterErrors test is not supported for non persistent pub/sub") - } - + t.Run("TestContinueAfterErrors", func(t *testing.T) { t.Parallel() - TestContinueAfterErrors(t, pubSubConstructor) + TestContinueAfterErrors(t, pubSubConstructor, features) }) - t.Run("publishSubscribeInOrder", func(t *testing.T) { - if !features.GuaranteedOrder { - t.Skipf("order is not guaranteed") - } - + t.Run("TestPublishSubscribeInOrder", func(t *testing.T) { t.Parallel() - TestPublishSubscribeInOrder(t, pubSubConstructor(t)) + TestPublishSubscribeInOrder(t, pubSubConstructor(t), features) }) - t.Run("consumerGroups", func(t *testing.T) { - if !features.ConsumerGroups { - t.Skip("consumer groups are not supported") - } - + t.Run("TestConsumerGroups", func(t *testing.T) { t.Parallel() - TestConsumerGroups(t, consumerGroupPubSubConstructor) + TestConsumerGroups(t, consumerGroupPubSubConstructor, features) }) - t.Run("publisherClose", func(t *testing.T) { + t.Run("TestPublisherClose", func(t *testing.T) { t.Parallel() - - pubsub := pubSubConstructor(t) - - TestPublisherClose(t, pubsub, pubsub) + TestPublisherClose(t, pubSubConstructor(t), features) }) - t.Run("topic", func(t *testing.T) { + t.Run("TestTopic", func(t *testing.T) { t.Parallel() - TopicTest(t, pubSubConstructor(t)) + TestTopic(t, pubSubConstructor(t), features) }) - t.Run("messageCtx", func(t *testing.T) { + t.Run("TestMessageCtx", func(t *testing.T) { t.Parallel() - TestMessageCtx(t, pubSubConstructor(t)) - }) - - t.Run("reconnect", func(t *testing.T) { - if len(features.RestartServiceCommand) == 0 { - t.Skip("no RestartServiceCommand provided, cannot test reconnect") - } - - // this test cannot be parallel - TestReconnect(t, pubSubConstructor(t), features) + TestMessageCtx(t, pubSubConstructor(t), features) }) } @@ -155,10 +136,13 @@ func TestPubSubStressTest( } } -func TestPublishSubscribe(t *testing.T, pubSub message.PubSub) { - defer closePubSub(t, pubSub) +func TestPublishSubscribe(t *testing.T, pubSub PubSub, features Features) { topicName := testTopicName() + if subscribeInitializer, ok := pubSub.Subscriber().(message.SubscribeInitializer); ok { + require.NoError(t, subscribeInitializer.SubscribeInitialize(topicName)) + } + var messagesToPublish []*message.Message messagesPayloads := map[string]interface{}{} messagesTestMetadata := map[string]string{} @@ -176,16 +160,13 @@ func TestPublishSubscribe(t *testing.T, pubSub message.PubSub) { messagesToPublish = append(messagesToPublish, msg) messagesPayloads[id] = payload } + err := publishWithRetry(pubSub, topicName, messagesToPublish...) + require.NoError(t, err, "cannot publish message") messages, err := pubSub.Subscribe(topicName) require.NoError(t, err) - go func() { - err := pubSub.Publish(topicName, messagesToPublish...) - require.NoError(t, err, "cannot publish message") - }() - - receivedMessages, all := subscriber.BulkRead(messages, len(messagesToPublish), defaultTimeout*3) + receivedMessages, all := bulkRead(messages, len(messagesToPublish), defaultTimeout*3, features) assert.True(t, all) tests.AssertAllMessagesReceived(t, messagesToPublish, receivedMessages) @@ -196,10 +177,18 @@ func TestPublishSubscribe(t *testing.T, pubSub message.PubSub) { assertMessagesChannelClosed(t, messages) } -func TestPublishSubscribeInOrder(t *testing.T, pubSub message.PubSub) { +func TestPublishSubscribeInOrder(t *testing.T, pubSub PubSub, features Features) { + if !features.GuaranteedOrder { + t.Skipf("order is not guaranteed") + } + defer closePubSub(t, pubSub) topicName := testTopicName() + if subscribeInitializer, ok := pubSub.Subscriber().(message.SubscribeInitializer); ok { + require.NoError(t, subscribeInitializer.SubscribeInitialize(topicName)) + } + var messagesToPublish []*message.Message expectedMessages := map[string][]string{} @@ -217,20 +206,17 @@ func TestPublishSubscribeInOrder(t *testing.T, pubSub message.PubSub) { expectedMessages[msgType] = append(expectedMessages[msgType], msg.UUID) } - messages, err := pubSub.Subscribe(topicName) + err := publishWithRetry(pubSub, topicName, messagesToPublish...) require.NoError(t, err) - go func() { - err := pubSub.Publish(topicName, messagesToPublish...) - require.NoError(t, err) - }() + messages, err := pubSub.Subscribe(topicName) + require.NoError(t, err) - receivedMessages, all := subscriber.BulkRead(messages, len(messagesToPublish), defaultTimeout) + receivedMessages, all := bulkRead(messages, len(messagesToPublish), defaultTimeout, features) require.True(t, all, "not all messages received (%d of %d)", len(receivedMessages), len(messagesToPublish)) receivedMessagesByType := map[string][]string{} for _, msg := range receivedMessages { - if _, ok := receivedMessagesByType[string(msg.Payload)]; !ok { receivedMessagesByType[string(msg.Payload)] = []string{} } @@ -245,77 +231,71 @@ func TestPublishSubscribeInOrder(t *testing.T, pubSub message.PubSub) { } } -func TestResendOnError(t *testing.T, pubSub message.PubSub) { +func TestResendOnError(t *testing.T, pubSub PubSub, features Features) { defer closePubSub(t, pubSub) topicName := testTopicName() - messages, err := pubSub.Subscribe(topicName) - require.NoError(t, err) + if subscribeInitializer, ok := pubSub.Subscriber().(message.SubscribeInitializer); ok { + require.NoError(t, subscribeInitializer.SubscribeInitialize(topicName)) + } - //var messagesToPublish message.Messages messagesToSend := 100 + nacksCount := 2 var publishedMessages message.Messages allMessagesSent := make(chan struct{}) - go func() { - publishedMessages = addSimpleMessagesMessages(t, messagesToSend, pubSub, topicName) - allMessagesSent <- struct{}{} - }() - - var receivedMessages []*message.Message + publishedMessages = AddSimpleMessages(t, messagesToSend, pubSub, topicName) + close(allMessagesSent) - i := 0 - errsSent := 0 + messages, err := pubSub.Subscribe(topicName) + require.NoError(t, err) -ReadMessagesLoop: - for len(receivedMessages) < messagesToSend { +NackLoop: + for i := 0; i < nacksCount; i++ { select { - case msg := <-messages: - if msg == nil { - break - } - - if errsSent < 2 { - log.Println("sending err for ", msg.UUID) - msg.Nack() - errsSent++ - continue + case msg, closed := <-messages: + if !closed { + t.Fatal("messages channel closed before all received") } - receivedMessages = append(receivedMessages, msg) - i++ - - msg.Ack() - fmt.Println("acked msg ", msg.UUID) - + log.Println("sending err for ", msg.UUID) + msg.Nack() case <-time.After(defaultTimeout): - break ReadMessagesLoop + break NackLoop } } + receivedMessages, _ := bulkRead(messages, messagesToSend, defaultTimeout, features) + <-allMessagesSent tests.AssertAllMessagesReceived(t, publishedMessages, receivedMessages) } -func TestNoAck(t *testing.T, pubSub message.PubSub) { +func TestNoAck(t *testing.T, pubSub PubSub, features Features) { + if !features.GuaranteedOrder { + t.Skip("guaranteed order is required for this test") + } + defer closePubSub(t, pubSub) topicName := testTopicName() - messages, err := pubSub.Subscribe(topicName) - require.NoError(t, err) + if subscribeInitializer, ok := pubSub.Subscriber().(message.SubscribeInitializer); ok { + require.NoError(t, subscribeInitializer.SubscribeInitialize(topicName)) + } - go func() { - for i := 0; i < 2; i++ { - id := uuid.NewV4().String() - log.Printf("sending %s", id) + for i := 0; i < 2; i++ { + id := uuid.NewV4().String() + log.Printf("sending %s", id) - msg := message.NewMessage(id, nil) + msg := message.NewMessage(id, nil) - err := pubSub.Publish(topicName, msg) - require.NoError(t, err) - } - }() + err := publishWithRetry(pubSub, topicName, msg) + require.NoError(t, err) + } + + messages, err := pubSub.Subscribe(topicName) + require.NoError(t, err) receivedMessage := make(chan struct{}) unlockAck := make(chan struct{}, 1) @@ -326,7 +306,13 @@ func TestNoAck(t *testing.T, pubSub message.PubSub) { msg.Ack() }() - <-receivedMessage + select { + case <-receivedMessage: + // ok + case <-time.After(defaultTimeout): + t.Fatal("timeouted") + } + select { case msg := <-messages: t.Fatalf("messages channel should be blocked since Ack() was not sent, received %s", msg.UUID) @@ -351,74 +337,71 @@ func TestNoAck(t *testing.T, pubSub message.PubSub) { } } -func TestContinueAfterClose(t *testing.T, createPubSub PubSubConstructor) { - topicName := testTopicName() - totalMessagesCount := 500 +// TestContinueAfterSubscribeClose checks, that we don't lose messages after closing subscriber. +func TestContinueAfterSubscribeClose(t *testing.T, createPubSub PubSubConstructor, features Features) { + if !features.Persistent { + t.Skip("ExactlyOnceDelivery test is not supported yet") + } - pubSub := createPubSub(t) - defer pubSub.Close() + totalMessagesCount := 5000 + batches := 5 + if testing.Short() { + totalMessagesCount = 50 + batches = 2 + } + batchSize := int(totalMessagesCount / batches) + readAttempts := batches * 4 - // call subscribe once for those pubsubs which require subscribe before publish - _, err := pubSub.Subscribe(topicName) - require.NoError(t, err) - closePubSub(t, pubSub) + pubSub := createPubSub(t) + defer closePubSub(t, pubSub) - pubSub = createPubSub(t) - messagesToPublish := addSimpleMessagesMessages(t, totalMessagesCount, pubSub, topicName) - closePubSub(t, pubSub) + topicName := testTopicName() + if subscribeInitializer, ok := pubSub.Subscriber().(message.SubscribeInitializer); ok { + require.NoError(t, subscribeInitializer.SubscribeInitialize(topicName)) + } - receivedMessagesMap := map[string]*message.Message{} - var receivedMessages []*message.Message - messagesLeft := totalMessagesCount + messagesToPublish := AddSimpleMessages(t, totalMessagesCount, pubSub, topicName) - // with at-least-once delivery we cannot assume that 5 (5*20msg=100) clients will be enough - // because messages will be delivered twice - for i := 0; i < 20; i++ { - addedBySubscriber := 0 + receivedMessages := map[string]*message.Message{} + for i := 0; i < readAttempts; i++ { pubSub := createPubSub(t) messages, err := pubSub.Subscribe(topicName) require.NoError(t, err) - receivedMessagesPart, _ := subscriber.BulkRead(messages, 100, defaultTimeout) - - for _, msg := range receivedMessagesPart { - // we assume at at-least-once delivery, so we ignore duplicates - if _, ok := receivedMessagesMap[msg.UUID]; ok { - fmt.Printf("%s is duplicated\n", msg.UUID) - } else { - addedBySubscriber++ - messagesLeft-- - receivedMessagesMap[msg.UUID] = msg - receivedMessages = append(receivedMessages, msg) - } + receivedMessagesBatch, _ := bulkRead(messages, batchSize, defaultTimeout, features) + for _, msg := range receivedMessagesBatch { + receivedMessages[msg.UUID] = msg } closePubSub(t, pubSub) - fmt.Println( - "already received:", len(receivedMessagesMap), - "total:", len(messagesToPublish), - "received by this subscriber:", addedBySubscriber, - "new in this subscriber (unique):", len(receivedMessagesPart), - ) - if messagesLeft == 0 { + if len(receivedMessages) >= totalMessagesCount { break } } - for _, msgToPublish := range messagesToPublish { - _, ok := receivedMessagesMap[msgToPublish.UUID] - assert.True(t, ok, "missing msg %s", msgToPublish.UUID) + // we need to deduplicate messages, because bulkRead will deduplicate only per one batch + uniqueReceivedMessages := message.Messages{} + for _, msg := range receivedMessages { + uniqueReceivedMessages = append(uniqueReceivedMessages, msg) } - fmt.Println("received:", len(receivedMessagesMap)) - fmt.Println("missing:", tests.MissingMessages(messagesToPublish, receivedMessages)) - fmt.Println("extra:", tests.MissingMessages(messagesToPublish, receivedMessages)) + tests.AssertAllMessagesReceived(t, messagesToPublish, uniqueReceivedMessages) } -func TestConcurrentClose(t *testing.T, createPubSub PubSubConstructor) { +func TestConcurrentClose(t *testing.T, createPubSub PubSubConstructor, features Features) { + if features.ExactlyOnceDelivery { + t.Skip("ExactlyOnceDelivery test is not supported yet") + } + topicName := testTopicName() + createTopicPubSub := createPubSub(t) + if subscribeInitializer, ok := createTopicPubSub.Subscriber().(message.SubscribeInitializer); ok { + require.NoError(t, subscribeInitializer.SubscribeInitialize(topicName)) + } + require.NoError(t, createTopicPubSub.Close()) + totalMessagesCount := 50 closeWg := sync.WaitGroup{} @@ -438,172 +421,177 @@ func TestConcurrentClose(t *testing.T, createPubSub PubSubConstructor) { closeWg.Wait() pubSub := createPubSub(t) - expectedMessages := addSimpleMessagesMessages(t, totalMessagesCount, pubSub, topicName) + expectedMessages := AddSimpleMessages(t, totalMessagesCount, pubSub, topicName) closePubSub(t, pubSub) pubSub = createPubSub(t) messages, err := pubSub.Subscribe(topicName) require.NoError(t, err) - receivedMessages, all := subscriber.BulkRead(messages, len(expectedMessages), defaultTimeout*3) + receivedMessages, all := bulkRead(messages, len(expectedMessages), defaultTimeout*3, features) assert.True(t, all) tests.AssertAllMessagesReceived(t, expectedMessages, receivedMessages) } -func TestContinueAfterErrors(t *testing.T, createPubSub PubSubConstructor) { +func TestContinueAfterErrors(t *testing.T, createPubSub PubSubConstructor, features Features) { + pubSub := createPubSub(t) + defer closePubSub(t, pubSub) + topicName := testTopicName() + if subscribeInitializer, ok := pubSub.Subscriber().(message.SubscribeInitializer); ok { + require.NoError(t, subscribeInitializer.SubscribeInitialize(topicName)) + } totalMessagesCount := 50 + subscribersToNack := 3 + nacksPerSubscriber := 100 - pubSub := createPubSub(t) - - // call subscribe once for those pubsubs which require subscribe before publish - _, err := pubSub.Subscribe(topicName) - require.NoError(t, err) - closePubSub(t, pubSub) - - pubSub = createPubSub(t) - defer closePubSub(t, pubSub) + if testing.Short() { + subscribersToNack = 1 + nacksPerSubscriber = 5 + } - messagesToPublish := addSimpleMessagesMessages(t, totalMessagesCount, pubSub, topicName) + messagesToPublish := AddSimpleMessages(t, totalMessagesCount, pubSub, topicName) - // sending totalMessagesCount*2 errors from 3 subscribers - for i := 0; i < 3; i++ { - errorsPubSub := createPubSub(t) + for i := 0; i < subscribersToNack; i++ { + var errorsPubSub PubSub + if !features.Persistent { + errorsPubSub = pubSub + } else { + errorsPubSub = createPubSub(t) + } messages, err := errorsPubSub.Subscribe(topicName) require.NoError(t, err) - // waiting to initialize - msg := <-messages - msg.Nack() - - for j := 0; j < totalMessagesCount*2; j++ { + for j := 0; j < nacksPerSubscriber; j++ { select { case msg := <-messages: msg.Nack() - case <-time.After(time.Second * 5): + case <-time.After(defaultTimeout): t.Fatal("no messages left, probably seek after error doesn't work") } } - closePubSub(t, errorsPubSub) + if features.Persistent { + closePubSub(t, errorsPubSub) + } } messages, err := pubSub.Subscribe(topicName) require.NoError(t, err) // only nacks was sent, so all messages should be consumed - receivedMessages, all := subscriber.BulkRead(messages, len(messagesToPublish), defaultTimeout) - require.True(t, all) + receivedMessages, all := bulkRead(messages, totalMessagesCount, defaultTimeout, features) + assert.True(t, all) tests.AssertAllMessagesReceived(t, messagesToPublish, receivedMessages) } -func TestConsumerGroups(t *testing.T, pubSubConstructor ConsumerGroupPubSubConstructor) { +func TestConsumerGroups(t *testing.T, pubSubConstructor ConsumerGroupPubSubConstructor, features Features) { + if !features.ConsumerGroups { + t.Skip("consumer groups are not supported") + } + + publisherPubSub := pubSubConstructor(t, "test_"+uuid.NewV4().String()) + topicName := testTopicName() + if subscribeInitializer, ok := publisherPubSub.Subscriber().(message.SubscribeInitializer); ok { + require.NoError(t, subscribeInitializer.SubscribeInitialize(topicName)) + } totalMessagesCount := 50 group1 := generateConsumerGroup(t, pubSubConstructor, topicName) group2 := generateConsumerGroup(t, pubSubConstructor, topicName) - publisher := pubSubConstructor(t, "test_"+uuid.NewV4().String()) - messagesToPublish := addSimpleMessagesMessages(t, totalMessagesCount, publisher, topicName) - closePubSub(t, publisher) + messagesToPublish := AddSimpleMessages(t, totalMessagesCount, publisherPubSub, topicName) + closePubSub(t, publisherPubSub) assertConsumerGroupReceivedMessages(t, pubSubConstructor, group1, topicName, messagesToPublish) assertConsumerGroupReceivedMessages(t, pubSubConstructor, group2, topicName, messagesToPublish) - subscriberGroup1 := pubSubConstructor(t, group1) - defer closePubSub(t, subscriberGroup1) + defer closePubSub(t, publisherPubSub) } -func TestPublisherClose(t *testing.T, pub message.Publisher, sub message.Subscriber) { +// TestPublisherClose sends big amount of messages and them run close to ensure that messages are not lost during adding. +func TestPublisherClose(t *testing.T, pubSub PubSub, features Features) { topicName := testTopicName() + if subscribeInitializer, ok := pubSub.Subscriber().(message.SubscribeInitializer); ok { + require.NoError(t, subscribeInitializer.SubscribeInitialize(topicName)) + } messagesCount := 10000 + if testing.Short() { + messagesCount = 1000 + } - messages, err := sub.Subscribe(topicName) - require.NoError(t, err) - - var producedMessages message.Messages - allMessagesProduced := make(chan struct{}) - - go func() { - producedMessages = addSimpleMessagesMessages(t, messagesCount, pub, topicName) - close(allMessagesProduced) - }() - - receivedMessages, _ := subscriber.BulkRead(messages, messagesCount, defaultTimeout*3) + producedMessages := AddSimpleMessagesParallel(t, messagesCount, pubSub, topicName, 20) - select { - case <-allMessagesProduced: - // ok - case <-time.After(time.Second * 30): - t.Fatal("messages send timeouted") - } + messages, err := pubSub.Subscribe(topicName) + require.NoError(t, err) + receivedMessages, _ := bulkRead(messages, messagesCount, defaultTimeout*3, features) tests.AssertAllMessagesReceived(t, producedMessages, receivedMessages) - - require.NoError(t, pub.Close()) - require.NoError(t, sub.Close()) + require.NoError(t, pubSub.Close()) } -func TopicTest(t *testing.T, pubSub message.PubSub) { +func TestTopic(t *testing.T, pubSub PubSub, features Features) { defer closePubSub(t, pubSub) topic1 := testTopicName() topic2 := testTopicName() + if subscribeInitializer, ok := pubSub.Subscriber().(message.SubscribeInitializer); ok { + require.NoError(t, subscribeInitializer.SubscribeInitialize(topic1)) + } + if subscribeInitializer, ok := pubSub.Subscriber().(message.SubscribeInitializer); ok { + require.NoError(t, subscribeInitializer.SubscribeInitialize(topic2)) + } + + topic1Msg := message.NewMessage(uuid.NewV4().String(), nil) + topic2Msg := message.NewMessage(uuid.NewV4().String(), nil) + + require.NoError(t, publishWithRetry(pubSub, topic1, topic1Msg)) + require.NoError(t, publishWithRetry(pubSub, topic2, topic2Msg)) + messagesTopic1, err := pubSub.Subscribe(topic1) require.NoError(t, err) messagesTopic2, err := pubSub.Subscribe(topic2) require.NoError(t, err) - topic1Msg := message.NewMessage(uuid.NewV4().String(), nil) - topic2Msg := message.NewMessage(uuid.NewV4().String(), nil) - - messagesSent := make(chan struct{}) - go func() { - require.NoError(t, pubSub.Publish(topic1, topic1Msg)) - require.NoError(t, pubSub.Publish(topic2, topic2Msg)) - close(messagesSent) - }() - - messagesConsumedTopic1, received := subscriber.BulkRead(messagesTopic1, 1, defaultTimeout) + messagesConsumedTopic1, received := bulkRead(messagesTopic1, 1, defaultTimeout, features) require.True(t, received, "no messages received in topic %s", topic1) - messagesConsumedTopic2, received := subscriber.BulkRead(messagesTopic2, 1, defaultTimeout) + messagesConsumedTopic2, received := bulkRead(messagesTopic2, 1, defaultTimeout, features) require.True(t, received, "no messages received in topic %s", topic2) - <-messagesSent - assert.Equal(t, messagesConsumedTopic1.IDs()[0], topic1Msg.UUID) assert.Equal(t, messagesConsumedTopic2.IDs()[0], topic2Msg.UUID) } -func TestMessageCtx(t *testing.T, pubSub message.PubSub) { - defer pubSub.Close() +func TestMessageCtx(t *testing.T, pubSub PubSub, features Features) { + defer closePubSub(t, pubSub) - topic := testTopicName() + topicName := testTopicName() + if subscribeInitializer, ok := pubSub.Subscriber().(message.SubscribeInitializer); ok { + require.NoError(t, subscribeInitializer.SubscribeInitialize(topicName)) + } - messages, err := pubSub.Subscribe(topic) - require.NoError(t, err) + msg := message.NewMessage(uuid.NewV4().String(), nil) - go func() { - msg := message.NewMessage(uuid.NewV4().String(), nil) + // ensuring that context is not propagated via pub/sub + ctx, ctxCancel := context.WithCancel(context.Background()) + ctxCancel() + msg.SetContext(ctx) - // ensuring that context is not propagated via pub/sub - ctx, ctxCancel := context.WithCancel(context.Background()) - ctxCancel() - msg.SetContext(ctx) + require.NoError(t, publishWithRetry(pubSub, topicName, msg)) + // this might actually be an error in some pubsubs (http), because we close the subscriber without ACK. + _ = pubSub.Publish(topicName, msg) - require.NoError(t, pubSub.Publish(topic, msg)) - // this might actually be an error in some pubsubs (http), because we close the subscriber without ACK. - _ = pubSub.Publish(topic, msg) - }() + messages, err := pubSub.Subscribe(topicName) + require.NoError(t, err) select { case msg := <-messages: @@ -652,12 +640,18 @@ func TestMessageCtx(t *testing.T, pubSub message.PubSub) { } } -func TestReconnect(t *testing.T, pubSub message.PubSub, features Features) { +func TestReconnect(t *testing.T, pubSub PubSub, features Features) { + if len(features.RestartServiceCommand) == 0 { + t.Skip("no RestartServiceCommand provided, cannot test reconnect") + } + topicName := testTopicName() + if subscribeInitializer, ok := pubSub.Subscriber().(message.SubscribeInitializer); ok { + require.NoError(t, subscribeInitializer.SubscribeInitialize(topicName)) + } const messagesCount = 10000 const publishersCount = 100 - const timeout = time.Second * 60 restartAfterMessages := map[int]struct{}{ messagesCount / 3: {}, // restart at 1/3 of messages @@ -668,7 +662,6 @@ func TestReconnect(t *testing.T, pubSub message.PubSub, features Features) { require.NoError(t, err) var publishedMessages message.Messages - allMessagesPublished := make(chan struct{}) messagePublished := make(chan *message.Message, messagesCount) publishMessage := make(chan struct{}) @@ -684,15 +677,8 @@ func TestReconnect(t *testing.T, pubSub message.PubSub, features Features) { }() go func() { - count := 0 - for msg := range messagePublished { publishedMessages = append(publishedMessages, msg) - count++ - - if count >= messagesCount { - close(allMessagesPublished) - } } }() @@ -710,7 +696,7 @@ func TestReconnect(t *testing.T, pubSub message.PubSub, features Features) { time.Sleep(time.Millisecond * 500) } - if err := pubSub.Publish(topicName, msg); err == nil { + if err := publishWithRetry(pubSub, topicName, msg); err == nil { break } @@ -723,15 +709,9 @@ func TestReconnect(t *testing.T, pubSub message.PubSub, features Features) { }() } - receivedMessages, allMessages := subscriber.BulkReadWithDeduplication(messages, messagesCount, timeout) + receivedMessages, allMessages := bulkRead(messages, messagesCount, time.Second*60, features) assert.True(t, allMessages, "not all messages received (has %d of %d)", len(receivedMessages), messagesCount) - select { - case <-allMessagesPublished: - //ok - case <-time.After(timeout): - t.Fatal("all messages not sent after", timeout) - } tests.AssertAllMessagesReceived(t, publishedMessages, receivedMessages) require.NoError(t, pubSub.Close()) @@ -773,9 +753,9 @@ func testTopicName() string { return "topic_" + uuid.NewV4().String() } -func closePubSub(t *testing.T, pubSub message.PubSub) { +func closePubSub(t *testing.T, pubSub PubSub) { err := pubSub.Close() - assert.NoError(t, err) + require.NoError(t, err) } func generateConsumerGroup(t *testing.T, pubSubConstructor ConsumerGroupPubSubConstructor, topicName string) string { @@ -791,7 +771,7 @@ func generateConsumerGroup(t *testing.T, pubSubConstructor ConsumerGroupPubSubCo return groupName } -func addSimpleMessagesMessages(t *testing.T, messagesCount int, publisher message.Publisher, topicName string) message.Messages { +func AddSimpleMessages(t *testing.T, messagesCount int, publisher message.Publisher, topicName string) message.Messages { var messagesToPublish []*message.Message for i := 0; i < messagesCount; i++ { @@ -800,24 +780,77 @@ func addSimpleMessagesMessages(t *testing.T, messagesCount int, publisher messag msg := message.NewMessage(id, nil) messagesToPublish = append(messagesToPublish, msg) - err := publisher.Publish(topicName, msg) + err := publishWithRetry(publisher, topicName, msg) require.NoError(t, err, "cannot publish messages") } return messagesToPublish } +func AddSimpleMessagesParallel(t *testing.T, messagesCount int, publisher message.Publisher, topicName string, publishers int) message.Messages { + var messagesToPublish []*message.Message + publishMsg := make(chan *message.Message) + + wg := sync.WaitGroup{} + wg.Add(messagesCount) + + for i := 0; i < publishers; i++ { + go func() { + for msg := range publishMsg { + err := publishWithRetry(publisher, topicName, msg) + require.NoError(t, err, "cannot publish messages") + wg.Done() + } + }() + } + + for i := 0; i < messagesCount; i++ { + id := uuid.NewV4().String() + + msg := message.NewMessage(id, nil) + messagesToPublish = append(messagesToPublish, msg) + + publishMsg <- msg + } + close(publishMsg) + + wg.Wait() + + return messagesToPublish +} + func assertMessagesChannelClosed(t *testing.T, messages chan *message.Message) bool { select { - case msg := <-messages: - if msg == nil { - return true - } - - t.Error("messages channel is not closed (received message)") - return false + case _, open := <-messages: + return assert.False(t, open) default: t.Error("messages channel is not closed (blocked)") return false } } + +func publishWithRetry(publisher message.Publisher, topic string, messages ...*message.Message) error { + retries := 5 + + for { + err := publisher.Publish(topic, messages...) + if err == nil { + return nil + } + retries-- + + fmt.Printf("error on publish: %s, %d retries left\n", err, retries) + + if retries == 0 { + return err + } + } +} + +func bulkRead(messagesCh <-chan *message.Message, limit int, timeout time.Duration, features Features) (receivedMessages message.Messages, all bool) { + if !features.ExactlyOnceDelivery { + return subscriber.BulkReadWithDeduplication(messagesCh, limit, timeout) + } + + return subscriber.BulkRead(messagesCh, limit, timeout) +} diff --git a/message/pubsub.go b/message/pubsub.go index e9d87c6dc..d0868ffa1 100644 --- a/message/pubsub.go +++ b/message/pubsub.go @@ -1,6 +1,9 @@ package message -import "github.com/pkg/errors" +import ( + "github.com/hashicorp/go-multierror" + "github.com/pkg/errors" +) type PubSub interface { publisher @@ -14,25 +17,35 @@ func NewPubSub(publisher Publisher, subscriber Subscriber) PubSub { } type pubSub struct { - Publisher - Subscriber + pub Publisher + sub Subscriber } -func (p pubSub) Close() error { - publisherErr := p.Publisher.Close() - subscriberErr := p.Subscriber.Close() +func (p pubSub) Publish(topic string, messages ...*Message) error { + return p.pub.Publish(topic, messages...) +} - if publisherErr == nil && subscriberErr == nil { - return nil - } +func (p pubSub) Subscribe(topic string) (chan *Message, error) { + return p.sub.Subscribe(topic) +} + +func (p pubSub) Publisher() Publisher { + return p.pub +} + +func (p pubSub) Subscriber() Subscriber { + return p.sub +} + +func (p pubSub) Close() error { + var err error - errMsg := "cannot close pubSub: " - if publisherErr != nil { - errMsg += "publisher err: " + publisherErr.Error() + if publisherErr := p.pub.Close(); publisherErr != nil { + err = multierror.Append(err, errors.Wrap(publisherErr, "cannot close publisher")) } - if subscriberErr != nil { - errMsg += "subscriber err: " + subscriberErr.Error() + if subscriberErr := p.sub.Close(); subscriberErr != nil { + err = multierror.Append(err, errors.Wrap(subscriberErr, "cannot close subscriber")) } - return errors.New(errMsg) + return err } diff --git a/message/router/middleware/correlation_test.go b/message/router/middleware/correlation_test.go new file mode 100644 index 000000000..9ed9868b0 --- /dev/null +++ b/message/router/middleware/correlation_test.go @@ -0,0 +1,28 @@ +package middleware_test + +import ( + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + + "github.com/ThreeDotsLabs/watermill/message/router/middleware" + + "github.com/ThreeDotsLabs/watermill/message" +) + +func TestCorrelationID(t *testing.T) { + handlerErr := errors.New("foo") + + handler := middleware.CorrelationID(func(msg *message.Message) ([]*message.Message, error) { + return message.Messages{message.NewMessage("2", nil)}, handlerErr + }) + + msg := message.NewMessage("1", nil) + middleware.SetCorrelationID("correlation_id", msg) + + producedMsgs, err := handler(msg) + + assert.Equal(t, middleware.MessageCorrelationID(producedMsgs[0]), "correlation_id") + assert.Equal(t, handlerErr, err) +} diff --git a/message/router/middleware/instant_ack_test.go b/message/router/middleware/instant_ack_test.go new file mode 100644 index 000000000..a41dce8be --- /dev/null +++ b/message/router/middleware/instant_ack_test.go @@ -0,0 +1,33 @@ +package middleware + +import ( + "testing" + + "github.com/ThreeDotsLabs/watermill/message" + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func TestInstantAck(t *testing.T) { + producedMessages := message.Messages{message.NewMessage("2", nil)} + producedErr := errors.New("foo") + + h := InstantAck(func(msg *message.Message) (messages []*message.Message, e error) { + return producedMessages, producedErr + }) + + msg := message.NewMessage("1", nil) + + handlerMessages, handlerErr := h(msg) + assert.EqualValues(t, producedMessages, handlerMessages) + assert.Equal(t, producedErr, handlerErr) + + select { + case <-msg.Acked(): + // ok + case <-msg.Nacked(): + t.Fatal("expected ack, not nack") + default: + t.Fatal("no ack received") + } +} diff --git a/message/router/middleware/randomfail_test.go b/message/router/middleware/randomfail_test.go new file mode 100644 index 000000000..0f1867e7e --- /dev/null +++ b/message/router/middleware/randomfail_test.go @@ -0,0 +1,28 @@ +package middleware_test + +import ( + "testing" + + "github.com/ThreeDotsLabs/watermill/message" + "github.com/ThreeDotsLabs/watermill/message/router/middleware" + "github.com/stretchr/testify/assert" +) + +func TestRandomFail(t *testing.T) { + h := middleware.RandomFail(1)(func(msg *message.Message) (messages []*message.Message, e error) { + return nil, nil + }) + + _, err := h(message.NewMessage("1", nil)) + assert.Error(t, err) +} + +func TestRandomPanic(t *testing.T) { + h := middleware.RandomPanic(1)(func(msg *message.Message) (messages []*message.Message, e error) { + return nil, nil + }) + + assert.Panics(t, func() { + _, _ = h(message.NewMessage("1", nil)) + }) +} diff --git a/message/router/middleware/recoverer_test.go b/message/router/middleware/recoverer_test.go new file mode 100644 index 000000000..79a0f3f64 --- /dev/null +++ b/message/router/middleware/recoverer_test.go @@ -0,0 +1,19 @@ +package middleware_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/ThreeDotsLabs/watermill/message" + "github.com/ThreeDotsLabs/watermill/message/router/middleware" +) + +func TestRecoverer(t *testing.T) { + h := middleware.Recoverer(func(msg *message.Message) (messages []*message.Message, e error) { + panic("foo") + }) + + _, err := h(message.NewMessage("1", nil)) + assert.Error(t, err) +} diff --git a/message/router/middleware/retry.go b/message/router/middleware/retry.go index 15ce28925..c415c60ba 100644 --- a/message/router/middleware/retry.go +++ b/message/router/middleware/retry.go @@ -69,5 +69,5 @@ func (r Retry) calculateWaitTime() time.Duration { } func (r Retry) shouldRetry(err error, retries int) bool { - return err != nil && (retries <= r.MaxRetries || r.MaxRetries == RetryForever) + return err != nil && (retries < r.MaxRetries || r.MaxRetries == RetryForever) } diff --git a/message/router/middleware/retry_test.go b/message/router/middleware/retry_test.go new file mode 100644 index 000000000..a1d70e249 --- /dev/null +++ b/message/router/middleware/retry_test.go @@ -0,0 +1,93 @@ +package middleware_test + +import ( + "testing" + "time" + + "github.com/ThreeDotsLabs/watermill" + + "github.com/stretchr/testify/assert" + + "github.com/ThreeDotsLabs/watermill/message" + "github.com/pkg/errors" + + "github.com/ThreeDotsLabs/watermill/message/router/middleware" +) + +func TestRetry_retry(t *testing.T) { + retry := middleware.Retry{ + MaxRetries: 1, + } + + runCount := 0 + producedMessages := message.Messages{message.NewMessage("2", nil)} + + h := retry.Middleware(func(msg *message.Message) (messages []*message.Message, e error) { + runCount++ + if runCount == 0 { + return nil, errors.New("foo") + } + + return producedMessages, nil + }) + + handlerMessages, handlerErr := h(message.NewMessage("1", nil)) + + assert.Equal(t, 1, runCount) + assert.EqualValues(t, producedMessages, handlerMessages) + assert.NoError(t, handlerErr) +} + +func TestRetry_max_retries(t *testing.T) { + retry := middleware.Retry{ + MaxRetries: 1, + } + + runCount := 0 + + h := retry.Middleware(func(msg *message.Message) (messages []*message.Message, e error) { + runCount++ + return nil, errors.New("foo") + }) + + _, err := h(message.NewMessage("1", nil)) + + assert.Equal(t, 2, runCount) + assert.EqualError(t, err, "foo") +} + +func TestRetry_retry_hook(t *testing.T) { + var retriesFromHook []int + + retry := middleware.Retry{ + MaxRetries: 2, + OnRetryHook: func(retryNum int, delay time.Duration) { + retriesFromHook = append(retriesFromHook, retryNum) + }, + } + + h := retry.Middleware(func(msg *message.Message) (messages []*message.Message, e error) { + return nil, errors.New("foo") + }) + _, _ = h(message.NewMessage("1", nil)) + + assert.EqualValues(t, []int{1, 2}, retriesFromHook) +} + +func TestRetry_logger(t *testing.T) { + logger := watermill.NewCaptureLogger() + + retry := middleware.Retry{ + MaxRetries: 2, + Logger: logger, + } + + handlerErr := errors.New("foo") + + h := retry.Middleware(func(msg *message.Message) (messages []*message.Message, e error) { + return nil, handlerErr + }) + _, _ = h(message.NewMessage("1", nil)) + + assert.True(t, logger.HasError(handlerErr)) +} diff --git a/message/router_test.go b/message/router_test.go index 95718a9c5..c44d62543 100644 --- a/message/router_test.go +++ b/message/router_test.go @@ -27,7 +27,7 @@ func TestRouter_functional(t *testing.T) { assert.NoError(t, pubSub.Close()) }() - messagesCount := 100 + messagesCount := 50 var expectedReceivedMessages message.Messages allMessagesSent := make(chan struct{}) @@ -237,7 +237,7 @@ func TestRouterNoPublisherHandler(t *testing.T) { r, err := message.NewRouter( message.RouterConfig{}, - &logger, + logger, ) require.NoError(t, err) @@ -344,7 +344,7 @@ func publishMessagesForHandler(t *testing.T, messagesCount int, pubSub message.P } func createPubSub() (message.PubSub, error) { - return gochannel.NewGoChannel(0, watermill.NewStdLogger(true, true), time.Second*10), nil + return gochannel.NewPersistentGoChannel(0, watermill.NewStdLogger(true, true)), nil } func readMessages(messagesCh <-chan *message.Message, limit int, timeout time.Duration) (receivedMessages []*message.Message, all bool) { diff --git a/message/subscriber.go b/message/subscriber.go index b3183f0b4..6b51b3b4f 100644 --- a/message/subscriber.go +++ b/message/subscriber.go @@ -15,3 +15,14 @@ type Subscriber interface { // Close closes all subscriptions with their output channels and flush offsets etc. when needed. Close() error } + +type SubscribeInitializer interface { + // SubscribeInitialize can be called to initialize subscribe before consume. + // When calling Subscribe before Publish, SubscribeInitialize should be not required. + // + // Not every Pub/Sub requires this initialize and it may be optional for performance improvements etc. + // For detailed SubscribeInitialize functionality, please check Pub/Subs godoc. + // + // Implementing SubscribeInitialize is not obligatory. + SubscribeInitialize(topic string) error +} diff --git a/message/subscriber/read.go b/message/subscriber/read.go index 74e6be41f..10da54a9b 100644 --- a/message/subscriber/read.go +++ b/message/subscriber/read.go @@ -7,56 +7,43 @@ import ( ) func BulkRead(messagesCh <-chan *message.Message, limit int, timeout time.Duration) (receivedMessages message.Messages, all bool) { - allMessagesReceived := make(chan struct{}, 1) +MessagesLoop: + for len(receivedMessages) < limit { + select { + case msg, ok := <-messagesCh: + if !ok { + break MessagesLoop + } - go func() { - for msg := range messagesCh { receivedMessages = append(receivedMessages, msg) msg.Ack() - - if len(receivedMessages) == limit { - allMessagesReceived <- struct{}{} - break - } + case <-time.After(timeout): + break MessagesLoop } - // messagesCh closed - allMessagesReceived <- struct{}{} - }() - - select { - case <-allMessagesReceived: - case <-time.After(timeout): } return receivedMessages, len(receivedMessages) == limit } -// todo -add tests & deduplicate func BulkReadWithDeduplication(messagesCh <-chan *message.Message, limit int, timeout time.Duration) (receivedMessages message.Messages, all bool) { - allMessagesReceived := make(chan struct{}, 1) - receivedIDs := map[string]struct{}{} - go func() { - for msg := range messagesCh { - if _, alreadyReceived := receivedIDs[msg.UUID]; !alreadyReceived { - receivedMessages = append(receivedMessages, msg) - receivedIDs[msg.UUID] = struct{}{} +MessagesLoop: + for len(receivedMessages) < limit { + select { + case msg, ok := <-messagesCh: + if !ok { + break MessagesLoop } - msg.Ack() - if len(receivedMessages) == limit { - allMessagesReceived <- struct{}{} - break + if _, ok := receivedIDs[msg.UUID]; !ok { + receivedIDs[msg.UUID] = struct{}{} + receivedMessages = append(receivedMessages, msg) } + msg.Ack() + case <-time.After(timeout): + break MessagesLoop } - // messagesCh closed - allMessagesReceived <- struct{}{} - }() - - select { - case <-allMessagesReceived: - case <-time.After(timeout): } return receivedMessages, len(receivedMessages) == limit diff --git a/message/subscriber/read_test.go b/message/subscriber/read_test.go index 7e7280230..fbe8a7ba1 100644 --- a/message/subscriber/read_test.go +++ b/message/subscriber/read_test.go @@ -5,98 +5,188 @@ import ( "time" "github.com/ThreeDotsLabs/watermill/internal/tests" + "github.com/ThreeDotsLabs/watermill/message" "github.com/ThreeDotsLabs/watermill/message/subscriber" "github.com/satori/go.uuid" "github.com/stretchr/testify/assert" ) +type bulkReadFunc func(messagesCh <-chan *message.Message, limit int, timeout time.Duration) (receivedMessages message.Messages, all bool) + func TestBulkRead(t *testing.T) { - messagesCount := 100 + testCases := []struct { + Name string + BulkReadFunc bulkReadFunc + }{ + { + Name: "BulkRead", + BulkReadFunc: subscriber.BulkRead, + }, + { + Name: "BulkReadWithDeduplication", + BulkReadFunc: subscriber.BulkReadWithDeduplication, + }, + } - var messages []*message.Message - messagesCh := make(chan *message.Message, messagesCount) + for _, c := range testCases { + t.Run(c.Name, func(t *testing.T) { + messagesCount := 100 - for i := 0; i < messagesCount; i++ { - msg := message.NewMessage(uuid.NewV4().String(), nil) + var messages []*message.Message + messagesCh := make(chan *message.Message, messagesCount) - messages = append(messages, msg) - messagesCh <- msg - } + for i := 0; i < messagesCount; i++ { + msg := message.NewMessage(uuid.NewV4().String(), nil) - readMessages, all := subscriber.BulkRead(messagesCh, messagesCount, time.Second) - assert.True(t, all) + messages = append(messages, msg) + messagesCh <- msg + } + + readMessages, all := subscriber.BulkRead(messagesCh, messagesCount, time.Second) + assert.True(t, all) - tests.AssertAllMessagesReceived(t, messages, readMessages) + tests.AssertAllMessagesReceived(t, messages, readMessages) + }) + } } func TestBulkRead_timeout(t *testing.T) { - messagesCount := 100 - sendLimit := 90 + testCases := []struct { + Name string + BulkReadFunc bulkReadFunc + }{ + { + Name: "BulkRead", + BulkReadFunc: subscriber.BulkRead, + }, + { + Name: "BulkReadWithDeduplication", + BulkReadFunc: subscriber.BulkReadWithDeduplication, + }, + } - var messages []*message.Message - messagesCh := make(chan *message.Message, messagesCount) + for _, c := range testCases { + t.Run(c.Name, func(t *testing.T) { + messagesCount := 100 + sendLimit := 90 - for i := 0; i < messagesCount; i++ { - msg := message.NewMessage(uuid.NewV4().String(), nil) + var messages []*message.Message + messagesCh := make(chan *message.Message, messagesCount) - messages = append(messages, msg) + for i := 0; i < messagesCount; i++ { + msg := message.NewMessage(uuid.NewV4().String(), nil) - if i < sendLimit { - messagesCh <- msg - } - } + messages = append(messages, msg) - bulkReadStart := time.Now() - readMessages, all := subscriber.BulkRead(messagesCh, messagesCount, time.Millisecond) + if i < sendLimit { + messagesCh <- msg + } + } - assert.WithinDuration(t, bulkReadStart, time.Now(), time.Millisecond*100) - assert.False(t, all) - assert.Equal(t, sendLimit, len(readMessages)) + bulkReadStart := time.Now() + readMessages, all := subscriber.BulkRead(messagesCh, messagesCount, time.Millisecond) + + assert.WithinDuration(t, bulkReadStart, time.Now(), time.Millisecond*100) + assert.False(t, all) + assert.Equal(t, sendLimit, len(readMessages)) + }) + } } func TestBulkRead_with_limit(t *testing.T) { - messagesCount := 110 - limit := 100 + testCases := []struct { + Name string + BulkReadFunc bulkReadFunc + }{ + { + Name: "BulkRead", + BulkReadFunc: subscriber.BulkRead, + }, + { + Name: "BulkReadWithDeduplication", + BulkReadFunc: subscriber.BulkReadWithDeduplication, + }, + } - var messages []*message.Message - messagesCh := make(chan *message.Message, messagesCount) + for _, c := range testCases { + t.Run(c.Name, func(t *testing.T) { + messagesCount := 110 + limit := 100 - for i := 0; i < messagesCount; i++ { - msg := message.NewMessage(uuid.NewV4().String(), nil) + var messages []*message.Message + messagesCh := make(chan *message.Message, messagesCount) - messages = append(messages, msg) - messagesCh <- msg - } + for i := 0; i < messagesCount; i++ { + msg := message.NewMessage(uuid.NewV4().String(), nil) - readMessages, all := subscriber.BulkRead(messagesCh, limit, time.Second) - assert.True(t, all) - assert.Equal(t, limit, len(readMessages)) + messages = append(messages, msg) + messagesCh <- msg + } + + readMessages, all := subscriber.BulkRead(messagesCh, limit, time.Second) + assert.True(t, all) + assert.Equal(t, limit, len(readMessages)) + }) + } } func TestBulkRead_return_on_channel_close(t *testing.T) { - messagesCount := 100 - sendLimit := 90 - - var messages []*message.Message - messagesCh := make(chan *message.Message, messagesCount) - messagesChClosed := false - - for i := 0; i < messagesCount; i++ { - msg := message.NewMessage(uuid.NewV4().String(), nil) - messages = append(messages, msg) - - if i < sendLimit { - messagesCh <- msg - } else if !messagesChClosed { - close(messagesCh) - messagesChClosed = true - } + testCases := []struct { + Name string + BulkReadFunc bulkReadFunc + }{ + { + Name: "BulkRead", + BulkReadFunc: subscriber.BulkRead, + }, + { + Name: "BulkReadWithDeduplication", + BulkReadFunc: subscriber.BulkReadWithDeduplication, + }, } - bulkReadStart := time.Now() - _, all := subscriber.BulkRead(messagesCh, messagesCount, time.Second) + for _, c := range testCases { + t.Run(c.Name, func(t *testing.T) { + messagesCount := 100 + sendLimit := 90 + + var messages []*message.Message + messagesCh := make(chan *message.Message, messagesCount) + messagesChClosed := false + + for i := 0; i < messagesCount; i++ { + msg := message.NewMessage(uuid.NewV4().String(), nil) + messages = append(messages, msg) + + if i < sendLimit { + messagesCh <- msg + } else if !messagesChClosed { + close(messagesCh) + messagesChClosed = true + } + } + + bulkReadStart := time.Now() + _, all := subscriber.BulkRead(messagesCh, messagesCount, time.Second) + + assert.WithinDuration(t, bulkReadStart, time.Now(), time.Millisecond*100) + assert.False(t, all) + }) + } +} + +func TestBulkReadWithDeduplication(t *testing.T) { + messagesCh := make(chan *message.Message, 3) + + msg1 := message.NewMessage(uuid.NewV4().String(), nil) + msg2 := message.NewMessage(uuid.NewV4().String(), nil) + messagesCh <- msg1 + messagesCh <- msg1 + messagesCh <- msg2 + + readMessages, all := subscriber.BulkReadWithDeduplication(messagesCh, 2, time.Second) + assert.True(t, all) - assert.WithinDuration(t, bulkReadStart, time.Now(), time.Millisecond*100) - assert.False(t, all) + assert.Equal(t, []string{msg1.UUID, msg2.UUID}, readMessages.IDs()) }