Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Advanced middleware options #347

Draft
wants to merge 10 commits into
base: master
Choose a base branch
from
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ dependencies:
go get -u github.com/gogo/protobuf/protoc-gen-gogo@21df5aa0e680850681b8643f0024f92d3b09930c
go get -u github.com/gogo/protobuf/protoc-gen-gogofaster@21df5aa0e680850681b8643f0024f92d3b09930c
go get -u github.com/gogo/protobuf/proto@21df5aa0e680850681b8643f0024f92d3b09930c
go get -u github.com/kevinburke/go-bindata/go-bindata
GO111MODULE=off go get -u github.com/kevinburke/go-bindata/go-bindata

# Generate go files containing the all template files in []byte form
gobindata:
go generate github.com/metaverse/truss/gengokit/template
GO111MODULE=off go generate github.com/metaverse/truss/gengokit/template

# Install truss
truss: gobindata
Expand Down
2 changes: 1 addition & 1 deletion cmd/_integration-tests/middlewares/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
pb "github.com/metaverse/truss/cmd/_integration-tests/middlewares/proto"
)

// NewService returns a naïve, stateless implementation of Service.
// NewService returns a naive, stateless implementation of Service.
func NewService() pb.MiddlewaresTestServer {
return middlewarestestService{}
}
Expand Down
9 changes: 4 additions & 5 deletions cmd/_integration-tests/middlewares/setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ func TestMain(m *testing.M) {
sometimesWrapped := svc.MakeSometimesWrappedEndpoint(service)
labeledTestHandler := svc.MakeLabeledTestHandlerEndpoint(service)

middlewareEndpoints = svc.Endpoints{
AlwaysWrappedEndpoint: alwaysWrapped,
SometimesWrappedEndpoint: sometimesWrapped,
LabeledTestHandlerEndpoint: labeledTestHandler,
}
middlewareEndpoints = svc.NewEndpoints()
middlewareEndpoints.AlwaysWrappedEndpoint = alwaysWrapped
middlewareEndpoints.SometimesWrappedEndpoint = sometimesWrapped
middlewareEndpoints.LabeledTestHandlerEndpoint = labeledTestHandler

middlewareEndpoints = handlers.WrapEndpoints(middlewareEndpoints)

Expand Down
2 changes: 1 addition & 1 deletion cmd/_integration-tests/transport/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
pb "github.com/metaverse/truss/cmd/_integration-tests/transport/proto"
)

// NewService returns a naïve, stateless implementation of Service.
// NewService returns a naive, stateless implementation of Service.
func NewService() pb.TransportPermutationsServer {
return transportpermutationsService{}
}
Expand Down
5 changes: 2 additions & 3 deletions cmd/_integration-tests/transport/http_benchmarks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,8 @@ func BenchmarkGetWithQueryClient_NoNetwork(b *testing.B) {
service = handlers.WrapService(service)
}
var getwithqueryEndpoint = svc.MakeGetWithQueryEndpoint(service)
endpoints := svc.Endpoints{
GetWithQueryEndpoint: getwithqueryEndpoint,
}
endpoints := svc.NewEndpoints()
endpoints.GetWithQueryEndpoint = getwithqueryEndpoint
ctx := context.WithValue(context.Background(), "request-url", "/getwithquery")
ctx = context.WithValue(ctx, "transport", "HTTPJSON")
server := httptransport.NewServer(
Expand Down
41 changes: 20 additions & 21 deletions cmd/_integration-tests/transport/setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,27 +43,26 @@ func TestMain(m *testing.M) {
StatusCodeAndHeadersE := svc.MakeStatusCodeAndHeadersEndpoint(service)
CustomVerbE := svc.MakeCustomVerbEndpoint(service)

endpoints := svc.Endpoints{
GetWithQueryEndpoint: getWithQueryE,
GetWithRepeatedQueryEndpoint: getWithRepeatedQueryE,
GetWithRepeatedStringQueryEndpoint: getWithRepeatedStringQueryE,
GetWithEnumQueryEndpoint: getWithEnumQueryE,
PostWithNestedMessageBodyEndpoint: postWithNestedMessageBodyE,
CtxToCtxEndpoint: ctxToCtxE,
GetWithCapsPathEndpoint: getWithCapsPathE,
GetWithPathParamsEndpoint: getWithPathParamsE,
GetWithEnumPathEndpoint: getWithEnumPathE,
GetWithOneofQueryEndpoint: getWithOneofQueryE,
EchoOddNamesEndpoint: echoOddNamesE,
ErrorRPCEndpoint: errorRPCE,
ErrorRPCNonJSONEndpoint: errorRPCNonJSONE,
ErrorRPCNonJSONLongEndpoint: errorRPCNonJSONLongE,
X2AOddRPCNameEndpoint: X2AOddRPCNameE,
ContentTypeTestEndpoint: contentTypeTestE,
StatusCodeAndNilHeadersEndpoint: StatusCodeAndNilHeadersE,
StatusCodeAndHeadersEndpoint: StatusCodeAndHeadersE,
CustomVerbEndpoint: CustomVerbE,
}
endpoints := svc.NewEndpoints()
endpoints.GetWithQueryEndpoint = getWithQueryE
endpoints.GetWithRepeatedQueryEndpoint = getWithRepeatedQueryE
endpoints.GetWithRepeatedStringQueryEndpoint = getWithRepeatedStringQueryE
endpoints.GetWithEnumQueryEndpoint = getWithEnumQueryE
endpoints.PostWithNestedMessageBodyEndpoint = postWithNestedMessageBodyE
endpoints.CtxToCtxEndpoint = ctxToCtxE
endpoints.GetWithCapsPathEndpoint = getWithCapsPathE
endpoints.GetWithPathParamsEndpoint = getWithPathParamsE
endpoints.GetWithEnumPathEndpoint = getWithEnumPathE
endpoints.GetWithOneofQueryEndpoint = getWithOneofQueryE
endpoints.EchoOddNamesEndpoint = echoOddNamesE
endpoints.ErrorRPCEndpoint = errorRPCE
endpoints.ErrorRPCNonJSONEndpoint = errorRPCNonJSONE
endpoints.ErrorRPCNonJSONLongEndpoint = errorRPCNonJSONLongE
endpoints.X2AOddRPCNameEndpoint = X2AOddRPCNameE
endpoints.ContentTypeTestEndpoint = contentTypeTestE
endpoints.StatusCodeAndNilHeadersEndpoint = StatusCodeAndNilHeadersE
endpoints.StatusCodeAndHeadersEndpoint = StatusCodeAndHeadersE
endpoints.CustomVerbEndpoint = CustomVerbE

// http test server
h := svc.MakeHTTPHandler(endpoints)
Expand Down
4 changes: 2 additions & 2 deletions gengokit/handlers/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ func TestApplyServerTempl(t *testing.T) {
pb "github.com/metaverse/truss/gengokit/general-service"
)

// NewService returns a naïve, stateless implementation of Service.
// NewService returns a naive, stateless implementation of Service.
func NewService() pb.ProtoServer {
return protoService{}
}
Expand Down Expand Up @@ -304,7 +304,7 @@ func TestPruneDecls(t *testing.T) {
pb "github.com/metaverse/truss/gengokit/general-service"
)

// NewService returns a naïve, stateless implementation of Service.
// NewService returns a naive, stateless implementation of Service.
func NewService() pb.ProtoServer {
return protoService{}
}
Expand Down
2 changes: 1 addition & 1 deletion gengokit/handlers/templates/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import (
pb "{{.PBImportPath -}}"
)

// NewService returns a naïve, stateless implementation of Service.
// NewService returns a naive, stateless implementation of Service.
func NewService() pb.{{GoName .Service.Name}}Server {
return {{ToLower .Service.Name}}Service{}
}
Expand Down
1 change: 1 addition & 0 deletions gengokit/httptransport/httptransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ func (b *Binding) PathSections() []string {
for _, part := range parts {
if len(part) > 2 && part[0] == '{' && part[len(part)-1] == '}' {
name := RemoveBraces(part)
name = strings.Split(name, ":")[0]
if _, ok := isEnum[gogen.CamelCase(name)]; ok {
convert := fmt.Sprintf("fmt.Sprintf(\"%%d\", req.%v)", gogen.CamelCase(name))
rv = append(rv, convert)
Expand Down
6 changes: 3 additions & 3 deletions gengokit/httptransport/templates/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,15 @@ func New(instance string, options ...httptransport.ClientOption) (pb.{{.Service.
{{- end}}
{{- end}}

return svc.Endpoints{
endpoints := svc.NewEndpoints()
{{range $method := .HTTPHelper.Methods -}}
{{ if $method.Bindings -}}
{{ with $binding := index $method.Bindings 0 -}}
{{$method.Name}}Endpoint: {{$binding.Label}}Endpoint,
endpoints.{{$method.Name}}Endpoint = {{$binding.Label}}Endpoint
{{end}}
{{- end}}
{{- end}}
}, nil
return endpoints, nil
}

func copyURL(base *url.URL, path string) *url.URL {
Expand Down
16 changes: 10 additions & 6 deletions gengokit/httptransport/templates/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,16 @@ func MakeHTTPHandler(endpoints Endpoints, options ...httptransport.ServerOption)

{{range $method := .HTTPHelper.Methods}}
{{range $binding := $method.Bindings}}
m.Methods("{{$binding.Verb | ToUpper}}").Path("{{$binding.PathTemplate}}").Handler(httptransport.NewServer(
endpoints.{{$method.Name}}Endpoint,
DecodeHTTP{{$binding.Label}}Request,
EncodeHTTPGenericResponse,
serverOptions...,
))
if endpoints.HasHttpHandlerFunc("{{$method.Name}}") {
m.Methods("{{$binding.Verb | ToUpper}}").Path("{{$binding.PathTemplate}}").HandlerFunc(endpoints.GetHttpHandlerFunc("{{$method.Name}}"))
} else {
m.Methods("{{$binding.Verb | ToUpper}}").Path("{{$binding.PathTemplate}}").Handler(httptransport.NewServer(
endpoints.{{$method.Name}}Endpoint,
endpoints.GetHttpRequestDecoder("{{$method.Name}}", DecodeHTTP{{$binding.Label}}Request),
endpoints.GetHttpResponseEncoder("{{$method.Name}}", EncodeHTTPGenericResponse),
append(serverOptions, endpoints.GetHttpServerOptions("{{$method.Name}}")...)...,
))
}
{{- end}}
{{- end}}
return m
Expand Down
2 changes: 1 addition & 1 deletion gengokit/httptransport/templates_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"github.com/metaverse/truss/gengokit/gentesthelper"
)

// Test that rendering certain templates will ouput the code we expect. The
// Test that rendering certain templates will output the code we expect. The
// code we expect is either the source code literal defined in each test, or
// it's the source code of certain actual functions within this package (see
// embeddable-funcs.go for more info).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ func New(conn *grpc.ClientConn, options ...ClientOption) (pb.{{.Service.Name}}Se
{{end}}
{{end}}

return svc.Endpoints{
endpoints := svc.NewEndpoints()
{{range $i := .Service.Methods -}}
{{$i.Name}}Endpoint: {{ToLower $i.Name}}Endpoint,
endpoints.{{$i.Name}}Endpoint = {{ToLower $i.Name}}Endpoint
{{end}}
}, nil
return endpoints, nil
}

// GRPC Client Decode
Expand Down
114 changes: 114 additions & 0 deletions gengokit/template/NAME-service/svc/endpoints.gotemplate
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ package svc
import (
"fmt"
"context"
"net/http"
httptransport "github.com/go-kit/kit/transport/http"

"github.com/go-kit/kit/endpoint"

Expand All @@ -33,11 +35,24 @@ import (
// single type that implements the Service interface. For example, you might
// construct individual endpoints using transport/http.NewClient, combine them into an Endpoints, and return it to the caller as a Service.
type Endpoints struct {
httpServerOptions map[string][]httptransport.ServerOption
httpRequestDecoders map[string]httptransport.DecodeRequestFunc
httpResponseEncoders map[string]httptransport.EncodeResponseFunc
httpHandlerFuncs map[string]func(http.ResponseWriter, *http.Request)
{{range $i := .Service.Methods}}
{{$i.Name}}Endpoint endpoint.Endpoint
{{- end}}
}

func NewEndpoints() Endpoints {
return Endpoints{
httpServerOptions: make(map[string][]httptransport.ServerOption),
httpRequestDecoders: make(map[string]httptransport.DecodeRequestFunc),
httpResponseEncoders: make(map[string]httptransport.EncodeResponseFunc),
httpHandlerFuncs: make(map[string]func(http.ResponseWriter, *http.Request)),
}
}

// Endpoints
{{range $i := .Service.Methods}}
func (e Endpoints) {{$i.Name}}(ctx context.Context, in *pb.{{GoName $i.RequestType.Name}}) (*pb.{{GoName $i.ResponseType.Name}}, error) {
Expand Down Expand Up @@ -124,3 +139,102 @@ func (e *Endpoints) WrapAllLabeledExcept(middleware func(string, endpoint.Endpoi
{{- end}}
}
}

// WrapAllWithHttpOptionExcept wraps each Endpoint entry of filed HttpServerOptions of struct Endpoints with a
// httptransport.ServerOption.
// Use this for applying a set of server options to every endpoint in the service.
// Optionally, endpoints can be passed in by name to be excluded from being wrapped.
// WrapAllWithHttpOptionExcept(serverOption, "Status", "Ping")
func (e *Endpoints) WrapAllWithHttpOptionExcept(serverOption httptransport.ServerOption, excluded ...string) {
included := map[string]struct{}{
{{- range $i := .Service.Methods}}
"{{$i.Name}}": {},
{{- end}}
}

for _, ex := range excluded {
if _, ok := included[ex]; !ok {
panic(fmt.Sprintf("Excluded endpoint '%s' does not exist; see middlewares/endpoints.go", ex))
}
delete(included, ex)
}

for inc := range included {
var options []httptransport.ServerOption
if o, ok := e.httpServerOptions[inc]; ok {
options = append(o, serverOption)
} else {
options = make([]httptransport.ServerOption, 1)
options[0] = serverOption
}
e.httpServerOptions[inc] = options
}
}

// WrapWithHttpOption wraps one Endpoint entry of filed HttpServerOptions of struct Endpoints with a
// httptransport.ServerOption.
// WrapWithHttpOption(serverOption, "Status")
func (e *Endpoints) WrapWithHttpOption(endpoint string, serverOption httptransport.ServerOption) {
var options []httptransport.ServerOption
if o, ok := e.httpServerOptions[endpoint]; ok {
options = append(o, serverOption)
} else {
options = []httptransport.ServerOption{
serverOption,
}
}
e.httpServerOptions[endpoint] = options
}

// GetHttpServerOptions returns all httptransport.ServerOption associated with the given endpoint.
func (e Endpoints) GetHttpServerOptions(endpoint string) []httptransport.ServerOption {
if options, ok := e.httpServerOptions[endpoint]; ok {
return options
}
return make([]httptransport.ServerOption, 0)
}

// SetHttpRequestDecoder assigns a httptransport.DecodeRequestFunc to an endpoint.
func (e Endpoints) SetHttpRequestDecoder(endpoint string, decoder httptransport.DecodeRequestFunc) {
e.httpRequestDecoders[endpoint] = decoder
}

// GetHttpRequestDecoder returns the httptransport.DecodeRequestFunc associated with the given endpoint.
func (e Endpoints) GetHttpRequestDecoder(endpoint string, fallback httptransport.DecodeRequestFunc) httptransport.DecodeRequestFunc {
if decoder, ok := e.httpRequestDecoders[endpoint]; ok {
return decoder
}
return fallback
}

// SetHttpResponseEncoder assigns a httptransport.EncodeResponseFunc to an endpoint.
func (e Endpoints) SetHttpResponseEncoder(endpoint string, encoder httptransport.EncodeResponseFunc) {
e.httpResponseEncoders[endpoint] = encoder
}

// GetHttpResponseEncoder returns the httptransport.EncodeResponseFunc associated with the given endpoint.
func (e Endpoints) GetHttpResponseEncoder(endpoint string, fallback httptransport.EncodeResponseFunc) httptransport.EncodeResponseFunc {
if encoder, ok := e.httpResponseEncoders[endpoint]; ok {
return encoder
}
return fallback
}

// SetHttpHandlerFunc assigns a custom http HandlerFunc to an endpoint instead of using the default one.
func (e Endpoints) SetHttpHandlerFunc(endpoint string, handler func(http.ResponseWriter, *http.Request)) {
e.httpHandlerFuncs[endpoint] = handler
}

// GetHttpHandlerFunc returns the http HandlerFunc for the given endpoint.
func (e Endpoints) GetHttpHandlerFunc(endpoint string) func(http.ResponseWriter, *http.Request) {
if handler, ok := e.httpHandlerFuncs[endpoint]; ok {
return handler
}
return nil
}

// HasHttpHandlerFunc checks if a custom http HandlerFunc is associated with the given endpoint.
func (e Endpoints) HasHttpHandlerFunc(endpoint string) bool {
_, ok := e.httpHandlerFuncs[endpoint]
return ok
}
5 changes: 2 additions & 3 deletions gengokit/template/NAME-service/svc/server/run.gotemplate
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,10 @@ func NewEndpoints(service pb.{{.Service.Name}}Server) svc.Endpoints {
{{end}}
)

endpoints := svc.Endpoints{
endpoints := svc.NewEndpoints()
{{range $i := .Service.Methods -}}
{{$i.Name}}Endpoint: {{ToLower $i.Name}}Endpoint,
endpoints.{{$i.Name}}Endpoint = {{ToLower $i.Name}}Endpoint
{{end}}
}

// Wrap selected Endpoints with middlewares. See handlers/middlewares.go
endpoints = handlers.WrapEndpoints(endpoints)
Expand Down
Loading