diff --git a/context.go b/context.go index fa13bb8..756a3ea 100644 --- a/context.go +++ b/context.go @@ -8,10 +8,6 @@ import ( // NewContext creates new hime's context func NewContext(w http.ResponseWriter, r *http.Request) *Context { - return newInternalContext(w, r) -} - -func newInternalContext(w http.ResponseWriter, r *http.Request) *Context { app, ok := r.Context().Value(ctxKeyApp).(*App) if !ok { panic(ErrAppNotFound) diff --git a/handler.go b/handler.go index 0ec54bd..df3a63e 100644 --- a/handler.go +++ b/handler.go @@ -7,7 +7,7 @@ import ( // Wrap wraps hime handler with http.Handler func Wrap(h Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := newInternalContext(w, r) + ctx := NewContext(w, r) h(ctx).ServeHTTP(ctx.w, ctx.r) }) } diff --git a/request.go b/request.go index 57a3d8f..2d437ee 100644 --- a/request.go +++ b/request.go @@ -8,7 +8,12 @@ import ( "strings" ) -func trimComma(s string) string { +const ( + // defaultMaxMemory is http.defaultMaxMemory + defaultMaxMemory = 32 << 20 // 32 MB +) + +func removeComma(s string) string { return strings.Replace(s, ",", "", -1) } @@ -44,7 +49,7 @@ func (ctx *Context) FormValueTrimSpace(key string) string { // FormValueTrimSpaceComma trims space and remove comma from form value func (ctx *Context) FormValueTrimSpaceComma(key string) string { - return trimComma(strings.TrimSpace(ctx.FormValue(key))) + return removeComma(strings.TrimSpace(ctx.FormValue(key))) } // FormValueInt converts form value to int @@ -83,7 +88,7 @@ func (ctx *Context) PostFormValueTrimSpace(key string) string { // PostFormValueTrimSpaceComma trims space and remove comma from post form value func (ctx *Context) PostFormValueTrimSpaceComma(key string) string { - return trimComma(strings.TrimSpace(ctx.PostFormValue(key))) + return removeComma(strings.TrimSpace(ctx.PostFormValue(key))) } // PostFormValueInt converts post form value to int @@ -130,6 +135,36 @@ func (ctx *Context) FormFileNotEmpty(key string) (multipart.File, *multipart.Fil return file, header, err } +// FormFileHeader returns file header for given key without open file +func (ctx *Context) FormFileHeader(key string) (*multipart.FileHeader, error) { + // edit from http.Request.FormFile + if ctx.r.MultipartForm == nil { + err := ctx.r.ParseMultipartForm(defaultMaxMemory) + if err != nil { + return nil, err + } + } + if ctx.r.MultipartForm != nil && ctx.r.MultipartForm.File != nil { + if fhs := ctx.r.MultipartForm.File[key]; len(fhs) > 0 { + return fhs[0], nil + } + } + return nil, http.ErrMissingFile +} + +// FormFileHeaderNotEmpty returns file header if not empty, +// or http.ErrMissingFile if file is empty +func (ctx *Context) FormFileHeaderNotEmpty(key string) (*multipart.FileHeader, error) { + fh, err := ctx.FormFileHeader(key) + if err != nil { + return nil, err + } + if fh.Size == 0 { + return nil, http.ErrMissingFile + } + return fh, nil +} + // MultipartForm returns r.MultipartForm func (ctx *Context) MultipartForm() *multipart.Form { return ctx.r.MultipartForm diff --git a/request_internal_test.go b/request_internal_test.go index e172905..8736cf9 100644 --- a/request_internal_test.go +++ b/request_internal_test.go @@ -6,7 +6,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestTrimComma(t *testing.T) { +func TestRemoveComma(t *testing.T) { t.Parallel() cases := []struct { @@ -21,6 +21,6 @@ func TestTrimComma(t *testing.T) { } for _, c := range cases { - assert.Equal(t, c.Output, trimComma(c.Input)) + assert.Equal(t, c.Output, removeComma(c.Input)) } } diff --git a/result_test.go b/result_test.go index c25fcf6..183c9e5 100644 --- a/result_test.go +++ b/result_test.go @@ -155,4 +155,59 @@ func TestResult(t *testing.T) { assert.Equal(t, http.StatusOK, w.Result().StatusCode) assert.Equal(t, "public, max-age=3600", w.Header().Get("Cache-Control")) }) + + t.Run("Error", func(t *testing.T) { + t.Parallel() + + app := hime.New(). + Handler(hime.H(func(ctx *hime.Context) hime.Result { + return ctx.Error("some error :P") + })) + + w := invokeHandler(app, "GET", "/", nil) + assert.Equal(t, http.StatusInternalServerError, w.Result().StatusCode) + assert.Equal(t, "some error :P\n", w.Body.String()) + }) + + t.Run("ErrorCustomStatusCode", func(t *testing.T) { + t.Parallel() + + app := hime.New(). + Handler(hime.H(func(ctx *hime.Context) hime.Result { + return ctx.Status(http.StatusNotFound).Error("some not found error :P") + })) + + w := invokeHandler(app, "GET", "/", nil) + assert.Equal(t, http.StatusNotFound, w.Result().StatusCode) + assert.Equal(t, "some not found error :P\n", w.Body.String()) + }) + + t.Run("RedirectTo", func(t *testing.T) { + t.Parallel() + + app := hime.New(). + Routes(hime.Routes{ + "route1": "/route/1", + }). + Handler(hime.H(func(ctx *hime.Context) hime.Result { + return ctx.RedirectTo("route1") + })) + + w := invokeHandler(app, "GET", "/", nil) + assert.Equal(t, http.StatusFound, w.Result().StatusCode) + l, err := w.Result().Location() + assert.NoError(t, err) + assert.Equal(t, "/route/1", l.String()) + }) + + t.Run("RedirectToUnknownRoute", func(t *testing.T) { + t.Parallel() + + app := hime.New(). + Handler(hime.H(func(ctx *hime.Context) hime.Result { + return ctx.RedirectTo("unknown") + })) + + assert.Panics(t, func() { invokeHandler(app, "GET", "/", nil) }) + }) }