Skip to content

Commit

Permalink
Fix proxy manipulation of single payload fields (#202)
Browse files Browse the repository at this point in the history
* Fix proxy manipulation of single payload fields

* Remove case for *common.Payload
  • Loading branch information
bergundy authored Jan 17, 2025
1 parent 28f4dfc commit a6b8ad8
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 85 deletions.
63 changes: 39 additions & 24 deletions cmd/proxygenerator/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,10 @@ func visitPayloads(
parent proto.Message,
objs ...interface{},
) error {
for i, obj := range objs {
for _, obj := range objs {
ctx.SinglePayloadRequired = false
switch o := obj.(type) {
case *common.Payload:
if o == nil { continue }
no, err := visitPayload(ctx, options, parent, o)
if err != nil { return err }
objs[i] = no
case map[string]*common.Payload:
for ix, x := range o {
if nx, err := visitPayload(ctx, options, parent, x); err != nil {
Expand Down Expand Up @@ -273,6 +268,14 @@ func visitPayloads(
if options.SkipSearchAttributes { continue }
{{end}}
if o == nil { continue }
{{range $record.Payloads -}}
if o.{{.}} != nil {
no, err := visitPayload(ctx, options, o, o.{{.}})
if err != nil { return err }
o.{{.}} = no
}
{{end}}
{{if $record.Methods}}
if err := visitPayloads(
ctx,
options,
Expand All @@ -281,6 +284,7 @@ func visitPayloads(
o.{{.}}(),
{{end}}
); err != nil { return err }
{{end}}
{{end}}
}
}
Expand Down Expand Up @@ -326,10 +330,11 @@ var interceptorTemplate = template.Must(template.New("interceptor").Parse(Interc

// TypeRecord holds the state for a type referred to by the workflow service
type TypeRecord struct {
Methods []string // List of methods on this type that can eventually lead to Payload(s)
Slice bool // The API refers to slices of this type
Map bool // The API refers to maps with this type as the value
Matches bool // We found methods on this type that can eventually lead to Payload(s)
Methods []string // List of methods on this type that can eventually lead to Payload(s)
Payloads []string // List of attributes on this type that are of type Payload
Slice bool // The API refers to slices of this type
Map bool // The API refers to maps with this type as the value
Matches bool // We found methods on this type that can eventually lead to Payload(s)
}

// isSlice returns true if a type is slice, false otherwise
Expand Down Expand Up @@ -411,12 +416,12 @@ func pruneRecords(input map[string]*TypeRecord) map[string]*TypeRecord {

// walk iterates the methods on a type and returns whether any of them can eventually lead to Payload(s)
// The return type for each method on this type is walked recursively to decide which methods can lead to Payload(s)
func walk(desired []types.Type, typ types.Type, records *map[string]*TypeRecord) bool {
func walk(desired []types.Type, typ types.Type, records *map[string]*TypeRecord, checkDirectPayload bool) bool {
typeName := typeName(typ)

// If this type is a slice then walk the underlying type and then make a note we need to encode slices of this type
if isSlice(typ) {
result := walk(desired, elemType(typ), records)
result := walk(desired, elemType(typ), records, checkDirectPayload)
if result {
record := (*records)[typeName]
record.Slice = true
Expand All @@ -426,7 +431,7 @@ func walk(desired []types.Type, typ types.Type, records *map[string]*TypeRecord)

// If this type is a map then walk the underlying type and then make a note we need to encode maps with values of this type
if isMap(typ) {
result := walk(desired, elemType(typ), records)
result := walk(desired, elemType(typ), records, checkDirectPayload)
if result {
record := (*records)[typeName]
record.Map = true
Expand Down Expand Up @@ -459,8 +464,18 @@ func walk(desired []types.Type, typ types.Type, records *map[string]*TypeRecord)
// All the Get... methods return the relevant protobuf as the first result
resultType := sig.Results().At(0).Type()

if checkDirectPayload && resultType.String() == "*go.temporal.io/api/common/v1.Payload" {
record.Matches = true
prefix, ok := strings.CutPrefix(methodName, "Get")
if !ok {
panic(fmt.Errorf("expected method to have a Get prefix: %s", methodName))
}
record.Payloads = append(record.Payloads, prefix)
continue
}

// Check if this method returns a Payload(s) or if it leads (eventually) to a Type which refers to a Payload(s)
if typeMatches(resultType, desired...) || walk(desired, resultType, records) {
if typeMatches(resultType, desired...) || walk(desired, resultType, records, checkDirectPayload) {
record.Matches = true
record.Methods = append(record.Methods, methodName)
}
Expand Down Expand Up @@ -536,10 +551,10 @@ func generateInterceptor(cfg config) error {
}

sig := meth.Obj().Type().(*types.Signature)
walk(payloadTypes, sig.Params().At(1).Type(), &payloadRecords)
walk(failureTypes, sig.Params().At(1).Type(), &failureRecords)
walk(payloadTypes, sig.Results().At(0).Type(), &payloadRecords)
walk(failureTypes, sig.Results().At(0).Type(), &failureRecords)
walk(payloadTypes, sig.Params().At(1).Type(), &payloadRecords, true)
walk(failureTypes, sig.Params().At(1).Type(), &failureRecords, false)
walk(payloadTypes, sig.Results().At(0).Type(), &payloadRecords, true)
walk(failureTypes, sig.Results().At(0).Type(), &failureRecords, false)
}

for _, meth := range typeutil.IntuitiveMethodSet(operatorService, nil) {
Expand All @@ -548,14 +563,14 @@ func generateInterceptor(cfg config) error {
}

sig := meth.Obj().Type().(*types.Signature)
walk(payloadTypes, sig.Params().At(1).Type(), &payloadRecords)
walk(failureTypes, sig.Params().At(1).Type(), &failureRecords)
walk(payloadTypes, sig.Results().At(0).Type(), &payloadRecords)
walk(failureTypes, sig.Results().At(0).Type(), &failureRecords)
walk(payloadTypes, sig.Params().At(1).Type(), &payloadRecords, true)
walk(failureTypes, sig.Params().At(1).Type(), &failureRecords, false)
walk(payloadTypes, sig.Results().At(0).Type(), &payloadRecords, true)
walk(failureTypes, sig.Results().At(0).Type(), &failureRecords, false)
}

walk(payloadTypes, workflowExecutions, &payloadRecords)
walk(failureTypes, workflowExecutions, &failureRecords)
walk(payloadTypes, workflowExecutions, &payloadRecords, true)
walk(failureTypes, workflowExecutions, &failureRecords, false)

payloadRecords = pruneRecords(payloadRecords)
failureRecords = pruneRecords(failureRecords)
Expand Down
Loading

0 comments on commit a6b8ad8

Please sign in to comment.