Skip to content

Commit

Permalink
fix a couple of race conditions in tests (#535)
Browse files Browse the repository at this point in the history
  • Loading branch information
roblaszczak authored Jan 10, 2025
1 parent 457a1ce commit 0e64e21
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 45 deletions.
73 changes: 47 additions & 26 deletions log.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package watermill

import (
"errors"
"fmt"
"io"
"log"
Expand All @@ -9,6 +10,7 @@ import (
"sort"
"strings"
"sync"
"time"
)

// LogFields is the logger's key-value list of fields.
Expand Down Expand Up @@ -162,34 +164,58 @@ const (

type CapturedMessage struct {
Level LogLevel
Time time.Time
Fields LogFields
Msg string
Err error
}

func (c CapturedMessage) ContentEquals(other CapturedMessage) bool {
return c.Level == other.Level &&
reflect.DeepEqual(c.Fields, other.Fields) &&
c.Msg == other.Msg &&
errors.Is(c.Err, other.Err)
}

// CaptureLoggerAdapter is a logger which captures all logs.
// This logger is mostly useful for testing logging.
type CaptureLoggerAdapter struct {
captured map[LogLevel][]CapturedMessage
fields LogFields
lock sync.Mutex
lock *sync.Mutex
}

func NewCaptureLogger() *CaptureLoggerAdapter {
return &CaptureLoggerAdapter{
captured: map[LogLevel][]CapturedMessage{},
lock: &sync.Mutex{},
}
}

func (c *CaptureLoggerAdapter) With(fields LogFields) LoggerAdapter {
return &CaptureLoggerAdapter{captured: c.captured, fields: c.fields.Add(fields)}
c.lock.Lock()
defer c.lock.Unlock()

return &CaptureLoggerAdapter{
captured: c.captured, // we are passing the same map, so we'll capture logs from this instance as well
fields: c.fields.Copy().Add(fields),
lock: c.lock,
}
}

func (c *CaptureLoggerAdapter) capture(msg CapturedMessage) {
func (c *CaptureLoggerAdapter) capture(level LogLevel, msg string, err error, fields LogFields) {
c.lock.Lock()
defer c.lock.Unlock()

c.captured[msg.Level] = append(c.captured[msg.Level], msg)
logMsg := CapturedMessage{
Level: level,
Time: time.Now(),
Fields: c.fields.Add(fields),
Msg: msg,
Err: err,
}

c.captured[level] = append(c.captured[level], logMsg)
}

func (c *CaptureLoggerAdapter) Captured() map[LogLevel][]CapturedMessage {
Expand All @@ -199,12 +225,24 @@ func (c *CaptureLoggerAdapter) Captured() map[LogLevel][]CapturedMessage {
return c.captured
}

type Logfer interface {
Logf(format string, a ...interface{})
}

func (c *CaptureLoggerAdapter) PrintCaptured(t Logfer) {
for level, messages := range c.Captured() {
for _, msg := range messages {
t.Logf("%s %d %s %v", msg.Time.Format("15:04:05.999999999"), level, msg.Msg, msg.Fields)
}
}
}

func (c *CaptureLoggerAdapter) Has(msg CapturedMessage) bool {
c.lock.Lock()
defer c.lock.Unlock()

for _, capturedMsg := range c.captured[msg.Level] {
if reflect.DeepEqual(msg, capturedMsg) {
if msg.ContentEquals(capturedMsg) {
return true
}
}
Expand All @@ -224,34 +262,17 @@ func (c *CaptureLoggerAdapter) HasError(err error) bool {
}

func (c *CaptureLoggerAdapter) Error(msg string, err error, fields LogFields) {
c.capture(CapturedMessage{
Level: ErrorLogLevel,
Fields: c.fields.Add(fields),
Msg: msg,
Err: err,
})
c.capture(ErrorLogLevel, msg, err, fields)
}

func (c *CaptureLoggerAdapter) Info(msg string, fields LogFields) {
c.capture(CapturedMessage{
Level: InfoLogLevel,
Fields: c.fields.Add(fields),
Msg: msg,
})
c.capture(InfoLogLevel, msg, nil, fields)
}

func (c *CaptureLoggerAdapter) Debug(msg string, fields LogFields) {
c.capture(CapturedMessage{
Level: DebugLogLevel,
Fields: c.fields.Add(fields),
Msg: msg,
})
c.capture(DebugLogLevel, msg, nil, fields)
}

func (c *CaptureLoggerAdapter) Trace(msg string, fields LogFields) {
c.capture(CapturedMessage{
Level: TraceLogLevel,
Fields: c.fields.Add(fields),
Msg: msg,
})
c.capture(TraceLogLevel, msg, nil, fields)
}
2 changes: 1 addition & 1 deletion log_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ func TestCaptureLoggerAdapter(t *testing.T) {
}

capturedLogger := logger.(*watermill.CaptureLoggerAdapter)
assert.EqualValues(t, expectedLogs, capturedLogger.Captured())

assert.Equal(t, len(expectedLogs), len(capturedLogger.Captured()))
for _, logs := range expectedLogs {
for _, log := range logs {
assert.True(t, capturedLogger.Has(log))
Expand Down
18 changes: 12 additions & 6 deletions message/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -428,11 +428,13 @@ func (r *Router) RunHandlers(ctx context.Context) error {
return errors.Wrapf(err, "could not decorate subscriber of handler %s", name)
}

r.logger.Debug("Subscribing to topic", watermill.LogFields{
logger := r.logger.With(watermill.LogFields{
"subscriber_name": h.name,
"topic": h.subscribeTopic,
})

logger.Debug("Subscribing to topic", nil)

ctx, cancel := context.WithCancel(ctx)

messages, err := h.subscriber.Subscribe(ctx, h.subscribeTopic)
Expand All @@ -458,14 +460,15 @@ func (r *Router) RunHandlers(ctx context.Context) error {
h.run(ctx, middlewares)

r.handlersWg.Done()
r.logger.Info("Subscriber stopped", watermill.LogFields{
"subscriber_name": h.name,
"topic": h.subscribeTopic,
})
logger.Info("Subscriber stopped", nil)

r.handlersLock.Lock()
delete(r.handlers, name)
r.handlersLock.Unlock()

logger.Trace("Removed subscriber from r.handlers", nil)

close(h.stopped)
}()
}
return nil
Expand All @@ -492,6 +495,7 @@ func (r *Router) closeWhenAllHandlersStopped(ctx context.Context) {

r.handlersWg.Wait()
if r.IsClosed() {
r.logger.Trace("closeWhenAllHandlersStopped: already closed", nil)
// already closed
return
}
Expand Down Expand Up @@ -543,8 +547,11 @@ func (r *Router) Close() error {
defer r.handlersLock.Unlock()

if r.closed {
r.logger.Debug("Already closed", nil)
return nil
}

r.logger.Debug("Running Close()", nil)
r.closed = true

r.logger.Info("Closing router", nil)
Expand Down Expand Up @@ -649,7 +656,6 @@ func (h *handler) run(ctx context.Context, middlewares []middleware) {
}

h.logger.Debug("Router handler stopped", nil)
close(h.stopped)
}

// Handler handles Messages.
Expand Down
4 changes: 3 additions & 1 deletion message/router/middleware/deduplicator.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ func NewMapExpiringKeyRepository(window time.Duration) (ExpiringKeyRepository, e
mu: &sync.Mutex{},
tags: make(map[string]time.Time),
}
go kr.cleanOutLoop(context.Background(), time.NewTicker(window/2))
ticker := time.NewTicker(window / 2)

go kr.cleanOutLoop(context.Background(), ticker)
return kr, nil
}

Expand Down
26 changes: 19 additions & 7 deletions message/router/middleware/deduplicator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,23 @@ func TestMapExpiringKeyRepositoryCleanup(t *testing.T) {
t.Errorf("expected 6 tags, but %d remain", l)
}

time.Sleep(wait * 2)
if count != 6 {
t.Errorf("sent six messages, but only received %d", count)
}
if l := measurable.Len(); l != 0 {
t.Errorf("tags should have been cleaned out, but %d remain", l)
}
assert.Eventually(
t,
func() bool {
return count == 6
},
wait*3,
time.Millisecond,
"sent six messages, but only received %d", count,
)
assert.Eventually(
t,
func() bool {
return measurable.Len() == 0
},
wait*3,
time.Millisecond,
"tags should have been cleaned out, but %d remain",
measurable.Len(),
)
}
16 changes: 12 additions & 4 deletions message/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1372,6 +1372,8 @@ func TestRouter_stopping_all_handlers_logs_error(t *testing.T) {

logger := watermill.NewCaptureLogger()

defer logger.PrintCaptured(t)

r, err := message.NewRouter(message.RouterConfig{}, logger)
require.NoError(t, err)

Expand All @@ -1392,13 +1394,19 @@ func TestRouter_stopping_all_handlers_logs_error(t *testing.T) {
}()
<-r.Running()

// Stop the subscriber - this should close the router with an error
// Stop the subscriber - this should close the router with an error logged
err = sub.Close()
require.NoError(t, err)

require.Eventually(t, func() bool {
return r.IsClosed()
}, 1*time.Second, 1*time.Millisecond, "Router should be closed after all handlers are stopped")
require.Eventually(
t,
func() bool {
return r.IsClosed()
},
1*time.Second,
1*time.Millisecond,
"Router should be closed after all handlers are stopped",
)

expectedLogMessage := watermill.CapturedMessage{
Level: watermill.ErrorLogLevel,
Expand Down

0 comments on commit 0e64e21

Please sign in to comment.