Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
- Move variables around
- Add a timeout for acquiring a write slot; once it expires, the request is dropped. This prevents the `WriteTo` function from blocking indefinitely
  • Loading branch information
julienduchesne committed Oct 9, 2024
1 parent e241903 commit 9155c11
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 10 deletions.
27 changes: 22 additions & 5 deletions kv/memberlist/tcp_transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ type TCPTransportConfig struct {
// Maximum number of concurrent writes to other nodes.
MaxConcurrentWrites int `yaml:"max_concurrent_writes" category:"advanced"`

// Timeout for acquiring one of the concurrent write slots.
AcquireWriterTimeout time.Duration `yaml:"acquire_writer_timeout" category:"advanced"`

// Transport logs lots of messages at debug level, so it deserves an extra flag for turning it on
TransportDebug bool `yaml:"-" category:"advanced"`

Expand All @@ -76,6 +79,7 @@ func (cfg *TCPTransportConfig) RegisterFlagsWithPrefix(f *flag.FlagSet, prefix s
f.DurationVar(&cfg.PacketDialTimeout, prefix+"memberlist.packet-dial-timeout", 2*time.Second, "Timeout used when connecting to other nodes to send packet.")
f.DurationVar(&cfg.PacketWriteTimeout, prefix+"memberlist.packet-write-timeout", 5*time.Second, "Timeout for writing 'packet' data.")
f.IntVar(&cfg.MaxConcurrentWrites, prefix+"memberlist.max-concurrent-writes", 3, "Maximum number of concurrent writes to other nodes.")
f.DurationVar(&cfg.AcquireWriterTimeout, prefix+"memberlist.acquire-writer-timeout", 250*time.Millisecond, "Timeout for acquiring one of the concurrent write slots. After this time, the message will be dropped.")
f.BoolVar(&cfg.TransportDebug, prefix+"memberlist.transport-debug", false, "Log debug transport messages. Note: global log.level must be at debug level as well.")

f.BoolVar(&cfg.TLSEnabled, prefix+"memberlist.tls-enabled", false, "Enable TLS on the memberlist transport layer.")
Expand All @@ -99,11 +103,11 @@ type TCPTransport struct {
tcpListeners []net.Listener
tlsConfig *tls.Config

writeCh chan writeRequest
writeWG sync.WaitGroup

shutdown bool
shutdownMu sync.RWMutex
shutdown bool
writeCh chan writeRequest // this channel is protected by shutdownMu

writeWG sync.WaitGroup

advertiseMu sync.RWMutex
advertiseAddr string
Expand Down Expand Up @@ -454,7 +458,20 @@ func (t *TCPTransport) WriteTo(b []byte, addr string) (time.Time, error) {
if t.shutdown {
return time.Time{}, errors.New("transport is shutting down")
}
t.writeCh <- writeRequest{b: b, addr: addr}

// Send the packet to the write workers
// If this blocks for too long (as configured), abort and log an error.
select {
case <-time.After(t.cfg.AcquireWriterTimeout):
level.Warn(t.logger).Log("msg", "WriteTo failed to acquire a writer. Dropping message", "timeout", t.cfg.AcquireWriterTimeout, "addr", addr)
t.sentPacketsErrors.Inc()
// WriteTo is used to send "UDP" packets. Since we use TCP, we can detect more errors,
// but memberlist library doesn't seem to cope with that very well. That is why we return nil instead.
return time.Now(), nil
case t.writeCh <- writeRequest{b: b, addr: addr}:
// OK
}

return time.Now(), nil
}

Expand Down
41 changes: 36 additions & 5 deletions kv/memberlist/tcp_transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package memberlist
import (
"net"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -78,10 +79,7 @@ func TestTCPTransportWriteToUnreachableAddr(t *testing.T) {
writeCt := 50

// Listen for TCP connections on a random port
freePorts, err := getFreePorts(1)
require.NoError(t, err)
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: freePorts[0]}
listener, err := net.ListenTCP("tcp", addr)
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()

Expand All @@ -107,7 +105,7 @@ func TestTCPTransportWriteToUnreachableAddr(t *testing.T) {
timeStart := time.Now()

for i := 0; i < writeCt; i++ {
_, err = transport.WriteTo([]byte("test"), addr.String())
_, err = transport.WriteTo([]byte("test"), listener.Addr().String())
require.NoError(t, err)
}

Expand All @@ -119,6 +117,39 @@ func TestTCPTransportWriteToUnreachableAddr(t *testing.T) {
assert.LessOrEqual(t, time.Since(timeStart), 2*time.Second, "expected to take less than 2s (timeout + a good margin), writing to unreachable addresses should not block")
}

// TestTCPTransportWriterAcquireTimeout fires many concurrent WriteTo calls at a
// transport configured with a single write slot and a 1ms acquire timeout, then
// checks the captured logs: at least one write must have been dropped with the
// "failed to acquire a writer" warning, but not every one of them.
func TestTCPTransportWriterAcquireTimeout(t *testing.T) {
	// Destination for every write: a local listener on an OS-assigned port.
	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	defer listener.Close()

	// Everything the transport logs is captured here so it can be inspected below.
	logs := &concurrency.SyncBuffer{}

	// Start from the flag defaults, then restrict to one writer and a very short
	// acquire timeout.
	cfg := TCPTransportConfig{}
	flagext.DefaultValues(&cfg)
	cfg.MaxConcurrentWrites = 1
	cfg.AcquireWriterTimeout = time.Millisecond

	transport, err := NewTCPTransport(cfg, log.NewLogfmtLogger(logs), nil)
	require.NoError(t, err)

	totalWrites := 100
	target := listener.Addr().String()

	// Launch all writes at once so that they compete for the single write slot.
	var pending sync.WaitGroup
	pending.Add(totalWrites)
	for n := 0; n < totalWrites; n++ {
		go func() {
			defer pending.Done()
			transport.WriteTo([]byte("test"), target) // nolint:errcheck
		}()
	}
	pending.Wait()

	require.NoError(t, transport.Shutdown())

	// Each dropped write is expected to leave exactly this message in the logs.
	dropped := strings.Count(logs.String(), "WriteTo failed to acquire a writer. Dropping message")
	assert.Less(t, dropped, totalWrites, "expected to have less errors (%d) than total writes (%d). Some writes should pass.", dropped, totalWrites)
	assert.NotZero(t, dropped, "expected errors, got none")
}

func TestFinalAdvertiseAddr(t *testing.T) {
tests := map[string]struct {
advertiseAddr string
Expand Down

0 comments on commit 9155c11

Please sign in to comment.