diff --git a/memcache/memcache.go b/memcache/memcache.go index b98a7653..a204d3c9 100644 --- a/memcache/memcache.go +++ b/memcache/memcache.go @@ -70,7 +70,20 @@ const ( // DefaultMaxIdleConns is the default maximum number of idle connections // kept for any single address. - DefaultMaxIdleConns = 2 + DefaultMaxIdleConns = 100 + + // DefaultMinIdleConns is the default minimum number of idle connections + // kept for any single address. + DefaultMinIdleConns = 2 + + // DefaultIdleTimeout is the time after which the freelist is scanned and + // idle connections are closed. Since the connections are scanned with the + // same period as their TTL, they may live for up to twice this duration. + DefaultIdleTimeout = 60 * time.Second + + // DefaultMaxOpenConns is the maximum number of active and idle connections + // for any single address. + DefaultMaxOpenConns = 200 ) const buffered = 8 // arbitrary buffered channel size, for readability @@ -126,7 +139,7 @@ func New(server ...string) *Client { // NewFromSelector returns a new Client using the provided ServerSelector. func NewFromSelector(ss ServerSelector) *Client { - return &Client{selector: ss} + return &Client{MinIdleConns: -1, selector: ss, openconn: make(map[string]chan struct{})} } // Client is a memcache client. @@ -144,10 +157,23 @@ type Client struct { // be set to a number higher than your peak parallel requests. MaxIdleConns int + // MinIdleConns is the minimum number of idle connections to keep in the freelist. + // -1 indicates that DefaultMinIdleConns should be used. + MinIdleConns int + + // IdleTimeout specifies the duration an Idle connection will remain unclosed. + IdleTimeout time.Duration + + // MaxOpenConns is the maximum number of active and idle connections per address. + MaxOpenConns int + selector ServerSelector lk sync.Mutex freeconn map[string][]*conn + // openconn is a semaphore for limiting connections in use. + openconn map[string]chan struct{} + bg *time.Timer } // Item is an item to be got or stored in a memcached server. @@ -177,32 +203,51 @@ type conn struct { rw *bufio.ReadWriter addr net.Addr c *Client -} - -// release returns this connection back to the client's free pool -func (cn *conn) release() { - cn.c.putFreeConn(cn.addr, cn) + t time.Time } func (cn *conn) extendDeadline() { cn.nc.SetDeadline(time.Now().Add(cn.c.netTimeout())) } +func (c *Client) releaseOpenConn(addr net.Addr) { + c.lk.Lock() + c.releaseOpenConnLocked(addr) + c.lk.Unlock() +} + // condRelease releases this connection if the error pointed to by err // is nil (not an error) or is only a protocol level error (e.g. a // cache miss). The purpose is to not recycle TCP connections that // are bad. func (cn *conn) condRelease(err *error) { + c := cn.c + c.lk.Lock() if *err == nil || resumableError(*err) { - cn.release() + c.putFreeConn(cn.addr, cn) } else { cn.nc.Close() } + c.releaseOpenConnLocked(cn.addr) + c.lk.Unlock() } -func (c *Client) putFreeConn(addr net.Addr, cn *conn) { +func (c *Client) acquireOpenConn(addr net.Addr) { c.lk.Lock() - defer c.lk.Unlock() + s := c.openconn[addr.String()] + if s == nil { + s = make(chan struct{}, c.maxOpenConns()) + c.openconn[addr.String()] = s + } + c.lk.Unlock() + s <- struct{}{} +} + +func (c *Client) releaseOpenConnLocked(addr net.Addr) { + <-c.openconn[addr.String()] +} + +func (c *Client) putFreeConn(addr net.Addr, cn *conn) { if c.freeconn == nil { c.freeconn = make(map[string][]*conn) } @@ -214,6 +259,44 @@ func (c *Client) putFreeConn(addr net.Addr, cn *conn) { c.freeconn[addr.String()] = append(freelist, cn) } +func (c *Client) cleanup() { + c.lk.Lock() + defer c.lk.Unlock() + if c.freeconn == nil { + return + } + new_freeconn := make(map[string][]*conn) + timeout := time.Now().Add(-c.idleTimeout()) + for k, freelist := range c.freeconn { + if len(freelist) <= c.minIdleConns() { + if len(freelist) > 0 { + new_freeconn[k] = freelist + } + continue + } + // freelist is sorted descending in time, so loop to find the + // index to keep. + i := 0 + for ; i < len(freelist)-c.minIdleConns(); i++ { + cn := freelist[i] + if cn.t.After(timeout) { + break + } + cn.nc.Close() + } + if i < len(freelist) { + new_freeconn[k] = freelist[i:] + } + } + if len(new_freeconn) > 0 { + c.freeconn = new_freeconn + c.bg = time.AfterFunc(c.idleTimeout(), c.cleanup) + } else { + c.freeconn = nil + c.bg = nil + } +} + func (c *Client) getFreeConn(addr net.Addr) (cn *conn, ok bool) { c.lk.Lock() defer c.lk.Unlock() @@ -243,6 +326,27 @@ func (c *Client) maxIdleConns() int { return DefaultMaxIdleConns } +func (c *Client) minIdleConns() int { + if c.MinIdleConns >= 0 { + return c.MinIdleConns + } + return DefaultMinIdleConns +} + +func (c *Client) idleTimeout() time.Duration { + if c.IdleTimeout > 0 { + return c.IdleTimeout + } + return DefaultIdleTimeout +} + +func (c *Client) maxOpenConns() int { + if c.MaxOpenConns > 0 { + return c.MaxOpenConns + } + return DefaultMaxOpenConns +} + // ConnectTimeoutError is the error type used when it takes // too long to connect to the desired host. This level of // detail can generally be ignored. @@ -273,6 +377,7 @@ func (c *Client) dial(addr net.Addr) (net.Conn, error) { } func (c *Client) getConn(addr net.Addr) (*conn, error) { + c.acquireOpenConn(addr) cn, ok := c.getFreeConn(addr) if ok { cn.extendDeadline() @@ -280,6 +385,7 @@ func (c *Client) getConn(addr net.Addr) (*conn, error) { } nc, err := c.dial(addr) if err != nil { + c.releaseOpenConn(addr) return nil, err } cn = &conn{ diff --git a/memcache/memcache_test.go b/memcache/memcache_test.go index 4b52a911..d81327d4 100644 --- a/memcache/memcache_test.go +++ b/memcache/memcache_test.go @@ -22,10 +22,12 @@ import ( "fmt" "io" "io/ioutil" + "log" "net" "os" "os/exec" "strings" + "sync" "testing" "time" ) @@ -49,17 +51,13 @@ func TestLocalhost(t *testing.T) { testWithClient(t, New(testServer)) } -// Run the memcached binary as a child process and connect to its unix socket. -func TestUnixSocket(t *testing.T) { - sock := fmt.Sprintf("/tmp/test-gomemcache-%d.sock", os.Getpid()) +func MakeUnixSocketMemcached(t *testing.T, tag string) (string, *exec.Cmd, error) { + sock := fmt.Sprintf("/tmp/test-gomemcache-%d%s.sock", tag, os.Getpid()) cmd := exec.Command("memcached", "-s", sock) if err := cmd.Start(); err != nil { t.Skipf("skipping test; couldn't find memcached") - return + return "", nil, err } - defer cmd.Wait() - defer cmd.Process.Kill() - // Wait a bit for the socket to appear. for i := 0; i < 10; i++ { if _, err := os.Stat(sock); err == nil { @@ -67,7 +65,17 @@ func TestUnixSocket(t *testing.T) { } time.Sleep(time.Duration(25*i) * time.Millisecond) } + return sock, cmd, nil +} +// Run the memcached binary as a child process and connect to its unix socket. +func TestUnixSocket(t *testing.T) { + sock, cmd, err := MakeUnixSocketMemcached(t, "") + if err != nil { + return + } + defer cmd.Wait() + defer cmd.Process.Kill() testWithClient(t, New(sock)) } @@ -258,6 +266,120 @@ func testTouchWithClient(t *testing.T, c *Client) { } } +// stallProxy blocks all connections until a caller unblocks it. +type stallProxy struct { + l *net.UnixListener + stallCh chan struct{} + sock string + mu sync.Mutex + waiting int +} + +func newStallProxy(addr string) (*stallProxy, error) { + sock := fmt.Sprintf("/tmp/test-stall-gomemcache-%d.sock", os.Getpid()) + laddr, err := net.ResolveUnixAddr("unix", sock) + if err != nil { + return nil, err + } + l, err := net.ListenUnix("unix", laddr) + if err != nil { + return nil, err + } + s := &stallProxy{ + stallCh: make(chan struct{}, 1000), + l: l, + sock: sock, + } + go s.listenLoop(addr) + return s, nil +} + +func (s *stallProxy) listenLoop(addr string) { + laddr, err := net.ResolveUnixAddr("unix", addr) + if err != nil { + log.Fatalf("ResolveUnixAddr(%s) failed: %v", addr, err) + } + for { + lConn, err := s.l.AcceptUnix() + if err != nil { + log.Fatalf("AcceptUnix(%s) failed: %v", addr, err) + } + rConn, err := net.DialUnix("unix", nil, laddr) + if err != nil { + log.Fatalf("Dial(%s) failed: %v", addr, err) + } + go func() { // Send loop + io.Copy(rConn, lConn) + rConn.CloseWrite() + lConn.CloseRead() + }() + go func() { // Receive loop + s.mu.Lock() + s.waiting++ + s.mu.Unlock() + <-s.stallCh + io.Copy(lConn, rConn) + rConn.CloseRead() + lConn.CloseWrite() + }() + } +} + +func (s *stallProxy) Close() { + s.l.Close() +} + +func (s *stallProxy) Unstall() { + close(s.stallCh) + return + s.mu.Lock() + for s.waiting > 0 { + s.stallCh <- struct{}{} + s.waiting-- + } + s.mu.Unlock() +} + +func (s *stallProxy) Waiting() int { + s.mu.Lock() + defer s.mu.Unlock() + return s.waiting +} + +// Run a stallProxy in front of a unix domain socket to memacached. +func TestMaxConn(t *testing.T) { + sock, cmd, err := MakeUnixSocketMemcached(t, "-max") + if err != nil { + return + } + defer cmd.Wait() + defer cmd.Process.Kill() + + stallp, err := newStallProxy(sock) + c := New(stallp.sock) + n := 10 + c.MaxOpenConns = n + c.Timeout = 1 * time.Second + mustSet := mustSetF(t, c) + // Start 2 * MaxOpenConns operations. + var wg sync.WaitGroup + for i := 0; i < n*2; i++ { + wg.Add(1) + go func(a int) { + mustSet(&Item{Key: "foo", Value: []byte("42")}) + wg.Done() + }(i) + } + // Wait and verify n are queued. + time.Sleep(c.Timeout / 2) + if stallp.Waiting() != n { + t.Fatalf("Incorrect number of connections waiting: %d\n", stallp.Waiting()) + } + stallp.Unstall() + // Ensure that they all complete successfully. + wg.Wait() +} + func BenchmarkOnItem(b *testing.B) { fakeServer, err := net.Listen("tcp", "localhost:0") if err != nil {