Skip to content

Commit

Permalink
feat: thrift streaming (#153)
Browse files Browse the repository at this point in the history
  • Loading branch information
felix021 committed Jan 10, 2024
1 parent 6f0023e commit 9649481
Show file tree
Hide file tree
Showing 16 changed files with 224 additions and 58 deletions.
24 changes: 24 additions & 0 deletions generator/golang/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"strings"
"text/template"

"github.com/cloudwego/thriftgo/generator/golang/streaming"
"github.com/cloudwego/thriftgo/tool/trimmer/trim"

ref_tpl "github.com/cloudwego/thriftgo/generator/golang/templates/ref"
Expand Down Expand Up @@ -95,6 +96,9 @@ func (g *GoBackend) Generate(req *plugin.Request, log backend.LogFunc) *plugin.R
}
g.prepareTemplates()
g.fillRequisitions()
if !g.utils.Features().ThriftStreaming {
g.removeStreamingFunctions(req.GetAST())
}
g.executeTemplates()
return g.buildResponse()
}
Expand Down Expand Up @@ -265,3 +269,23 @@ func (g *GoBackend) PostProcess(path string, content []byte) ([]byte, error) {
}
return content, nil
}

func (g *GoBackend) removeStreamingFunctions(ast *parser.Thrift) {
for _, svc := range ast.Services {
functions := make([]*parser.Function, 0, len(svc.Functions))
for _, f := range svc.Functions {
st, err := streaming.ParseStreaming(f)
if err != nil {
g.log.Warn(fmt.Sprintf("%s.%s: failed to parse streaming, err = %v", svc.Name, f.Name, err))
continue
}
if st.IsStreaming {
g.log.Warn(fmt.Sprintf("skip streaming function %s.%s: not supported by your kitex, "+
"please update your kitex tool to the latest version", svc.Name, f.Name))
continue
}
functions = append(functions, f)
}
svc.Functions = functions
}
}
1 change: 1 addition & 0 deletions generator/golang/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func (im *importManager) init(cu *CodeUtils, ast *parser.Thrift) {
"thrift_reflection": ThriftReflectionLib,
"json_utils": ThriftJSONUtilLib,
"fieldmask": ThriftFieldMaskLib,
"streaming": KitexStreamingLib,
}
for pkg, path := range std {
ns.Add(pkg, path)
Expand Down
2 changes: 2 additions & 0 deletions generator/golang/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ type Features struct {
WithFieldMask bool `with_field_mask:"Support field-mask for generated code."`
FieldMaskHalfway bool `field_mask_halfway:"Support set field-mask on not-root struct."`
FieldMaskZeroRequired bool `field_mask_zero_required:"Write zero value instead of current value for required fields filtered by fieldmask."`
ThriftStreaming bool `thrift_streaming:"Recognize thrift streaming annotation and generate streaming code."`
}

var defaultFeatures = Features{
Expand Down Expand Up @@ -88,6 +89,7 @@ var defaultFeatures = Features{
SnakeTyleJSONTag: false,
LowerCamelCaseJSONTag: false,
GenerateReflectionInfo: false,
ThriftStreaming: false,
EnumAsINT32: false,
TrimIDL: false,
JSONStringer: false,
Expand Down
12 changes: 12 additions & 0 deletions generator/golang/scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"path/filepath"
"strings"

"github.com/cloudwego/thriftgo/generator/golang/streaming"
"github.com/cloudwego/thriftgo/parser"
"github.com/cloudwego/thriftgo/pkg/namespace"
"github.com/cloudwego/thriftgo/reflection"
Expand Down Expand Up @@ -652,13 +653,20 @@ type Function struct {
throws []*Field
argType *StructLike
resType *StructLike
service *Service
streaming *streaming.Streaming
}

// GoName returns the go name of the function.
func (f *Function) GoName() Name {
return f.name
}

// Service returns the service that the function is defined in.
func (f *Function) Service() *Service {
return f.service
}

// ResponseGoTypeName returns the go type of the response type of the function.
func (f *Function) ResponseGoTypeName() TypeName {
return f.responseType
Expand All @@ -684,6 +692,10 @@ func (f *Function) ResType() *StructLike {
return f.resType
}

func (f *Function) Streaming() *streaming.Streaming {
return f.streaming
}

func buildSynthesized(v *parser.Function) (argType, resType *parser.StructLike) {
argType = &parser.StructLike{
Category: "struct",
Expand Down
33 changes: 24 additions & 9 deletions generator/golang/scope_internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ package golang

import (
"fmt"
"runtime/debug"
"strconv"
"strings"

"github.com/cloudwego/thriftgo/generator/golang/common"
"github.com/cloudwego/thriftgo/generator/golang/streaming"
"github.com/cloudwego/thriftgo/parser"
"github.com/cloudwego/thriftgo/pkg/namespace"
)
Expand Down Expand Up @@ -47,8 +49,8 @@ func newScope(ast *parser.Thrift) *Scope {

func (s *Scope) init(cu *CodeUtils) (err error) {
defer func() {
if x, ok := recover().(error); ok && x != nil {
err = x
if r := recover(); r != nil {
err = fmt.Errorf("err = %v, stack = %s", r, debug.Stack())
}
}()
if cu.Features().ReorderFields {
Expand All @@ -62,7 +64,9 @@ func (s *Scope) init(cu *CodeUtils) (err error) {
}
s.imports.init(cu, s.ast)
s.buildIncludes(cu)
s.installNames(cu)
if err = s.installNames(cu); err != nil {
return err
}
s.resolveTypesAndValues(cu)
return nil
}
Expand Down Expand Up @@ -108,9 +112,11 @@ func (s *Scope) includeIDL(cu *CodeUtils, t *parser.Thrift) (pkgName string) {
return inc.PackageName
}

func (s *Scope) installNames(cu *CodeUtils) {
func (s *Scope) installNames(cu *CodeUtils) error {
for _, v := range s.ast.Services {
s.buildService(cu, v)
if err := s.buildService(cu, v); err != nil {
return err
}
}
for _, v := range s.ast.GetStructLikes() {
s.buildStructLike(cu, v)
Expand All @@ -124,6 +130,7 @@ func (s *Scope) installNames(cu *CodeUtils) {
for _, v := range s.ast.Constants {
s.buildConstant(cu, v)
}
return nil
}

func (s *Scope) identify(cu *CodeUtils, raw string) string {
Expand All @@ -139,7 +146,7 @@ func (s *Scope) identify(cu *CodeUtils, raw string) string {
return name
}

func (s *Scope) buildService(cu *CodeUtils, v *parser.Service) {
func (s *Scope) buildService(cu *CodeUtils, v *parser.Service) error {
// service name
sn := s.identify(cu, v.Name)
sn = s.globals.Add(sn, v.Name)
Expand All @@ -156,10 +163,17 @@ func (s *Scope) buildService(cu *CodeUtils, v *parser.Service) {
for _, f := range v.Functions {
fn := s.identify(cu, f.Name)
fn = svc.scope.Add(fn, f.Name)
st, err := streaming.ParseStreaming(f)
if err != nil {
return fmt.Errorf("service %s: %s", v.Name, err.Error())
}

svc.functions = append(svc.functions, &Function{
Function: f,
scope: namespace.NewNamespace(namespace.UnderscoreSuffix),
name: Name(fn),
Function: f,
scope: namespace.NewNamespace(namespace.UnderscoreSuffix),
name: Name(fn),
service: svc,
streaming: st,
})
}

Expand All @@ -186,6 +200,7 @@ func (s *Scope) buildService(cu *CodeUtils, v *parser.Service) {
pn := sn + "Processor"
s.globals.MustReserve(cn, _p("client:"+v.Name))
s.globals.MustReserve(pn, _p("processor:"+v.Name))
return nil
}

// buildFunction builds a namespace for parameters of a Function.
Expand Down
89 changes: 89 additions & 0 deletions generator/golang/streaming/streaming.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright 2023 CloudWeGo Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package streaming

import (
"fmt"

"github.com/cloudwego/thriftgo/parser"
)

// TODO this package should finally be migrated to Kitex tool together with server/client/processor templates.

const (
// Annotation key used in Thrift IDL
StreamingModeKey = "streaming.mode"

// Streaming mode identifiers
StreamingBidirectional = "bidirectional" // Bidirectional streaming API over HTTP2
StreamingClientSide = "client" // Client-side streaming API over HTTP2
StreamingServerSide = "server" // Server-side streaming API over HTTP2
StreamingUnary = "unary" // Unary API over HTTP2, different from Kitex Thrift/Protobuf
)

// Streaming represents the streaming mode of a function
type Streaming struct {
Mode string `thrift:"Mode,1" json:"Mode"`
ClientStreaming bool `thrift:"ClientStreaming,2" json:"ClientStreaming"`
ServerStreaming bool `thrift:"ServerStreaming,3" json:"ServerStreaming"`
BidirectionalStreaming bool `thrift:"BidirectionalStreaming,4" json:"BidirectionalStreaming"`
Unary bool `thrift:"Unary,5" json:"Unary"`
IsStreaming bool `thrift:"IsStreaming,6" json:"IsStreaming"`
}

// ParseStreaming parses the streaming mode from a Thrift function parsed from IDL
func ParseStreaming(f *parser.Function) (s *Streaming, err error) {
s = &Streaming{}
for _, anno := range f.Annotations {
if anno.Key != StreamingModeKey {
continue
}
if len(anno.Values) != 1 {
return nil, fmt.Errorf("%s has multiple values for annotation %v (at most 1 allowed)",
f.Name, StreamingModeKey)
}
for _, value := range anno.Values {
s.IsStreaming = true
switch value {
case StreamingClientSide:
s.ClientStreaming = true
case StreamingServerSide:
s.ServerStreaming = true
case StreamingBidirectional:
s.ClientStreaming = true
s.ServerStreaming = true
s.BidirectionalStreaming = true
case StreamingUnary:
s.Unary = true
default: // other types are not recognized
return nil, fmt.Errorf("invalid value (%s) for annotation %v", value, StreamingModeKey)
}
}
if s.IsStreaming && len(f.Arguments) != 1 {
return nil, fmt.Errorf("streaming function %s should have exactly 1 argument", f.Name)
}

if s.BidirectionalStreaming {
s.Mode = StreamingBidirectional
} else if s.ServerStreaming {
s.Mode = StreamingServerSide
} else if s.ClientStreaming {
s.Mode = StreamingClientSide
} else if s.Unary {
s.Mode = StreamingUnary
}
}
return s, nil
}
25 changes: 23 additions & 2 deletions generator/golang/templates/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ package templates

// Client .
var Client = `
{{define "Client"}}
{{define "ThriftClient"}}
{{- UseStdLibrary "thrift"}}
{{- $BasePrefix := ServicePrefix .Base}}
{{- $BaseService := ServiceName .Base}}
Expand Down Expand Up @@ -71,6 +71,9 @@ func (p *{{$ClientName}}) Client_() thrift.TClient {
{{- $ArgType := .ArgType}}
{{- $ResType := .ResType}}
func (p *{{$ClientName}}) {{- template "FunctionSignature" . -}} {
{{if .Streaming.IsStreaming -}}
panic("streaming method {{$ServiceName}}.{{.Name}}(mode = {{.Streaming.Mode}}) not available, please use Kitex Thrift Streaming Client.")
{{else -}}
var _args {{$ArgType.GoName}}
{{- range .Arguments}}
_args.{{($ArgType.Field .Name).GoName}} = {{.GoName}}
Expand Down Expand Up @@ -112,7 +115,25 @@ func (p *{{$ClientName}}) {{- template "FunctionSignature" . -}} {
{{- end}}
return _result.GetSuccess(), nil
{{- end}}{{/* If .Void */}}
{{- end}}{{/* If .Streaming.IsStreaming */ -}}
}
{{- if or .Streaming.ClientStreaming .Streaming.ServerStreaming}}
{{- $arg := index .Arguments 0}}
{{- $ResponseType := .FunctionType.Name}}
type {{.Service.GoName}}_{{.Name}}Server interface {
{{- UseStdLibrary "streaming" -}}
streaming.Stream
{{if .Streaming.ClientStreaming }}
Recv() (*{{$arg.Type.Name}}, error)
{{end}}
{{if .Streaming.ServerStreaming}}
Send(*{{$ResponseType}}) error
{{end}}
{{if and .Streaming.ClientStreaming (not .Streaming.ServerStreaming) }}
SendAndClose(*{{$ResponseType}}) error
{{end}}
}
{{- end}}{{/* Streaming */}}
{{- end}}{{/* range .Functions */}}
{{- end}}{{/* define "Client" */}}
{{- end}}{{/* define "ThriftClient" */}}
`
6 changes: 3 additions & 3 deletions generator/golang/templates/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ import (
{{- end}}
{{- range .Services}}
{{template "Service" .}}
{{template "Client" .}}
{{template "ThriftService" .}}
{{template "ThriftClient" .}}
{{- end}}
{{- range .Services}}
{{template "Processor" .}}
{{template "ThriftProcessor" .}}
{{- end}}
{{- if Features.GenerateReflectionInfo}}
Expand Down
8 changes: 6 additions & 2 deletions generator/golang/templates/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ package templates

// Processor .
var Processor = `
{{define "Processor"}}
{{define "ThriftProcessor"}}
{{- UseStdLibrary "thrift"}}
{{- $BasePrefix := ServicePrefix .Base}}
{{- $BaseService := ServiceName .Base}}
Expand Down Expand Up @@ -90,6 +90,9 @@ type {{$ProcessorName | Unexport}}{{$FuncName}} struct {
{{- UseStdLibrary "context"}}
func (p *{{$ProcessName}}) Process(ctx context.Context, seqId int32, iprot, oprot thrift.TProtocol) (success bool, err thrift.TException) {
{{if .Streaming.IsStreaming -}}
panic("streaming method {{$ServiceName}}.{{.Name}}(mode = {{.Streaming.Mode}}) not available, please use Kitex Thrift Streaming Client.")
{{else -}}
args := {{$ArgType.GoName}}{}
if err = args.Read(iprot); err != nil {
iprot.ReadMessageEnd()
Expand Down Expand Up @@ -165,6 +168,7 @@ func (p *{{$ProcessName}}) Process(ctx context.Context, seqId int32, iprot, opro
}
return true, err
{{- end}}{{/* if .Oneway */}}
{{- end -}}{{- /* end if not Has Streaming */ -}}
}
{{- end}}{{/* range .Functions */}}
Expand All @@ -180,5 +184,5 @@ func (p *{{$ProcessName}}) Process(ctx context.Context, seqId int32, iprot, opro
{{- $_ := (SetWithFieldMask $withFieldMask) }}
{{- end}}
{{- end}}{{/* range .Functions */}}
{{- end}}{{/* define "Processor" */}}
{{- end}}{{/* define "ThriftProcessor" */}}
`
Loading

0 comments on commit 9649481

Please sign in to comment.