Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix race condition in TestMapExpiringKeyRepositoryCleanup #535

Merged
merged 3 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 47 additions & 26 deletions log.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package watermill

import (
"errors"
"fmt"
"io"
"log"
Expand All @@ -9,6 +10,7 @@ import (
"sort"
"strings"
"sync"
"time"
)

// LogFields is the logger's key-value list of fields.
Expand Down Expand Up @@ -162,34 +164,58 @@ const (

type CapturedMessage struct {
Level LogLevel
Time time.Time
Fields LogFields
Msg string
Err error
}

func (c CapturedMessage) ContentEquals(other CapturedMessage) bool {
return c.Level == other.Level &&
reflect.DeepEqual(c.Fields, other.Fields) &&
c.Msg == other.Msg &&
errors.Is(c.Err, other.Err)
}

// CaptureLoggerAdapter is a logger which captures all logs.
// This logger is mostly useful for testing logging.
type CaptureLoggerAdapter struct {
captured map[LogLevel][]CapturedMessage
fields LogFields
lock sync.Mutex
lock *sync.Mutex
}

func NewCaptureLogger() *CaptureLoggerAdapter {
return &CaptureLoggerAdapter{
captured: map[LogLevel][]CapturedMessage{},
lock: &sync.Mutex{},
}
}

func (c *CaptureLoggerAdapter) With(fields LogFields) LoggerAdapter {
return &CaptureLoggerAdapter{captured: c.captured, fields: c.fields.Add(fields)}
c.lock.Lock()
defer c.lock.Unlock()

return &CaptureLoggerAdapter{
captured: c.captured, // we are passing the same map, so we'll capture logs from this instance as well
fields: c.fields.Copy().Add(fields),
lock: c.lock,
}
}

func (c *CaptureLoggerAdapter) capture(msg CapturedMessage) {
func (c *CaptureLoggerAdapter) capture(level LogLevel, msg string, err error, fields LogFields) {
c.lock.Lock()
defer c.lock.Unlock()

c.captured[msg.Level] = append(c.captured[msg.Level], msg)
logMsg := CapturedMessage{
Level: level,
Time: time.Now(),
Fields: c.fields.Add(fields),
Msg: msg,
Err: err,
}

c.captured[level] = append(c.captured[level], logMsg)
}

func (c *CaptureLoggerAdapter) Captured() map[LogLevel][]CapturedMessage {
Expand All @@ -199,12 +225,24 @@ func (c *CaptureLoggerAdapter) Captured() map[LogLevel][]CapturedMessage {
return c.captured
}

type Logfer interface {
Logf(format string, a ...interface{})
}

func (c *CaptureLoggerAdapter) PrintCaptured(t Logfer) {
for level, messages := range c.Captured() {
for _, msg := range messages {
t.Logf("%s %d %s %v", msg.Time.Format("15:04:05.999999999"), level, msg.Msg, msg.Fields)
}
}
}

func (c *CaptureLoggerAdapter) Has(msg CapturedMessage) bool {
c.lock.Lock()
defer c.lock.Unlock()

for _, capturedMsg := range c.captured[msg.Level] {
if reflect.DeepEqual(msg, capturedMsg) {
if msg.ContentEquals(capturedMsg) {
return true
}
}
Expand All @@ -224,34 +262,17 @@ func (c *CaptureLoggerAdapter) HasError(err error) bool {
}

func (c *CaptureLoggerAdapter) Error(msg string, err error, fields LogFields) {
c.capture(CapturedMessage{
Level: ErrorLogLevel,
Fields: c.fields.Add(fields),
Msg: msg,
Err: err,
})
c.capture(ErrorLogLevel, msg, err, fields)
}

func (c *CaptureLoggerAdapter) Info(msg string, fields LogFields) {
c.capture(CapturedMessage{
Level: InfoLogLevel,
Fields: c.fields.Add(fields),
Msg: msg,
})
c.capture(InfoLogLevel, msg, nil, fields)
}

func (c *CaptureLoggerAdapter) Debug(msg string, fields LogFields) {
c.capture(CapturedMessage{
Level: DebugLogLevel,
Fields: c.fields.Add(fields),
Msg: msg,
})
c.capture(DebugLogLevel, msg, nil, fields)
}

func (c *CaptureLoggerAdapter) Trace(msg string, fields LogFields) {
c.capture(CapturedMessage{
Level: TraceLogLevel,
Fields: c.fields.Add(fields),
Msg: msg,
})
c.capture(TraceLogLevel, msg, nil, fields)
}
2 changes: 1 addition & 1 deletion log_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ func TestCaptureLoggerAdapter(t *testing.T) {
}

capturedLogger := logger.(*watermill.CaptureLoggerAdapter)
assert.EqualValues(t, expectedLogs, capturedLogger.Captured())

assert.Equal(t, len(expectedLogs), len(capturedLogger.Captured()))
for _, logs := range expectedLogs {
for _, log := range logs {
assert.True(t, capturedLogger.Has(log))
Expand Down
18 changes: 12 additions & 6 deletions message/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,11 +428,13 @@ func (r *Router) RunHandlers(ctx context.Context) error {
return errors.Wrapf(err, "could not decorate subscriber of handler %s", name)
}

r.logger.Debug("Subscribing to topic", watermill.LogFields{
logger := r.logger.With(watermill.LogFields{
"subscriber_name": h.name,
"topic": h.subscribeTopic,
})

logger.Debug("Subscribing to topic", nil)

ctx, cancel := context.WithCancel(ctx)

messages, err := h.subscriber.Subscribe(ctx, h.subscribeTopic)
Expand All @@ -458,14 +460,15 @@ func (r *Router) RunHandlers(ctx context.Context) error {
h.run(ctx, middlewares)

r.handlersWg.Done()
r.logger.Info("Subscriber stopped", watermill.LogFields{
"subscriber_name": h.name,
"topic": h.subscribeTopic,
})
logger.Info("Subscriber stopped", nil)

r.handlersLock.Lock()
delete(r.handlers, name)
r.handlersLock.Unlock()

logger.Trace("Removed subscriber from r.handlers", nil)

close(h.stopped)
}()
}
return nil
Expand All @@ -492,6 +495,7 @@ func (r *Router) closeWhenAllHandlersStopped(ctx context.Context) {

r.handlersWg.Wait()
if r.IsClosed() {
r.logger.Trace("closeWhenAllHandlersStopped: already closed", nil)
// already closed
return
}
Expand Down Expand Up @@ -543,8 +547,11 @@ func (r *Router) Close() error {
defer r.handlersLock.Unlock()

if r.closed {
r.logger.Debug("Already closed", nil)
return nil
}

r.logger.Debug("Running Close()", nil)
r.closed = true

r.logger.Info("Closing router", nil)
Expand Down Expand Up @@ -649,7 +656,6 @@ func (h *handler) run(ctx context.Context, middlewares []middleware) {
}

h.logger.Debug("Router handler stopped", nil)
close(h.stopped)
}

// Handler handles Messages.
Expand Down
4 changes: 3 additions & 1 deletion message/router/middleware/deduplicator.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ func NewMapExpiringKeyRepository(window time.Duration) (ExpiringKeyRepository, e
mu: &sync.Mutex{},
tags: make(map[string]time.Time),
}
go kr.cleanOutLoop(context.Background(), time.NewTicker(window/2))
ticker := time.NewTicker(window / 2)

go kr.cleanOutLoop(context.Background(), ticker)
return kr, nil
}

Expand Down
26 changes: 19 additions & 7 deletions message/router/middleware/deduplicator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,23 @@ func TestMapExpiringKeyRepositoryCleanup(t *testing.T) {
t.Errorf("expected 6 tags, but %d remain", l)
}

time.Sleep(wait * 2)
if count != 6 {
t.Errorf("sent six messages, but only received %d", count)
}
if l := measurable.Len(); l != 0 {
t.Errorf("tags should have been cleaned out, but %d remain", l)
}
assert.Eventually(
t,
func() bool {
return count == 6
},
wait*3,
time.Millisecond,
"sent six messages, but only received %d", count,
)
assert.Eventually(
t,
func() bool {
return measurable.Len() == 0
},
wait*3,
time.Millisecond,
"tags should have been cleaned out, but %d remain",
measurable.Len(),
)
}
16 changes: 12 additions & 4 deletions message/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1372,6 +1372,8 @@ func TestRouter_stopping_all_handlers_logs_error(t *testing.T) {

logger := watermill.NewCaptureLogger()

defer logger.PrintCaptured(t)

r, err := message.NewRouter(message.RouterConfig{}, logger)
require.NoError(t, err)

Expand All @@ -1392,13 +1394,19 @@ func TestRouter_stopping_all_handlers_logs_error(t *testing.T) {
}()
<-r.Running()

// Stop the subscriber - this should close the router with an error
// Stop the subscriber - this should close the router with an error logged
err = sub.Close()
require.NoError(t, err)

require.Eventually(t, func() bool {
return r.IsClosed()
}, 1*time.Second, 1*time.Millisecond, "Router should be closed after all handlers are stopped")
require.Eventually(
t,
func() bool {
return r.IsClosed()
},
1*time.Second,
1*time.Millisecond,
"Router should be closed after all handlers are stopped",
)

expectedLogMessage := watermill.CapturedMessage{
Level: watermill.ErrorLogLevel,
Expand Down
Loading