Skip to content

Commit

Permalink
use a threadsafe writer in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sethgrid committed Dec 3, 2024
1 parent 1283348 commit c602906
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 10 deletions.
64 changes: 64 additions & 0 deletions logger/lockbuffer/lockbuffer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Package lockbuffer provides a thread-safe buffer that can be leveraged during testing to capture logs or other output.

package lockbuffer

import (
"bytes"
"io"
"sync"
)

// LockBuffer is a thread-safe ReadWriter that wraps a bytes.Buffer
type LockBuffer struct {
mu sync.Mutex
buffer *bytes.Buffer
}

// NewLockBuffer creates a new LockBuffer instance
func NewLockBuffer() *LockBuffer {
return &LockBuffer{
buffer: bytes.NewBuffer(nil),
}
}

// Write writes to the buffer safely
func (lb *LockBuffer) Write(p []byte) (int, error) {
lb.mu.Lock()
defer lb.mu.Unlock()
return lb.buffer.Write(p)
}

// Read reads from the buffer safely
func (lb *LockBuffer) Read(p []byte) (int, error) {
lb.mu.Lock()
defer lb.mu.Unlock()
return lb.buffer.Read(p)
}

// WriteTo implements the io.WriterTo interface, writing the buffer's content to the given Writer safely
func (lb *LockBuffer) WriteTo(w io.Writer) (int64, error) {
lb.mu.Lock()
defer lb.mu.Unlock()
return lb.buffer.WriteTo(w)
}

// ReadFrom implements the io.ReaderFrom interface, reading content into the buffer from the given Reader safely
func (lb *LockBuffer) ReadFrom(r io.Reader) (int64, error) {
lb.mu.Lock()
defer lb.mu.Unlock()
return lb.buffer.ReadFrom(r)
}

// Bytes returns the buffer's content as a byte slice safely
func (lb *LockBuffer) Bytes() []byte {
lb.mu.Lock()
defer lb.mu.Unlock()
return lb.buffer.Bytes()
}

// String returns the buffer's content as a string safely
func (lb *LockBuffer) String() string {
lb.mu.Lock()
defer lb.mu.Unlock()
return lb.buffer.String()
}
10 changes: 6 additions & 4 deletions server/server.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package server

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"database/sql"
"fmt"
"io"
"log"
"log/slog"
"net"
Expand Down Expand Up @@ -352,10 +352,12 @@ func (s *Server) Serve() error {
return nil
}

// WithLogbuf is a test helper for server configuration to override logger's buffer
func WithLogbuf(buf *bytes.Buffer) func(*Server) {
// WithLogWriter is a test helper for server configuration to override logger's writer.
// Typically used with the lockbuffer package for testing, allowing concurrent reads and writes,
// preventing races in the test suite.
func WithLogWriter(w io.Writer) func(*Server) {
return func(s *Server) {
logger := slog.New(slog.NewJSONHandler(buf, nil))
logger := slog.New(slog.NewJSONHandler(w, nil))
s.parentLogger = logger
s.taskq = taskqueue.NewInMemoryTaskQueue(1, 15*time.Second, logger)
}
Expand Down
11 changes: 5 additions & 6 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"testing"
"time"

"github.com/sethgrid/helloworld/logger/lockbuffer"
"github.com/sethgrid/helloworld/taskqueue"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -30,9 +31,9 @@ func TestHealthcheck(t *testing.T) {
}

func TestEventStoreErr(t *testing.T) {
var logbuf bytes.Buffer
logbuf := lockbuffer.NewLockBuffer()

srv, err := newTestServer(WithLogbuf(&logbuf))
srv, err := newTestServer(WithLogWriter(logbuf))
require.NoError(t, err)
defer srv.Close()

Expand Down Expand Up @@ -83,21 +84,19 @@ func TestGracefulShutdown(t *testing.T) {
}

func TestContextTimeoutAndRequestTimeout(t *testing.T) {
var logbuf bytes.Buffer
logbuf := lockbuffer.NewLockBuffer()
customConfig := Config{
// server kills any request that takes longer than this
RequestTimeout: 100 * time.Millisecond,
}

srv, err := newTestServer(WithConfig(customConfig), WithLogbuf(&logbuf))
srv, err := newTestServer(WithConfig(customConfig), WithLogWriter(logbuf))
require.NoError(t, err)

source := fmt.Sprintf("http://localhost:%d/?delay=101ms", srv.Port())
_, err = http.Get(source)
require.Error(t, err)

// close the server to prevent concurrent writes to the log buffer so we can assert on it
assert.NoError(t, srv.Close())
assert.Contains(t, logbuf.String(), `"error":"context canceled"`)

}
Expand Down

0 comments on commit c602906

Please sign in to comment.