From b8df030533c143ab49750777c7c6930df6f0d54e Mon Sep 17 00:00:00 2001 From: jefftt Date: Tue, 24 Nov 2020 08:33:12 -0500 Subject: [PATCH 1/3] Support graceful shutdown --- diam/server.go | 253 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 253 insertions(+) diff --git a/diam/server.go b/diam/server.go index 44514bcc..3e233b10 100644 --- a/diam/server.go +++ b/diam/server.go @@ -9,12 +9,15 @@ package diam import ( "bufio" "crypto/tls" + "errors" "fmt" "io" "log" + "math/rand" "net" "runtime" "sync" + "sync/atomic" "time" "golang.org/x/net/context" @@ -87,6 +90,8 @@ type conn struct { tlsState *tls.ConnectionState // or nil when not using TLS writer *response // the diam.Conn exposed to handlers + curState struct{ atomic uint64 } // packed (unixtime<<8|uint8(ConnState)) + mu sync.Mutex // guards the following closeNotifyc chan struct{} clientGone bool @@ -137,6 +142,50 @@ func (c *conn) notifyClientGone() { } } +// A ConnState represents the state of a client connection to a server. +type ConnState int + +const ( + // StateNew represents a new connection that is expected to + // send a request immediately. Connections begin at this + // state and then transition to either StateActive or + // StateClosed. + StateNew ConnState = iota + + // StateActive represents a connection that has read 1 or more + // bytes of a request. + // After the request is handled, the state + // transitions to StateClosed, or StateIdle. + // For HTTP/2, StateActive fires on the transition from zero + // to one active request, and only transitions away once all + // active requests are complete. That means that ConnState + // cannot be used to do per-request work; ConnState only notes + // the overall state of the connection. + StateActive + + // StateIdle represents a connection that has finished + // handling a request and is in the keep-alive state, waiting + // for a new request. Connections transition from StateIdle + // to either StateActive or StateClosed. 
+ StateIdle + + // StateClosed represents a closed connection. + // This is a terminal state. Hijacked connections do not + // transition to StateClosed. + StateClosed +) + +var stateName = map[ConnState]string{ + StateNew: "new", + StateActive: "active", + StateIdle: "idle", + StateClosed: "closed", +} + +func (c ConnState) String() string { + return stateName[c] +} + // Create new connection from rwc. func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { msc, isMulti := rwc.(MultistreamConn) @@ -157,6 +206,26 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { return c, nil } +func (c *conn) setState(state ConnState) { + srv := c.server + switch state { + case StateNew: + srv.trackConn(c, true) + case StateClosed: + srv.trackConn(c, false) + } + if state > 0xff || state < 0 { + panic("internal error") + } + packedState := uint64(time.Now().Unix()<<8) | uint64(state) + atomic.StoreUint64(&c.curState.atomic, packedState) +} + +func (c *conn) getState() (state ConnState, unixSec int64) { + packedState := atomic.LoadUint64(&c.curState.atomic) + return ConnState(packedState & 0xff), int64(packedState >> 8) +} + // Read next message from connection. func (c *conn) readMessage() (m *Message, err error) { if c.server.ReadTimeout > 0 { @@ -185,6 +254,7 @@ func (c *conn) serve() { c.rwc.RemoteAddr().String(), err, buf) } c.rwc.Close() + c.setState(StateClosed) }() if tlsConn, ok := c.rwc.(*tls.Conn); ok { if err := tlsConn.Handshake(); err != nil { @@ -195,8 +265,10 @@ func (c *conn) serve() { } for { m, err := c.readMessage() + c.setState(StateActive) if err != nil { c.rwc.Close() + c.setState(StateClosed) // Report errors to the channel, except EOF. if err != io.EOF && err != io.ErrUnexpectedEOF { h := c.server.Handler @@ -211,6 +283,7 @@ func (c *conn) serve() { } // Handle messages in this goroutine. 
serverHandler{c.server}.ServeDIAM(c.writer, m) + c.setState(StateIdle) } } @@ -223,6 +296,12 @@ func (c *conn) dictionary() *dict.Parser { return c.server.Dict } +type atomicBool int32 + +func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } +func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } +func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) } + // A response represents the server side of a diameter response. // It implements the Conn and CloseNotifier interfaces. type response struct { @@ -548,6 +627,10 @@ func ErrorReports() <-chan *ErrorReport { return DefaultServeMux.ErrorReports() } +// ErrServerClosed is returned by the Server's Serve and ListenAndServe +// methods after a call to Shutdown or Close. +var ErrServerClosed = errors.New("diameter: Server closed") + // Serve accepts incoming diameter connections on the listener l, // creating a new service goroutine for each. The service goroutines // read messages and then call handler to reply to them. @@ -567,6 +650,11 @@ type Server struct { WriteTimeout time.Duration // maximum duration before timing out write of the response TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS LocalAddr net.Addr // optional Local Address to bind dailer's (Dail...) socket to + + inShutdown atomicBool // true when server is in shutdown + mu sync.Mutex + listeners map[*net.Listener]struct{} + activeConn map[*conn]struct{} } // serverHandler delegates to either the server's Handler or DefaultServeMux. @@ -607,11 +695,22 @@ func (srv *Server) ListenAndServe() error { // new service goroutine for each. The service goroutines read requests and // then call srv.Handler to reply to them. 
func (srv *Server) Serve(l net.Listener) error { + l = &onceCloseListener{Listener: l} defer l.Close() + + if !srv.trackListener(&l, true) { + return ErrServerClosed + } + defer srv.trackListener(&l, false) + var tempDelay time.Duration // how long to sleep on accept failure for { rw, e := l.Accept() if e != nil { + if srv.shuttingDown() { + return ErrServerClosed + } + if ne, ok := e.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { tempDelay = 5 * time.Millisecond @@ -640,11 +739,150 @@ func (srv *Server) Serve(l net.Listener) error { log.Printf("srv.newConn error: %v", err) continue } else { + c.setState(StateNew) go c.serve() } } } +// shutdownPollIntervalMax is the max polling interval when checking +// quiescence during Server.Shutdown. Polling starts with a small +// interval and backs off to the max. +// Ideally we could find a solution that doesn't involve polling, +// but which also doesn't have a high runtime cost (and doesn't +// involve any contentious mutexes), but that is left as an +// exercise for the reader. +const shutdownPollIntervalMax = 500 * time.Millisecond + +// Shutdown gracefully shuts down the server without interrupting any +// active connections. Shutdown works by first closing all open +// listeners, then closing all idle connections, and then waiting +// indefinitely for connections to return to idle and then shut down. +// +// When Shutdown is called, Serve, ListenAndServe, and +// ListenAndServeTLS immediately return ErrServerClosed. Make sure the +// program doesn't exit and waits instead for Shutdown to return. +// +// Once Shutdown has been called on a server, it may not be reused; +// future calls to methods such as Serve will return ErrServerClosed. +func (srv *Server) Shutdown() error { + srv.inShutdown.setTrue() + + srv.mu.Lock() + lnerr := srv.closeListenersLocked() + srv.mu.Unlock() + + pollIntervalBase := time.Millisecond + nextPollInterval := func() time.Duration { + // Add 10% jitter. 
+ interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10))) + // Double and clamp for next time. + pollIntervalBase *= 2 + if pollIntervalBase > shutdownPollIntervalMax { + pollIntervalBase = shutdownPollIntervalMax + } + return interval + } + + timer := time.NewTimer(nextPollInterval()) + defer timer.Stop() + for { + if srv.closeIdleConns() && srv.numListeners() == 0 { + return lnerr + } + select { + case <-timer.C: + timer.Reset(nextPollInterval()) + } + } +} + +func (s *Server) numListeners() int { + s.mu.Lock() + defer s.mu.Unlock() + return len(s.listeners) +} + +// closeIdleConns closes all idle connections and reports whether the +// server is quiescent. +func (s *Server) closeIdleConns() bool { + s.mu.Lock() + defer s.mu.Unlock() + quiescent := true + for c := range s.activeConn { + st, unixSec := c.getState() + // treat StateNew connections as if + // they're idle if we haven't read the first request's + // header in over 5 seconds. + if st == StateNew && unixSec < time.Now().Unix()-5 { + st = StateIdle + } + if st != StateIdle || unixSec == 0 { + // Assume unixSec == 0 means it's a very new + // connection, without state set yet. + quiescent = false + continue + } + c.rwc.Close() + delete(s.activeConn, c) + } + return quiescent +} + +// trackListener adds or removes a net.Listener to the set of tracked +// listeners. +// +// We store a pointer to interface in the map set, in case the +// net.Listener is not comparable. This is safe because we only call +// trackListener via Serve and can track+defer untrack the same +// pointer to local variable there. We never need to compare a +// Listener from another caller. +// +// It reports whether the server is still up (not Shutdown or Closed). 
+func (s *Server) trackListener(ln *net.Listener, add bool) bool { + s.mu.Lock() + defer s.mu.Unlock() + if s.listeners == nil { + s.listeners = make(map[*net.Listener]struct{}) + } + if add { + if s.shuttingDown() { + return false + } + s.listeners[ln] = struct{}{} + } else { + delete(s.listeners, ln) + } + return true +} + +func (s *Server) trackConn(c *conn, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.activeConn == nil { + s.activeConn = make(map[*conn]struct{}) + } + if add { + s.activeConn[c] = struct{}{} + } else { + delete(s.activeConn, c) + } +} + +func (s *Server) shuttingDown() bool { + return s.inShutdown.isSet() +} + +func (s *Server) closeListenersLocked() error { + var err error + for ln := range s.listeners { + if cerr := (*ln).Close(); cerr != nil && err == nil { + err = cerr + } + } + return err +} + // ListenAndServeNetwork listens on the network & addr // and then calls Serve with handler to handle requests // on incoming connections. @@ -729,3 +967,18 @@ func ListenAndServeNetworkTLS(network, addr string, certFile string, keyFile str func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Handler, dp *dict.Parser) error { return ListenAndServeNetworkTLS("tcp", addr, certFile, keyFile, handler, dp) } + +// onceCloseListener wraps a net.Listener, protecting it from +// multiple Close calls. 
+type onceCloseListener struct { + net.Listener + once sync.Once + closeErr error +} + +func (oc *onceCloseListener) Close() error { + oc.once.Do(oc.close) + return oc.closeErr +} + +func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() } From 38720e4ef2efaaab45a5e39b38e1b229af207ab2 Mon Sep 17 00:00:00 2001 From: jefftt Date: Wed, 25 Nov 2020 17:54:41 -0500 Subject: [PATCH 2/3] Fix comment --- diam/server.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/diam/server.go b/diam/server.go index 3e233b10..2495375f 100644 --- a/diam/server.go +++ b/diam/server.go @@ -156,11 +156,6 @@ const ( // bytes of a request. // After the request is handled, the state // transitions to StateClosed, or StateIdle. - // For HTTP/2, StateActive fires on the transition from zero - // to one active request, and only transitions away once all - // active requests are complete. That means that ConnState - // cannot be used to do per-request work; ConnState only notes - // the overall state of the connection. StateActive // StateIdle represents a connection that has finished From 60f8f43b76e6880930ee98080ac210eec870b168 Mon Sep 17 00:00:00 2001 From: jeff Date: Fri, 26 Feb 2021 09:09:26 -0500 Subject: [PATCH 3/3] change ConnState to int8, fix ConnState String() --- diam/server.go | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/diam/server.go b/diam/server.go index 2495375f..6f46b2e9 100644 --- a/diam/server.go +++ b/diam/server.go @@ -143,7 +143,7 @@ func (c *conn) notifyClientGone() { } // A ConnState represents the state of a client connection to a server. -type ConnState int +type ConnState int8 const ( // StateNew represents a new connection that is expected to @@ -178,7 +178,11 @@ var stateName = map[ConnState]string{ } func (c ConnState) String() string { - return stateName[c] + name, ok := stateName[c] + if !ok { + return "UNDEFINED" + } + return name } // Create new connection from rwc. 
@@ -209,7 +213,7 @@ func (c *conn) setState(state ConnState) { case StateClosed: srv.trackConn(c, false) } - if state > 0xff || state < 0 { + if state > 0xf || state < 0 { panic("internal error") } packedState := uint64(time.Now().Unix()<<8) | uint64(state)