Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support using protobuf types in query parameters #306

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 43 additions & 7 deletions gengokit/httptransport/httptransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ import (
"text/template"
"unicode"

log "github.com/sirupsen/logrus"
gogen "github.com/gogo/protobuf/protoc-gen-gogo/generator"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"

"github.com/metaverse/truss/gengokit/httptransport/templates"
"github.com/metaverse/truss/svcdef"
Expand Down Expand Up @@ -139,9 +139,9 @@ func NewBinding(i int, meth *svcdef.ServiceMethod) *Binding {
LocalName: fmt.Sprintf("%s%s", gogen.CamelCase(field.Name), gogen.CamelCase(meth.Name)),
}

if field.Type.Message == nil && field.Type.Enum == nil && field.Type.Map == nil {
if field.Type.Message == nil && field.Type.Enum == nil && field.Type.Map == nil && !isProtobufType(field.Type.Name) && field.Type.Name != "time.Time" {
newField.IsBaseType = true
} else {
} else if !isProtobufType(field.Type.Name) && field.Type.Name != "time.Time" {
newField.GoType = "pb." + newField.GoType
}

Expand All @@ -165,15 +165,15 @@ func NewBinding(i int, meth *svcdef.ServiceMethod) *Binding {
}

// Emit warnings for certain cases
if !newField.IsBaseType && newField.Location != "body" {
if !newField.IsBaseType && newField.Location != "body" && !isSafeNonBaseType(newField.GoType) {
log.Warnf(
"%s.%s is a non-base type specified to be located outside of "+
"the body. Non-base types outside the body may result in "+
"generated code which fails to compile.",
meth.Name,
newField.Name)
}
if newField.Repeated && newField.Location == "path" {
if newField.Repeated && newField.Location == "path" && !isSafeNonBaseType(newField.GoType) {
log.Warnf(
"%s.%s is a repeated field specified to be in the path. "+
"Repeated fields are not supported in the path and may"+
Expand All @@ -185,6 +185,24 @@ func NewBinding(i int, meth *svcdef.ServiceMethod) *Binding {
return &nBinding
}

func isProtobufType(t string) bool {
if strings.HasPrefix(t, "types.") {
return true
}
return false
}

func isSafeNonBaseType(t string) bool {
switch t {
case "time.Time": // gogo stdtime case
return true
case "types.Timestamp", "types.Duration":
return true
default:
return false
}
}

func GenServerTemplate(exec interface{}) (string, error) {
code, err := ApplyTemplate("ServerTemplate", templates.ServerTemplate, exec, TemplateFuncs)
if err != nil {
Expand Down Expand Up @@ -381,6 +399,10 @@ func createDecodeConvertFunc(f Field) (string, bool) {
var {{.LocalName}} *{{.GoType}}
{{.LocalName}} = &{{.GoType}}{}
err = json.Unmarshal([]byte({{.LocalName}}Str), {{.LocalName}})`
singlePBTypeUnmarshalTmpl := `
var {{.LocalName}} *{{.GoType}}
{{.LocalName}} = &{{.GoType}}{}
err = jsonpb.UnmarshalString({{.LocalName}}Str, {{.LocalName}})`
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of creating thses tailored decode functions for each individual field, how about we we first merge all the variables from path, query, and body into a single JSON string, then call jsonpb.Unmarshal(s, req) to decode into the request struct in one go? Seems like it would simplify the decode logic significantly?

// All repeated args of any type are represented as slices, and bare
// assignments to a slice accept a slice as the rvalue. As a result,
// LocalName will be declared as a slice, and json.Unmarshal handles
Expand All @@ -399,16 +421,30 @@ if err != nil {
{{- end}}
err = json.Unmarshal([]byte({{.LocalName}}Str), &{{.LocalName}})`

repeatedPBUnmarshalTmpl := `
var {{.LocalName}} {{.GoType}}
{{- if and (and .IsBaseType .Repeated) (not (Contains .GoType "[]byte"))}}
err = jsonpb.UnmarshalString({{.LocalName}}Str, &{{.LocalName}})
if err != nil {
{{.LocalName}}Str = "[" + {{.LocalName}}Str + "]"
}
{{- end}}
err = jsonpb.UnmarshalString({{.LocalName}}Str, &{{.LocalName}})`

errorCheckingTmpl := `
if err != nil {
return nil, errors.Wrapf(err, "couldn't decode {{.LocalName}} from %v", {{.LocalName}}Str)
}`

var preamble string
if !f.Repeated {
if !f.Repeated && !isProtobufType(f.GoType) {
preamble = singleCustomTypeUnmarshalTmpl
} else {
} else if !f.Repeated && isProtobufType(f.GoType) {
preamble = singlePBTypeUnmarshalTmpl
} else if !isProtobufType(f.GoType) {
preamble = repeatedUnmarshalTmpl
} else {
preamble = repeatedPBUnmarshalTmpl
}
jsonConvTmpl := preamble + errorCheckingTmpl
code, err := ApplyTemplate("UnmarshalNonBaseType", jsonConvTmpl, f, TemplateFuncs)
Expand Down
4 changes: 4 additions & 0 deletions gengokit/httptransport/templates/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,11 @@ import (
"strconv"
"strings"
"io"
"time"

"github.com/gogo/protobuf/jsonpb"
"github.com/gogo/protobuf/proto"
"github.com/gogo/protobuf/types"

"context"

Expand All @@ -102,6 +104,8 @@ var (
_ = pb.New{{.Service.Name}}Client
_ = io.Copy
_ = errors.Wrap
_ = types.EmptyAny
_ = time.NewTimer
)

// MakeHTTPHandler returns a handler that makes a set of endpoints available
Expand Down
3 changes: 3 additions & 0 deletions svcdef/svcdef.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,9 @@ func NewField(f *ast.Field) (*Field, error) {
if oneof, ok := oneofs[ex.Name]; ok {
rv.Type.Oneof = oneof
}
case *ast.SelectorExpr:
packageIdent := ex.X.(*ast.Ident)
rv.Type.Name += fmt.Sprintf("%s.%s", packageIdent.Name, ex.Sel.Name)
case *ast.StarExpr:
rv.Type.StarExpr = true
typeFollower(ex.X)
Expand Down