Skip to content

Commit

Permalink
Merge pull request #30 from acoshift/graceful/remove-before
Browse files Browse the repository at this point in the history
update listen and serve
  • Loading branch information
acoshift authored May 18, 2018
2 parents 014a2e6 + 67d6b51 commit 0ea8cb5
Show file tree
Hide file tree
Showing 12 changed files with 288 additions and 46 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ func main() {
BeforeRender(addHeaderRender).
Handler(router(app)).
GracefulShutdown().
ListenAndServe(":8080")
Address(":8080").
ListenAndServe()
}

func router(app *hime.App) http.Handler {
Expand Down
37 changes: 23 additions & 14 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ import (

// App is the hime app
type App struct {
// Addr is server address
Addr string

// TLSConfig overrides http.Server TLSConfig
TLSConfig *tls.Config

Expand Down Expand Up @@ -67,6 +70,12 @@ func New() *App {
return &App{}
}

// Address sets server address
func (app *App) Address(addr string) *App {
app.Addr = addr
return app
}

// Handler sets the handler
func (app *App) Handler(h http.Handler) *App {
app.handler = h
Expand All @@ -87,7 +96,8 @@ func (app *App) ServeHTTP(w http.ResponseWriter, r *http.Request) {
app.handler.ServeHTTP(w, r)
}

func (app *App) configServer(addr string) {
func (app *App) configServer() {
app.srv.Addr = app.Addr
app.srv.TLSConfig = app.TLSConfig
app.srv.ReadTimeout = app.ReadTimeout
app.srv.ReadHeaderTimeout = app.ReadHeaderTimeout
Expand All @@ -98,46 +108,45 @@ func (app *App) configServer(addr string) {
app.srv.ConnState = app.ConnState
app.srv.ErrorLog = app.ErrorLog
app.srv.Handler = app
app.srv.Addr = addr
}

func (app *App) listenAndServe(addr string) error {
app.configServer(addr)
func (app *App) listenAndServe() error {
app.configServer()

return app.srv.ListenAndServe()
}

func (app *App) listenAndServeTLS(addr, certFile, keyFile string) error {
app.configServer(addr)
func (app *App) listenAndServeTLS(certFile, keyFile string) error {
app.configServer()

return app.srv.ListenAndServeTLS(certFile, keyFile)
}

// ListenAndServe starts web server
func (app *App) ListenAndServe(addr string) error {
func (app *App) ListenAndServe() error {
if app.gracefulShutdown != nil {
return app.GracefulShutdown().ListenAndServe(addr)
return app.GracefulShutdown().ListenAndServe()
}

return app.listenAndServe(addr)
return app.listenAndServe()
}

// ListenAndServeTLS starts web server in tls mode
func (app *App) ListenAndServeTLS(addr, certFile, keyFile string) error {
func (app *App) ListenAndServeTLS(certFile, keyFile string) error {
if app.gracefulShutdown != nil {
return app.GracefulShutdown().ListenAndServeTLS(addr, certFile, keyFile)
return app.GracefulShutdown().ListenAndServeTLS(certFile, keyFile)
}

return app.listenAndServeTLS(addr, certFile, keyFile)
return app.listenAndServeTLS(certFile, keyFile)
}

// GracefulShutdown returns graceful shutdown server
func (app *App) GracefulShutdown() *GracefulShutdown {
func (app *App) GracefulShutdown() *GracefulShutdownApp {
if app.gracefulShutdown == nil {
app.gracefulShutdown = &gracefulShutdown{}
}

return &GracefulShutdown{
return &GracefulShutdownApp{
App: app,
gracefulShutdown: app.gracefulShutdown,
}
Expand Down
12 changes: 7 additions & 5 deletions app_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,23 @@ func TestApp(t *testing.T) {
gs.Wait(5 * time.Second)
assert.Equal(t, 5*time.Second, gs.wait)

gs.Before(func() {})
gs.Before(func() {})
assert.Len(t, gs.beforeFns, 2)

gs.Notify(func() {})
gs.Notify(func() {})
gs.Notify(func() {})
assert.Len(t, gs.notiFns, 3)
})

t.Run("Address", func(t *testing.T) {
app.Address(":1234")
assert.Equal(t, ":1234", app.Addr)
})
}

func TestConfigServer(t *testing.T) {
t.Parallel()

app := &App{
Addr: ":8080",
TLSConfig: &tls.Config{},
ReadTimeout: 5 * time.Second,
ReadHeaderTimeout: 6 * time.Second,
Expand All @@ -79,7 +81,7 @@ func TestConfigServer(t *testing.T) {
}

assert.Empty(t, &app.srv)
app.configServer(":8080")
app.configServer()
assert.NotEmpty(t, &app.srv)
assert.Equal(t, ":8080", app.srv.Addr)
}
31 changes: 31 additions & 0 deletions app_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package hime_test

import (
"net/http"
"testing"

"github.com/acoshift/hime"
"github.com/stretchr/testify/assert"
)

func TestHandler(t *testing.T) {
t.Parallel()

t.Run("net/http", func(t *testing.T) {
app := hime.New().
Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("ok"))
}))

assert.HTTPBodyContains(t, app.ServeHTTP, "GET", "/", nil, "ok")
})

t.Run("hime", func(t *testing.T) {
app := hime.New().
Handler(hime.H(func(ctx hime.Context) hime.Result {
return ctx.String("ok")
}))

assert.HTTPBodyContains(t, app.ServeHTTP, "GET", "/", nil, "ok")
})
}
4 changes: 4 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ type AppConfig struct {
Routes map[string]string `yaml:"routes" json:"routes"`
Templates []TemplateConfig `yaml:"templates" json:"templates"`
Server struct {
Addr string `yaml:"addr" json:"addr"`
ReadTimeout string `yaml:"readTimeout" json:"readTimeout"`
ReadHeaderTimeout string `yaml:"readHeaderTimeout" json:"readHeaderTimeout"`
WriteTimeout string `yaml:"writeTimeout" json:"writeTimeout"`
Expand Down Expand Up @@ -74,6 +75,9 @@ func (app *App) Config(config AppConfig) *App {
}

// load server config
if config.Server.Addr != "" {
app.Addr = config.Server.Addr
}
parseDuration(config.Server.ReadTimeout, &app.ReadTimeout)
parseDuration(config.Server.ReadHeaderTimeout, &app.ReadHeaderTimeout)
parseDuration(config.Server.WriteTimeout, &app.WriteTimeout)
Expand Down
1 change: 1 addition & 0 deletions config_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func TestConfig(t *testing.T) {
assert.Contains(t, app.template, "main2")

// server
assert.Equal(t, ":8080", app.Addr)
assert.Equal(t, 10*time.Second, app.ReadTimeout)
assert.Equal(t, 5*time.Second, app.ReadHeaderTimeout)
assert.Equal(t, 6*time.Second, app.WriteTimeout)
Expand Down
3 changes: 2 additions & 1 deletion examples/middlewarevalue/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import (
func main() {
err := hime.New().
Handler(router()).
ListenAndServe(":8080")
Address(":8080").
ListenAndServe()
if err != nil {
log.Fatal(err)
}
Expand Down
3 changes: 2 additions & 1 deletion examples/net-http/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,14 @@ func main() {
}).
BeforeRender(addHeaderRender).
Handler(router(app)).
Address(":8080").
GracefulShutdown().
Notify(func() {
log.Println("Received terminate signal")
}).
Wait(5 * time.Second).
Timeout(5 * time.Second).
ListenAndServe(":8080")
ListenAndServe()
if err != nil {
log.Fatal(err)
}
Expand Down
80 changes: 56 additions & 24 deletions graceful.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,56 +5,60 @@ import (
"net/http"
"os"
"os/signal"
"sync"
"syscall"
"time"
)

// GracefulShutdown is the app in graceful shutdown mode
type GracefulShutdown struct {
// GracefulShutdownApp is the app in graceful shutdown mode
type GracefulShutdownApp struct {
*gracefulShutdown

App *App
}

type gracefulShutdown struct {
timeout time.Duration
wait time.Duration
notiFns []func()
beforeFns []func()
timeout time.Duration
wait time.Duration
notiFns []func()
}

// Address sets server address
func (gs *GracefulShutdownApp) Address(addr string) *GracefulShutdownApp {
gs.App.Addr = addr
return gs
}

// Timeout sets shutdown timeout for graceful shutdown,
// set to 0 to disable timeout
//
// default is 0
func (gs *GracefulShutdown) Timeout(d time.Duration) *GracefulShutdown {
func (gs *GracefulShutdownApp) Timeout(d time.Duration) *GracefulShutdownApp {
gs.timeout = d
return gs
}

// Wait sets wait time before shutdown
func (gs *GracefulShutdown) Wait(d time.Duration) *GracefulShutdown {
func (gs *GracefulShutdownApp) Wait(d time.Duration) *GracefulShutdownApp {
gs.wait = d
return gs
}

// Notify calls fn when receive terminate signal from os
func (gs *GracefulShutdown) Notify(fn func()) *GracefulShutdown {
func (gs *GracefulShutdownApp) Notify(fn func()) *GracefulShutdownApp {
if fn != nil {
gs.notiFns = append(gs.notiFns, fn)
}
return gs
}

// Before runs fn before start waiting to SIGTERM
func (gs *GracefulShutdown) Before(fn func()) *GracefulShutdown {
if fn != nil {
gs.beforeFns = append(gs.beforeFns, fn)
}
// OnShutdown calls server.RegisterOnShutdown(fn)
func (gs *GracefulShutdownApp) OnShutdown(fn func()) *GracefulShutdownApp {
gs.App.srv.RegisterOnShutdown(fn)
return gs
}

func (gs *GracefulShutdown) start(listenAndServe func() error) (err error) {
func (gs *GracefulShutdownApp) start(listenAndServe func() error) (err error) {
serverCtx, cancelServer := context.WithCancel(context.Background())
defer cancelServer()
go func() {
Expand All @@ -63,10 +67,6 @@ func (gs *GracefulShutdown) start(listenAndServe func() error) (err error) {
}
}()

for _, fn := range gs.beforeFns {
fn()
}

stop := make(chan os.Signal, 1)
signal.Notify(stop, syscall.SIGTERM)

Expand All @@ -75,7 +75,7 @@ func (gs *GracefulShutdown) start(listenAndServe func() error) (err error) {
return
case <-stop:
for _, fn := range gs.notiFns {
fn()
go fn()
}
if gs.wait > 0 {
time.Sleep(gs.wait)
Expand All @@ -93,11 +93,43 @@ func (gs *GracefulShutdown) start(listenAndServe func() error) (err error) {
}

// ListenAndServe starts web server in graceful shutdown mode
func (gs *GracefulShutdown) ListenAndServe(addr string) error {
return gs.start(func() error { return gs.App.listenAndServe(addr) })
func (gs *GracefulShutdownApp) ListenAndServe() error {
return gs.start(gs.App.listenAndServe)
}

// ListenAndServeTLS starts web server in graceful shutdown and tls mode
func (gs *GracefulShutdown) ListenAndServeTLS(addr, certFile, keyFile string) error {
return gs.start(func() error { return gs.App.listenAndServeTLS(addr, certFile, keyFile) })
func (gs *GracefulShutdownApp) ListenAndServeTLS(certFile, keyFile string) error {
return gs.start(func() error { return gs.App.listenAndServeTLS(certFile, keyFile) })
}

// GracefulShutdown runs multiple hime's app in graceful shutdown mode
func GracefulShutdown(apps []*App) error {
wg := &sync.WaitGroup{}
errChan := make(chan error)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

for _, app := range apps {
app := app
wg.Add(1)
go func() {
err := app.GracefulShutdown().ListenAndServe()
if err != http.ErrServerClosed {
errChan <- err
}
wg.Done()
}()
}

go func() {
wg.Wait()
cancel()
}()

select {
case err := <-errChan:
return err
case <-ctx.Done():
return nil
}
}
Loading

0 comments on commit 0ea8cb5

Please sign in to comment.