Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support graceful shutdown #148

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 252 additions & 0 deletions diam/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@ package diam
import (
"bufio"
"crypto/tls"
"errors"
"fmt"
"io"
"log"
"math/rand"
"net"
"runtime"
"sync"
"sync/atomic"
"time"

"golang.org/x/net/context"
Expand Down Expand Up @@ -87,6 +90,8 @@ type conn struct {
tlsState *tls.ConnectionState // or nil when not using TLS
writer *response // the diam.Conn exposed to handlers

curState struct{ atomic uint64 } // packed (unixtime<<8|uint8(ConnState))

mu sync.Mutex // guards the following
closeNotifyc chan struct{}
clientGone bool
Expand Down Expand Up @@ -137,6 +142,49 @@ func (c *conn) notifyClientGone() {
}
}

// ConnState describes where a client connection currently is in its
// lifecycle on the server.
type ConnState int8

// The connection states, listed in the order a connection normally
// passes through them.
const (
	// StateNew is the state every connection starts in: it has just
	// been accepted and is expected to send a request right away.
	// From here a connection moves to StateActive or StateClosed.
	StateNew ConnState = iota

	// StateActive marks a connection on which a request has started
	// to arrive or is being handled. Once the request has been
	// handled the connection moves to StateIdle or StateClosed.
	StateActive

	// StateIdle marks a connection that has finished handling a
	// request and is waiting for the next one. From here a
	// connection moves to StateActive or StateClosed.
	StateIdle

	// StateClosed marks a closed connection. It is a terminal state.
	StateClosed
)

// stateName maps every declared ConnState to its printable name.
var stateName = map[ConnState]string{
	StateClosed: "closed",
	StateIdle:   "idle",
	StateActive: "active",
	StateNew:    "new",
}

// String returns the lower-case name of the state, or "UNDEFINED" for
// a value that is not one of the declared constants.
func (c ConnState) String() string {
	if name, known := stateName[c]; known {
		return name
	}
	return "UNDEFINED"
}

// Create new connection from rwc.
func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) {
msc, isMulti := rwc.(MultistreamConn)
Expand All @@ -157,6 +205,26 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) {
return c, nil
}

// setState records state as the connection's current state, together
// with the time of the change.
//
// Entering StateNew registers the connection with its server and
// entering StateClosed unregisters it; other states leave the server's
// bookkeeping untouched. The state and the current Unix time in seconds
// are then packed into one word (seconds<<8 | state) and published with
// an atomic store, so getState can read both without taking a lock.
//
// It panics with "internal error" if state is negative or above 0xf.
func (c *conn) setState(state ConnState) {
	server := c.server
	if state == StateNew {
		server.trackConn(c, true)
	} else if state == StateClosed {
		server.trackConn(c, false)
	}

	if state < 0 || state > 0xf {
		panic("internal error")
	}

	nowSec := time.Now().Unix()
	packed := uint64(nowSec<<8) | uint64(state)
	atomic.StoreUint64(&c.curState.atomic, packed)
}

// getState atomically loads the word written by setState and unpacks
// it: the low 8 bits are the connection state, the remaining bits are
// the Unix time in seconds at which that state was set.
func (c *conn) getState() (state ConnState, unixSec int64) {
	packed := atomic.LoadUint64(&c.curState.atomic)
	state = ConnState(packed & 0xff)
	unixSec = int64(packed >> 8)
	return state, unixSec
}

// Read next message from connection.
func (c *conn) readMessage() (m *Message, err error) {
if c.server.ReadTimeout > 0 {
Expand Down Expand Up @@ -185,6 +253,7 @@ func (c *conn) serve() {
c.rwc.RemoteAddr().String(), err, buf)
}
c.rwc.Close()
c.setState(StateClosed)
}()
if tlsConn, ok := c.rwc.(*tls.Conn); ok {
if err := tlsConn.Handshake(); err != nil {
Expand All @@ -195,8 +264,10 @@ func (c *conn) serve() {
}
for {
m, err := c.readMessage()
c.setState(StateActive)
if err != nil {
c.rwc.Close()
c.setState(StateClosed)
// Report errors to the channel, except EOF.
if err != io.EOF && err != io.ErrUnexpectedEOF {
h := c.server.Handler
Expand All @@ -211,6 +282,7 @@ func (c *conn) serve() {
}
// Handle messages in this goroutine.
serverHandler{c.server}.ServeDIAM(c.writer, m)
c.setState(StateIdle)
}
}

Expand All @@ -223,6 +295,12 @@ func (c *conn) dictionary() *dict.Parser {
return c.server.Dict
}

// atomicBool is a boolean flag accessed only through sync/atomic
// operations. It is stored as an int32: zero means false, any other
// value means true. The zero value is a usable, unset flag.
type atomicBool int32

// isSet reports whether the flag is currently true.
func (b *atomicBool) isSet() bool {
	current := atomic.LoadInt32((*int32)(b))
	return current != 0
}

// setTrue atomically sets the flag.
func (b *atomicBool) setTrue() {
	atomic.StoreInt32((*int32)(b), 1)
}

// setFalse atomically clears the flag.
func (b *atomicBool) setFalse() {
	atomic.StoreInt32((*int32)(b), 0)
}

// A response represents the server side of a diameter response.
// It implements the Conn and CloseNotifier interfaces.
type response struct {
Expand Down Expand Up @@ -548,6 +626,10 @@ func ErrorReports() <-chan *ErrorReport {
return DefaultServeMux.ErrorReports()
}

// ErrServerClosed is returned by the Server's Serve and ListenAndServe
// methods once the server has been shut down with Shutdown.
// NOTE(review): the original comment also mentioned a Close method;
// confirm whether Server has one that triggers this error.
var ErrServerClosed = errors.New("diameter: Server closed")

// Serve accepts incoming diameter connections on the listener l,
// creating a new service goroutine for each. The service goroutines
// read messages and then call handler to reply to them.
Expand All @@ -567,6 +649,11 @@ type Server struct {
WriteTimeout time.Duration // maximum duration before timing out write of the response
TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS
LocalAddr net.Addr // optional local address to bind the dialer's (Dial...) socket to

inShutdown atomicBool // true when the server is in shutdown
mu sync.Mutex
listeners map[*net.Listener]struct{}
activeConn map[*conn]struct{}
}

// serverHandler delegates to either the server's Handler or DefaultServeMux.
Expand Down Expand Up @@ -607,11 +694,22 @@ func (srv *Server) ListenAndServe() error {
// new service goroutine for each. The service goroutines read requests and
// then call srv.Handler to reply to them.
func (srv *Server) Serve(l net.Listener) error {
l = &onceCloseListener{Listener: l}
defer l.Close()

if !srv.trackListener(&l, true) {
return ErrServerClosed
}
defer srv.trackListener(&l, false)

var tempDelay time.Duration // how long to sleep on accept failure
for {
rw, e := l.Accept()
if e != nil {
if srv.shuttingDown() {
return ErrServerClosed
}

if ne, ok := e.(net.Error); ok && ne.Temporary() {
if tempDelay == 0 {
tempDelay = 5 * time.Millisecond
Expand Down Expand Up @@ -640,11 +738,150 @@ func (srv *Server) Serve(l net.Listener) error {
log.Printf("srv.newConn error: %v", err)
continue
} else {
c.setState(StateNew)
go c.serve()
}
}
}

// shutdownPollIntervalMax is the upper bound on the polling interval
// used by Server.Shutdown while it waits for the server to become
// quiescent. Polling starts with a small interval and backs off until
// it reaches this maximum.
//
// Ideally we could find a solution that doesn't involve polling,
// but which also doesn't have a high runtime cost (and doesn't
// involve any contentious mutexes), but that is left as an
// exercise for the reader.
const shutdownPollIntervalMax = 500 * time.Millisecond

// Shutdown gracefully shuts down the server without interrupting any
// active connections. It marks the server as shutting down, closes all
// tracked listeners, and then repeatedly closes idle connections until
// every tracked connection has been closed and every Serve call has
// unregistered its listener.
//
// Shutdown takes no deadline: it blocks for as long as any connection
// stays active. Between checks it sleeps for an interval that starts
// at one millisecond, doubles each round up to shutdownPollIntervalMax,
// and carries up to 10% random jitter.
//
// When Shutdown is called, Serve, ListenAndServe, and
// ListenAndServeTLS return ErrServerClosed. Make sure the program
// doesn't exit and waits instead for Shutdown to return.
//
// Once Shutdown has been called on a server, it may not be reused;
// future calls to methods such as Serve will return ErrServerClosed.
//
// The returned error is the first error reported while closing the
// listeners, or nil if they all closed cleanly.
func (srv *Server) Shutdown() error {
	// Set the flag before closing the listeners so that a Serve loop
	// whose Accept fails sees the shutdown and returns ErrServerClosed.
	srv.inShutdown.setTrue()

	srv.mu.Lock()
	lnerr := srv.closeListenersLocked()
	srv.mu.Unlock()

	pollIntervalBase := time.Millisecond
	nextPollInterval := func() time.Duration {
		// Add 10% jitter.
		interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10)))
		// Double and clamp for next time.
		pollIntervalBase *= 2
		if pollIntervalBase > shutdownPollIntervalMax {
			pollIntervalBase = shutdownPollIntervalMax
		}
		return interval
	}

	timer := time.NewTimer(nextPollInterval())
	defer timer.Stop()
	for {
		if srv.closeIdleConns() && srv.numListeners() == 0 {
			return lnerr
		}
		// Only one event is awaited, so a plain receive replaces the
		// single-case select. The timer has fired and been drained
		// here, which makes the Reset below safe.
		<-timer.C
		timer.Reset(nextPollInterval())
	}
}

// numListeners returns how many listeners are currently tracked by the
// server. The count is read while holding srv.mu.
func (srv *Server) numListeners() int {
	srv.mu.Lock()
	count := len(srv.listeners)
	srv.mu.Unlock()
	return count
}

// closeIdleConns closes every tracked connection that is idle and
// removes it from the tracked set. It reports whether the server is
// quiescent, i.e. whether no tracked connection had to be left open.
//
// A connection still in StateNew counts as idle once it has been in
// that state for more than 5 seconds. A connection whose recorded
// timestamp is zero is assumed to be so new that its state has not
// been set yet, and is left alone.
func (srv *Server) closeIdleConns() bool {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	allClosed := true
	for c := range srv.activeConn {
		state, since := c.getState()
		if state == StateNew && since < time.Now().Unix()-5 {
			// Nothing was read from this connection for over
			// 5 seconds after it was accepted; treat it as idle.
			state = StateIdle
		}
		if state == StateIdle && since != 0 {
			c.rwc.Close()
			delete(srv.activeConn, c)
			continue
		}
		// Still busy, or too new to have a state: keep waiting.
		allClosed = false
	}
	return allClosed
}

// trackListener registers (add == true) or unregisters (add == false)
// a listener in the server's set of tracked listeners, allocating the
// set on first use.
//
// The set is keyed by a pointer to the interface value rather than by
// the net.Listener itself, because the concrete listener type may not
// be comparable. That is safe here: Serve is the only caller, and it
// tracks and later untracks the address of the same local variable, so
// a listener never has to be matched against one from another caller.
//
// The result is false only when asked to add a listener while the
// server is shutting down; the listener is then not registered. In
// every other case the result is true.
func (srv *Server) trackListener(ln *net.Listener, add bool) bool {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	if srv.listeners == nil {
		srv.listeners = make(map[*net.Listener]struct{})
	}
	if !add {
		delete(srv.listeners, ln)
		return true
	}
	if srv.shuttingDown() {
		return false
	}
	srv.listeners[ln] = struct{}{}
	return true
}

// trackConn registers (add == true) or unregisters (add == false) the
// connection c in the server's set of tracked connections, allocating
// the set on first use. The set is modified while holding srv.mu.
func (srv *Server) trackConn(c *conn, add bool) {
	srv.mu.Lock()
	defer srv.mu.Unlock()

	if srv.activeConn == nil {
		srv.activeConn = make(map[*conn]struct{})
	}
	if !add {
		delete(srv.activeConn, c)
		return
	}
	srv.activeConn[c] = struct{}{}
}

// shuttingDown reports whether the server's inShutdown flag has been
// set. The flag is read atomically, so no lock is needed.
func (srv *Server) shuttingDown() bool {
	inShutdown := srv.inShutdown.isSet()
	return inShutdown
}

// closeListenersLocked closes every tracked listener and returns the
// first non-nil error encountered, or nil if all of them closed
// cleanly. All listeners are closed even after one has failed.
//
// The method does no locking itself; Shutdown calls it with srv.mu
// held. Listeners are not removed from the tracked set here.
func (srv *Server) closeListenersLocked() error {
	var firstErr error
	for ln := range srv.listeners {
		closeErr := (*ln).Close()
		if firstErr == nil && closeErr != nil {
			firstErr = closeErr
		}
	}
	return firstErr
}

// ListenAndServeNetwork listens on the network & addr
// and then calls Serve with handler to handle requests
// on incoming connections.
Expand Down Expand Up @@ -729,3 +966,18 @@ func ListenAndServeNetworkTLS(network, addr string, certFile string, keyFile str
func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Handler, dp *dict.Parser) error {
return ListenAndServeNetworkTLS("tcp", addr, certFile, keyFile, handler, dp)
}

// onceCloseListener wraps a net.Listener so that the underlying
// listener is closed at most once, however many times Close is called.
type onceCloseListener struct {
	net.Listener
	once     sync.Once // guards the single call to close
	closeErr error     // result of the underlying Close, set by close
}

// Close closes the wrapped listener on the first call and returns its
// error. Later calls do not touch the listener again; they return the
// error recorded by the first call.
func (oc *onceCloseListener) Close() error {
	oc.once.Do(func() {
		oc.close()
	})
	return oc.closeErr
}

// close closes the wrapped listener and records the result in closeErr.
func (oc *onceCloseListener) close() {
	oc.closeErr = oc.Listener.Close()
}