Skip to content

Commit

Permalink
Context timeout validation (#2)
Browse files Browse the repository at this point in the history
* ensure timeout is on request context to start and demonstrate using it in the handlers

* show timeout propagation via context

* some readme updates

* use a threadsafe log writer in tests
  • Loading branch information
sethgrid authored Dec 3, 2024
1 parent df1148f commit bf7ae6a
Show file tree
Hide file tree
Showing 10 changed files with 200 additions and 50 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ bin/*
settings.env
logs/*
docker-configs/loki/*
.DS_Store
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@
What comes out of the box with this example implementation:

- fully unit and unit-integration testable json http web server
- tests showing graceful shutdown and context propagation
- able to spin up multiple running http servers in parallel during tests
- able to assert against the server's logged contents
- structured logging with `slog`
- bubbling key:value error data upto `slog` for better structured error logging
- bubbling key:value error data upto `slog` via the `kverr` package for better structured error logging
- fakes as test doubles, as a practical example against the need for mock and dependency injection frameworks
- uses prometheus, and you can see logs and metrics via grafana and loki in docker
- demonstrates building and testing via docker compose, see `make targets` for a list options

Some interesting choices:

- The test server takes a variadic list of log writers, but only takes the first one. This is the closest thing to an optional parameter. It makes writing tests nice because you use the same new test server constructor, and you can optionally send in a logger. This would probably be better implemented as Options.
- test servers take Options, most commonly to be used for passing in a buffer to which the server writes logs
- each request places a logger into the context and there is a package with functions for making life easier for pulling the logger out of existing contexts / requests. This includes a strange thing I did where I put in a backup logger that is kinda gross, but because it is a variadic argument, you never see it nor have to use it.
- There is a util directory. I know, I know. I still find value in a junk drawer. When it makes sense, things get pulled into their own package. And it is behind /internal/ anyway.
- For migrations, I use `goose`, but I pulled those examples out for now.
Expand Down
64 changes: 64 additions & 0 deletions logger/lockbuffer/lockbuffer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Package lockbuffer provides a thread-safe buffer that can be leveraged during testing to capture logs or other output.

package lockbuffer

import (
"bytes"
"io"
"sync"
)

// LockBuffer is a thread-safe ReadWriter that wraps a bytes.Buffer
type LockBuffer struct {
mu sync.Mutex
buffer *bytes.Buffer
}

// NewLockBuffer creates a new LockBuffer instance
func NewLockBuffer() *LockBuffer {
return &LockBuffer{
buffer: bytes.NewBuffer(nil),
}
}

// Write writes to the buffer safely
func (lb *LockBuffer) Write(p []byte) (int, error) {
lb.mu.Lock()
defer lb.mu.Unlock()
return lb.buffer.Write(p)
}

// Read reads from the buffer safely
func (lb *LockBuffer) Read(p []byte) (int, error) {
lb.mu.Lock()
defer lb.mu.Unlock()
return lb.buffer.Read(p)
}

// WriteTo implements the io.WriterTo interface, writing the buffer's content to the given Writer safely
func (lb *LockBuffer) WriteTo(w io.Writer) (int64, error) {
lb.mu.Lock()
defer lb.mu.Unlock()
return lb.buffer.WriteTo(w)
}

// ReadFrom implements the io.ReaderFrom interface, reading content into the buffer from the given Reader safely
func (lb *LockBuffer) ReadFrom(r io.Reader) (int64, error) {
lb.mu.Lock()
defer lb.mu.Unlock()
return lb.buffer.ReadFrom(r)
}

// Bytes returns the buffer's content as a byte slice safely
func (lb *LockBuffer) Bytes() []byte {
lb.mu.Lock()
defer lb.mu.Unlock()
return lb.buffer.Bytes()
}

// String returns the buffer's content as a string safely
func (lb *LockBuffer) String() string {
lb.mu.Lock()
defer lb.mu.Unlock()
return lb.buffer.String()
}
1 change: 0 additions & 1 deletion logger/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ func Middleware(logger *slog.Logger, shouldPrint bool) func(next http.Handler) h

start := time.Now()
metrics.InFlightGauge.Inc()
logger.Error("should increment flight gauge")

defer func() {
metrics.InFlightGauge.Dec()
Expand Down
2 changes: 1 addition & 1 deletion server/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type errorResp struct {
func (s *Server) ErrorJSON(w http.ResponseWriter, r *http.Request, statusCode int, userMsg string, err error) {
// NOTE: if this was a kverr, those key:value pairs will be pulled out and attached to our error log here
logger := logger.FromRequest(r).With("status_code", statusCode).With(kverr.YoinkArgs(err)...)
defer logger.Error(userMsg, "error", err.Error())
logger.Error(userMsg, "error", err.Error())

w.Header().Set("content-type", "application/json")
w.WriteHeader(statusCode)
Expand Down
29 changes: 27 additions & 2 deletions server/handlers.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package server

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand All @@ -27,12 +28,20 @@ func (s *Server) helloworldHandler(w http.ResponseWriter, r *http.Request) {
s.ErrorJSON(w, r, http.StatusBadRequest, "invalid delay", kverr.New(err, "delay", delay))
return
}
if duration > 10*time.Second {
duration = 10 * time.Second
if duration > 90*time.Second {
logger.FromRequest(r).Error("delay too long", "duration", duration.String())
duration = 1 * time.Millisecond
} else if duration < 1*time.Millisecond {
duration = 1 * time.Millisecond
}
time.Sleep(duration)

err = someWorkThatChecksContextDeadline(r.Context())
if err != nil {
s.ErrorJSON(w, r, http.StatusRequestTimeout, "context deadline exceeded", err)
return
}

} else if err := RandomFailure(); err != nil {
// NOTE: we don't have to tell other services that a kverr is being passed in
s.ErrorJSON(w, r, http.StatusInternalServerError, "random failure", err)
Expand All @@ -55,6 +64,22 @@ func RandomFailure() error {
return nil
}

// someWorkThatChecksContextDeadline is a demo showing how to check if the context is canceled
func someWorkThatChecksContextDeadline(ctx context.Context) error {
// Check if the context is canceled
select {
case <-ctx.Done():
// Context is canceled
return kverr.New(fmt.Errorf("context canceled"), "context_err", ctx.Err())
default:
// Context is still active
}

// Do whatever work you need to do

return nil
}

// DoSomethingWithEvents is for illustrative purposes of faking during tests,
// showing how faked dependencies bubble up in test assertions
func (s *Server) DoSomethingWithEvents() error {
Expand Down
63 changes: 44 additions & 19 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"crypto/x509"
"database/sql"
"fmt"
"io"
"log"
"log/slog"
"net"
Expand Down Expand Up @@ -210,9 +211,9 @@ func (s *Server) newRouter() *chi.Mux {
router.Use(customCORSMiddleware(origins))

router.Use(middleware.RealIP)
router.Use(timeoutMiddleware(s.config.RequestTimeout))
router.Use(logger.Middleware(s.parentLogger, s.inDebug))
router.Use(middleware.Recoverer)
router.Use(middleware.Timeout(60 * time.Second))

return router
}
Expand Down Expand Up @@ -283,10 +284,10 @@ func (s *Server) Serve() error {

go func() {
internalHTTP := http.Server{
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 30 * time.Second,
ReadHeaderTimeout: 2 * time.Second,
ReadTimeout: s.config.RequestTimeout,
WriteTimeout: s.config.RequestTimeout,
IdleTimeout: s.config.RequestTimeout,
ReadHeaderTimeout: s.config.RequestTimeout,
Handler: privateRouter,
}

Expand Down Expand Up @@ -318,10 +319,10 @@ func (s *Server) Serve() error {
go runner.Start()

publicHTTP := http.Server{
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
IdleTimeout: 30 * time.Second,
ReadHeaderTimeout: 2 * time.Second,
ReadTimeout: s.config.RequestTimeout,
WriteTimeout: s.config.RequestTimeout,
IdleTimeout: s.config.RequestTimeout,
ReadHeaderTimeout: s.config.RequestTimeout,
Handler: router,
}

Expand Down Expand Up @@ -351,6 +352,32 @@ func (s *Server) Serve() error {
return nil
}

// WithLogWriter is a test helper for server configuration to override logger's writer.
// Typically used with the lockbuffer package for testing, allowing concurrent reads and writes,
// preventing races in the test suite.
func WithLogWriter(w io.Writer) func(*Server) {
return func(s *Server) {
logger := slog.New(slog.NewJSONHandler(w, nil))
s.parentLogger = logger
s.taskq = taskqueue.NewInMemoryTaskQueue(1, 15*time.Second, logger)
}
}

// WithLogger is a test helper for server configuration to override logger
func WithLogger(logger *slog.Logger) func(*Server) {
return func(s *Server) {
s.parentLogger = logger
s.taskq = taskqueue.NewInMemoryTaskQueue(1, 15*time.Second, logger)
}
}

// WithConfig is a test helper for overwriting server configuration
func WithConfig(config Config) func(*Server) {
return func(s *Server) {
s.config = config
}
}

func (s *Server) setLastError(err error) {
s.mu.Lock()
defer s.mu.Unlock()
Expand Down Expand Up @@ -405,15 +432,13 @@ func (s *Server) LastError() error {
return s.srvErr
}

func (s *Server) loggerMiddleware(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// create an instance of the logger bound to this single request
// if we change rid, also change rid's backup init in GetLoggerFromRequest
_, _, r = logger.NewRequestLogger(ctx, r, s.parentLogger)

next.ServeHTTP(w, r)
func timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel()
// Pass the new context to the next handler
next.ServeHTTP(w, r.WithContext(ctx))
})
}

return http.HandlerFunc(fn)
}
1 change: 1 addition & 0 deletions server/server_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type Config struct {
EnableDebug bool `default:"true" envconfig:"enable_debug"`
TaskExpiration time.Duration `default:"1m" envconfig:"task_expiration"`
ShutdownTimeout time.Duration `default:"30s" envconfig:"shutdown_timeout"`
RequestTimeout time.Duration `default:"30s" envconfig:"shutdown_timeout"`

SGAPIKey string `default:"" envconfig:"sendgrid_apikey"`

Expand Down
13 changes: 7 additions & 6 deletions server/server_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@
package server

import (
"bytes"
"fmt"
"os"
"strings"
"testing"

"github.com/sethgrid/helloworld/logger/lockbuffer"
"github.com/stretchr/testify/require"
)

// launchOrGetTestServer will try to run in a faked test server or will connect to the configured
// HOST_ADDR and PORT
func launchOrGetTestServer(t *testing.T) (theURL string, logs bytes.Buffer, closefn func() error) {
func launchOrGetTestServer(t *testing.T) (theURL string, logs *lockbuffer.LockBuffer, closefn func() error) {
logs = lockbuffer.NewLockBuffer()
if os.Getenv("USE_LOCAL_helloworld") != "" {
host := os.Getenv("HOST_ADDR")
if strings.Contains(host, "helloworld.com") {
Expand All @@ -24,7 +25,7 @@ func launchOrGetTestServer(t *testing.T) (theURL string, logs bytes.Buffer, clos
return fmt.Sprintf("http://%s:%s", host, os.Getenv("PORT")), logs, func() error { return nil }
}

srv, err := newTestServer(&logs)
srv, err := newTestServer(WithLogWriter(logs))
require.NoError(t, err)
return fmt.Sprintf("http://localhost:%d", srv.Port()), logs, srv.Close
}
Expand All @@ -38,12 +39,12 @@ func TestSomething(t *testing.T) {
defer dumpLogsOnFailure(t, logs)

// call the server at theURL. Inspect logs.
fmt.Sprintf(theURL)
fmt.Sprintf(logs.String())
fmt.Printf(theURL)
fmt.Printf(logs.String())

}

func dumpLogsOnFailure(t *testing.T, logBuf bytes.Buffer) {
func dumpLogsOnFailure(t *testing.T, logBuf *lockbuffer.LockBuffer) {
if t.Failed() {
fmt.Printf("\nServer Log Dump:\n%s\n", logBuf.String())
}
Expand Down
Loading

0 comments on commit bf7ae6a

Please sign in to comment.