Skip to content

Commit

Permalink
tonic: bind keys from gin.Context
Browse files Browse the repository at this point in the history
  • Loading branch information
wwwxu committed Dec 28, 2022
1 parent c4f8b2f commit 3a965f9
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tonic/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ func Handler(h interface{}, status int, options ...func(*Route)) gin.HandlerFunc
handleError(c, err)
return
}
// Bind context-keys
if err := bind(c, input, ContextTag, extractContext); err != nil {
handleError(c, err)
return
}
// validating query and path inputs if they have a validate tag
initValidator()
args = append(args, input)
Expand Down
21 changes: 21 additions & 0 deletions tonic/tonic.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const (
QueryTag = "query"
PathTag = "path"
HeaderTag = "header"
ContextTag = "context"
EnumTag = "enum"
RequiredTag = "required"
DefaultTag = "default"
Expand Down Expand Up @@ -345,6 +346,26 @@ func extractHeader(c *gin.Context, tag string) (string, []string, error) {
return name, []string{header}, nil
}

// extractContext is an extractor that operates on the gin.Context
// of a request.
func extractContext(c *gin.Context, tag string) (string, []string, error) {
name, required, defaultVal, err := parseTagKey(tag)
if err != nil {
return "", nil, err
}
context := c.GetString(name)

// XXX: deprecated, use of "default" tag is preferred
if context == "" && defaultVal != "" {
return name, []string{defaultVal}, nil
}
// XXX: deprecated, use of "validate" tag is preferred
if required && context == "" {
return "", nil, fmt.Errorf("missing header parameter: %s", name)
}
return name, []string{context}, nil
}

// Public signature does not expose "required" and "default" because
// they are deprecated in favor of the "validate" and "default" tags
func parseTagKey(tag string) (string, bool, string, error) {
Expand Down
38 changes: 38 additions & 0 deletions tonic/tonic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,23 @@ func TestMain(m *testing.M) {
g.GET("/query", tonic.Handler(queryHandler, 200))
g.GET("/query-old", tonic.Handler(queryHandlerOld, 200))
g.POST("/body", tonic.Handler(bodyHandler, 200))
g.GET("/context", tonic.Handler(contextHandler, 200))

// for context test
g.Use(func(c *gin.Context) {
if c.FullPath() == "/context" {
if val, ok := c.GetQuery("param"); ok {
c.Set("param", val)
}
if val, ok := c.GetQuery("param-optional"); ok {
c.Set("param-optional", val)
}
if val, ok := c.GetQuery("param-optional-validated"); ok {
c.Set("param-optional-validated", val)
}
}
c.Next()
})

r = g

Expand Down Expand Up @@ -130,6 +147,17 @@ func TestBody(t *testing.T) {
tester.Run()
}

func TestContext(t *testing.T) {
tester := iffy.NewTester(t, r)

tester.AddCall("context", "GET", "/context?param=foo", ``).Checkers(iffy.ExpectStatus(200), expectString("param", "foo"))
tester.AddCall("context", "GET", "/context?param=foo", ``).Checkers(iffy.ExpectStatus(400))
tester.AddCall("context", "GET", "/context?param=foo&param-optional=bar", ``).Checkers(iffy.ExpectStatus(200), expectString("param-optional", "bar"))
tester.AddCall("context", "GET", "/context?param=foo&param-optional-validated=foo", ``).Checkers(iffy.ExpectStatus(200), expectString("param-optional-validated", "foo"))

tester.Run()
}

func errorHandler(c *gin.Context) error {
return errors.New("error")
}
Expand Down Expand Up @@ -199,6 +227,16 @@ func bodyHandler(c *gin.Context, in *bodyIn) (*bodyIn, error) {
return in, nil
}

type ContextIn struct {
Param string `context:"param" json:"param" validate:"required"`
ParamOptional string `context:"param-optional" json:"param-optional"`
ValidatedParamOptional string `context:"param-optional-validated" json:"param-optional-validated" validate:"eq=|eq=foo|gt=10"`
}

func contextHandler(c *gin.Context, in *ContextIn) (*ContextIn, error) {
return in, nil
}

func expectEmptyBody(r *http.Response, body string, obj interface{}) error {
if len(body) != 0 {
return fmt.Errorf("Body '%s' should be empty", body)
Expand Down

0 comments on commit 3a965f9

Please sign in to comment.